[PD1] Error handling project-wide

This commit is contained in:
Afonso Franco 2024-04-28 22:02:13 +01:00
parent f5b3726673
commit b918211736
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
13 changed files with 364 additions and 245 deletions

View file

@ -5,5 +5,5 @@ import (
) )
func main(){ func main(){
server.Run(8080) server.Run()
} }

View file

@ -5,6 +5,7 @@ import (
"PD1/internal/utils/cryptoUtils" "PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking" "PD1/internal/utils/networking"
"crypto/x509" "crypto/x509"
"errors"
"flag" "flag"
"log" "log"
"sort" "sort"
@ -17,45 +18,27 @@ func Run() {
flag.Parse() flag.Parse()
if flag.NArg() == 0 { if flag.NArg() == 0 {
panic("No command provided. Use 'help' for instructions.") log.Fatalln("No command provided. Use 'help' for instructions.")
} }
//Get user KeyStore //Get user KeyStore
password := readStdin("Insert keystore passphrase") password := readStdin("Insert keystore passphrase")
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password) clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
if err != nil {
log.Fatalln(err)
}
command := flag.Arg(0) command := flag.Arg(0)
switch command { switch command {
case "send": case "send":
if flag.NArg() < 3 { if flag.NArg() < 3 {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>") log.Fatalln("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
} }
uid := flag.Arg(1) uid := flag.Arg(1)
plainSubject := flag.Arg(2) plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):") plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
plainSubjectBytes := Marshal(plainSubject) if err != nil {
plainBodyBytes := Marshal(plainBody) log.Fatalln(err)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
receiverCert := getUserCert(cl, clientKeyStore, uid)
if receiverCert == nil {
return
}
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(sendMsgPacket) {
return
}
answerSendMsg, active := cl.Connection.Receive()
if !active {
return
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
log.Println(reportError.ErrorMessage)
} }
case "askqueue": case "askqueue":
@ -74,41 +57,24 @@ func Run() {
} }
} }
cl := networking.NewClient[protocol.Packet](&clientKeyStore) err := askQueueCommand(clientKeyStore, page, pageSize)
defer cl.Connection.Conn.Close() if err != nil {
askQueue(cl, clientKeyStore, page, pageSize) log.Fatalln(err)
}
case "getmsg": case "getmsg":
if flag.NArg() < 2 { if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>") log.Fatalln("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
} }
numString := flag.Arg(1) numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString) num, err := strconv.Atoi(numString)
if err != nil { if err != nil {
log.Panicln("NUM argument provided is not a number") log.Fatalln(err)
} }
packet := protocol.NewGetMsgPacket(num) err = getMsgCommand(clientKeyStore, num)
cl.Connection.Send(packet) if err != nil {
log.Fatalln(err)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
} }
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
log.Println(reportError.ErrorMessage)
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help": case "help":
showHelp() showHelp()
@ -119,43 +85,152 @@ func Run() {
} }
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) *x509.Certificate { func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
getUserCertPacket := protocol.NewGetUserCertPacket(uid) //Turn content to bytes
if !cl.Connection.Send(getUserCertPacket) { plainSubjectBytes, err := Marshal(plainSubject)
return nil
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, active := cl.Connection.Receive()
if !active {
return nil
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
log.Println(reportError.ErrorMessage)
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil { if err != nil {
return nil return err
} }
if !keyStore.CheckCert(userCert, uid){ plainBodyBytes, err := Marshal(plainBody)
return nil if err != nil {
} return err
return userCert
} }
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) { cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive() if err != nil {
if !active { return err
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil }
defer cl.Connection.Conn.Close()
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
if err != nil {
return err
}
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
if err != nil {
return err
}
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
if err != nil {
return err
}
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if err := cl.Connection.Send(sendMsgPacket); err != nil {
return err
}
answerSendMsg, err := cl.Connection.Receive()
if err != nil {
return err
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
return nil
}
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
packet := protocol.NewGetMsgPacket(num)
if err := cl.Connection.Send(packet); err != nil {
return err
}
receivedMsgPacket, err := cl.Connection.Receive()
if err != nil {
return err
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
if err != nil {
return err
}
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
if err != nil {
return err
}
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
if err != nil {
return err
}
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
if err != nil {
return err
}
subject, err := Unmarshal(decSubjectBytes)
if err != nil {
return err
}
body, err := Unmarshal(decBodyBytes)
if err != nil {
return err
}
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
return nil
}
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if err := cl.Connection.Send(getUserCertPacket); err != nil {
return nil, err
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, err := cl.Connection.Receive()
if err != nil {
return nil, err
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
return nil, errors.New(reportError.ErrorMessage)
}
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil, err
}
if err := keyStore.CheckCert(userCert, uid); err != nil {
return nil, err
}
return userCert, nil
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
if err != nil {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
} }
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError { if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body) reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
log.Println(reportError.ErrorMessage) if err != nil {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
}
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
}
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
if err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
} }
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates //Create Set of needed certificates
senderSet := map[string]bool{} senderSet := map[string]bool{}
@ -165,32 +240,60 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoU
certificatesMap := map[string]*x509.Certificate{} certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates //Get senders' certificates
for senderUID := range senderSet { for senderUID := range senderSet {
senderCert := getUserCert(cl, keyStore, senderUID) senderCert, err := getUserCert(cl, keyStore, senderUID)
if err == nil {
certificatesMap[senderUID] = senderCert certificatesMap[senderUID] = senderCert
} }
return answerGetUnreadMsgsInfo, certificatesMap }
return answerGetUnreadMsgsInfo, certificatesMap, nil
} }
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) { func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize) cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { if err != nil {
return return err
}
defer cl.Connection.Conn.Close()
return askQueueRec(cl, clientKeyStore, page, pageSize)
}
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
return err
}
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
if err != nil {
return err
} }
unreadMsgsInfo, certificates := getManyMessagesInfo(cl, clientKeyStore)
var clientMessages []ClientMessageInfo var clientMessages []ClientMessageInfo
for _, message := range unreadMsgsInfo.MessagesInfo { for _, message := range unreadMsgsInfo.MessagesInfo {
var clientMessageInfo ClientMessageInfo
senderCert, ok := certificates[message.FromUID] senderCert, ok := certificates[message.FromUID]
if ok { if !ok {
var subject string clientMessageInfo = newClientMessageInfo(message.Num,
if senderCert != nil { message.FromUID,
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) "",
subject = Unmarshal(decryptedSubjectBytes) message.Timestamp,
} else { errors.New("certificate needed to decrypt not received"))
subject = "" clientMessages = append(clientMessages, clientMessageInfo)
continue
} }
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
clientMessages = append(clientMessages, clientMessage) if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
} }
subject, err := Unmarshal(decryptedSubjectBytes)
if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
clientMessages = append(clientMessages, clientMessageInfo)
} }
//Sort the messages //Sort the messages
sort.Slice(clientMessages, func(i, j int) bool { sort.Slice(clientMessages, func(i, j int) bool {
@ -200,10 +303,10 @@ func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages) action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action { switch action {
case -1: case -1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize) return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 0:
return
case 1: case 1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize) return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
default:
return nil
} }
} }

