[PD1] Error handling project-wide

This commit is contained in:
Afonso Franco 2024-04-28 22:02:13 +01:00
parent f5b3726673
commit b918211736
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
13 changed files with 364 additions and 245 deletions

View file

@ -5,5 +5,5 @@ import (
)
func main(){
server.Run(8080)
server.Run()
}

View file

@ -5,6 +5,7 @@ import (
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"crypto/x509"
"errors"
"flag"
"log"
"sort"
@ -17,45 +18,27 @@ func Run() {
flag.Parse()
if flag.NArg() == 0 {
panic("No command provided. Use 'help' for instructions.")
log.Fatalln("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
password := readStdin("Insert keystore passphrase")
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
if err != nil {
log.Fatalln(err)
}
command := flag.Arg(0)
switch command {
case "send":
if flag.NArg() < 3 {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
log.Fatalln("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
}
uid := flag.Arg(1)
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes
plainSubjectBytes := Marshal(plainSubject)
plainBodyBytes := Marshal(plainBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
receiverCert := getUserCert(cl, clientKeyStore, uid)
if receiverCert == nil {
return
}
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(sendMsgPacket) {
return
}
answerSendMsg, active := cl.Connection.Receive()
if !active {
return
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
log.Println(reportError.ErrorMessage)
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
if err != nil {
log.Fatalln(err)
}
case "askqueue":
@ -74,41 +57,24 @@ func Run() {
}
}
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
askQueue(cl, clientKeyStore, page, pageSize)
err := askQueueCommand(clientKeyStore, page, pageSize)
if err != nil {
log.Fatalln(err)
}
case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
log.Fatalln("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString)
if err != nil {
log.Panicln("NUM argument provided is not a number")
log.Fatalln(err)
}
packet := protocol.NewGetMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
err = getMsgCommand(clientKeyStore, num)
if err != nil {
log.Fatalln(err)
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
log.Println(reportError.ErrorMessage)
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help":
showHelp()
@ -119,43 +85,152 @@ func Run() {
}
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) *x509.Certificate {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if !cl.Connection.Send(getUserCertPacket) {
return nil
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, active := cl.Connection.Receive()
if !active {
return nil
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
log.Println(reportError.ErrorMessage)
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
//Turn content to bytes
plainSubjectBytes, err := Marshal(plainSubject)
if err != nil {
return nil
return err
}
if !keyStore.CheckCert(userCert, uid){
return nil
plainBodyBytes, err := Marshal(plainBody)
if err != nil {
return err
}
return userCert
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
if err != nil {
return err
}
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
if err != nil {
return err
}
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
if err != nil {
return err
}
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if err := cl.Connection.Send(sendMsgPacket); err != nil {
return err
}
answerSendMsg, err := cl.Connection.Receive()
if err != nil {
return err
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
return nil
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
packet := protocol.NewGetMsgPacket(num)
if err := cl.Connection.Send(packet); err != nil {
return err
}
receivedMsgPacket, err := cl.Connection.Receive()
if err != nil {
return err
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
if err != nil {
return err
}
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
if err != nil {
return err
}
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
if err != nil {
return err
}
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
if err != nil {
return err
}
subject, err := Unmarshal(decSubjectBytes)
if err != nil {
return err
}
body, err := Unmarshal(decBodyBytes)
if err != nil {
return err
}
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
return nil
}
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if err := cl.Connection.Send(getUserCertPacket); err != nil {
return nil, err
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, err := cl.Connection.Receive()
if err != nil {
return nil, err
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
return nil, errors.New(reportError.ErrorMessage)
}
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil, err
}
if err := keyStore.CheckCert(userCert, uid); err != nil {
return nil, err
}
return userCert, nil
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
if err != nil {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
}
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
log.Println(reportError.ErrorMessage)
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
if err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
}
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
}
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
if err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates
senderSet := map[string]bool{}
@ -165,32 +240,60 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoU
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, keyStore, senderUID)
senderCert, err := getUserCert(cl, keyStore, senderUID)
if err == nil {
certificatesMap[senderUID] = senderCert
}
return answerGetUnreadMsgsInfo, certificatesMap
}
return answerGetUnreadMsgsInfo, certificatesMap, nil
}
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
return askQueueRec(cl, clientKeyStore, page, pageSize)
}
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
return err
}
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
if err != nil {
return err
}
unreadMsgsInfo, certificates := getManyMessagesInfo(cl, clientKeyStore)
var clientMessages []ClientMessageInfo
for _, message := range unreadMsgsInfo.MessagesInfo {
var clientMessageInfo ClientMessageInfo
senderCert, ok := certificates[message.FromUID]
if ok {
var subject string
if senderCert != nil {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject = Unmarshal(decryptedSubjectBytes)
} else {
subject = ""
if !ok {
clientMessageInfo = newClientMessageInfo(message.Num,
message.FromUID,
"",
message.Timestamp,
errors.New("certificate needed to decrypt not received"))
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
clientMessages = append(clientMessages, clientMessage)
decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
subject, err := Unmarshal(decryptedSubjectBytes)
if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
clientMessages = append(clientMessages, clientMessageInfo)
}
//Sort the messages
sort.Slice(clientMessages, func(i, j int) bool {
@ -200,10 +303,10 @@ func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 0:
return
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
default:
return nil
}
}

View file

@ -1,7 +1,7 @@
package client
import (
"log"
"encoding/json"
"time"
)
@ -18,29 +18,30 @@ type ClientMessageInfo struct {
FromUID string
Timestamp time.Time
Subject string
decryptError error
}
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
}
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo {
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp}
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo {
return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err}
}
func Marshal(data any) []byte {
func Marshal(data any) ([]byte, error) {
subject, err := json.Marshal(data)
if err != nil {
log.Panicf("Error when marshalling message: %v", err)
return nil, err
}
return subject
return subject, nil
}
func Unmarshal(data []byte) string {
func Unmarshal(data []byte) (string, error) {
var c string
err := json.Unmarshal(data, &c)
if err != nil {
log.Panicln("Could not unmarshal data")
return "", err
}
return c
return c, nil
}

View file

@ -3,6 +3,7 @@ package client
import (
"bufio"
"fmt"
"log"
"os"
"strings"
)
@ -34,12 +35,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
return 0
}
for _, message := range messages {
if message.Subject == "" {
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
continue
}
if message.decryptError != nil {
fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp)
log.Println(message.decryptError)
} else {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
}
fmt.Printf("Page %v/%v\n", page, numPages)
return messagesInfoPageNavigation(page, numPages)
}

View file

@ -2,7 +2,6 @@ package protocol
import (
"encoding/json"
"fmt"
"time"
)
@ -192,118 +191,118 @@ func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, t
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
}
func NewAnswerSendMsgPacket() Packet{
func NewAnswerSendMsgPacket() Packet {
//This packet has no body
return NewPacket(FlagAnswerSendMsg,nil)
return NewPacket(FlagAnswerSendMsg, nil)
}
func NewReportErrorPacket(errorMessage string) Packet {
return NewPacket(FlagReportError, NewReportError(errorMessage))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
func UnmarshalGetUserCert(data PacketBody) (GetUserCert, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return GetUserCert{}, err
}
var packet GetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
return GetUserCert{}, err
}
return packet
return packet, nil
}
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
func UnmarshalGetUnreadMsgsInfo(data PacketBody) (GetUnreadMsgsInfo, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return GetUnreadMsgsInfo{}, err
}
var packet GetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
return GetUnreadMsgsInfo{}, err
}
return packet
return packet, nil
}
func UnmarshalGetMsg(data PacketBody) GetMsg {
func UnmarshalGetMsg(data PacketBody) (GetMsg, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return GetMsg{}, err
}
var packet GetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
return GetMsg{}, err
}
return packet
return packet, nil
}
func UnmarshalSendMsg(data PacketBody) SendMsg {
func UnmarshalSendMsg(data PacketBody) (SendMsg, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return SendMsg{}, err
}
var packet SendMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
return SendMsg{}, err
}
return packet
return packet, nil
}
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
func UnmarshalAnswerGetUserCert(data PacketBody) (AnswerGetUserCert, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return AnswerGetUserCert{}, err
}
var packet AnswerGetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
return AnswerGetUserCert{}, err
}
return packet
return packet, nil
}
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
func UnmarshalUnreadMsgInfo(data PacketBody) (MsgInfo, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return MsgInfo{}, err
}
var packet MsgInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
return MsgInfo{}, err
}
return packet
return packet, nil
}
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) (AnswerGetUnreadMsgsInfo, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return AnswerGetUnreadMsgsInfo{}, err
}
var packet AnswerGetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
return AnswerGetUnreadMsgsInfo{}, err
}
return packet
return packet, nil
}
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
func UnmarshalAnswerGetMsg(data PacketBody) (AnswerGetMsg, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return AnswerGetMsg{}, err
}
var packet AnswerGetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
return AnswerGetMsg{}, err
}
return packet
return packet, nil
}
func UnmarshalReportError(data PacketBody) ReportError {
func UnmarshalReportError(data PacketBody) (ReportError, error) {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
return ReportError{}, err
}
var packet ReportError
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
return ReportError{}, err
}
return packet
return packet, nil
}

