diff --git a/Projs/PD1/go.mod b/Projs/PD1/go.mod index 29ae242..17a5c07 100644 --- a/Projs/PD1/go.mod +++ b/Projs/PD1/go.mod @@ -3,7 +3,9 @@ module PD1 go 1.22.2 require ( - github.com/mattn/go-sqlite3 v1.14.22 // indirect - golang.org/x/crypto v0.11.0 // indirect - software.sslmate.com/src/go-pkcs12 v0.4.0 // indirect + github.com/mattn/go-sqlite3 v1.14.22 + golang.org/x/crypto v0.11.0 + software.sslmate.com/src/go-pkcs12 v0.4.0 ) + +require golang.org/x/sys v0.10.0 // indirect diff --git a/Projs/PD1/go.sum b/Projs/PD1/go.sum index 90a922e..cb0edca 100644 --- a/Projs/PD1/go.sum +++ b/Projs/PD1/go.sum @@ -2,5 +2,7 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/Projs/PD1/internal/client/client.go b/Projs/PD1/internal/client/client.go index 8ffca2c..3f8487e 100644 --- a/Projs/PD1/internal/client/client.go +++ b/Projs/PD1/internal/client/client.go @@ -4,8 +4,10 @@ import ( "PD1/internal/protocol" "PD1/internal/utils/cryptoUtils" "PD1/internal/utils/networking" + "crypto/x509" "flag" "fmt" + "sort" ) func Run() { @@ -16,7 +18,7 @@ func Run() { if flag.NArg() == 0 { panic("No command provided. Use 'help' for instructions.") } - //Get user KeyStore + //Get user KeyStore password := AskUserPassword() clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password) @@ -27,27 +29,79 @@ func Run() { panic("Insufficient arguments for 'send' command. Usage: send ") } uid := flag.Arg(1) - //subject := flag.Arg(2) - //messageContent := readMessageContent() + subject := flag.Arg(2) + messageBody := readMessageBody() + //Turn content to bytes + marshaledSubject := Marshal(subject) + marshaledBody := Marshal(messageBody) cl := networking.NewClient[protocol.Packet](&clientKeyStore) defer cl.Connection.Conn.Close() - certRequestPacket := protocol.NewRequestUserCertPacket(uid) - cl.Connection.Send(certRequestPacket) - - var certPacket protocol.Packet - cl.Connection.Receive(&certPacket) - uidCert := (certPacket.Body).(protocol.SendUserCertPacket) - fmt.Println(uidCert) - - // TODO: Encrypt message - //submitMessage(cl, uid, cipherContent) + uidCert := getUserCert(cl, uid) + if uidCert == nil { + return + } + encryptedSubject := clientKeyStore.EncryptMessageContent(uidCert, marshaledSubject) + encryptedBody := clientKeyStore.EncryptMessageContent(uidCert, marshaledBody) + submitMessage := protocol.NewSubmitMessagePacket(uid, encryptedSubject, encryptedBody) + if !cl.Connection.Send(submitMessage) { + return + } + cl.Connection.Conn.Close() case "askqueue": cl := networking.NewClient[protocol.Packet](&clientKeyStore) defer cl.Connection.Conn.Close() + requestUnreadMsgsQueuePacket := protocol.NewRequestUnreadMsgsQueuePacket() + if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { + return + } + serverMessagePackets, certificates := getManyMessagesInfo(cl) + var clientMessages []ClientMessageInfo + for _, message := range serverMessagePackets { + senderCert, ok := certificates[message.FromUID] + if ok { + decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) + subject := Unmarshal(decryptedSubjectBytes) + clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) + clientMessages = append(clientMessages, clientMessage) + } + } + //Sort the messages + sort.Slice(clientMessages, func(i, j int) bool { + return clientMessages[i].Num > clientMessages[j].Num + }) + + showMessagesInfo(clientMessages) + + //case "getall": + // cl := networking.NewClient[protocol.Packet](&clientKeyStore) + // defer cl.Connection.Conn.Close() + + // requestAllMsgPacket := protocol.NewRequestAllMsgPacket() + // if !cl.Connection.Send(requestAllMsgPacket) { + // return + // } + // serverMessagePackets,certificates := getManyMessages(cl) + // var clientMessages []ClientMessage + // for _, message := range serverMessagePackets { + // senderCert, ok := certificates[message.FromUID] + // if ok { + // decryptedContentBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Content) + // content := UnmarshalContent(decryptedContentBytes) + // clientMessage := newClientMessage(message.FromUID, message.ToUID, content, message.Timestamp) + // clientMessages = append(clientMessages, clientMessage) + // } + // } + // //Sort the messages + // sort.Slice(clientMessages, func(i, j int) bool { + // return clientMessages[i].Timestamp.After(clientMessages[j].Timestamp) + // }) + + // showMessages(clientMessages) + case "getmsg": if flag.NArg() < 2 { panic("Insufficient arguments for 'getmsg' command. Usage: getmsg ") @@ -65,7 +119,52 @@ func Run() { } -func submitMessage(cl networking.Client[protocol.Packet], uid string, content []byte) { - pack := protocol.NewSubmitMessagePacket(uid, content) - cl.Connection.Send(pack) +func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate { + certRequestPacket := protocol.NewRequestUserCertPacket(uid) + if !cl.Connection.Send(certRequestPacket) { + return nil + } + var certPacket *protocol.Packet + certPacket, active := cl.Connection.Receive() + if !active { + return nil + } + uidCertInBytes := protocol.UnmarshalSendUserCertPacket(certPacket.Body) + uidCert, err := x509.ParseCertificate(uidCertInBytes.Certificate) + if err != nil { + return nil + } + return uidCert +} + +func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.ServerMessageInfoPacket, map[string]*x509.Certificate) { + //Create the slice to hold the incoming messages before decrypting + //Create the map to hold the sender certificates + //Create sync mutexes + serverMessageInfoPackets := []protocol.ServerMessageInfoPacket{} + //Run while message isn't the last one + msg := protocol.ServerMessageInfoPacket{} + for !msg.Last { + sendMsgPacket, active := cl.Connection.Receive() + if !active { + return nil, nil + } + msg = protocol.UnmarshalServerMessageInfoPacket(sendMsgPacket.Body) + //Lock and append + serverMessageInfoPackets = append(serverMessageInfoPackets, msg) + } + + //Create Set of needed certificates + senderSet := map[string]bool{} + for _, messageInfo := range serverMessageInfoPackets { + senderSet[messageInfo.FromUID] = true + } + certificatesMap := map[string]*x509.Certificate{} + //Get senders' certificates + for senderUID := range senderSet { + senderCert := getUserCert(cl, senderUID) + fmt.Println("Got a User cert") + certificatesMap[senderUID] = senderCert + } + return serverMessageInfoPackets, certificatesMap } diff --git a/Projs/PD1/internal/client/datastore.go b/Projs/PD1/internal/client/datastore.go index 2ff30e9..55a38ff 100644 --- a/Projs/PD1/internal/client/datastore.go +++ b/Projs/PD1/internal/client/datastore.go @@ -1,15 +1,47 @@ package client -import "time" +import ( + "encoding/json" + "log" + "time" +) -type Content struct { - Subject []byte - Body []byte -} - -type RecievedMessage struct { +type ClientMessage struct { FromUID string ToUID string - Content Content + Subject string + Body string Timestamp time.Time } + +type ClientMessageInfo struct { + Num int + FromUID string + Timestamp time.Time + Subject string +} + +func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage { + return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp} +} + +func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo { + return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp} +} + +func Marshal(data any) []byte { + subject, err := json.Marshal(data) + if err != nil { + log.Panicf("Error when marshalling message: %v", err) + } + return subject +} + +func Unmarshal(data []byte) string { + var c string + err := json.Unmarshal(data, &c) + if err != nil { + log.Panicln("Could not unmarshal data") + } + return c +} diff --git a/Projs/PD1/internal/client/interface.go b/Projs/PD1/internal/client/interface.go index 02cee71..8d49bb6 100644 --- a/Projs/PD1/internal/client/interface.go +++ b/Projs/PD1/internal/client/interface.go @@ -6,7 +6,7 @@ import ( "os" ) -func readMessageContent() string { +func readMessageBody() string { fmt.Println("Enter message content (limited to 1000 bytes):") scanner := bufio.NewScanner(os.Stdin) scanner.Scan() @@ -23,11 +23,10 @@ func AskUserPassword() string { } func commandError() { - fmt.Println("MSG SERVICE: command error!") - showHelp() + fmt.Println("MSG SERVICE: command error!") + showHelp() } - func showHelp() { fmt.Println("Comandos da aplicação cliente:") fmt.Println("-user : Especifica o ficheiro com dados do utilizador. Por omissão, será assumido que esse ficheiro é userdata.p12.") @@ -36,3 +35,16 @@ func showHelp() { fmt.Println("getmsg : Solicita ao servidor o envio da mensagem da sua queue com número .") fmt.Println("help: Imprime instruções de uso do programa.") } + +func showMessagesInfo(messages []ClientMessageInfo) { + for _, message := range messages { + fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) + } +} + +func showMessage(message ClientMessage) { + fmt.Printf("From:%v\n", message.FromUID) + fmt.Printf("To:%v\n", message.ToUID) + fmt.Printf("Subject:%v\n", message.Subject) + fmt.Printf("Body:%v\n", message.Body) +} diff --git a/Projs/PD1/internal/protocol/protocol.go b/Projs/PD1/internal/protocol/protocol.go index 6134233..6aeecec 100644 --- a/Projs/PD1/internal/protocol/protocol.go +++ b/Projs/PD1/internal/protocol/protocol.go @@ -1,114 +1,219 @@ package protocol import ( - "time" + "encoding/json" + "fmt" + "time" ) type PacketType int const ( - ReqUserCertPkt PacketType = iota - ReqAllMsgPkt - ReqMsgPkt - SubmitMsgPkt - SendUserCertPkt - ServerMsgPkt + ReqUserCertPkt PacketType = iota + ReqMsgsQueue + ReqMsgPkt + SubmitMsgPkt + SendUserCertPkt + ServerMsgInfoPkt + ServerMsgPkt ) -// Define interfaces for packet bodies type ( - RequestUserCertPacket struct { - UID string `json:"uid"` - } + RequestUserCertPacket struct { + UID string `json:"uid"` + } - RequestAllMsgPacket struct { - FromUID string `json:"from_uid"` - } + RequestMsgsQueuePacket struct { + } - RequestMsgPacket struct { - Num uint16 `json:"num"` - } + RequestMsgPacket struct { + Num int `json:"num"` + } - SubmitMessagePacket struct { - ToUID string `json:"to_uid"` - Content []byte `json:"content"` - } + SubmitMessagePacket struct { + ToUID string `json:"to_uid"` + Subject []byte `json:"subject"` + Body []byte `json:"body"` + } - SendUserCertPacket struct { - UID string `json:"uid"` - Key []byte `json:"key"` - } + SendUserCertPacket struct { + UID string `json:"uid"` + Certificate []byte `json:"certificate"` + } - ServerMessagePacket struct { - FromUID string `json:"from_uid"` - ToUID string `json:"to_uid"` - Content []byte `json:"content"` - Timestamp time.Time `json:"timestamp"` - } + ServerMessageInfoPacket struct { + Num int `json:"num"` + FromUID string `json:"from_uid"` + Subject []byte `json:"subject"` + Timestamp time.Time `json:"timestamp"` + Last bool `json:"last"` + } + ServerMessagePacket struct { + FromUID string `json:"from_uid"` + ToUID string `json:"to_uid"` + Subject []byte `json:"subject"` + Body []byte `json:"body"` + Timestamp time.Time `json:"timestamp"` + } ) type PacketBody interface{} type Packet struct { - Flag PacketType `json:"flag"` - Body PacketBody `json:"body"` + Flag PacketType `json:"flag"` + Body PacketBody `json:"body"` } func NewRequestUserCertPacket(UID string) Packet { - return Packet{ - Flag: ReqUserCertPkt, - Body: RequestUserCertPacket{ - UID: UID, - }, - } + return Packet{ + Flag: ReqUserCertPkt, + Body: RequestUserCertPacket{ + UID: UID, + }, + } } -func NewRequestAllMsgPacket(fromUID string) Packet { - return Packet{ - Flag: ReqAllMsgPkt, - Body: RequestAllMsgPacket{ - FromUID: fromUID, - }, - } +func NewRequestUnreadMsgsQueuePacket() Packet { + return Packet{ + Flag: ReqMsgsQueue, + Body: RequestMsgsQueuePacket{}, + } } -func NewRequestMsgPacket(num uint16) Packet { - return Packet{ - Flag: ReqMsgPkt, - Body: RequestMsgPacket{ - Num: num, - }, - } +func NewRequestMsgPacket(num int) Packet { + return Packet{ + Flag: ReqMsgPkt, + Body: RequestMsgPacket{ + Num: num, + }, + } } -func NewSubmitMessagePacket(toUID string, content []byte) Packet { - return Packet{ - Flag: SubmitMsgPkt, - Body: SubmitMessagePacket{ - ToUID: toUID, - Content: content, - }, - } +func NewSubmitMessagePacket(toUID string, subject []byte, body []byte) Packet { + return Packet{ + Flag: SubmitMsgPkt, + Body: SubmitMessagePacket{ + ToUID: toUID, + Subject: subject, + Body: body, + }, + } } -func NewSendUserCertPacket(uid string, key []byte) Packet { - return Packet{ - Flag: SendUserCertPkt, - Body: SendUserCertPacket{ - UID: uid, - Key: key, - }, - } +func NewSendUserCertPacket(uid string, certificate []byte) Packet { + return Packet{ + Flag: SendUserCertPkt, + Body: SendUserCertPacket{ + UID: uid, + Certificate: certificate, + }, + } +} +func NewServerMessageInfoPacket(num int, fromUID string, subject []byte, timestamp time.Time, last bool) Packet { + return Packet{ + Flag: ServerMsgInfoPkt, + Body: ServerMessageInfoPacket{ + Num: num, + FromUID: fromUID, + Subject: subject, + Timestamp: timestamp, + Last: last, + }, + } } -func NewServerMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet { - return Packet{ - Flag: ServerMsgPkt, - Body: ServerMessagePacket{ - FromUID: fromUID, - ToUID: toUID, - Content: content, - Timestamp: timestamp, - }, - } +func NewServerMessagePacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet { + return Packet{ + Flag: ServerMsgPkt, + Body: ServerMessagePacket{ + FromUID: fromUID, + ToUID: toUID, + Subject: subject, + Body: body, + Timestamp: timestamp, + }, + } +} + +func UnmarshalRequestUserCertPacket(data PacketBody) RequestUserCertPacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet RequestUserCertPacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into RequestUserCertPacket: %v", err)) + } + return packet +} + +func UnmarshalRequestMsgsQueuePacket(data PacketBody) RequestMsgsQueuePacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet RequestMsgsQueuePacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into RequestMsgsQueuePacket: %v", err)) + } + return packet +} + +func UnmarshalRequestMsgPacket(data PacketBody) RequestMsgPacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet RequestMsgPacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into RequestMsgPacket: %v", err)) + } + return packet +} + +func UnmarshalSubmitMessagePacket(data PacketBody) SubmitMessagePacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet SubmitMessagePacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into SubmitMessagePacket: %v", err)) + } + return packet +} + +func UnmarshalSendUserCertPacket(data PacketBody) SendUserCertPacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet SendUserCertPacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into SendUserCertPacket: %v", err)) + } + return packet +} +func UnmarshalServerMessageInfoPacket(data PacketBody) ServerMessageInfoPacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet ServerMessageInfoPacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into ServerMessageInfoPacket: %v", err)) + } + return packet +} + +func UnmarshalServerMessagePacket(data PacketBody) ServerMessagePacket { + jsonData, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal data: %v", err)) + } + var packet ServerMessagePacket + if err := json.Unmarshal(jsonData, &packet); err != nil { + panic(fmt.Errorf("failed to unmarshal into ServerMessagePacket: %v", err)) + } + return packet } diff --git a/Projs/PD1/internal/server/datastore.go b/Projs/PD1/internal/server/datastore.go index 7120ea9..75e54e9 100644 --- a/Projs/PD1/internal/server/datastore.go +++ b/Projs/PD1/internal/server/datastore.go @@ -2,7 +2,9 @@ package server import ( "PD1/internal/protocol" + "crypto/x509" "database/sql" + "fmt" "log" "time" @@ -18,7 +20,9 @@ func OpenDB() DataStore { if err != nil { log.Fatalln("Error opening db file") } - return DataStore{db: db} + ds := DataStore{db: db} + ds.CreateTables() + return ds } func (ds DataStore) CreateTables() error { @@ -28,6 +32,7 @@ func (ds DataStore) CreateTables() error { userCert BLOB )`) if err != nil { + fmt.Println("Error creating users table", err) return err } @@ -36,23 +41,26 @@ func (ds DataStore) CreateTables() error { fromUID TEXT, toUID TEXT, timestamp TIMESTAMP, - content BLOB, + subject BLOB, + body BLOB, + status INT CHECK (status IN (0,1)), PRIMARY KEY (toUID, fromUID, timestamp), FOREIGN KEY(fromUID) REFERENCES users(UID), FOREIGN KEY(toUID) REFERENCES users(UID) )`) if err != nil { + fmt.Println("Error creating messages table", err) return err } return nil } -func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessagePacket { +func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet { var serverMessage protocol.ServerMessagePacket query := ` - SELECT fromUID, toUID, content, timestamp + SELECT fromUID, toUID, subject, body, timestamp FROM messages WHERE toUID = ? AND status = 0 @@ -61,15 +69,16 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessag ` // Execute the query row := ds.db.QueryRow(query, toUID, position) - err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Content, &serverMessage.Timestamp) + err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp) if err != nil { - log.Panicln("Could not map DB query to ServerMessage") + log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err) } - return serverMessage + + return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) } -func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error { +func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) { query := ` UPDATE messages SET status = 1 @@ -81,61 +90,90 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error { // Execute the SQL statement _, err := ds.db.Exec(query, toUID, position) if err != nil { - return err + log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err) } - - return nil } -func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet { - var messagePackets []protocol.Packet +func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet { + var messageInfoPackets []protocol.Packet // Query to retrieve all messages from the user's queue query := ` - SELECT fromUID, toUID, content, timestamp - FROM messages - WHERE toUID = ? - AND status = 0 - ORDER BY timestamp + SELECT + fromUID, + toUID, + timestamp, + queue_position, + subject, + status + FROM ( + SELECT + fromUID, + toUID, + timestamp, + ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position, + subject, + status + FROM + messages + WHERE + toUID = ? + ) AS ranked_messages + WHERE + status = 0 + ORDER BY + timestamp; ` // Execute the query rows, err := ds.db.Query(query, toUID) if err != nil { - log.Panicln("Failed to execute query:", err) + log.Printf("Error getting all messages for UID %v: %v", toUID, err) } defer rows.Close() // Iterate through the result set and scan each row into a ServerMessage struct - for rows.Next() { + //First row + if !rows.Next() { + return []protocol.Packet{} + } + for { var fromUID string - var toUID string - var content []byte + var subject []byte var timestamp time.Time - if err := rows.Scan(&fromUID, &toUID, &content, ×tamp); err != nil { - log.Panicln("Failed to scan row:", err) + var queuePosition, status int + if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil { + panic(err) + } + var message protocol.Packet + hasNext := rows.Next() + if !hasNext { + message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true) + messageInfoPackets = append(messageInfoPackets, message) + break + } else { + message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false) + messageInfoPackets = append(messageInfoPackets, message) } - message := protocol.NewServerMessagePacket(fromUID, toUID, content, timestamp) - messagePackets = append(messagePackets, message) } if err := rows.Err(); err != nil { - log.Panicln("Error when getting user's messages") + log.Printf("Error when getting messages for UID %v: %v", toUID, err) } - return messagePackets + return messageInfoPackets } -func (ds DataStore) AddMessageToQueue(uid string, message protocol.SubmitMessagePacket) { +func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) { query := ` - INSERT INTO messages (fromUID, toUID, content, timestamp, status) - VALUES (?, ?, ?, ?, 0) + INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status) + VALUES (?, ?, ?, ?, ?, 0) ` // Execute the SQL statement currentTime := time.Now() - _, err := ds.db.Exec(query, uid, message.ToUID, message.Content, currentTime) + _, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime) if err != nil { - log.Panicln("Error adding message to database") + log.Printf("Error adding message to UID %v: %v", fromUID, err) } } @@ -147,29 +185,53 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet { ` // Execute the SQL query - var userCert []byte - err := ds.db.QueryRow(query, uid).Scan(&userCert) - if err != nil { - log.Panicln("Error getting user certificate from the database") + var userCertBytes []byte + err := ds.db.QueryRow(query, uid).Scan(&userCertBytes) + if err == sql.ErrNoRows { + log.Panicf("No certificate for UID %v found in the database", uid) } - return protocol.NewSendUserCertPacket(uid, userCert) + //userCert,err := x509.ParseCertificate(userCertBytes) + //if err!=nil { + // log.Panicf("Error parsing certificate for UID %v",uid) + //} + return protocol.NewSendUserCertPacket(uid, userCertBytes) } -func userExists(db *sql.DB, uid string) bool { - // Prepare the SQL statement for checking if a user exists - query := ` +func (ds DataStore) userExists(uid string) bool { + // Prepare the SQL statement for checking if a user exists + query := ` SELECT COUNT(*) FROM users WHERE UID = ? ` - var count int - // Execute the SQL query - err := db.QueryRow(query, uid).Scan(&count) - if err != nil { - log.Panicln("Error checking if user exists") - } - - // If count is greater than 0, the user exists - return count > 0 + var count int + // Execute the SQL query + err := ds.db.QueryRow(query, uid).Scan(&count) + if err == sql.ErrNoRows { + log.Printf("User with UID %v does not exist", uid) + return false + } else { + return true + } +} + +func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) { + // Check if the user already exists + if ds.userExists(uid) { + log.Printf("User certificate for UID %s already exists.\n", uid) + return + } + + // Insert the user certificate + insertQuery := ` + INSERT INTO users (UID, userCert) + VALUES (?, ?) + ` + _, err := ds.db.Exec(insertQuery, uid, cert.Raw) + if err != nil { + log.Printf("Error storing user certificate for UID %s: %v\n", uid, err) + return + } + log.Printf("User certificate for UID %s stored successfully.\n", uid) } diff --git a/Projs/PD1/internal/server/server.go b/Projs/PD1/internal/server/server.go index c963342..ba22925 100644 --- a/Projs/PD1/internal/server/server.go +++ b/Projs/PD1/internal/server/server.go @@ -16,28 +16,44 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore oidMap := cryptoUtils.ExtractAllOIDValues(clientCert) //Get the UID of this user UID := oidMap["2.5.4.65"] - if UID=="" { + if UID == "" { panic("User certificate does not specify it's PSEUDONYM") } - + dataStore.storeUserCertIfNotExists(UID, *clientCert) +F: for { - var pac protocol.Packet - connection.Receive(&pac) + pac, active := connection.Receive() + if !active { + break F + } switch pac.Flag { case protocol.ReqUserCertPkt: - fmt.Printf("Type of pac.Body: %T\n", pac.Body) - UserCertPacket, ok := (pac.Body).(protocol.RequestUserCertPacket) - if !ok { - panic("Could not cast packet to it's type") + reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body) + userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) + if active := connection.Send(userCertPacket); !active { + break F + } + case protocol.ReqMsgsQueue: + _ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body) + messages := dataStore.GetUnreadMessagesInfoQueue(UID) + fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages)) + for _, message := range messages { + if !connection.Send(message) { + break + } } - userCertPacket := dataStore.GetUserCertificate(UserCertPacket.UID) - connection.Send(userCertPacket) - case protocol.ReqAllMsgPkt: - fmt.Println("ReqAllMsg") case protocol.ReqMsgPkt: - fmt.Println("ReqMsg") + reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body) + message := dataStore.GetMessage(UID, reqMsg.Num) + if active := connection.Send(message); !active { + break F + } case protocol.SubmitMsgPkt: - fmt.Println("SubmitMsg") + submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body) + if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) { + dataStore.AddMessageToQueue(UID, submitMsg) + } + } } diff --git a/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go b/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go index 3c5dd1c..49bfd06 100644 --- a/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go +++ b/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go @@ -178,7 +178,6 @@ func (k KeyStore) EncryptMessageContent(receiverCert *x509.Certificate, content } func (k KeyStore) DecryptMessageContent(senderCert *x509.Certificate, cipherContent []byte) []byte { - return nil } diff --git a/Projs/PD1/internal/utils/networking/connection.go b/Projs/PD1/internal/utils/networking/connection.go index 28a5997..1e40f9b 100644 --- a/Projs/PD1/internal/utils/networking/connection.go +++ b/Projs/PD1/internal/utils/networking/connection.go @@ -4,6 +4,8 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "io" + "log" ) type Connection[T any] struct { @@ -20,16 +22,33 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] { } } -func (c Connection[T]) Send(obj T) { - if err := c.encoder.Encode(&obj); err != nil { - panic("Failed encoding data or sending it to connection") - } +func (c Connection[T]) Send(obj T) bool { + if err := c.encoder.Encode(&obj); err!=nil { + if err == io.EOF { + log.Println("Connection closed by peer") + //Return false as connection not active + return false + } else { + log.Panic(err) + } + } + //Return true as connection active + return true } -func (c Connection[T]) Receive(objPtr *T) { - if err := c.decoder.Decode(objPtr); err != nil { - panic("Failed decoding data or reading it from connection") +func (c Connection[T]) Receive() (*T, bool) { + var obj T + if err := c.decoder.Decode(&obj); err != nil { + if err == io.EOF { + log.Println("Connection closed by peer") + //Return false as connection not active + return nil,false + } else { + log.Panic(err) + } } + //Return true as connection active + return &obj, true } func (c Connection[T]) GetPeerCertificate() *x509.Certificate { diff --git a/Projs/PD1/internal/utils/networking/server.go b/Projs/PD1/internal/utils/networking/server.go index 60e9c70..49f5948 100644 --- a/Projs/PD1/internal/utils/networking/server.go +++ b/Projs/PD1/internal/utils/networking/server.go @@ -43,7 +43,6 @@ func (s *Server[T]) ListenLoop() { state := tlsConn.ConnectionState() if len(state.PeerCertificates) == 0 { - fmt.Println(state.PeerCertificates) log.Panicln("Client did not provide a certificate") } conn := NewConnection[T](tlsConn) diff --git a/Projs/PD1/server.db b/Projs/PD1/server.db new file mode 100644 index 0000000..6496786 Binary files /dev/null and b/Projs/PD1/server.db differ diff --git a/Projs/PD1/tokefile.toml b/Projs/PD1/tokefile.toml index 2c6fe08..055056f 100644 --- a/Projs/PD1/tokefile.toml +++ b/Projs/PD1/tokefile.toml @@ -10,21 +10,16 @@ cmd="@@" cmd="go build" [targets.server] -deps=["check"] cmd="go run ./cmd/server/server.go" [targets.client1] -deps=["check"] -cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CLI1 testsubject" +cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject" [targets.FakeClient1] -deps=["check"] -cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CLI1 testsubject" +cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CL2 testsubject" [targets.client2] -deps=["check"] -cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CLI1 testsubject" +cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CL3 testsubject" [targets.client3] -deps=["check"] -cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CLI1 testsubject" +cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CL1 testsubject"