[PD1] Cleaned up protocol,database and receiving multiple messageInfos

This commit is contained in:
Afonso Franco 2024-04-20 17:16:52 +01:00
parent 4c141bbc6e
commit f3cf9cfc40
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
6 changed files with 347 additions and 231 deletions

View file

@ -30,38 +30,130 @@ func Run() {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>") panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
} }
uid := flag.Arg(1) uid := flag.Arg(1)
subject := flag.Arg(2) plainSubject := flag.Arg(2)
messageBody := readMessageBody() plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes //Turn content to bytes
marshaledSubject := Marshal(subject) plainSubjectBytes := Marshal(plainSubject)
marshaledBody := Marshal(messageBody) plainBodyBytes := Marshal(plainBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore) cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close() defer cl.Connection.Conn.Close()
uidCert := getUserCert(cl, uid) receiverCert := getUserCert(cl, uid)
if uidCert == nil { if receiverCert == nil {
return return
} }
encryptedSubject := clientKeyStore.EncryptMessageContent(uidCert, marshaledSubject) subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
encryptedBody := clientKeyStore.EncryptMessageContent(uidCert, marshaledBody) body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
submitMessage := protocol.NewSubmitMessagePacket(uid, encryptedSubject, encryptedBody) sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(submitMessage) { if !cl.Connection.Send(sendMsgPacket) {
return return
} }
cl.Connection.Conn.Close() cl.Connection.Conn.Close()
case "askqueue": case "askqueue":
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
cl := networking.NewClient[protocol.Packet](&clientKeyStore) cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close() defer cl.Connection.Conn.Close()
askQueue(cl,clientKeyStore, page, pageSize)
requestUnreadMsgsQueuePacket := protocol.NewRequestUnreadMsgsQueuePacket() case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString)
if err != nil {
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewGetMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help":
showHelp()
default:
commandError()
}
}
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if !cl.Connection.Send(getUserCertPacket) {
return nil
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, active := cl.Connection.Receive()
if !active {
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil
}
return userCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
certificatesMap[senderUID] = senderCert
}
return answerGetUnreadMsgsInfo, certificatesMap
}
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return return
} }
serverMessagePackets, certificates := getManyMessagesInfo(cl) unreadMsgsInfo, certificates := getManyMessagesInfo(cl)
var clientMessages []ClientMessageInfo var clientMessages []ClientMessageInfo
for _, message := range serverMessagePackets { for _, message := range unreadMsgsInfo.MessagesInfo {
senderCert, ok := certificates[message.FromUID] senderCert, ok := certificates[message.FromUID]
if ok { if ok {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
@ -75,89 +167,13 @@ func Run() {
return clientMessages[i].Num > clientMessages[j].Num return clientMessages[i].Num > clientMessages[j].Num
}) })
showMessagesInfo(clientMessages) action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case "getmsg": case -1:
if flag.NArg() < 2 { askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize)
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>") case 0:
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num,err :=strconv.Atoi(numString)
if err!=nil{
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewRequestMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket,active := cl.Connection.Receive()
if !active{
return return
case 1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize)
} }
serverMessagePacket := protocol.UnmarshalServerMessagePacket(receivedMsgPacket.Body)
senderCert := getUserCert(cl, serverMessagePacket.FromUID)
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Subject)
decryptedBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Body)
subject := Unmarshal(decryptedSubjectBytes)
body := Unmarshal(decryptedBodyBytes)
message := newClientMessage(serverMessagePacket.FromUID, serverMessagePacket.ToUID, subject, body, serverMessagePacket.Timestamp)
showMessage(message)
case "help":
showHelp()
default:
commandError()
}
}
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
certRequestPacket := protocol.NewRequestUserCertPacket(uid)
if !cl.Connection.Send(certRequestPacket) {
return nil
}
var certPacket *protocol.Packet
certPacket, active := cl.Connection.Receive()
if !active {
return nil
}
uidCertInBytes := protocol.UnmarshalSendUserCertPacket(certPacket.Body)
uidCert, err := x509.ParseCertificate(uidCertInBytes.Certificate)
if err != nil {
return nil
}
return uidCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.ServerMessageInfoPacket, map[string]*x509.Certificate) {
//Create the slice to hold the incoming messages before decrypting
//Create the map to hold the sender certificates
//Create sync mutexes
serverMessageInfoPackets := []protocol.ServerMessageInfoPacket{}
//Run while message isn't the last one
msg := protocol.ServerMessageInfoPacket{}
for !msg.Last {
sendMsgPacket, active := cl.Connection.Receive()
if !active {
return nil, nil
}
msg = protocol.UnmarshalServerMessageInfoPacket(sendMsgPacket.Body)
//Lock and append
serverMessageInfoPackets = append(serverMessageInfoPackets, msg)
}
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range serverMessageInfoPackets {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
certificatesMap[senderUID] = senderCert
}
return serverMessageInfoPackets, certificatesMap
} }