View file

@ -1,7 +1,7 @@
package client package client
import ( import (
"log" "encoding/json"
"time" "time"
) )
@ -18,29 +18,30 @@ type ClientMessageInfo struct {
FromUID string FromUID string
Timestamp time.Time Timestamp time.Time
Subject string Subject string
decryptError error
} }
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage { func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp} return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
} }
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo { func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo {
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp} return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err}
} }
func Marshal(data any) []byte { func Marshal(data any) ([]byte, error) {
subject, err := json.Marshal(data) subject, err := json.Marshal(data)
if err != nil { if err != nil {
log.Panicf("Error when marshalling message: %v", err) return nil, err
} }
return subject return subject, nil
} }
func Unmarshal(data []byte) string { func Unmarshal(data []byte) (string, error) {
var c string var c string
err := json.Unmarshal(data, &c) err := json.Unmarshal(data, &c)
if err != nil { if err != nil {
log.Panicln("Could not unmarshal data") return "", err
} }
return c return c, nil
} }

View file

@ -3,6 +3,7 @@ package client
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"log"
"os" "os"
"strings" "strings"
) )
@ -34,12 +35,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
return 0 return 0
} }
for _, message := range messages { for _, message := range messages {
if message.Subject == "" { if message.decryptError != nil {
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID) fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp)
continue log.Println(message.decryptError)
} } else {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
} }
}
fmt.Printf("Page %v/%v\n", page, numPages) fmt.Printf("Page %v/%v\n", page, numPages)
return messagesInfoPageNavigation(page, numPages) return messagesInfoPageNavigation(page, numPages)
} }

