[PD1] Fixed almost everything

This commit is contained in:
Afonso Franco 2024-04-19 23:59:26 +01:00
parent 39a0e5c01f
commit 7b3172a850
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
13 changed files with 534 additions and 192 deletions

View file

@ -2,7 +2,9 @@ package server
import (
"PD1/internal/protocol"
"crypto/x509"
"database/sql"
"fmt"
"log"
"time"
@ -18,7 +20,9 @@ func OpenDB() DataStore {
if err != nil {
log.Fatalln("Error opening db file")
}
return DataStore{db: db}
ds := DataStore{db: db}
ds.CreateTables()
return ds
}
func (ds DataStore) CreateTables() error {
@ -28,6 +32,7 @@ func (ds DataStore) CreateTables() error {
userCert BLOB
)`)
if err != nil {
fmt.Println("Error creating users table", err)
return err
}
@ -36,23 +41,26 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
content BLOB,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
PRIMARY KEY (toUID, fromUID, timestamp),
FOREIGN KEY(fromUID) REFERENCES users(UID),
FOREIGN KEY(toUID) REFERENCES users(UID)
)`)
if err != nil {
fmt.Println("Error creating messages table", err)
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessagePacket {
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket
query := `
SELECT fromUID, toUID, content, timestamp
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ?
AND status = 0
@ -61,15 +69,16 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessag
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Content, &serverMessage.Timestamp)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err != nil {
log.Panicln("Could not map DB query to ServerMessage")
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
}
return serverMessage
return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
}
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error {
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
query := `
UPDATE messages
SET status = 1
@ -81,61 +90,90 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error {
// Execute the SQL statement
_, err := ds.db.Exec(query, toUID, position)
if err != nil {
return err
log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
}
return nil
}
func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet {
var messagePackets []protocol.Packet
func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
var messageInfoPackets []protocol.Packet
// Query to retrieve all messages from the user's queue
query := `
SELECT fromUID, toUID, content, timestamp
FROM messages
WHERE toUID = ?
AND status = 0
ORDER BY timestamp
SELECT
fromUID,
toUID,
timestamp,
queue_position,
subject,
status
FROM (
SELECT
fromUID,
toUID,
timestamp,
ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position,
subject,
status
FROM
messages
WHERE
toUID = ?
) AS ranked_messages
WHERE
status = 0
ORDER BY
timestamp;
`
// Execute the query
rows, err := ds.db.Query(query, toUID)
if err != nil {
log.Panicln("Failed to execute query:", err)
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
}
defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct
for rows.Next() {
//First row
if !rows.Next() {
return []protocol.Packet{}
}
for {
var fromUID string
var toUID string
var content []byte
var subject []byte
var timestamp time.Time
if err := rows.Scan(&fromUID, &toUID, &content, &timestamp); err != nil {
log.Panicln("Failed to scan row:", err)
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
}
var message protocol.Packet
hasNext := rows.Next()
if !hasNext {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true)
messageInfoPackets = append(messageInfoPackets, message)
break
} else {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false)
messageInfoPackets = append(messageInfoPackets, message)
}
message := protocol.NewServerMessagePacket(fromUID, toUID, content, timestamp)
messagePackets = append(messagePackets, message)
}
if err := rows.Err(); err != nil {
log.Panicln("Error when getting user's messages")
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
}
return messagePackets
return messageInfoPackets
}
func (ds DataStore) AddMessageToQueue(uid string, message protocol.SubmitMessagePacket) {
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) {
query := `
INSERT INTO messages (fromUID, toUID, content, timestamp, status)
VALUES (?, ?, ?, ?, 0)
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
`
// Execute the SQL statement
currentTime := time.Now()
_, err := ds.db.Exec(query, uid, message.ToUID, message.Content, currentTime)
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil {
log.Panicln("Error adding message to database")
log.Printf("Error adding message to UID %v: %v", fromUID, err)
}
}
@ -147,29 +185,53 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
`
// Execute the SQL query
var userCert []byte
err := ds.db.QueryRow(query, uid).Scan(&userCert)
if err != nil {
log.Panicln("Error getting user certificate from the database")
var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows {
log.Panicf("No certificate for UID %v found in the database", uid)
}
return protocol.NewSendUserCertPacket(uid, userCert)
//userCert,err := x509.ParseCertificate(userCertBytes)
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewSendUserCertPacket(uid, userCertBytes)
}
func userExists(db *sql.DB, uid string) bool {
// Prepare the SQL statement for checking if a user exists
query := `
func (ds DataStore) userExists(uid string) bool {
// Prepare the SQL statement for checking if a user exists
query := `
SELECT COUNT(*)
FROM users
WHERE UID = ?
`
var count int
// Execute the SQL query
err := db.QueryRow(query, uid).Scan(&count)
if err != nil {
log.Panicln("Error checking if user exists")
}
// If count is greater than 0, the user exists
return count > 0
var count int
// Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count)
if err == sql.ErrNoRows {
log.Printf("User with UID %v does not exist", uid)
return false
} else {
return true
}
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
// Check if the user already exists
if ds.userExists(uid) {
log.Printf("User certificate for UID %s already exists.\n", uid)
return
}
// Insert the user certificate
insertQuery := `
INSERT INTO users (UID, userCert)
VALUES (?, ?)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
return
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
}

View file

@ -16,28 +16,44 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
//Get the UID of this user
UID := oidMap["2.5.4.65"]
if UID=="" {
if UID == "" {
panic("User certificate does not specify it's PSEUDONYM")
}
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F:
for {
var pac protocol.Packet
connection.Receive(&pac)
pac, active := connection.Receive()
if !active {
break F
}
switch pac.Flag {
case protocol.ReqUserCertPkt:
fmt.Printf("Type of pac.Body: %T\n", pac.Body)
UserCertPacket, ok := (pac.Body).(protocol.RequestUserCertPacket)
if !ok {
panic("Could not cast packet to it's type")
reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active {
break F
}
case protocol.ReqMsgsQueue:
_ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body)
messages := dataStore.GetUnreadMessagesInfoQueue(UID)
fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages))
for _, message := range messages {
if !connection.Send(message) {
break
}
}
userCertPacket := dataStore.GetUserCertificate(UserCertPacket.UID)
connection.Send(userCertPacket)
case protocol.ReqAllMsgPkt:
fmt.Println("ReqAllMsg")
case protocol.ReqMsgPkt:
fmt.Println("ReqMsg")
reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
break F
}
case protocol.SubmitMsgPkt:
fmt.Println("SubmitMsg")
submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
}
}
}