View file

@ -4,6 +4,7 @@ import (
"PD1/internal/protocol"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"log"
"time"
@ -15,14 +16,17 @@ type DataStore struct {
db *sql.DB
}
func OpenDB() DataStore {
func OpenDB() (DataStore, error) {
db, err := sql.Open("sqlite3", "server.db")
if err != nil {
log.Fatalln("Error opening db file")
return DataStore{}, err
}
ds := DataStore{db: db}
ds.CreateTables()
return ds
err = ds.CreateTables()
if err != nil {
return DataStore{}, err
}
return ds, nil
}
func (ds DataStore) CreateTables() error {
@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error {
userCert BLOB
)`)
if err != nil {
fmt.Println("Error creating users table", err)
return err
}
@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error {
FOREIGN KEY(toUID) REFERENCES users(UID)
)`)
if err != nil {
fmt.Println("Error creating messages table", err)
return err
}
@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error {
END;
`)
if err != nil {
fmt.Println("Error creating trigger", err)
return err
}
@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
}
}
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err == sql.ErrNoRows {
log.Printf("No unread messages for UID %v: %v", toUID, err)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil
}
// Query to retrieve all messages from the user's queue
@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
var timestamp time.Time
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
return protocol.Packet{}, err
}
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.NewReportErrorPacket(err.Error())
return protocol.Packet{}, err
}
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool {
// Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count)
if err == sql.ErrNoRows {
log.Printf("User with UID %v does not exist", uid)
log.Println("user with UID %v does not exist", uid)
return false
} else {
return true
}
return true
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
// Check if the user already exists
if ds.userExists(uid) {
return
return nil
}
// Insert the user certificate
@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
return
return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err))
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
return nil
}

View file

@ -3,7 +3,6 @@ package server
import (
"bufio"
"fmt"
"log"
"os"
)
@ -13,7 +12,3 @@ func readStdin(message string) string {
scanner.Scan()
return scanner.Text()
}
func LogFatal(err error) {
log.Fatalln(err)
}

View file

@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
//Check if certificate usage is MSG SERVICE
usage := oidMap["2.5.4.11"]
if usage == "" {
log.Println("User certificate does not have the correct usage")
return
log.Fatalln("User certificate does not have the correct usage")
}
//Get the UID of this user
UID := oidMap["2.5.4.65"]
if UID == "" {
log.Println("User certificate does not specify it's PSEUDONYM")
log.Fatalln("User certificate does not specify it's PSEUDONYM")
}
err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
if err != nil {
log.Fatalln(err)
}
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F:
for {
pac, active := connection.Receive()
if !active {
pac, err := connection.Receive()
if err != nil {
break
}
switch pac.Flag {
case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body)
if err != nil {
log.Fatalln(err)
}
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if !connection.Send(userCertPacket) {
if err := connection.Send(userCertPacket); err != nil {
log.Fatalln(err)
break F
}
case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
if err != nil {
log.Fatalln(err)
}
var messages protocol.Packet
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
} else {
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
if err != nil {
log.Fatalln(err)
}
if !connection.Send(messages) {
break F
}
if err := connection.Send(messages); err != nil {
log.Fatalln(err)
}
case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
reqMsg, err := protocol.UnmarshalGetMsg(pac.Body)
if err != nil {
log.Fatalln(err)
}
var message protocol.Packet
if reqMsg.Num <= 0 {
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
} else {
message = dataStore.GetMessage(UID, reqMsg.Num)
}
if !connection.Send(message) {
if err := connection.Send(message); err != nil {
log.Fatalln(err)
break F
}
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
submitMsg, err := protocol.UnmarshalSendMsg(pac.Body)
if err != nil {
log.Fatalln(err)
}
var answerSendMsgPacket protocol.Packet
if submitMsg.ToUID == UID {
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user")
} else if !dataStore.userExists(submitMsg.ToUID) {
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist")
} else {
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
}
if !connection.Send(answerSendMsgPacket) {
if err := connection.Send(answerSendMsgPacket); err != nil {
log.Fatalln(err)
break F
}
}
}
}
func Run(port int) {
func Run() {
//Open connection to DB
dataStore := OpenDB()
dataStore, err := OpenDB()
if err != nil {
log.Fatalln(err)
}
defer dataStore.db.Close()
//FIX: Get the server's keystore path instead of hardcoding it
//Read server keystore
password := readStdin("Insert keystore passphrase")
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
keystorePassphrase := readStdin("Insert keystore passphrase")
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
if err != nil {
LogFatal(err)
log.Fatalln(err)
}
//Create server listener
server := networking.NewServer[protocol.Packet](&serverKeyStore, port)
server, err := networking.NewServer[protocol.Packet](&serverKeyStore)
if err != nil {
log.Fatalln(err)
}
go server.ListenLoop()
for {

View file

@ -2,7 +2,6 @@ package networking
import (
"crypto/tls"
"log"
)
@ -14,11 +13,11 @@ type Client[T any] struct {
Connection Connection[T]
}
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] {
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) {
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
if err != nil {
log.Panicln("Server connection error:\n",err)
return Client[T]{},err
}
conn := NewConnection[T](dialConn)
return Client[T]{Connection: conn}
return Client[T]{Connection: conn},nil
}

View file

@ -22,33 +22,27 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
}
}
func (c Connection[T]) Send(obj T) bool {
func (c Connection[T]) Send(obj T) error {
if err := c.encoder.Encode(&obj); err!=nil {
if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return false
} else {
log.Panic(err)
}
return err
}
//Return true as connection active
return true
return nil
}
func (c Connection[T]) Receive() (*T, bool) {
func (c Connection[T]) Receive() (*T, error) {
var obj T
if err := c.decoder.Decode(&obj); err != nil {
if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return nil,false
} else {
log.Panic(err)
}
return nil,err
}
//Return true as connection active
return &obj, true
return &obj, nil
}
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {

View file

@ -2,7 +2,6 @@ package networking
import (
"crypto/tls"
"fmt"
"log"
"net"
)
@ -16,16 +15,16 @@ type Server[T any] struct {
C chan Connection[T]
}
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] {
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) {
listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig())
listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig())
if err != nil {
log.Fatalln("Server could not bind to address")
return Server[T]{}, err
}
return Server[T]{
listener: listener,
C: make(chan Connection[T]),
}
}, nil
}
func (s *Server[T]) ListenLoop() {
@ -39,7 +38,9 @@ func (s *Server[T]) ListenLoop() {
if !ok {
log.Fatalln("Connection is not a TLS connection")
}
tlsConn.Handshake()
if err := tlsConn.Handshake(); err != nil {
log.Fatalln(err)
}
state := tlsConn.ConnectionState()
if len(state.PeerCertificates) == 0 {

Binary file not shown.

View file

@ -13,7 +13,7 @@ cmd="go build"
cmd="go run ./cmd/server/server.go"
[targets.send]
cmd="echo client1 | go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
[targets.askQueue]
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"