View file

@ -2,7 +2,6 @@ package protocol
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
) )
@ -201,109 +200,109 @@ func NewReportErrorPacket(errorMessage string) Packet {
return NewPacket(FlagReportError, NewReportError(errorMessage)) return NewPacket(FlagReportError, NewReportError(errorMessage))
} }
func UnmarshalGetUserCert(data PacketBody) GetUserCert { func UnmarshalGetUserCert(data PacketBody) (GetUserCert, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return GetUserCert{}, err
} }
var packet GetUserCert var packet GetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err)) return GetUserCert{}, err
} }
return packet return packet, nil
} }
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo { func UnmarshalGetUnreadMsgsInfo(data PacketBody) (GetUnreadMsgsInfo, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return GetUnreadMsgsInfo{}, err
} }
var packet GetUnreadMsgsInfo var packet GetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err)) return GetUnreadMsgsInfo{}, err
} }
return packet return packet, nil
} }
func UnmarshalGetMsg(data PacketBody) GetMsg { func UnmarshalGetMsg(data PacketBody) (GetMsg, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return GetMsg{}, err
} }
var packet GetMsg var packet GetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err)) return GetMsg{}, err
} }
return packet return packet, nil
} }
func UnmarshalSendMsg(data PacketBody) SendMsg { func UnmarshalSendMsg(data PacketBody) (SendMsg, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return SendMsg{}, err
} }
var packet SendMsg var packet SendMsg
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err)) return SendMsg{}, err
} }
return packet return packet, nil
} }
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert { func UnmarshalAnswerGetUserCert(data PacketBody) (AnswerGetUserCert, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return AnswerGetUserCert{}, err
} }
var packet AnswerGetUserCert var packet AnswerGetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err)) return AnswerGetUserCert{}, err
} }
return packet return packet, nil
} }
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo { func UnmarshalUnreadMsgInfo(data PacketBody) (MsgInfo, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return MsgInfo{}, err
} }
var packet MsgInfo var packet MsgInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err)) return MsgInfo{}, err
} }
return packet return packet, nil
} }
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo { func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) (AnswerGetUnreadMsgsInfo, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return AnswerGetUnreadMsgsInfo{}, err
} }
var packet AnswerGetUnreadMsgsInfo var packet AnswerGetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err)) return AnswerGetUnreadMsgsInfo{}, err
} }
return packet return packet, nil
} }
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg { func UnmarshalAnswerGetMsg(data PacketBody) (AnswerGetMsg, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return AnswerGetMsg{}, err
} }
var packet AnswerGetMsg var packet AnswerGetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err)) return AnswerGetMsg{}, err
} }
return packet return packet, nil
} }
func UnmarshalReportError(data PacketBody) ReportError { func UnmarshalReportError(data PacketBody) (ReportError, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err)) return ReportError{}, err
} }
var packet ReportError var packet ReportError
if err := json.Unmarshal(jsonData, &packet); err != nil { if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err)) return ReportError{}, err
} }
return packet return packet, nil
} }

View file

