[PD1] Cleaned up protocol,database and receiving multiple messageInfos

This commit is contained in:
Afonso Franco 2024-04-20 17:16:52 +01:00
parent 4c141bbc6e
commit f3cf9cfc40
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
6 changed files with 347 additions and 231 deletions

View file

@ -41,6 +41,7 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
queue_position INT DEFAULT 0,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
@ -53,18 +54,36 @@ func (ds DataStore) CreateTables() error {
return err
}
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
fmt.Println("Error creating trigger", err)
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket
var serverMessage protocol.AnswerGetMsg
query := `
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
WHERE toUID = ? AND queue_position = ?
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
@ -73,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
}
return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
}
@ -84,9 +103,7 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp
FROM messages
WHERE toUID = ?
ORDER BY timestamp
LIMIT 1 OFFSET ?
WHERE toUID = ? AND queue_position = ?
)
`
@ -97,8 +114,14 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
}
}
func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
var messageInfoPackets []protocol.Packet
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
}
// Query to retrieve all messages from the user's queue
query := `
@ -109,38 +132,23 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
queue_position,
subject,
status
FROM (
SELECT
fromUID,
toUID,
timestamp,
ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position,
subject,
status
FROM
messages
WHERE
toUID = ?
) AS ranked_messages
FROM messages
WHERE
status = 0
toUID = ? AND status = 0
ORDER BY
timestamp;
queue_position DESC
LIMIT ? OFFSET ?;
`
// Execute the query
rows, err := ds.db.Query(query, toUID)
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
}
defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct
//First row
if !rows.Next() {
return []protocol.Packet{}
}
for {
messageInfoPackets := []protocol.MsgInfo{}
for rows.Next() {
var fromUID string
var subject []byte
var timestamp time.Time
@ -148,25 +156,19 @@ func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
}
var message protocol.Packet
hasNext := rows.Next()
if !hasNext {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true)
messageInfoPackets = append(messageInfoPackets, message)
break
} else {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false)
messageInfoPackets = append(messageInfoPackets, message)
}
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
}
return messageInfoPackets
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) {
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
@ -197,7 +199,7 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewSendUserCertPacket(uid, userCertBytes)
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
}
func (ds DataStore) userExists(uid string) bool {

View file

@ -4,7 +4,6 @@ import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"fmt"
)
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
@ -24,33 +23,30 @@ F:
for {
pac, active := connection.Receive()
if !active {
break F
break
}
switch pac.Flag {
case protocol.ReqUserCertPkt:
reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body)
case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active {
break F
}
case protocol.ReqMsgsQueue:
_ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body)
messages := dataStore.GetUnreadMessagesInfoQueue(UID)
fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages))
for _, message := range messages {
if !connection.Send(message) {
break
}
case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
if !connection.Send(messages) {
break F
}
case protocol.ReqMsgPkt:
reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body)
case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
break F
}
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.SubmitMsgPkt:
submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body)
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
}