diff --git a/Projs/PD1/go.mod b/Projs/PD1/go.mod index ee2f96b..29ae242 100644 --- a/Projs/PD1/go.mod +++ b/Projs/PD1/go.mod @@ -3,6 +3,7 @@ module PD1 go 1.22.2 require ( + github.com/mattn/go-sqlite3 v1.14.22 // indirect golang.org/x/crypto v0.11.0 // indirect software.sslmate.com/src/go-pkcs12 v0.4.0 // indirect ) diff --git a/Projs/PD1/go.sum b/Projs/PD1/go.sum index 1741a0a..90a922e 100644 --- a/Projs/PD1/go.sum +++ b/Projs/PD1/go.sum @@ -1,3 +1,5 @@ +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= diff --git a/Projs/PD1/internal/client/client.go b/Projs/PD1/internal/client/client.go index 1b0b588..0f7dccb 100644 --- a/Projs/PD1/internal/client/client.go +++ b/Projs/PD1/internal/client/client.go @@ -30,7 +30,7 @@ func Run() { defer cl.Connection.Conn.Close() // TODO: getuserinfo client cert // TODO: ask server for the recieving client's cert - certRequestPacket := protocol.NewRequestPubKey() + certRequestPacket := protocol.NewRequestUserCertPacket(uid) cl.Connection.Send(certRequestPacket) certPacket := cl.Connection.Receive() // TODO: cipherContent := cryptoUtils.encryptMessageContent() diff --git a/Projs/PD1/internal/protocol/protocol.go b/Projs/PD1/internal/protocol/protocol.go index db7d6eb..be05e8d 100644 --- a/Projs/PD1/internal/protocol/protocol.go +++ b/Projs/PD1/internal/protocol/protocol.go @@ -7,12 +7,12 @@ import ( type PacketType int const ( - ReqPK PacketType = iota - ReqAllMsg - ReqMsg - SubmitMsg - SendPK - Msg + ReqUserCertPkt PacketType = iota + ReqAllMsgPkt + ReqMsgPkt + SubmitMsgPkt + SendUserCertPkt + ServerMsgPkt ) type PacketBody interface{} @@ -22,91 +22,91 @@ type Packet struct { Body PacketBody } -// Client --> Server: Ask for a user's public key -type RequestPubKey struct { - FromUID string - KeyUID string +// Client --> Server: Ask for a user's certificate +type RequestUserCertPacket struct { + UID string } -func NewRequestPubKey(fromUID, keyUID string) Packet { +func NewRequestUserCertPacket(UID string) Packet { return Packet{ - Flag: ReqPK, - Body: RequestPubKey{ - FromUID: fromUID, - KeyUID: keyUID, + Flag: ReqUserCertPkt, + Body: RequestUserCertPacket{ + UID: UID, }, } } // Client --> Server: Ask for all the client's messages in the queue -type RequestAllMsg struct { +type RequestAllMsgPacket struct { FromUID string } -func NewRequestAllMsg(fromUID string) Packet { +func NewRequestAllMsgPacket(fromUID string) Packet { return Packet{ - Flag: ReqAllMsg, - Body: RequestAllMsg{ + Flag: ReqAllMsgPkt, + Body: RequestAllMsgPacket{ FromUID: fromUID, }, } } // Client --> Server: Ask for a specific message in the queue -type RequestMsg struct { +type RequestMsgPacket struct { Num uint16 } -func NewRequestMsg(num uint16) Packet { +func NewRequestMsgPacket(num uint16) Packet { return Packet{ - Flag: ReqMsg, - Body: RequestMsg{ + Flag: ReqMsgPkt, + Body: RequestMsgPacket{ Num: num, }, } } // Client --> Server: Send message from client to server -type SubmitMessage struct { +type SubmitMessagePacket struct { ToUID string Content []byte } -func NewSubmitMessage(toUID string, content []byte) Packet { +func NewSubmitMessagePacket(toUID string, content []byte) Packet { return Packet{ - Flag: SubmitMsg, - Body: SubmitMessage{ + Flag: SubmitMsgPkt, + Body: SubmitMessagePacket{ ToUID: toUID, Content: content}, } } // Server --> Client: Send the client the requested public key -type SendPubKey struct { +type SendUserCertPacket struct { + UID string Key []byte } -func NewSendPubKey(key []byte) Packet { +func NewSendUserCertPacket(uid string, key []byte) Packet { return Packet{ - Flag: SendPK, - Body: SendPubKey{ + Flag: SendUserCertPkt, + Body: SendUserCertPacket{ + UID: uid, Key: key, }, } } // Server --> Client: Send the client a message -type ServerMessage struct { +type ServerMessagePacket struct { FromUID string ToUID string Content []byte Timestamp time.Time } -func NewMessage(fromUID, toUID string, content []byte, timestamp time.Time) Packet { +func NewMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet { return Packet{ Flag: Msg, - Body: ServerMessage{ + Body: ServerMessagePacket{ FromUID: fromUID, ToUID: toUID, Content: content, diff --git a/Projs/PD1/internal/server/datastore.go b/Projs/PD1/internal/server/datastore.go index abb4e43..b92cba2 100644 --- a/Projs/PD1/internal/server/datastore.go +++ b/Projs/PD1/internal/server/datastore.go @@ -1 +1,156 @@ package server + +import ( + "PD1/internal/protocol" + "database/sql" + "log" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +type DataStore struct { + db *sql.DB +} + +func OpenDB() DataStore { + db, err := sql.Open("sqlite3", "server.db") + if err != nil { + log.Fatalln("Error opening db file") + } + return DataStore{db: db} +} + +func (ds DataStore) CreateTables() error { + // Create users table + _, err := ds.db.Exec(`CREATE TABLE IF NOT EXISTS users ( + UID TEXT PRIMARY KEY, + userCert BLOB + )`) + if err != nil { + return err + } + + // Create messages table + _, err = ds.db.Exec(`CREATE TABLE IF NOT EXISTS messages ( + fromUID TEXT, + toUID TEXT, + timestamp TIMESTAMP, + content BLOB, + PRIMARY KEY (toUID, fromUID, timestamp), + FOREIGN KEY(fromUID) REFERENCES users(UID), + FOREIGN KEY(toUID) REFERENCES users(UID) + )`) + if err != nil { + return err + } + + return nil +} + +func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessagePacket { + + var serverMessage protocol.ServerMessagePacket + query := ` + SELECT fromUID, toUID, content, timestamp + FROM messages + WHERE toUID = ? + AND status = 0 + ORDER BY timestamp + LIMIT 1 OFFSET ? + ` + // Execute the query + row := ds.db.QueryRow(query, toUID, position) + err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Content, &serverMessage.Timestamp) + if err != nil { + log.Panicln("Could not map DB query to ServerMessage") + } + return serverMessage + +} + +func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error { + query := ` + UPDATE messages + SET status = 1 + WHERE toUID = ? AND status = 0 + ORDER BY timestamp + LIMIT 1 OFFSET ? + ` + + // Execute the SQL statement + _, err := ds.db.Exec(query, toUID, position) + if err != nil { + return err + } + + return nil +} + +func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet { + var messagePackets []protocol.Packet + + // Query to retrieve all messages from the user's queue + query := ` + SELECT fromUID, toUID, content, timestamp + FROM messages + WHERE toUID = ? + AND status = 0 + ORDER BY timestamp + ` + + // Execute the query + rows, err := ds.db.Query(query, toUID) + if err != nil { + log.Panicln("Failed to execute query:", err) + } + defer rows.Close() + + // Iterate through the result set and scan each row into a ServerMessage struct + for rows.Next() { + var fromUID string + var toUID string + var content []byte + var timestamp time.Time + if err := rows.Scan(&fromUID, &toUID, &content, ×tamp); err != nil { + log.Panicln("Failed to scan row:", err) + } + message := protocol.NewMessagePacket(fromUID, toUID, content, timestamp) + messagePackets = append(messagePackets, message) + } + if err := rows.Err(); err != nil { + log.Panicln("Error when getting user's messages") + } + + return messagePackets +} + +func (ds DataStore) AddMessageToQueue(uid string, message protocol.SubmitMessagePacket) { + query := ` + INSERT INTO messages (fromUID, toUID, content, timestamp, status) + VALUES (?, ?, ?, ?, 0) + ` + + // Execute the SQL statement + currentTime := time.Now() + _, err := ds.db.Exec(query, uid, message.ToUID, message.Content, currentTime) + if err != nil { + log.Panicln("Error adding message to database") + } +} + +func (ds DataStore) GetUserCertificate(uid string) protocol.Packet { + query := ` + SELECT userCert + FROM users + WHERE UID = ? + ` + + // Execute the SQL query + var userCert []byte + err := ds.db.QueryRow(query, uid).Scan(&userCert) + if err != nil { + log.Panicln("Error getting user certificate from the database") + } + return protocol.NewSendUserCertPacket(uid, userCert) +} diff --git a/Projs/PD1/internal/server/server.go b/Projs/PD1/internal/server/server.go index 25a7978..fcd2073 100644 --- a/Projs/PD1/internal/server/server.go +++ b/Projs/PD1/internal/server/server.go @@ -6,19 +6,23 @@ import ( "fmt" ) -func clientHandler(connection networking.Connection[protocol.Packet]) { - defer connection.Conn.Close() +func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) { + defer connection.Conn.Close() + + // FIX: GET THE UID FROM THE USER CERTIFICATE FROM THE TLS SESSION + uid := "0" for { - pac := connection.Receive() + pac := connection.Receive() switch pac.Flag { - case protocol.ReqPK: - fmt.Println("ReqPK") - case protocol.ReqAllMsg: + case protocol.ReqUserCertPkt: + userCertPacket := dataStore.GetUserCertificate(uid) + connection.Send(userCertPacket) + case protocol.ReqAllMsgPkt: fmt.Println("ReqAllMsg") - case protocol.ReqMsg: + case protocol.ReqMsgPkt: fmt.Println("ReqMsg") - case protocol.SubmitMsg: + case protocol.SubmitMsgPkt: fmt.Println("SubmitMsg") } } @@ -26,13 +30,18 @@ func clientHandler(connection networking.Connection[protocol.Packet]) { } func Run(port int) { - server := networking.NewServer[protocol.Packet](port) - go server.ListenLoop() + //Open connection to DB + dataStore := OpenDB() + defer dataStore.db.Close() + + //Create server listener + server := networking.NewServer[protocol.Packet](port) + go server.ListenLoop() for { //Receive Connection via channel conn := <-server.C //Launch client handler via clientHandler - go clientHandler(conn) + go clientHandler(conn, dataStore) } }