[PD1] Cleaned up protocol,database and receiving multiple messageInfos

This commit is contained in:
Afonso Franco 2024-04-20 17:16:52 +01:00
parent 4c141bbc6e
commit f3cf9cfc40
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
6 changed files with 347 additions and 231 deletions

View file

@ -30,38 +30,130 @@ func Run() {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
}
uid := flag.Arg(1)
subject := flag.Arg(2)
messageBody := readMessageBody()
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes
marshaledSubject := Marshal(subject)
marshaledBody := Marshal(messageBody)
plainSubjectBytes := Marshal(plainSubject)
plainBodyBytes := Marshal(plainBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
uidCert := getUserCert(cl, uid)
if uidCert == nil {
receiverCert := getUserCert(cl, uid)
if receiverCert == nil {
return
}
encryptedSubject := clientKeyStore.EncryptMessageContent(uidCert, marshaledSubject)
encryptedBody := clientKeyStore.EncryptMessageContent(uidCert, marshaledBody)
submitMessage := protocol.NewSubmitMessagePacket(uid, encryptedSubject, encryptedBody)
if !cl.Connection.Send(submitMessage) {
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(sendMsgPacket) {
return
}
cl.Connection.Conn.Close()
case "askqueue":
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
askQueue(cl,clientKeyStore, page, pageSize)
requestUnreadMsgsQueuePacket := protocol.NewRequestUnreadMsgsQueuePacket()
case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString)
if err != nil {
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewGetMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help":
showHelp()
default:
commandError()
}
}
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if !cl.Connection.Send(getUserCertPacket) {
return nil
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, active := cl.Connection.Receive()
if !active {
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil
}
return userCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
certificatesMap[senderUID] = senderCert
}
return answerGetUnreadMsgsInfo, certificatesMap
}
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return
}
serverMessagePackets, certificates := getManyMessagesInfo(cl)
unreadMsgsInfo, certificates := getManyMessagesInfo(cl)
var clientMessages []ClientMessageInfo
for _, message := range serverMessagePackets {
for _, message := range unreadMsgsInfo.MessagesInfo {
senderCert, ok := certificates[message.FromUID]
if ok {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
@ -75,89 +167,13 @@ func Run() {
return clientMessages[i].Num > clientMessages[j].Num
})
showMessagesInfo(clientMessages)
case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num,err :=strconv.Atoi(numString)
if err!=nil{
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewRequestMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket,active := cl.Connection.Receive()
if !active{
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize)
case 0:
return
case 1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize)
}
serverMessagePacket := protocol.UnmarshalServerMessagePacket(receivedMsgPacket.Body)
senderCert := getUserCert(cl, serverMessagePacket.FromUID)
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Subject)
decryptedBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Body)
subject := Unmarshal(decryptedSubjectBytes)
body := Unmarshal(decryptedBodyBytes)
message := newClientMessage(serverMessagePacket.FromUID, serverMessagePacket.ToUID, subject, body, serverMessagePacket.Timestamp)
showMessage(message)
case "help":
showHelp()
default:
commandError()
}
}
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
certRequestPacket := protocol.NewRequestUserCertPacket(uid)
if !cl.Connection.Send(certRequestPacket) {
return nil
}
var certPacket *protocol.Packet
certPacket, active := cl.Connection.Receive()
if !active {
return nil
}
uidCertInBytes := protocol.UnmarshalSendUserCertPacket(certPacket.Body)
uidCert, err := x509.ParseCertificate(uidCertInBytes.Certificate)
if err != nil {
return nil
}
return uidCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.ServerMessageInfoPacket, map[string]*x509.Certificate) {
//Create the slice to hold the incoming messages before decrypting
//Create the map to hold the sender certificates
//Create sync mutexes
serverMessageInfoPackets := []protocol.ServerMessageInfoPacket{}
//Run while message isn't the last one
msg := protocol.ServerMessageInfoPacket{}
for !msg.Last {
sendMsgPacket, active := cl.Connection.Receive()
if !active {
return nil, nil
}
msg = protocol.UnmarshalServerMessageInfoPacket(sendMsgPacket.Body)
//Lock and append
serverMessageInfoPackets = append(serverMessageInfoPackets, msg)
}
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range serverMessageInfoPackets {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
certificatesMap[senderUID] = senderCert
}
return serverMessageInfoPackets, certificatesMap
}

View file