View file

@ -4,10 +4,11 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"os" "os"
"strings"
) )
func readMessageBody() string { func readStdin(message string) string {
fmt.Println("Enter message content (limited to 1000 bytes):") fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
scanner.Scan() scanner.Scan()
// FIX: make sure this doesnt die // FIX: make sure this doesnt die
@ -36,15 +37,63 @@ func showHelp() {
fmt.Println("help: Imprime instruções de uso do programa.") fmt.Println("help: Imprime instruções de uso do programa.")
} }
func showMessagesInfo(messages []ClientMessageInfo) { func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int {
if messages == nil {
fmt.Println("No unread messages in the queue")
return 0
}
for _, message := range messages { for _, message := range messages {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
} }
fmt.Printf("Page %v/%v\n",page,numPages)
return messagesInfoPageNavigation(page, numPages)
} }
func showMessage(message ClientMessage) { func messagesInfoPageNavigation(page int, numPages int) int {
fmt.Printf("From:%v\n", message.FromUID) var action string
fmt.Printf("To:%v\n", message.ToUID)
fmt.Printf("Subject:%v\n", message.Subject) switch page {
fmt.Printf("Body:%v\n", message.Body) case 1:
if page == numPages {
action = readStdin("Actions: quit")
} else {
action = readStdin("Actions: quit/next")
} }
case numPages:
action = readStdin("Actions: prev/quit")
default:
action = readStdin("prev/quit/next")
}
switch strings.ToLower(action) {
case "prev":
if page == 1 {
fmt.Println("Unavailable action: Already in first page")
messagesInfoPageNavigation(page, numPages)
} else {
return -1
}
case "quit":
return 0
case "next":
if page == numPages {
fmt.Println("Unavailable action: Already in last page")
messagesInfoPageNavigation(page, numPages)
} else {
return 1
}
default:
fmt.Println("Unknown action")
messagesInfoPageNavigation(page, numPages)
}
return 0
}
func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body)
}

View file