@ -4,6 +4,7 @@ import (
"PD1/internal/protocol" "PD1/internal/protocol"
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
@ -15,14 +16,17 @@ type DataStore struct {
db *sql.DB db *sql.DB
} }
func OpenDB() DataStore { func OpenDB() (DataStore, error) {
db, err := sql.Open("sqlite3", "server.db") db, err := sql.Open("sqlite3", "server.db")
if err != nil { if err != nil {
log.Fatalln("Error opening db file") return DataStore{}, err
} }
ds := DataStore{db: db} ds := DataStore{db: db}
ds.CreateTables() err = ds.CreateTables()
return ds if err != nil {
return DataStore{}, err
}
return ds, nil
} }
func (ds DataStore) CreateTables() error { func (ds DataStore) CreateTables() error {
@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error {
userCert BLOB userCert BLOB
)`) )`)
if err != nil { if err != nil {
fmt.Println("Error creating users table", err)
return err return err
} }
@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error {
FOREIGN KEY(toUID) REFERENCES users(UID) FOREIGN KEY(toUID) REFERENCES users(UID)
)`) )`)
if err != nil { if err != nil {
fmt.Println("Error creating messages table", err)
return err return err
} }
@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error {
END; END;
`) `)
if err != nil { if err != nil {
fmt.Println("Error creating trigger", err)
return err return err
} }
@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
} }
} }
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet { func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) {
// Retrieve the total count of unread messages // Retrieve the total count of unread messages
var totalCount int var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount) err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
log.Printf("No unread messages for UID %v: %v", toUID, err) return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
} }
// Query to retrieve all messages from the user's queue // Query to retrieve all messages from the user's queue
@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
var timestamp time.Time var timestamp time.Time
var queuePosition, status int var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil { if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err) return protocol.Packet{}, err
} }
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp) answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo) messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err) log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.NewReportErrorPacket(err.Error()) return protocol.Packet{}, err
} }
numberOfPages := (totalCount + pageSize - 1) / pageSize numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page) currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets) return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil
} }
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet { func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool {
// Execute the SQL query // Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count) err := ds.db.QueryRow(query, uid).Scan(&count)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
log.Printf("User with UID %v does not exist", uid) log.Println("user with UID %v does not exist", uid)
return false return false
} else { }
return true return true
} }
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) { func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
// Check if the user already exists // Check if the user already exists
if ds.userExists(uid) { if ds.userExists(uid) {
return return nil
} }
// Insert the user certificate // Insert the user certificate
@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate)
` `
_, err := ds.db.Exec(insertQuery, uid, cert.Raw) _, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil { if err != nil {
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err) return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err))
return
} }
log.Printf("User certificate for UID %s stored successfully.\n", uid) log.Printf("User certificate for UID %s stored successfully.\n", uid)
return nil
} }

View file

@ -3,7 +3,6 @@ package server
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"log"
"os" "os"
) )
@ -13,7 +12,3 @@ func readStdin(message string) string {
scanner.Scan() scanner.Scan()
return scanner.Text() return scanner.Text()
} }
func LogFatal(err error) {
log.Fatalln(err)
}

View file

@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
//Check if certificate usage is MSG SERVICE //Check if certificate usage is MSG SERVICE
usage := oidMap["2.5.4.11"] usage := oidMap["2.5.4.11"]
if usage == "" { if usage == "" {
log.Println("User certificate does not have the correct usage") log.Fatalln("User certificate does not have the correct usage")
return
} }
//Get the UID of this user //Get the UID of this user
UID := oidMap["2.5.4.65"] UID := oidMap["2.5.4.65"]
if UID == "" { if UID == "" {
log.Println("User certificate does not specify it's PSEUDONYM") log.Fatalln("User certificate does not specify it's PSEUDONYM")
}
err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
if err != nil {
log.Fatalln(err)
} }
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F: F:
for { for {
pac, active := connection.Receive() pac, err := connection.Receive()
if !active { if err != nil {
break break
} }
switch pac.Flag { switch pac.Flag {
case protocol.FlagGetUserCert: case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body) reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body)
if err != nil {
log.Fatalln(err)
}
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if !connection.Send(userCertPacket) { if err := connection.Send(userCertPacket); err != nil {
log.Fatalln(err)
break F break F
} }
case protocol.FlagGetUnreadMsgsInfo: case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body) getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
if err != nil {
log.Fatalln(err)
}
var messages protocol.Packet var messages protocol.Packet
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 { if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1") messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
} else { } else {
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize) messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
if err != nil {
log.Fatalln(err)
} }
if !connection.Send(messages) {
break F
} }
if err := connection.Send(messages); err != nil {
log.Fatalln(err)
}
case protocol.FlagGetMsg: case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body) reqMsg, err := protocol.UnmarshalGetMsg(pac.Body)
if err != nil {
log.Fatalln(err)
}
var message protocol.Packet var message protocol.Packet
if reqMsg.Num <= 0 { if reqMsg.Num <= 0 {
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1") message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
} else { } else {
message = dataStore.GetMessage(UID, reqMsg.Num) message = dataStore.GetMessage(UID, reqMsg.Num)
} }
if !connection.Send(message) { if err := connection.Send(message); err != nil {
log.Fatalln(err)
break F break F
} }
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg: case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body) submitMsg, err := protocol.UnmarshalSendMsg(pac.Body)
if err != nil {
log.Fatalln(err)
}
var answerSendMsgPacket protocol.Packet var answerSendMsgPacket protocol.Packet
if submitMsg.ToUID == UID { if submitMsg.ToUID == UID {
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself") answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user")
} else if !dataStore.userExists(submitMsg.ToUID) { } else if !dataStore.userExists(submitMsg.ToUID) {
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database") answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist")
} else { } else {
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg) answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
} }
if !connection.Send(answerSendMsgPacket) { if err := connection.Send(answerSendMsgPacket); err != nil {
log.Fatalln(err)
break F break F
} }
} }
} }
} }
func Run(port int) { func Run() {
//Open connection to DB //Open connection to DB
dataStore := OpenDB() dataStore, err := OpenDB()
if err != nil {
log.Fatalln(err)
}
defer dataStore.db.Close() defer dataStore.db.Close()
//FIX: Get the server's keystore path instead of hardcoding it
//Read server keystore //Read server keystore
password := readStdin("Insert keystore passphrase") keystorePassphrase := readStdin("Insert keystore passphrase")
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password) serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
if err != nil { if err != nil {
LogFatal(err) log.Fatalln(err)
} }
//Create server listener //Create server listener
server := networking.NewServer[protocol.Packet](&serverKeyStore, port) server, err := networking.NewServer[protocol.Packet](&serverKeyStore)
if err != nil {
log.Fatalln(err)
}
go server.ListenLoop() go server.ListenLoop()
for { for {

View file

@ -2,7 +2,6 @@ package networking
import ( import (
"crypto/tls" "crypto/tls"
"log"
) )
@ -14,11 +13,11 @@ type Client[T any] struct {
Connection Connection[T] Connection Connection[T]
} }
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] { func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) {
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig()) dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
if err != nil { if err != nil {
log.Panicln("Server connection error:\n",err) return Client[T]{},err
} }
conn := NewConnection[T](dialConn) conn := NewConnection[T](dialConn)
return Client[T]{Connection: conn} return Client[T]{Connection: conn},nil
} }