@ -4,10 +4,11 @@ import (
"bufio"
"fmt"
"os"
"strings"
)
func readMessageBody() string {
fmt.Println("Enter message content (limited to 1000 bytes):")
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
@ -36,15 +37,63 @@ func showHelp() {
fmt.Println("help: Imprime instruções de uso do programa.")
}
func showMessagesInfo(messages []ClientMessageInfo) {
func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int {
if messages == nil {
fmt.Println("No unread messages in the queue")
return 0
}
for _, message := range messages {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
fmt.Printf("Page %v/%v\n",page,numPages)
return messagesInfoPageNavigation(page, numPages)
}
func showMessage(message ClientMessage) {
fmt.Printf("From:%v\n", message.FromUID)
fmt.Printf("To:%v\n", message.ToUID)
fmt.Printf("Subject:%v\n", message.Subject)
fmt.Printf("Body:%v\n", message.Body)
func messagesInfoPageNavigation(page int, numPages int) int {
var action string
switch page {
case 1:
if page == numPages {
action = readStdin("Actions: quit")
} else {
action = readStdin("Actions: quit/next")
}
case numPages:
action = readStdin("Actions: prev/quit")
default:
action = readStdin("prev/quit/next")
}
switch strings.ToLower(action) {
case "prev":
if page == 1 {
fmt.Println("Unavailable action: Already in first page")
messagesInfoPageNavigation(page, numPages)
} else {
return -1
}
case "quit":
return 0
case "next":
if page == numPages {
fmt.Println("Unavailable action: Already in last page")
messagesInfoPageNavigation(page, numPages)
} else {
return 1
}
default:
fmt.Println("Unknown action")
messagesInfoPageNavigation(page, numPages)
}
return 0
}
func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body)
}

View file

@ -9,46 +9,67 @@ import (
type PacketType int
const (
ReqUserCertPkt PacketType = iota
ReqMsgsQueue
ReqMsgPkt
SubmitMsgPkt
SendUserCertPkt
ServerMsgInfoPkt
ServerMsgPkt
// Client requests user certificate
FlagGetUserCert PacketType = iota
// Client requests unread message info
FlagGetUnreadMsgsInfo
// Client requests a message from the queue
FlagGetMsg
// Client sends a message
FlagSendMsg
// Server sends user certificate
FlagAnswerGetUserCert
// Server sends list of unread messages
FlagAnswerGetUnreadMsgsInfo
// Server sends requested message
FlagAnswerGetMsg
)
type (
RequestUserCertPacket struct {
GetUserCert struct {
UID string `json:"uid"`
}
RequestMsgsQueuePacket struct {
GetUnreadMsgsInfo struct {
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
RequestMsgPacket struct {
GetMsg struct {
Num int `json:"num"`
}
SubmitMessagePacket struct {
SendMsg struct {
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
Body []byte `json:"body"`
}
SendUserCertPacket struct {
AnswerGetUserCert struct {
UID string `json:"uid"`
Certificate []byte `json:"certificate"`
}
ServerMessageInfoPacket struct {
AnswerGetUnreadMsgsInfo struct {
Page int `json:"page"`
NumPages int `json:"num_pages"`
MessagesInfo []MsgInfo `json:"messages_info"`
}
MsgInfo struct {
Num int `json:"num"`
FromUID string `json:"from_uid"`
Subject []byte `json:"subject"`
Timestamp time.Time `json:"timestamp"`
Last bool `json:"last"`
}
ServerMessagePacket struct {
AnswerGetMsg struct {
FromUID string `json:"from_uid"`
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
@ -64,156 +85,188 @@ type Packet struct {
Body PacketBody `json:"body"`
}
func NewRequestUserCertPacket(UID string) Packet {
func NewPacket(fl PacketType, body PacketBody) Packet {
return Packet{
Flag: ReqUserCertPkt,
Body: RequestUserCertPacket{
Flag: fl,
Body: body,
}
}
func NewGetUserCert(UID string) GetUserCert {
return GetUserCert{
UID: UID,
},
}
}
func NewRequestUnreadMsgsQueuePacket() Packet {
return Packet{
Flag: ReqMsgsQueue,
Body: RequestMsgsQueuePacket{},
}
func NewGetUnreadMsgsInfo(page int, pageSize int) GetUnreadMsgsInfo {
return GetUnreadMsgsInfo{
Page: page,
PageSize: pageSize}
}
func NewRequestMsgPacket(num int) Packet {
return Packet{
Flag: ReqMsgPkt,
Body: RequestMsgPacket{
func NewGetMsg(num int) GetMsg {
return GetMsg{
Num: num,
},
}
}
func NewSubmitMessagePacket(toUID string, subject []byte, body []byte) Packet {
return Packet{
Flag: SubmitMsgPkt,
Body: SubmitMessagePacket{
func NewSendMsg(toUID string, subject []byte, body []byte) SendMsg {
return SendMsg{
ToUID: toUID,
Subject: subject,
Body: body,
},
}
}
func NewSendUserCertPacket(uid string, certificate []byte) Packet {
return Packet{
Flag: SendUserCertPkt,
Body: SendUserCertPacket{
func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
return AnswerGetUserCert{
UID: uid,
Certificate: certificate,
},
}
}
func NewServerMessageInfoPacket(num int, fromUID string, subject []byte, timestamp time.Time, last bool) Packet {
return Packet{
Flag: ServerMsgInfoPkt,
Body: ServerMessageInfoPacket{
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo}
}
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{
Num: num,
FromUID: fromUID,
Subject: subject,
Timestamp: timestamp,
Last: last,
},
}
}
func NewServerMessagePacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return Packet{
Flag: ServerMsgPkt,
Body: ServerMessagePacket{
func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) AnswerGetMsg {
return AnswerGetMsg{
FromUID: fromUID,
ToUID: toUID,
Subject: subject,
Body: body,
Timestamp: timestamp,
},
}
}
func UnmarshalRequestUserCertPacket(data PacketBody) RequestUserCertPacket {
func NewGetUserCertPacket(UID string) Packet {
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
}
func NewGetUnreadMsgsInfoPacket(page int, pageSize int) Packet {
return NewPacket(FlagGetUnreadMsgsInfo, NewGetUnreadMsgsInfo(page, pageSize))
}
func NewGetMsgPacket(num int) Packet {
return NewPacket(FlagGetMsg, NewGetMsg(num))
}
func NewSendMsgPacket(toUID string, subject []byte, body []byte) Packet {
return NewPacket(FlagSendMsg, NewSendMsg(toUID, subject, body))
}
func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
return NewPacket(FlagAnswerGetUserCert, NewAnswerGetUserCert(uid, certificate))
}
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo))
}
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestUserCertPacket
var packet GetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestUserCertPacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
}
return packet
}
func UnmarshalRequestMsgsQueuePacket(data PacketBody) RequestMsgsQueuePacket {
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestMsgsQueuePacket
var packet GetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgsQueuePacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
}
return packet
}
func UnmarshalRequestMsgPacket(data PacketBody) RequestMsgPacket {
func UnmarshalGetMsg(data PacketBody) GetMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestMsgPacket
var packet GetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgPacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
}
return packet
}
func UnmarshalSubmitMessagePacket(data PacketBody) SubmitMessagePacket {
func UnmarshalSendMsg(data PacketBody) SendMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet SubmitMessagePacket
var packet SendMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SubmitMessagePacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
}
return packet
}
func UnmarshalSendUserCertPacket(data PacketBody) SendUserCertPacket {
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet SendUserCertPacket
var packet AnswerGetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendUserCertPacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
}
return packet
}
func UnmarshalServerMessageInfoPacket(data PacketBody) ServerMessageInfoPacket {
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ServerMessageInfoPacket
var packet MsgInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessageInfoPacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
}
return packet
}
func UnmarshalServerMessagePacket(data PacketBody) ServerMessagePacket {
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ServerMessagePacket
var packet AnswerGetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessagePacket: %v", err))
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
}
return packet
}
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet AnswerGetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
}
return packet
}

View file

@ -41,6 +41,7 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
queue_position INT DEFAULT 0,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
@ -53,18 +54,36 @@ func (ds DataStore) CreateTables() error {
return err
}
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
fmt.Println("Error creating trigger", err)
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket
var serverMessage protocol.AnswerGetMsg
query := `
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
WHERE toUID = ? AND queue_position = ?
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
@ -73,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
}
return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
}
@ -84,9 +103,7 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp
FROM messages
WHERE toUID = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
WHERE toUID = ? AND queue_position = ?
)
`
@ -97,8 +114,14 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
}
}
func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
var messageInfoPackets []protocol.Packet
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
}
// Query to retrieve all messages from the user's queue
query := `
@ -109,38 +132,23 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
queue_position,
subject,
status
FROM (
SELECT
fromUID,
toUID,
timestamp,
ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position,
subject,
status
FROM
messages
FROM messages
WHERE
toUID = ?
) AS ranked_messages
WHERE
status = 0
toUID = ? AND status = 0
ORDER BY
timestamp;
queue_position DESC
LIMIT ? OFFSET ?;
`
// Execute the query
rows, err := ds.db.Query(query, toUID)
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
}
defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct
//First row
if !rows.Next() {
return []protocol.Packet{}
}
for {
messageInfoPackets := []protocol.MsgInfo{}
for rows.Next() {
var fromUID string
var subject []byte
var timestamp time.Time
@ -148,25 +156,19 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
}
var message protocol.Packet
hasNext := rows.Next()
if !hasNext {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true)
messageInfoPackets = append(messageInfoPackets, message)
break
} else {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false)
messageInfoPackets = append(messageInfoPackets, message)
}
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
}
return messageInfoPackets
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) {
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
@ -197,7 +199,7 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewSendUserCertPacket(uid, userCertBytes)
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
}
func (ds DataStore) userExists(uid string) bool {

View file

@ -4,7 +4,6 @@ import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"fmt"
)
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
@ -24,33 +23,30 @@ F:
for {
pac, active := connection.Receive()
if !active {
break F
break
}
switch pac.Flag {
case protocol.ReqUserCertPkt:
reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body)
case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active {
break F
}
case protocol.ReqMsgsQueue:
_ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body)
messages := dataStore.GetUnreadMessagesInfoQueue(UID)
fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages))
for _, message := range messages {
if !connection.Send(message) {
break
case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
if !connection.Send(messages) {
break F
}
}
case protocol.ReqMsgPkt:
reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body)
case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
break F
}
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.SubmitMsgPkt:
submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body)
case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
}

Binary file not shown.