[PD1] Fixed almost everything

This commit is contained in:
Afonso Franco 2024-04-19 23:59:26 +01:00
parent 39a0e5c01f
commit 7b3172a850
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
13 changed files with 534 additions and 192 deletions

View file

@ -3,7 +3,9 @@ module PD1
go 1.22.2 go 1.22.2
require ( require (
github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mattn/go-sqlite3 v1.14.22
golang.org/x/crypto v0.11.0 // indirect golang.org/x/crypto v0.11.0
software.sslmate.com/src/go-pkcs12 v0.4.0 // indirect software.sslmate.com/src/go-pkcs12 v0.4.0
) )
require golang.org/x/sys v0.10.0 // indirect

View file

@ -2,5 +2,7 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

View file

@ -4,8 +4,10 @@ import (
"PD1/internal/protocol" "PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils" "PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking" "PD1/internal/utils/networking"
"crypto/x509"
"flag" "flag"
"fmt" "fmt"
"sort"
) )
func Run() { func Run() {
@ -27,27 +29,79 @@ func Run() {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>") panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
} }
uid := flag.Arg(1) uid := flag.Arg(1)
//subject := flag.Arg(2) subject := flag.Arg(2)
//messageContent := readMessageContent() messageBody := readMessageBody()
//Turn content to bytes
marshaledSubject := Marshal(subject)
marshaledBody := Marshal(messageBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore) cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close() defer cl.Connection.Conn.Close()
certRequestPacket := protocol.NewRequestUserCertPacket(uid) uidCert := getUserCert(cl, uid)
cl.Connection.Send(certRequestPacket) if uidCert == nil {
return
var certPacket protocol.Packet }
cl.Connection.Receive(&certPacket) encryptedSubject := clientKeyStore.EncryptMessageContent(uidCert, marshaledSubject)
uidCert := (certPacket.Body).(protocol.SendUserCertPacket) encryptedBody := clientKeyStore.EncryptMessageContent(uidCert, marshaledBody)
fmt.Println(uidCert) submitMessage := protocol.NewSubmitMessagePacket(uid, encryptedSubject, encryptedBody)
if !cl.Connection.Send(submitMessage) {
// TODO: Encrypt message return
//submitMessage(cl, uid, cipherContent) }
cl.Connection.Conn.Close()
case "askqueue": case "askqueue":
cl := networking.NewClient[protocol.Packet](&clientKeyStore) cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close() defer cl.Connection.Conn.Close()
requestUnreadMsgsQueuePacket := protocol.NewRequestUnreadMsgsQueuePacket()
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return
}
serverMessagePackets, certificates := getManyMessagesInfo(cl)
var clientMessages []ClientMessageInfo
for _, message := range serverMessagePackets {
senderCert, ok := certificates[message.FromUID]
if ok {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject := Unmarshal(decryptedSubjectBytes)
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
clientMessages = append(clientMessages, clientMessage)
}
}
//Sort the messages
sort.Slice(clientMessages, func(i, j int) bool {
return clientMessages[i].Num > clientMessages[j].Num
})
showMessagesInfo(clientMessages)
//case "getall":
// cl := networking.NewClient[protocol.Packet](&clientKeyStore)
// defer cl.Connection.Conn.Close()
// requestAllMsgPacket := protocol.NewRequestAllMsgPacket()
// if !cl.Connection.Send(requestAllMsgPacket) {
// return
// }
// serverMessagePackets,certificates := getManyMessages(cl)
// var clientMessages []ClientMessage
// for _, message := range serverMessagePackets {
// senderCert, ok := certificates[message.FromUID]
// if ok {
// decryptedContentBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Content)
// content := UnmarshalContent(decryptedContentBytes)
// clientMessage := newClientMessage(message.FromUID, message.ToUID, content, message.Timestamp)
// clientMessages = append(clientMessages, clientMessage)
// }
// }
// //Sort the messages
// sort.Slice(clientMessages, func(i, j int) bool {
// return clientMessages[i].Timestamp.After(clientMessages[j].Timestamp)
// })
// showMessages(clientMessages)
case "getmsg": case "getmsg":
if flag.NArg() < 2 { if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>") panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
@ -65,7 +119,52 @@ func Run() {
} }
func submitMessage(cl networking.Client[protocol.Packet], uid string, content []byte) { func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
pack := protocol.NewSubmitMessagePacket(uid, content) certRequestPacket := protocol.NewRequestUserCertPacket(uid)
cl.Connection.Send(pack) if !cl.Connection.Send(certRequestPacket) {
return nil
}
var certPacket *protocol.Packet
certPacket, active := cl.Connection.Receive()
if !active {
return nil
}
uidCertInBytes := protocol.UnmarshalSendUserCertPacket(certPacket.Body)
uidCert, err := x509.ParseCertificate(uidCertInBytes.Certificate)
if err != nil {
return nil
}
return uidCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) ([]protocol.ServerMessageInfoPacket, map[string]*x509.Certificate) {
//Create the slice to hold the incoming messages before decrypting
//Create the map to hold the sender certificates
//Create sync mutexes
serverMessageInfoPackets := []protocol.ServerMessageInfoPacket{}
//Run while message isn't the last one
msg := protocol.ServerMessageInfoPacket{}
for !msg.Last {
sendMsgPacket, active := cl.Connection.Receive()
if !active {
return nil, nil
}
msg = protocol.UnmarshalServerMessageInfoPacket(sendMsgPacket.Body)
//Lock and append
serverMessageInfoPackets = append(serverMessageInfoPackets, msg)
}
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range serverMessageInfoPackets {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
fmt.Println("Got a User cert")
certificatesMap[senderUID] = senderCert
}
return serverMessageInfoPackets, certificatesMap
} }