@ -9,46 +9,67 @@ import (
type PacketType int type PacketType int
const ( const (
ReqUserCertPkt PacketType = iota // Client requests user certificate
ReqMsgsQueue FlagGetUserCert PacketType = iota
ReqMsgPkt
SubmitMsgPkt // Client requests unread message info
SendUserCertPkt FlagGetUnreadMsgsInfo
ServerMsgInfoPkt
ServerMsgPkt // Client requests a message from the queue
FlagGetMsg
// Client sends a message
FlagSendMsg
// Server sends user certificate
FlagAnswerGetUserCert
// Server sends list of unread messages
FlagAnswerGetUnreadMsgsInfo
// Server sends requested message
FlagAnswerGetMsg
) )
type ( type (
RequestUserCertPacket struct { GetUserCert struct {
UID string `json:"uid"` UID string `json:"uid"`
} }
RequestMsgsQueuePacket struct { GetUnreadMsgsInfo struct {
Page int `json:"page"`
PageSize int `json:"pageSize"`
} }
RequestMsgPacket struct { GetMsg struct {
Num int `json:"num"` Num int `json:"num"`
} }
SubmitMessagePacket struct { SendMsg struct {
ToUID string `json:"to_uid"` ToUID string `json:"to_uid"`
Subject []byte `json:"subject"` Subject []byte `json:"subject"`
Body []byte `json:"body"` Body []byte `json:"body"`
} }
SendUserCertPacket struct { AnswerGetUserCert struct {
UID string `json:"uid"` UID string `json:"uid"`
Certificate []byte `json:"certificate"` Certificate []byte `json:"certificate"`
} }
ServerMessageInfoPacket struct { AnswerGetUnreadMsgsInfo struct {
Page int `json:"page"`
NumPages int `json:"num_pages"`
MessagesInfo []MsgInfo `json:"messages_info"`
}
MsgInfo struct {
Num int `json:"num"` Num int `json:"num"`
FromUID string `json:"from_uid"` FromUID string `json:"from_uid"`
Subject []byte `json:"subject"` Subject []byte `json:"subject"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Last bool `json:"last"`
} }
ServerMessagePacket struct {
AnswerGetMsg struct {
FromUID string `json:"from_uid"` FromUID string `json:"from_uid"`
ToUID string `json:"to_uid"` ToUID string `json:"to_uid"`
Subject []byte `json:"subject"` Subject []byte `json:"subject"`
@ -64,156 +85,188 @@ type Packet struct {
Body PacketBody `json:"body"` Body PacketBody `json:"body"`
} }
func NewRequestUserCertPacket(UID string) Packet { func NewPacket(fl PacketType, body PacketBody) Packet {
return Packet{ return Packet{
Flag: ReqUserCertPkt, Flag: fl,
Body: RequestUserCertPacket{ Body: body,
}
}
func NewGetUserCert(UID string) GetUserCert {
return GetUserCert{
UID: UID, UID: UID,
},
} }
} }
func NewRequestUnreadMsgsQueuePacket() Packet { func NewGetUnreadMsgsInfo(page int, pageSize int) GetUnreadMsgsInfo {
return Packet{ return GetUnreadMsgsInfo{
Flag: ReqMsgsQueue, Page: page,
Body: RequestMsgsQueuePacket{}, PageSize: pageSize}
}
} }
func NewRequestMsgPacket(num int) Packet { func NewGetMsg(num int) GetMsg {
return Packet{ return GetMsg{
Flag: ReqMsgPkt,
Body: RequestMsgPacket{
Num: num, Num: num,
},
} }
} }
func NewSubmitMessagePacket(toUID string, subject []byte, body []byte) Packet { func NewSendMsg(toUID string, subject []byte, body []byte) SendMsg {
return Packet{ return SendMsg{
Flag: SubmitMsgPkt,
Body: SubmitMessagePacket{
ToUID: toUID, ToUID: toUID,
Subject: subject, Subject: subject,
Body: body, Body: body,
},
} }
} }
func NewSendUserCertPacket(uid string, certificate []byte) Packet { func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
return Packet{ return AnswerGetUserCert{
Flag: SendUserCertPkt,
Body: SendUserCertPacket{
UID: uid, UID: uid,
Certificate: certificate, Certificate: certificate,
},
} }
} }
func NewServerMessageInfoPacket(num int, fromUID string, subject []byte, timestamp time.Time, last bool) Packet {
return Packet{ func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
Flag: ServerMsgInfoPkt, return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo}
Body: ServerMessageInfoPacket{ }
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{
Num: num, Num: num,
FromUID: fromUID, FromUID: fromUID,
Subject: subject, Subject: subject,
Timestamp: timestamp, Timestamp: timestamp,
Last: last,
},
} }
} }
func NewServerMessagePacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet { func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) AnswerGetMsg {
return Packet{ return AnswerGetMsg{
Flag: ServerMsgPkt,
Body: ServerMessagePacket{
FromUID: fromUID, FromUID: fromUID,
ToUID: toUID, ToUID: toUID,
Subject: subject, Subject: subject,
Body: body, Body: body,
Timestamp: timestamp, Timestamp: timestamp,
},
} }
} }
func UnmarshalRequestUserCertPacket(data PacketBody) RequestUserCertPacket { func NewGetUserCertPacket(UID string) Packet {
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
}
func NewGetUnreadMsgsInfoPacket(page int, pageSize int) Packet {
return NewPacket(FlagGetUnreadMsgsInfo, NewGetUnreadMsgsInfo(page, pageSize))
}
func NewGetMsgPacket(num int) Packet {
return NewPacket(FlagGetMsg, NewGetMsg(num))
}
func NewSendMsgPacket(toUID string, subject []byte, body []byte) Packet {
return NewPacket(FlagSendMsg, NewSendMsg(toUID, subject, body))
}
func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
return NewPacket(FlagAnswerGetUserCert, NewAnswerGetUserCert(uid, certificate))
}
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo))
}
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet RequestUserCertPacket var packet GetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestUserCertPacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
} }
return packet return packet
} }
func UnmarshalRequestMsgsQueuePacket(data PacketBody) RequestMsgsQueuePacket { func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet RequestMsgsQueuePacket var packet GetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgsQueuePacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
} }
return packet return packet
} }
func UnmarshalRequestMsgPacket(data PacketBody) RequestMsgPacket { func UnmarshalGetMsg(data PacketBody) GetMsg {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet RequestMsgPacket var packet GetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgPacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
} }
return packet return packet
} }
func UnmarshalSubmitMessagePacket(data PacketBody) SubmitMessagePacket { func UnmarshalSendMsg(data PacketBody) SendMsg {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet SubmitMessagePacket var packet SendMsg
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SubmitMessagePacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
} }
return packet return packet
} }
func UnmarshalSendUserCertPacket(data PacketBody) SendUserCertPacket { func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet SendUserCertPacket var packet AnswerGetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendUserCertPacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
} }
return packet return packet
} }
func UnmarshalServerMessageInfoPacket(data PacketBody) ServerMessageInfoPacket { func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet ServerMessageInfoPacket var packet MsgInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessageInfoPacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
} }
return packet return packet
} }
func UnmarshalServerMessagePacket(data PacketBody) ServerMessagePacket { func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) panic(fmt.Errorf("failed to marshal data: %v", err))
} }
var packet ServerMessagePacket var packet AnswerGetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessagePacket: %v", err)) panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
}
return packet
}
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet AnswerGetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
} }
return packet return packet
} }

View file

@ -41,6 +41,7 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT, fromUID TEXT,
toUID TEXT, toUID TEXT,
timestamp TIMESTAMP, timestamp TIMESTAMP,
queue_position INT DEFAULT 0,
subject BLOB, subject BLOB,
body BLOB, body BLOB,
status INT CHECK (status IN (0,1)), status INT CHECK (status IN (0,1)),
@ -53,18 +54,36 @@ func (ds DataStore) CreateTables() error {
return err return err
} }
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
fmt.Println("Error creating trigger", err)
return err
}
return nil return nil
} }
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet { func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket var serverMessage protocol.AnswerGetMsg
query := ` query := `
SELECT fromUID, toUID, subject, body, timestamp SELECT fromUID, toUID, subject, body, timestamp
FROM messages FROM messages
WHERE toUID = ? WHERE toUID = ? AND queue_position = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
` `
// Execute the query // Execute the query
row := ds.db.QueryRow(query, toUID, position) row := ds.db.QueryRow(query, toUID, position)
@ -73,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err) log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
} }
return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
} }
@ -84,9 +103,7 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
WHERE (fromUID,toUID,timestamp) = ( WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp SELECT fromUID,toUID,timestamp
FROM messages FROM messages
WHERE toUID = ? WHERE toUID = ? AND queue_position = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
) )
` `
@ -97,8 +114,14 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
} }
} }
func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet { func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
var messageInfoPackets []protocol.Packet
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
}
// Query to retrieve all messages from the user's queue // Query to retrieve all messages from the user's queue
query := ` query := `
@ -109,38 +132,23 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
queue_position, queue_position,
subject, subject,
status status
FROM ( FROM messages
SELECT
fromUID,
toUID,
timestamp,
ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position,
subject,
status
FROM
messages
WHERE WHERE
toUID = ? toUID = ? AND status = 0
) AS ranked_messages
WHERE
status = 0
ORDER BY ORDER BY
timestamp; queue_position DESC
LIMIT ? OFFSET ?;
` `
// Execute the query // Execute the query
rows, err := ds.db.Query(query, toUID) rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil { if err != nil {
log.Printf("Error getting all messages for UID %v: %v", toUID, err) log.Printf("Error getting all messages for UID %v: %v", toUID, err)
} }
defer rows.Close() defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct messageInfoPackets := []protocol.MsgInfo{}
//First row for rows.Next() {
if !rows.Next() {
return []protocol.Packet{}
}
for {
var fromUID string var fromUID string
var subject []byte var subject []byte
var timestamp time.Time var timestamp time.Time
@ -148,25 +156,19 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil { if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err) panic(err)
} }
var message protocol.Packet answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
hasNext := rows.Next() messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
if !hasNext {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true)
messageInfoPackets = append(messageInfoPackets, message)
break
} else {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false)
messageInfoPackets = append(messageInfoPackets, message)
}
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err) log.Printf("Error when getting messages for UID %v: %v", toUID, err)
} }
return messageInfoPackets numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
} }
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) { func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
query := ` query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status) INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0) VALUES (?, ?, ?, ?, ?, 0)
@ -197,7 +199,7 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
//if err!=nil { //if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid) // log.Panicf("Error parsing certificate for UID %v",uid)
//} //}
return protocol.NewSendUserCertPacket(uid, userCertBytes) return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
} }
func (ds DataStore) userExists(uid string) bool { func (ds DataStore) userExists(uid string) bool {

View file

@ -4,7 +4,6 @@ import (
"PD1/internal/protocol" "PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils" "PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking" "PD1/internal/utils/networking"
"fmt"
) )
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) { func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
@ -24,33 +23,30 @@ F:
for { for {
pac, active := connection.Receive() pac, active := connection.Receive()
if !active { if !active {
break F break
} }
switch pac.Flag { switch pac.Flag {
case protocol.ReqUserCertPkt: case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body) reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active { if active := connection.Send(userCertPacket); !active {
break F break F
} }
case protocol.ReqMsgsQueue: case protocol.FlagGetUnreadMsgsInfo:
_ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body) getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMessagesInfoQueue(UID) messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages)) if !connection.Send(messages) {
for _, message := range messages { break F
if !connection.Send(message) {
break
} }
} case protocol.FlagGetMsg:
case protocol.ReqMsgPkt: reqMsg := protocol.UnmarshalGetMsg(pac.Body)
reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num) message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active { if active := connection.Send(message); !active {
break F break F
} }
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.SubmitMsgPkt: case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body) submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) { if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg) dataStore.AddMessageToQueue(UID, submitMsg)
} }

Binary file not shown.