diff --git a/Projs/PD1/internal/client/client.go b/Projs/PD1/internal/client/client.go index 3afbda7..3769cb2 100644 --- a/Projs/PD1/internal/client/client.go +++ b/Projs/PD1/internal/client/client.go @@ -30,52 +30,46 @@ func Run() { panic("Insufficient arguments for 'send' command. Usage: send ") } uid := flag.Arg(1) - subject := flag.Arg(2) - messageBody := readMessageBody() + plainSubject := flag.Arg(2) + plainBody := readStdin("Enter message content (limited to 1000 bytes):") //Turn content to bytes - marshaledSubject := Marshal(subject) - marshaledBody := Marshal(messageBody) + plainSubjectBytes := Marshal(plainSubject) + plainBodyBytes := Marshal(plainBody) cl := networking.NewClient[protocol.Packet](&clientKeyStore) defer cl.Connection.Conn.Close() - uidCert := getUserCert(cl, uid) - if uidCert == nil { + receiverCert := getUserCert(cl, uid) + if receiverCert == nil { return } - encryptedSubject := clientKeyStore.EncryptMessageContent(uidCert, marshaledSubject) - encryptedBody := clientKeyStore.EncryptMessageContent(uidCert, marshaledBody) - submitMessage := protocol.NewSubmitMessagePacket(uid, encryptedSubject, encryptedBody) - if !cl.Connection.Send(submitMessage) { + subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes) + body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes) + sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body) + if !cl.Connection.Send(sendMsgPacket) { return } cl.Connection.Conn.Close() case "askqueue": - cl := networking.NewClient[protocol.Packet](&clientKeyStore) - defer cl.Connection.Conn.Close() - - requestUnreadMsgsQueuePacket := protocol.NewRequestUnreadMsgsQueuePacket() - if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { - return - } - serverMessagePackets, certificates := getManyMessagesInfo(cl) - var clientMessages []ClientMessageInfo - for _, message := range serverMessagePackets { - senderCert, ok := certificates[message.FromUID] - if ok { - decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) - subject := Unmarshal(decryptedSubjectBytes) - clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) - clientMessages = append(clientMessages, clientMessage) + pageInput := flag.Arg(1) + page := 1 + if pageInput != "" { + if val, err := strconv.Atoi(pageInput); err == nil { + page = max(1, val) + } + } + pageSizeInput := flag.Arg(2) + pageSize := 5 + if pageSizeInput != "" { + if val, err := strconv.Atoi(pageSizeInput); err == nil { + pageSize = max(1, val) } } - //Sort the messages - sort.Slice(clientMessages, func(i, j int) bool { - return clientMessages[i].Num > clientMessages[j].Num - }) - showMessagesInfo(clientMessages) + cl := networking.NewClient[protocol.Packet](&clientKeyStore) + defer cl.Connection.Conn.Close() + askQueue(cl,clientKeyStore, page, pageSize) case "getmsg": if flag.NArg() < 2 { @@ -84,26 +78,26 @@ func Run() { numString := flag.Arg(1) cl := networking.NewClient[protocol.Packet](&clientKeyStore) defer cl.Connection.Conn.Close() - num,err :=strconv.Atoi(numString) - if err!=nil{ - log.Panicln("NUM argument provided is not a number") - } - packet := protocol.NewRequestMsgPacket(num) - cl.Connection.Send(packet) + num, err := strconv.Atoi(numString) + if err != nil { + log.Panicln("NUM argument provided is not a number") + } + packet := protocol.NewGetMsgPacket(num) + cl.Connection.Send(packet) + + receivedMsgPacket, active := cl.Connection.Receive() + if !active { + return + } + answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body) + senderCert := getUserCert(cl, answerGetMsg.FromUID) + decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject) + decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body) + subject := Unmarshal(decSubjectBytes) + body := Unmarshal(decBodyBytes) + message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp) + showMessage(message) - receivedMsgPacket,active := cl.Connection.Receive() - if !active{ - return - } - serverMessagePacket := protocol.UnmarshalServerMessagePacket(receivedMsgPacket.Body) - senderCert := getUserCert(cl, serverMessagePacket.FromUID) - decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Subject) - decryptedBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, serverMessagePacket.Body) - subject := Unmarshal(decryptedSubjectBytes) - body := Unmarshal(decryptedBodyBytes) - message := newClientMessage(serverMessagePacket.FromUID, serverMessagePacket.ToUID, subject, body, serverMessagePacket.Timestamp) - showMessage(message) - case "help": showHelp() @@ -114,43 +108,33 @@ func Run() { } func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate { - certRequestPacket := protocol.NewRequestUserCertPacket(uid) - if !cl.Connection.Send(certRequestPacket) { + getUserCertPacket := protocol.NewGetUserCertPacket(uid) + if !cl.Connection.Send(getUserCertPacket) { return nil } - var certPacket *protocol.Packet - certPacket, active := cl.Connection.Receive() + var answerGetUserCertPacket *protocol.Packet + answerGetUserCertPacket, active := cl.Connection.Receive() if !active { return nil } - uidCertInBytes := protocol.UnmarshalSendUserCertPacket(certPacket.Body) - uidCert, err := x509.ParseCertificate(uidCertInBytes.Certificate) + answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body) + userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate) if err != nil { return nil } - return uidCert + return userCert } -func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.ServerMessageInfoPacket, map[string]*x509.Certificate) { - //Create the slice to hold the incoming messages before decrypting - //Create the map to hold the sender certificates - //Create sync mutexes - serverMessageInfoPackets := []protocol.ServerMessageInfoPacket{} - //Run while message isn't the last one - msg := protocol.ServerMessageInfoPacket{} - for !msg.Last { - sendMsgPacket, active := cl.Connection.Receive() - if !active { - return nil, nil - } - msg = protocol.UnmarshalServerMessageInfoPacket(sendMsgPacket.Body) - //Lock and append - serverMessageInfoPackets = append(serverMessageInfoPackets, msg) +func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) { + answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive() + if !active { + return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil } + answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body) //Create Set of needed certificates senderSet := map[string]bool{} - for _, messageInfo := range serverMessageInfoPackets { + for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo { senderSet[messageInfo.FromUID] = true } certificatesMap := map[string]*x509.Certificate{} @@ -159,5 +143,37 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.Serv senderCert := getUserCert(cl, senderUID) certificatesMap[senderUID] = senderCert } - return serverMessageInfoPackets, certificatesMap + return answerGetUnreadMsgsInfo, certificatesMap +} + +func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) { + requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize) + if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { + return + } + unreadMsgsInfo, certificates := getManyMessagesInfo(cl) + var clientMessages []ClientMessageInfo + for _, message := range unreadMsgsInfo.MessagesInfo { + senderCert, ok := certificates[message.FromUID] + if ok { + decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) + subject := Unmarshal(decryptedSubjectBytes) + clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) + clientMessages = append(clientMessages, clientMessage) + } + } + //Sort the messages + sort.Slice(clientMessages, func(i, j int) bool { + return clientMessages[i].Num > clientMessages[j].Num + }) + + action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages) + switch action { + case -1: + askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize) + case 0: + return + case 1: + askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize) + } } diff --git a/Projs/PD1/internal/client/interface.go b/Projs/PD1/internal/client/interface.go index 8d49bb6..82f2d8a 100644 --- a/Projs/PD1/internal/client/interface.go +++ b/Projs/PD1/internal/client/interface.go @@ -4,10 +4,11 @@ import ( "bufio" "fmt" "os" + "strings" ) -func readMessageBody() string { - fmt.Println("Enter message content (limited to 1000 bytes):") +func readStdin(message string) string { + fmt.Println(message) scanner := bufio.NewScanner(os.Stdin) scanner.Scan() // FIX: make sure this doesnt die @@ -36,15 +37,63 @@ func showHelp() { fmt.Println("help: Imprime instruções de uso do programa.") } -func showMessagesInfo(messages []ClientMessageInfo) { +func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int { + if messages == nil { + fmt.Println("No unread messages in the queue") + return 0 + } for _, message := range messages { fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) } + fmt.Printf("Page %v/%v\n",page,numPages) + return messagesInfoPageNavigation(page, numPages) } -func showMessage(message ClientMessage) { - fmt.Printf("From:%v\n", message.FromUID) - fmt.Printf("To:%v\n", message.ToUID) - fmt.Printf("Subject:%v\n", message.Subject) - fmt.Printf("Body:%v\n", message.Body) +func messagesInfoPageNavigation(page int, numPages int) int { + var action string + + switch page { + case 1: + if page == numPages { + action = readStdin("Actions: quit") + } else { + action = readStdin("Actions: quit/next") + } + case numPages: + action = readStdin("Actions: prev/quit") + default: + action = readStdin("prev/quit/next") + } + + switch strings.ToLower(action) { + case "prev": + if page == 1 { + fmt.Println("Unavailable action: Already in first page") + messagesInfoPageNavigation(page, numPages) + } else { + return -1 + } + case "quit": + return 0 + case "next": + if page == numPages { + fmt.Println("Unavailable action: Already in last page") + messagesInfoPageNavigation(page, numPages) + } else { + return 1 + } + default: + fmt.Println("Unknown action") + messagesInfoPageNavigation(page, numPages) + } + return 0 } + + +func showMessage(message ClientMessage) { + fmt.Printf("From: %s\n", message.FromUID) + fmt.Printf("To: %s\n", message.ToUID) + fmt.Printf("Subject: %s\n", message.Subject) + fmt.Printf("Body: %s\n", message.Body) +} + diff --git a/Projs/PD1/internal/protocol/protocol.go b/Projs/PD1/internal/protocol/protocol.go index 6aeecec..680c2b5 100644 --- a/Projs/PD1/internal/protocol/protocol.go +++ b/Projs/PD1/internal/protocol/protocol.go @@ -9,46 +9,67 @@ import ( type PacketType int const ( - ReqUserCertPkt PacketType = iota - ReqMsgsQueue - ReqMsgPkt - SubmitMsgPkt - SendUserCertPkt - ServerMsgInfoPkt - ServerMsgPkt + // Client requests user certificate + FlagGetUserCert PacketType = iota + + // Client requests unread message info + FlagGetUnreadMsgsInfo + + // Client requests a message from the queue + FlagGetMsg + + // Client sends a message + FlagSendMsg + + // Server sends user certificate + FlagAnswerGetUserCert + + // Server sends list of unread messages + FlagAnswerGetUnreadMsgsInfo + + // Server sends requested message + FlagAnswerGetMsg ) type ( - RequestUserCertPacket struct { + GetUserCert struct { UID string `json:"uid"` } - RequestMsgsQueuePacket struct { + GetUnreadMsgsInfo struct { + Page int `json:"page"` + PageSize int `json:"pageSize"` } - RequestMsgPacket struct { + GetMsg struct { Num int `json:"num"` } - SubmitMessagePacket struct { + SendMsg struct { ToUID string `json:"to_uid"` Subject []byte `json:"subject"` Body []byte `json:"body"` } - SendUserCertPacket struct { + AnswerGetUserCert struct { UID string `json:"uid"` Certificate []byte `json:"certificate"` } - ServerMessageInfoPacket struct { + AnswerGetUnreadMsgsInfo struct { + Page int `json:"page"` + NumPages int `json:"num_pages"` + MessagesInfo []MsgInfo `json:"messages_info"` + } + + MsgInfo struct { Num int `json:"num"` FromUID string `json:"from_uid"` Subject []byte `json:"subject"` Timestamp time.Time `json:"timestamp"` - Last bool `json:"last"` } - ServerMessagePacket struct { + + AnswerGetMsg struct { FromUID string `json:"from_uid"` ToUID string `json:"to_uid"` Subject []byte `json:"subject"` @@ -64,156 +85,188 @@ type Packet struct { Body PacketBody `json:"body"` } -func NewRequestUserCertPacket(UID string) Packet { +func NewPacket(fl PacketType, body PacketBody) Packet { return Packet{ - Flag: ReqUserCertPkt, - Body: RequestUserCertPacket{ - UID: UID, - }, + Flag: fl, + Body: body, + } + +} + +func NewGetUserCert(UID string) GetUserCert { + return GetUserCert{ + UID: UID, } } -func NewRequestUnreadMsgsQueuePacket() Packet { - return Packet{ - Flag: ReqMsgsQueue, - Body: RequestMsgsQueuePacket{}, +func NewGetUnreadMsgsInfo(page int, pageSize int) GetUnreadMsgsInfo { + return GetUnreadMsgsInfo{ + Page: page, + PageSize: pageSize} +} + +func NewGetMsg(num int) GetMsg { + return GetMsg{ + Num: num, } } -func NewRequestMsgPacket(num int) Packet { - return Packet{ - Flag: ReqMsgPkt, - Body: RequestMsgPacket{ - Num: num, - }, +func NewSendMsg(toUID string, subject []byte, body []byte) SendMsg { + return SendMsg{ + ToUID: toUID, + Subject: subject, + Body: body, } } -func NewSubmitMessagePacket(toUID string, subject []byte, body []byte) Packet { - return Packet{ - Flag: SubmitMsgPkt, - Body: SubmitMessagePacket{ - ToUID: toUID, - Subject: subject, - Body: body, - }, +func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert { + return AnswerGetUserCert{ + UID: uid, + Certificate: certificate, } } -func NewSendUserCertPacket(uid string, certificate []byte) Packet { - return Packet{ - Flag: SendUserCertPkt, - Body: SendUserCertPacket{ - UID: uid, - Certificate: certificate, - }, - } +func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo { + return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo} } -func NewServerMessageInfoPacket(num int, fromUID string, subject []byte, timestamp time.Time, last bool) Packet { - return Packet{ - Flag: ServerMsgInfoPkt, - Body: ServerMessageInfoPacket{ - Num: num, - FromUID: fromUID, - Subject: subject, - Timestamp: timestamp, - Last: last, - }, +func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo { + return MsgInfo{ + Num: num, + FromUID: fromUID, + Subject: subject, + Timestamp: timestamp, } } -func NewServerMessagePacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet { - return Packet{ - Flag: ServerMsgPkt, - Body: ServerMessagePacket{ - FromUID: fromUID, - ToUID: toUID, - Subject: subject, - Body: body, - Timestamp: timestamp, - }, +func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) AnswerGetMsg { + return AnswerGetMsg{ + FromUID: fromUID, + ToUID: toUID, + Subject: subject, + Body: body, + Timestamp: timestamp, } } -func UnmarshalRequestUserCertPacket(data PacketBody) RequestUserCertPacket { +func NewGetUserCertPacket(UID string) Packet { + return NewPacket(FlagGetUserCert, NewGetUserCert(UID)) +} + +func NewGetUnreadMsgsInfoPacket(page int, pageSize int) Packet { + return NewPacket(FlagGetUnreadMsgsInfo, NewGetUnreadMsgsInfo(page, pageSize)) +} + +func NewGetMsgPacket(num int) Packet { + return NewPacket(FlagGetMsg, NewGetMsg(num)) +} + +func NewSendMsgPacket(toUID string, subject []byte, body []byte) Packet { + return NewPacket(FlagSendMsg, NewSendMsg(toUID, subject, body)) +} + +func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet { + return NewPacket(FlagAnswerGetUserCert, NewAnswerGetUserCert(uid, certificate)) +} + +func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet { + return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo)) +} + +func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet { + return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last)) +} + +func UnmarshalGetUserCert(data PacketBody) GetUserCert { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet RequestUserCertPacket + var packet GetUserCert if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into RequestUserCertPacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err)) } return packet } -func UnmarshalRequestMsgsQueuePacket(data PacketBody) RequestMsgsQueuePacket { +func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet RequestMsgsQueuePacket + var packet GetUnreadMsgsInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into RequestMsgsQueuePacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err)) } return packet } -func UnmarshalRequestMsgPacket(data PacketBody) RequestMsgPacket { +func UnmarshalGetMsg(data PacketBody) GetMsg { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet RequestMsgPacket + var packet GetMsg if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into RequestMsgPacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err)) } return packet } -func UnmarshalSubmitMessagePacket(data PacketBody) SubmitMessagePacket { +func UnmarshalSendMsg(data PacketBody) SendMsg { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet SubmitMessagePacket + var packet SendMsg if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into SubmitMessagePacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err)) } return packet } -func UnmarshalSendUserCertPacket(data PacketBody) SendUserCertPacket { +func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet SendUserCertPacket + var packet AnswerGetUserCert if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into SendUserCertPacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err)) } return packet } -func UnmarshalServerMessageInfoPacket(data PacketBody) ServerMessageInfoPacket { +func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet ServerMessageInfoPacket + var packet MsgInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into ServerMessageInfoPacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err)) } return packet } -func UnmarshalServerMessagePacket(data PacketBody) ServerMessagePacket { +func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo { jsonData, err := json.Marshal(data) if err != nil { panic(fmt.Errorf("failed to marshal data: %v", err)) } - var packet ServerMessagePacket + var packet AnswerGetUnreadMsgsInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into ServerMessagePacket: %v", err)) + panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err)) + } + return packet +} + +func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet AnswerGetMsg + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err)) } return packet } diff --git a/Projs/PD1/internal/server/datastore.go b/Projs/PD1/internal/server/datastore.go index 9229077..c2466c1 100644 --- a/Projs/PD1/internal/server/datastore.go +++ b/Projs/PD1/internal/server/datastore.go @@ -41,6 +41,7 @@ func (ds DataStore) CreateTables() error { fromUID TEXT, toUID TEXT, timestamp TIMESTAMP, + queue_position INT DEFAULT 0, subject BLOB, body BLOB, status INT CHECK (status IN (0,1)), @@ -53,18 +54,36 @@ func (ds DataStore) CreateTables() error { return err } + // Define a trigger to automatically assign numbers for each message of each user starting from 1 + _, err = ds.db.Exec(` + CREATE TRIGGER IF NOT EXISTS assign_queue_position + AFTER INSERT ON messages + FOR EACH ROW + BEGIN + UPDATE messages + SET queue_position = ( + SELECT COUNT(*) + FROM messages + WHERE toUID = NEW.toUID + ) + WHERE toUID = NEW.toUID AND rowid = NEW.rowid; + END; + `) + if err != nil { + fmt.Println("Error creating trigger", err) + return err + } + return nil } func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet { - var serverMessage protocol.ServerMessagePacket + var serverMessage protocol.AnswerGetMsg query := ` SELECT fromUID, toUID, subject, body, timestamp FROM messages - WHERE toUID = ? - ORDER BY timestamp - LIMIT 1 OFFSET ? + WHERE toUID = ? AND queue_position = ? ` // Execute the query row := ds.db.QueryRow(query, toUID, position) @@ -73,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet { log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err) } - return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) + return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) } @@ -84,9 +103,7 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) { WHERE (fromUID,toUID,timestamp) = ( SELECT fromUID,toUID,timestamp FROM messages - WHERE toUID = ? - ORDER BY timestamp - LIMIT 1 OFFSET ? + WHERE toUID = ? AND queue_position = ? ) ` @@ -97,8 +114,14 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) { } } -func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet { - var messageInfoPackets []protocol.Packet +func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet { + + // Retrieve the total count of unread messages + var totalCount int + err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount) + if err != nil { + log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err) + } // Query to retrieve all messages from the user's queue query := ` @@ -109,38 +132,23 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet { queue_position, subject, status - FROM ( - SELECT - fromUID, - toUID, - timestamp, - ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position, - subject, - status - FROM - messages - WHERE - toUID = ? - ) AS ranked_messages + FROM messages WHERE - status = 0 + toUID = ? AND status = 0 ORDER BY - timestamp; + queue_position DESC + LIMIT ? OFFSET ?; ` // Execute the query - rows, err := ds.db.Query(query, toUID) + rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize) if err != nil { log.Printf("Error getting all messages for UID %v: %v", toUID, err) } defer rows.Close() - // Iterate through the result set and scan each row into a ServerMessage struct - //First row - if !rows.Next() { - return []protocol.Packet{} - } - for { + messageInfoPackets := []protocol.MsgInfo{} + for rows.Next() { var fromUID string var subject []byte var timestamp time.Time @@ -148,25 +156,19 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet { if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil { panic(err) } - var message protocol.Packet - hasNext := rows.Next() - if !hasNext { - message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true) - messageInfoPackets = append(messageInfoPackets, message) - break - } else { - message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false) - messageInfoPackets = append(messageInfoPackets, message) - } + answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp) + messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo) } if err := rows.Err(); err != nil { log.Printf("Error when getting messages for UID %v: %v", toUID, err) } - return messageInfoPackets + numberOfPages := (totalCount + pageSize - 1) / pageSize + currentPage := min(numberOfPages, page) + return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets) } -func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) { +func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) { query := ` INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status) VALUES (?, ?, ?, ?, ?, 0) @@ -197,7 +199,7 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet { //if err!=nil { // log.Panicf("Error parsing certificate for UID %v",uid) //} - return protocol.NewSendUserCertPacket(uid, userCertBytes) + return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes) } func (ds DataStore) userExists(uid string) bool { diff --git a/Projs/PD1/internal/server/server.go b/Projs/PD1/internal/server/server.go index 1f2863a..e4d6474 100644 --- a/Projs/PD1/internal/server/server.go +++ b/Projs/PD1/internal/server/server.go @@ -4,7 +4,6 @@ import ( "PD1/internal/protocol" "PD1/internal/utils/cryptoUtils" "PD1/internal/utils/networking" - "fmt" ) func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) { @@ -24,33 +23,30 @@ F: for { pac, active := connection.Receive() if !active { - break F + break } switch pac.Flag { - case protocol.ReqUserCertPkt: - reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body) + case protocol.FlagGetUserCert: + reqUserCert := protocol.UnmarshalGetUserCert(pac.Body) userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) if active := connection.Send(userCertPacket); !active { break F } - case protocol.ReqMsgsQueue: - _ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body) - messages := dataStore.GetUnreadMessagesInfoQueue(UID) - fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages)) - for _, message := range messages { - if !connection.Send(message) { - break - } + case protocol.FlagGetUnreadMsgsInfo: + getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body) + messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize) + if !connection.Send(messages) { + break F } - case protocol.ReqMsgPkt: - reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body) + case protocol.FlagGetMsg: + reqMsg := protocol.UnmarshalGetMsg(pac.Body) message := dataStore.GetMessage(UID, reqMsg.Num) if active := connection.Send(message); !active { break F } - dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) - case protocol.SubmitMsgPkt: - submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body) + dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) + case protocol.FlagSendMsg: + submitMsg := protocol.UnmarshalSendMsg(pac.Body) if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) { dataStore.AddMessageToQueue(UID, submitMsg) } diff --git a/Projs/PD1/server.db b/Projs/PD1/server.db index b5ed3b1..3784ec9 100644 Binary files a/Projs/PD1/server.db and b/Projs/PD1/server.db differ