View file

@ -1,15 +1,47 @@
package client package client
import "time" import (
"encoding/json"
"log"
"time"
)
type Content struct { type ClientMessage struct {
Subject []byte
Body []byte
}
type RecievedMessage struct {
FromUID string FromUID string
ToUID string ToUID string
Content Content Subject string
Body string
Timestamp time.Time Timestamp time.Time
} }
type ClientMessageInfo struct {
Num int
FromUID string
Timestamp time.Time
Subject string
}
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
}
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo {
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp}
}
func Marshal(data any) []byte {
subject, err := json.Marshal(data)
if err != nil {
log.Panicf("Error when marshalling message: %v", err)
}
return subject
}
func Unmarshal(data []byte) string {
var c string
err := json.Unmarshal(data, &c)
if err != nil {
log.Panicln("Could not unmarshal data")
}
return c
}

View file

@ -6,7 +6,7 @@ import (
"os" "os"
) )
func readMessageContent() string { func readMessageBody() string {
fmt.Println("Enter message content (limited to 1000 bytes):") fmt.Println("Enter message content (limited to 1000 bytes):")
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
scanner.Scan() scanner.Scan()
@ -27,7 +27,6 @@ func commandError() {
showHelp() showHelp()
} }
func showHelp() { func showHelp() {
fmt.Println("Comandos da aplicação cliente:") fmt.Println("Comandos da aplicação cliente:")
fmt.Println("-user <FNAME>: Especifica o ficheiro com dados do utilizador. Por omissão, será assumido que esse ficheiro é userdata.p12.") fmt.Println("-user <FNAME>: Especifica o ficheiro com dados do utilizador. Por omissão, será assumido que esse ficheiro é userdata.p12.")
@ -36,3 +35,16 @@ func showHelp() {
fmt.Println("getmsg <NUM>: Solicita ao servidor o envio da mensagem da sua queue com número <NUM>.") fmt.Println("getmsg <NUM>: Solicita ao servidor o envio da mensagem da sua queue com número <NUM>.")
fmt.Println("help: Imprime instruções de uso do programa.") fmt.Println("help: Imprime instruções de uso do programa.")
} }
func showMessagesInfo(messages []ClientMessageInfo) {
for _, message := range messages {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
}
func showMessage(message ClientMessage) {
fmt.Printf("From:%v\n", message.FromUID)
fmt.Printf("To:%v\n", message.ToUID)
fmt.Printf("Subject:%v\n", message.Subject)
fmt.Printf("Body:%v\n", message.Body)
}

View file

@ -1,6 +1,8 @@
package protocol package protocol
import ( import (
"encoding/json"
"fmt"
"time" "time"
) )
@ -8,41 +10,49 @@ type PacketType int
const ( const (
ReqUserCertPkt PacketType = iota ReqUserCertPkt PacketType = iota
ReqAllMsgPkt ReqMsgsQueue
ReqMsgPkt ReqMsgPkt
SubmitMsgPkt SubmitMsgPkt
SendUserCertPkt SendUserCertPkt
ServerMsgInfoPkt
ServerMsgPkt ServerMsgPkt
) )
// Define interfaces for packet bodies
type ( type (
RequestUserCertPacket struct { RequestUserCertPacket struct {
UID string `json:"uid"` UID string `json:"uid"`
} }
RequestAllMsgPacket struct { RequestMsgsQueuePacket struct {
FromUID string `json:"from_uid"`
} }
RequestMsgPacket struct { RequestMsgPacket struct {
Num uint16 `json:"num"` Num int `json:"num"`
} }
SubmitMessagePacket struct { SubmitMessagePacket struct {
ToUID string `json:"to_uid"` ToUID string `json:"to_uid"`
Content []byte `json:"content"` Subject []byte `json:"subject"`
Body []byte `json:"body"`
} }
SendUserCertPacket struct { SendUserCertPacket struct {
UID string `json:"uid"` UID string `json:"uid"`
Key []byte `json:"key"` Certificate []byte `json:"certificate"`
} }
ServerMessageInfoPacket struct {
Num int `json:"num"`
FromUID string `json:"from_uid"`
Subject []byte `json:"subject"`
Timestamp time.Time `json:"timestamp"`
Last bool `json:"last"`
}
ServerMessagePacket struct { ServerMessagePacket struct {
FromUID string `json:"from_uid"` FromUID string `json:"from_uid"`
ToUID string `json:"to_uid"` ToUID string `json:"to_uid"`
Content []byte `json:"content"` Subject []byte `json:"subject"`
Body []byte `json:"body"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
} }
) )
@ -63,16 +73,14 @@ func NewRequestUserCertPacket(UID string) Packet {
} }
} }
func NewRequestAllMsgPacket(fromUID string) Packet { func NewRequestUnreadMsgsQueuePacket() Packet {
return Packet{ return Packet{
Flag: ReqAllMsgPkt, Flag: ReqMsgsQueue,
Body: RequestAllMsgPacket{ Body: RequestMsgsQueuePacket{},
FromUID: fromUID,
},
} }
} }
func NewRequestMsgPacket(num uint16) Packet { func NewRequestMsgPacket(num int) Packet {
return Packet{ return Packet{
Flag: ReqMsgPkt, Flag: ReqMsgPkt,
Body: RequestMsgPacket{ Body: RequestMsgPacket{
@ -81,34 +89,131 @@ func NewRequestMsgPacket(num uint16) Packet {
} }
} }
func NewSubmitMessagePacket(toUID string, content []byte) Packet { func NewSubmitMessagePacket(toUID string, subject []byte, body []byte) Packet {
return Packet{ return Packet{
Flag: SubmitMsgPkt, Flag: SubmitMsgPkt,
Body: SubmitMessagePacket{ Body: SubmitMessagePacket{
ToUID: toUID, ToUID: toUID,
Content: content, Subject: subject,
Body: body,
}, },
} }
} }
func NewSendUserCertPacket(uid string, key []byte) Packet { func NewSendUserCertPacket(uid string, certificate []byte) Packet {
return Packet{ return Packet{
Flag: SendUserCertPkt, Flag: SendUserCertPkt,
Body: SendUserCertPacket{ Body: SendUserCertPacket{
UID: uid, UID: uid,
Key: key, Certificate: certificate,
},
}
}
func NewServerMessageInfoPacket(num int, fromUID string, subject []byte, timestamp time.Time, last bool) Packet {
return Packet{
Flag: ServerMsgInfoPkt,
Body: ServerMessageInfoPacket{
Num: num,
FromUID: fromUID,
Subject: subject,
Timestamp: timestamp,
Last: last,
}, },
} }
} }
func NewServerMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet { func NewServerMessagePacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return Packet{ return Packet{
Flag: ServerMsgPkt, Flag: ServerMsgPkt,
Body: ServerMessagePacket{ Body: ServerMessagePacket{
FromUID: fromUID, FromUID: fromUID,
ToUID: toUID, ToUID: toUID,
Content: content, Subject: subject,
Body: body,
Timestamp: timestamp, Timestamp: timestamp,
}, },
} }
} }
func UnmarshalRequestUserCertPacket(data PacketBody) RequestUserCertPacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestUserCertPacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestUserCertPacket: %v", err))
}
return packet
}
func UnmarshalRequestMsgsQueuePacket(data PacketBody) RequestMsgsQueuePacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestMsgsQueuePacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgsQueuePacket: %v", err))
}
return packet
}
func UnmarshalRequestMsgPacket(data PacketBody) RequestMsgPacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet RequestMsgPacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into RequestMsgPacket: %v", err))
}
return packet
}
func UnmarshalSubmitMessagePacket(data PacketBody) SubmitMessagePacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet SubmitMessagePacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SubmitMessagePacket: %v", err))
}
return packet
}
func UnmarshalSendUserCertPacket(data PacketBody) SendUserCertPacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet SendUserCertPacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendUserCertPacket: %v", err))
}
return packet
}
func UnmarshalServerMessageInfoPacket(data PacketBody) ServerMessageInfoPacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ServerMessageInfoPacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessageInfoPacket: %v", err))
}
return packet
}
func UnmarshalServerMessagePacket(data PacketBody) ServerMessagePacket {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ServerMessagePacket
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into ServerMessagePacket: %v", err))
}
return packet
}

View file

@ -2,7 +2,9 @@ package server
import ( import (
"PD1/internal/protocol" "PD1/internal/protocol"
"crypto/x509"
"database/sql" "database/sql"
"fmt"
"log" "log"
"time" "time"
@ -18,7 +20,9 @@ func OpenDB() DataStore {
if err != nil { if err != nil {
log.Fatalln("Error opening db file") log.Fatalln("Error opening db file")
} }
return DataStore{db: db} ds := DataStore{db: db}
ds.CreateTables()
return ds
} }
func (ds DataStore) CreateTables() error { func (ds DataStore) CreateTables() error {
@ -28,6 +32,7 @@ func (ds DataStore) CreateTables() error {
userCert BLOB userCert BLOB
)`) )`)
if err != nil { if err != nil {
fmt.Println("Error creating users table", err)
return err return err
} }
@ -36,23 +41,26 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT, fromUID TEXT,
toUID TEXT, toUID TEXT,
timestamp TIMESTAMP, timestamp TIMESTAMP,
content BLOB, subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
PRIMARY KEY (toUID, fromUID, timestamp), PRIMARY KEY (toUID, fromUID, timestamp),
FOREIGN KEY(fromUID) REFERENCES users(UID), FOREIGN KEY(fromUID) REFERENCES users(UID),
FOREIGN KEY(toUID) REFERENCES users(UID) FOREIGN KEY(toUID) REFERENCES users(UID)
)`) )`)
if err != nil { if err != nil {
fmt.Println("Error creating messages table", err)
return err return err
} }
return nil return nil
} }
func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessagePacket { func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket var serverMessage protocol.ServerMessagePacket
query := ` query := `
SELECT fromUID, toUID, content, timestamp SELECT fromUID, toUID, subject, body, timestamp
FROM messages FROM messages
WHERE toUID = ? WHERE toUID = ?
AND status = 0 AND status = 0
@ -61,15 +69,16 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessag
` `
// Execute the query // Execute the query
row := ds.db.QueryRow(query, toUID, position) row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Content, &serverMessage.Timestamp) err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err != nil { if err != nil {
log.Panicln("Could not map DB query to ServerMessage") log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
} }
return serverMessage
return protocol.NewServerMessagePacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
} }
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error { func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
query := ` query := `
UPDATE messages UPDATE messages
SET status = 1 SET status = 1
@ -81,61 +90,90 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error {
// Execute the SQL statement // Execute the SQL statement
_, err := ds.db.Exec(query, toUID, position) _, err := ds.db.Exec(query, toUID, position)
if err != nil { if err != nil {
return err log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
} }
return nil
} }
func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet { func (ds DataStore) GetUnreadMessagesInfoQueue(toUID string) []protocol.Packet {
var messagePackets []protocol.Packet var messageInfoPackets []protocol.Packet
// Query to retrieve all messages from the user's queue // Query to retrieve all messages from the user's queue
query := ` query := `
SELECT fromUID, toUID, content, timestamp SELECT
FROM messages fromUID,
WHERE toUID = ? toUID,
AND status = 0 timestamp,
ORDER BY timestamp queue_position,
subject,
status
FROM (
SELECT
fromUID,
toUID,
timestamp,
ROW_NUMBER() OVER (PARTITION BY toUID ORDER BY timestamp) - 1 AS queue_position,
subject,
status
FROM
messages
WHERE
toUID = ?
) AS ranked_messages
WHERE
status = 0
ORDER BY
timestamp;
` `
// Execute the query // Execute the query
rows, err := ds.db.Query(query, toUID) rows, err := ds.db.Query(query, toUID)
if err != nil { if err != nil {
log.Panicln("Failed to execute query:", err) log.Printf("Error getting all messages for UID %v: %v", toUID, err)
} }
defer rows.Close() defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct // Iterate through the result set and scan each row into a ServerMessage struct
for rows.Next() { //First row
var fromUID string if !rows.Next() {
var toUID string return []protocol.Packet{}
var content []byte }
var timestamp time.Time for {
if err := rows.Scan(&fromUID, &toUID, &content, &timestamp); err != nil { var fromUID string
log.Panicln("Failed to scan row:", err) var subject []byte
var timestamp time.Time
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
}
var message protocol.Packet
hasNext := rows.Next()
if !hasNext {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, true)
messageInfoPackets = append(messageInfoPackets, message)
break
} else {
message = protocol.NewServerMessageInfoPacket(queuePosition, fromUID, subject, timestamp, false)
messageInfoPackets = append(messageInfoPackets, message)
} }
message := protocol.NewServerMessagePacket(fromUID, toUID, content, timestamp)
messagePackets = append(messagePackets, message)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Panicln("Error when getting user's messages") log.Printf("Error when getting messages for UID %v: %v", toUID, err)
} }
return messagePackets return messageInfoPackets
} }
func (ds DataStore) AddMessageToQueue(uid string, message protocol.SubmitMessagePacket) { func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SubmitMessagePacket) {
query := ` query := `
INSERT INTO messages (fromUID, toUID, content, timestamp, status) INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, 0) VALUES (?, ?, ?, ?, ?, 0)
` `
// Execute the SQL statement // Execute the SQL statement
currentTime := time.Now() currentTime := time.Now()
_, err := ds.db.Exec(query, uid, message.ToUID, message.Content, currentTime) _, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil { if err != nil {
log.Panicln("Error adding message to database") log.Printf("Error adding message to UID %v: %v", fromUID, err)
} }
} }
@ -147,15 +185,19 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
` `
// Execute the SQL query // Execute the SQL query
var userCert []byte var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCert) err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err != nil { if err == sql.ErrNoRows {
log.Panicln("Error getting user certificate from the database") log.Panicf("No certificate for UID %v found in the database", uid)
} }
return protocol.NewSendUserCertPacket(uid, userCert) //userCert,err := x509.ParseCertificate(userCertBytes)
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewSendUserCertPacket(uid, userCertBytes)
} }
func userExists(db *sql.DB, uid string) bool { func (ds DataStore) userExists(uid string) bool {
// Prepare the SQL statement for checking if a user exists // Prepare the SQL statement for checking if a user exists
query := ` query := `
SELECT COUNT(*) SELECT COUNT(*)
@ -165,11 +207,31 @@ func userExists(db *sql.DB, uid string) bool {
var count int var count int
// Execute the SQL query // Execute the SQL query
err := db.QueryRow(query, uid).Scan(&count) err := ds.db.QueryRow(query, uid).Scan(&count)
if err != nil { if err == sql.ErrNoRows {
log.Panicln("Error checking if user exists") log.Printf("User with UID %v does not exist", uid)
return false
} else {
return true
}
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
// Check if the user already exists
if ds.userExists(uid) {
log.Printf("User certificate for UID %s already exists.\n", uid)
return
} }
// If count is greater than 0, the user exists // Insert the user certificate
return count > 0 insertQuery := `
INSERT INTO users (UID, userCert)
VALUES (?, ?)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
return
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
} }

View file

@ -16,28 +16,44 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert) oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
//Get the UID of this user //Get the UID of this user
UID := oidMap["2.5.4.65"] UID := oidMap["2.5.4.65"]
if UID=="" { if UID == "" {
panic("User certificate does not specify it's PSEUDONYM") panic("User certificate does not specify it's PSEUDONYM")
} }
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F:
for { for {
var pac protocol.Packet pac, active := connection.Receive()
connection.Receive(&pac) if !active {
break F
}
switch pac.Flag { switch pac.Flag {
case protocol.ReqUserCertPkt: case protocol.ReqUserCertPkt:
fmt.Printf("Type of pac.Body: %T\n", pac.Body) reqUserCert := protocol.UnmarshalRequestUserCertPacket(pac.Body)
UserCertPacket, ok := (pac.Body).(protocol.RequestUserCertPacket) userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if !ok { if active := connection.Send(userCertPacket); !active {
panic("Could not cast packet to it's type") break F
}
case protocol.ReqMsgsQueue:
_ = protocol.UnmarshalRequestMsgsQueuePacket(pac.Body)
messages := dataStore.GetUnreadMessagesInfoQueue(UID)
fmt.Printf("Number of unread messages by user %v is %v\n",UID,len(messages))
for _, message := range messages {
if !connection.Send(message) {
break
}
} }
userCertPacket := dataStore.GetUserCertificate(UserCertPacket.UID)
connection.Send(userCertPacket)
case protocol.ReqAllMsgPkt:
fmt.Println("ReqAllMsg")
case protocol.ReqMsgPkt: case protocol.ReqMsgPkt:
fmt.Println("ReqMsg") reqMsg := protocol.UnmarshalRequestMsgPacket(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
break F
}
case protocol.SubmitMsgPkt: case protocol.SubmitMsgPkt:
fmt.Println("SubmitMsg") submitMsg := protocol.UnmarshalSubmitMessagePacket(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
}
} }
} }