View file

@ -22,33 +22,27 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
} }
} }
func (c Connection[T]) Send(obj T) bool { func (c Connection[T]) Send(obj T) error {
if err := c.encoder.Encode(&obj); err!=nil { if err := c.encoder.Encode(&obj); err!=nil {
if err == io.EOF { if err == io.EOF {
log.Println("Connection closed by peer") log.Println("Connection closed by peer")
//Return false as connection not active
return false
} else {
log.Panic(err)
} }
return err
} }
//Return true as connection active //Return true as connection active
return true return nil
} }
func (c Connection[T]) Receive() (*T, bool) { func (c Connection[T]) Receive() (*T, error) {
var obj T var obj T
if err := c.decoder.Decode(&obj); err != nil { if err := c.decoder.Decode(&obj); err != nil {
if err == io.EOF { if err == io.EOF {
log.Println("Connection closed by peer") log.Println("Connection closed by peer")
//Return false as connection not active
return nil,false
} else {
log.Panic(err)
} }
return nil,err
} }
//Return true as connection active //Return true as connection active
return &obj, true return &obj, nil
} }
func (c Connection[T]) GetPeerCertificate() *x509.Certificate { func (c Connection[T]) GetPeerCertificate() *x509.Certificate {

View file

@ -2,7 +2,6 @@ package networking
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"log" "log"
"net" "net"
) )
@ -16,16 +15,16 @@ type Server[T any] struct {
C chan Connection[T] C chan Connection[T]
} }
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] { func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) {
listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig()) listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig())
if err != nil { if err != nil {
log.Fatalln("Server could not bind to address") return Server[T]{}, err
} }
return Server[T]{ return Server[T]{
listener: listener, listener: listener,
C: make(chan Connection[T]), C: make(chan Connection[T]),
} }, nil
} }
func (s *Server[T]) ListenLoop() { func (s *Server[T]) ListenLoop() {
@ -39,7 +38,9 @@ func (s *Server[T]) ListenLoop() {
if !ok { if !ok {
log.Fatalln("Connection is not a TLS connection") log.Fatalln("Connection is not a TLS connection")
} }
tlsConn.Handshake() if err := tlsConn.Handshake(); err != nil {
log.Fatalln(err)
}
state := tlsConn.ConnectionState() state := tlsConn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {

Binary file not shown.

View file

@ -13,7 +13,7 @@ cmd="go build"
cmd="go run ./cmd/server/server.go" cmd="go run ./cmd/server/server.go"
[targets.send] [targets.send]
cmd="echo client1 | go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject" cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
[targets.askQueue] [targets.askQueue]
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue" cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"