View file

@ -178,7 +178,6 @@ func (k KeyStore) EncryptMessageContent(receiverCert *x509.Certificate, content
} }
func (k KeyStore) DecryptMessageContent(senderCert *x509.Certificate, cipherContent []byte) []byte { func (k KeyStore) DecryptMessageContent(senderCert *x509.Certificate, cipherContent []byte) []byte {
return nil return nil
} }

View file

@ -4,6 +4,8 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"io"
"log"
) )
type Connection[T any] struct { type Connection[T any] struct {
@ -20,16 +22,33 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
} }
} }
func (c Connection[T]) Send(obj T) { func (c Connection[T]) Send(obj T) bool {
if err := c.encoder.Encode(&obj); err != nil { if err := c.encoder.Encode(&obj); err!=nil {
panic("Failed encoding data or sending it to connection") if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return false
} else {
log.Panic(err)
} }
}
//Return true as connection active
return true
} }
func (c Connection[T]) Receive(objPtr *T) { func (c Connection[T]) Receive() (*T, bool) {
if err := c.decoder.Decode(objPtr); err != nil { var obj T
panic("Failed decoding data or reading it from connection") if err := c.decoder.Decode(&obj); err != nil {
if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return nil,false
} else {
log.Panic(err)
} }
}
//Return true as connection active
return &obj, true
} }
func (c Connection[T]) GetPeerCertificate() *x509.Certificate { func (c Connection[T]) GetPeerCertificate() *x509.Certificate {

View file

@ -43,7 +43,6 @@ func (s *Server[T]) ListenLoop() {
state := tlsConn.ConnectionState() state := tlsConn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
fmt.Println(state.PeerCertificates)
log.Panicln("Client did not provide a certificate") log.Panicln("Client did not provide a certificate")
} }
conn := NewConnection[T](tlsConn) conn := NewConnection[T](tlsConn)

BIN
Projs/PD1/server.db Normal file

Binary file not shown.

View file

@ -10,21 +10,16 @@ cmd="@@"
cmd="go build" cmd="go build"
[targets.server] [targets.server]
deps=["check"]
cmd="go run ./cmd/server/server.go" cmd="go run ./cmd/server/server.go"
[targets.client1] [targets.client1]
deps=["check"] cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CLI1 testsubject"
[targets.FakeClient1] [targets.FakeClient1]
deps=["check"] cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CL2 testsubject"
cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CLI1 testsubject"
[targets.client2] [targets.client2]
deps=["check"] cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CL3 testsubject"
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CLI1 testsubject"
[targets.client3] [targets.client3]
deps=["check"] cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CL1 testsubject"
cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CLI1 testsubject"