[PD1] Error handling project-wide
This commit is contained in:
parent
f5b3726673
commit
b918211736
13 changed files with 364 additions and 245 deletions
|
@ -5,5 +5,5 @@ import (
|
|||
)
|
||||
|
||||
func main(){
|
||||
server.Run(8080)
|
||||
server.Run()
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"PD1/internal/utils/cryptoUtils"
|
||||
"PD1/internal/utils/networking"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"flag"
|
||||
"log"
|
||||
"sort"
|
||||
|
@ -17,45 +18,27 @@ func Run() {
|
|||
flag.Parse()
|
||||
|
||||
if flag.NArg() == 0 {
|
||||
panic("No command provided. Use 'help' for instructions.")
|
||||
log.Fatalln("No command provided. Use 'help' for instructions.")
|
||||
}
|
||||
//Get user KeyStore
|
||||
password := readStdin("Insert keystore passphrase")
|
||||
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
|
||||
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
command := flag.Arg(0)
|
||||
switch command {
|
||||
case "send":
|
||||
if flag.NArg() < 3 {
|
||||
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
|
||||
log.Fatalln("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
|
||||
}
|
||||
uid := flag.Arg(1)
|
||||
plainSubject := flag.Arg(2)
|
||||
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
|
||||
//Turn content to bytes
|
||||
plainSubjectBytes := Marshal(plainSubject)
|
||||
plainBodyBytes := Marshal(plainBody)
|
||||
|
||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
defer cl.Connection.Conn.Close()
|
||||
|
||||
receiverCert := getUserCert(cl, clientKeyStore, uid)
|
||||
if receiverCert == nil {
|
||||
return
|
||||
}
|
||||
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
|
||||
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
|
||||
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
|
||||
if !cl.Connection.Send(sendMsgPacket) {
|
||||
return
|
||||
}
|
||||
answerSendMsg, active := cl.Connection.Receive()
|
||||
if !active {
|
||||
return
|
||||
}
|
||||
if answerSendMsg.Flag == protocol.FlagReportError {
|
||||
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
|
||||
log.Println(reportError.ErrorMessage)
|
||||
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
case "askqueue":
|
||||
|
@ -74,41 +57,24 @@ func Run() {
|
|||
}
|
||||
}
|
||||
|
||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
defer cl.Connection.Conn.Close()
|
||||
askQueue(cl, clientKeyStore, page, pageSize)
|
||||
err := askQueueCommand(clientKeyStore, page, pageSize)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
case "getmsg":
|
||||
if flag.NArg() < 2 {
|
||||
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
|
||||
log.Fatalln("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
|
||||
}
|
||||
numString := flag.Arg(1)
|
||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
defer cl.Connection.Conn.Close()
|
||||
num, err := strconv.Atoi(numString)
|
||||
if err != nil {
|
||||
log.Panicln("NUM argument provided is not a number")
|
||||
log.Fatalln(err)
|
||||
}
|
||||
packet := protocol.NewGetMsgPacket(num)
|
||||
cl.Connection.Send(packet)
|
||||
|
||||
receivedMsgPacket, active := cl.Connection.Receive()
|
||||
if !active {
|
||||
return
|
||||
err = getMsgCommand(clientKeyStore, num)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
||||
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
||||
log.Println(reportError.ErrorMessage)
|
||||
return
|
||||
}
|
||||
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
||||
senderCert := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
|
||||
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
||||
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
|
||||
subject := Unmarshal(decSubjectBytes)
|
||||
body := Unmarshal(decBodyBytes)
|
||||
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
|
||||
showMessage(message)
|
||||
|
||||
case "help":
|
||||
showHelp()
|
||||
|
@ -119,43 +85,152 @@ func Run() {
|
|||
|
||||
}
|
||||
|
||||
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) *x509.Certificate {
|
||||
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
|
||||
if !cl.Connection.Send(getUserCertPacket) {
|
||||
return nil
|
||||
}
|
||||
var answerGetUserCertPacket *protocol.Packet
|
||||
answerGetUserCertPacket, active := cl.Connection.Receive()
|
||||
if !active {
|
||||
return nil
|
||||
}
|
||||
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
||||
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
||||
log.Println(reportError.ErrorMessage)
|
||||
return nil
|
||||
}
|
||||
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
||||
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
||||
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
|
||||
//Turn content to bytes
|
||||
plainSubjectBytes, err := Marshal(plainSubject)
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
if !keyStore.CheckCert(userCert, uid){
|
||||
return nil
|
||||
plainBodyBytes, err := Marshal(plainBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return userCert
|
||||
|
||||
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cl.Connection.Conn.Close()
|
||||
|
||||
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
|
||||
if err := cl.Connection.Send(sendMsgPacket); err != nil {
|
||||
return err
|
||||
}
|
||||
answerSendMsg, err := cl.Connection.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if answerSendMsg.Flag == protocol.FlagReportError {
|
||||
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(reportError.ErrorMessage)
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
|
||||
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
|
||||
if !active {
|
||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
||||
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
|
||||
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cl.Connection.Conn.Close()
|
||||
packet := protocol.NewGetMsgPacket(num)
|
||||
if err := cl.Connection.Send(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
receivedMsgPacket, err := cl.Connection.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
||||
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(reportError.ErrorMessage)
|
||||
}
|
||||
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subject, err := Unmarshal(decSubjectBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := Unmarshal(decBodyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
|
||||
showMessage(message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
|
||||
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
|
||||
if err := cl.Connection.Send(getUserCertPacket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var answerGetUserCertPacket *protocol.Packet
|
||||
answerGetUserCertPacket, err := cl.Connection.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
||||
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errors.New(reportError.ErrorMessage)
|
||||
}
|
||||
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := keyStore.CheckCert(userCert, uid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userCert, nil
|
||||
}
|
||||
|
||||
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
|
||||
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
|
||||
if err != nil {
|
||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
|
||||
}
|
||||
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
|
||||
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
||||
log.Println(reportError.ErrorMessage)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
||||
reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
||||
if err != nil {
|
||||
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||||
}
|
||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
|
||||
}
|
||||
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
||||
if err != nil {
|
||||
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||||
}
|
||||
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
||||
|
||||
//Create Set of needed certificates
|
||||
senderSet := map[string]bool{}
|
||||
|
@ -165,32 +240,60 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoU
|
|||
certificatesMap := map[string]*x509.Certificate{}
|
||||
//Get senders' certificates
|
||||
for senderUID := range senderSet {
|
||||
senderCert := getUserCert(cl, keyStore, senderUID)
|
||||
senderCert, err := getUserCert(cl, keyStore, senderUID)
|
||||
if err == nil {
|
||||
certificatesMap[senderUID] = senderCert
|
||||
}
|
||||
return answerGetUnreadMsgsInfo, certificatesMap
|
||||
}
|
||||
return answerGetUnreadMsgsInfo, certificatesMap, nil
|
||||
}
|
||||
|
||||
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
|
||||
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
||||
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
|
||||
return
|
||||
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||||
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cl.Connection.Conn.Close()
|
||||
return askQueueRec(cl, clientKeyStore, page, pageSize)
|
||||
}
|
||||
|
||||
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||||
|
||||
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
||||
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
|
||||
return err
|
||||
}
|
||||
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
unreadMsgsInfo, certificates := getManyMessagesInfo(cl, clientKeyStore)
|
||||
var clientMessages []ClientMessageInfo
|
||||
for _, message := range unreadMsgsInfo.MessagesInfo {
|
||||
var clientMessageInfo ClientMessageInfo
|
||||
senderCert, ok := certificates[message.FromUID]
|
||||
if ok {
|
||||
var subject string
|
||||
if senderCert != nil {
|
||||
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
||||
subject = Unmarshal(decryptedSubjectBytes)
|
||||
} else {
|
||||
subject = ""
|
||||
if !ok {
|
||||
clientMessageInfo = newClientMessageInfo(message.Num,
|
||||
message.FromUID,
|
||||
"",
|
||||
message.Timestamp,
|
||||
errors.New("certificate needed to decrypt not received"))
|
||||
clientMessages = append(clientMessages, clientMessageInfo)
|
||||
continue
|
||||
}
|
||||
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
|
||||
clientMessages = append(clientMessages, clientMessage)
|
||||
decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
||||
if err != nil {
|
||||
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||||
clientMessages = append(clientMessages, clientMessageInfo)
|
||||
continue
|
||||
}
|
||||
subject, err := Unmarshal(decryptedSubjectBytes)
|
||||
if err != nil {
|
||||
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||||
clientMessages = append(clientMessages, clientMessageInfo)
|
||||
continue
|
||||
}
|
||||
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
|
||||
clientMessages = append(clientMessages, clientMessageInfo)
|
||||
}
|
||||
//Sort the messages
|
||||
sort.Slice(clientMessages, func(i, j int) bool {
|
||||
|
@ -200,10 +303,10 @@ func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.
|
|||
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
||||
switch action {
|
||||
case -1:
|
||||
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
||||
case 0:
|
||||
return
|
||||
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
||||
case 1:
|
||||
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
||||
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"log"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -18,29 +18,30 @@ type ClientMessageInfo struct {
|
|||
FromUID string
|
||||
Timestamp time.Time
|
||||
Subject string
|
||||
decryptError error
|
||||
}
|
||||
|
||||
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
|
||||
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
|
||||
}
|
||||
|
||||
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo {
|
||||
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp}
|
||||
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo {
|
||||
return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err}
|
||||
}
|
||||
|
||||
func Marshal(data any) []byte {
|
||||
func Marshal(data any) ([]byte, error) {
|
||||
subject, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Panicf("Error when marshalling message: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return subject
|
||||
return subject, nil
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte) string {
|
||||
func Unmarshal(data []byte) (string, error) {
|
||||
var c string
|
||||
err := json.Unmarshal(data, &c)
|
||||
if err != nil {
|
||||
log.Panicln("Could not unmarshal data")
|
||||
return "", err
|
||||
}
|
||||
return c
|
||||
return c, nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package client
|
|||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
@ -34,12 +35,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
|
|||
return 0
|
||||
}
|
||||
for _, message := range messages {
|
||||
if message.Subject == "" {
|
||||
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
|
||||
continue
|
||||
}
|
||||
if message.decryptError != nil {
|
||||
fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp)
|
||||
log.Println(message.decryptError)
|
||||
} else {
|
||||
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
|
||||
}
|
||||
}
|
||||
fmt.Printf("Page %v/%v\n", page, numPages)
|
||||
return messagesInfoPageNavigation(page, numPages)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package protocol
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -192,118 +191,118 @@ func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, t
|
|||
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
|
||||
}
|
||||
|
||||
func NewAnswerSendMsgPacket() Packet{
|
||||
func NewAnswerSendMsgPacket() Packet {
|
||||
//This packet has no body
|
||||
return NewPacket(FlagAnswerSendMsg,nil)
|
||||
return NewPacket(FlagAnswerSendMsg, nil)
|
||||
}
|
||||
|
||||
func NewReportErrorPacket(errorMessage string) Packet {
|
||||
return NewPacket(FlagReportError, NewReportError(errorMessage))
|
||||
}
|
||||
|
||||
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
|
||||
func UnmarshalGetUserCert(data PacketBody) (GetUserCert, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return GetUserCert{}, err
|
||||
}
|
||||
var packet GetUserCert
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
|
||||
return GetUserCert{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
|
||||
func UnmarshalGetUnreadMsgsInfo(data PacketBody) (GetUnreadMsgsInfo, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return GetUnreadMsgsInfo{}, err
|
||||
}
|
||||
var packet GetUnreadMsgsInfo
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
|
||||
return GetUnreadMsgsInfo{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalGetMsg(data PacketBody) GetMsg {
|
||||
func UnmarshalGetMsg(data PacketBody) (GetMsg, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return GetMsg{}, err
|
||||
}
|
||||
var packet GetMsg
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
|
||||
return GetMsg{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalSendMsg(data PacketBody) SendMsg {
|
||||
func UnmarshalSendMsg(data PacketBody) (SendMsg, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return SendMsg{}, err
|
||||
}
|
||||
var packet SendMsg
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
|
||||
return SendMsg{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
|
||||
func UnmarshalAnswerGetUserCert(data PacketBody) (AnswerGetUserCert, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return AnswerGetUserCert{}, err
|
||||
}
|
||||
var packet AnswerGetUserCert
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
|
||||
return AnswerGetUserCert{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
|
||||
func UnmarshalUnreadMsgInfo(data PacketBody) (MsgInfo, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return MsgInfo{}, err
|
||||
}
|
||||
var packet MsgInfo
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
|
||||
return MsgInfo{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
|
||||
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) (AnswerGetUnreadMsgsInfo, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return AnswerGetUnreadMsgsInfo{}, err
|
||||
}
|
||||
var packet AnswerGetUnreadMsgsInfo
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
|
||||
return AnswerGetUnreadMsgsInfo{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
|
||||
func UnmarshalAnswerGetMsg(data PacketBody) (AnswerGetMsg, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return AnswerGetMsg{}, err
|
||||
}
|
||||
var packet AnswerGetMsg
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
|
||||
return AnswerGetMsg{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func UnmarshalReportError(data PacketBody) ReportError {
|
||||
func UnmarshalReportError(data PacketBody) (ReportError, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||
return ReportError{}, err
|
||||
}
|
||||
var packet ReportError
|
||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
|
||||
return ReportError{}, err
|
||||
}
|
||||
return packet
|
||||
return packet, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"PD1/internal/protocol"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
@ -15,14 +16,17 @@ type DataStore struct {
|
|||
db *sql.DB
|
||||
}
|
||||
|
||||
func OpenDB() DataStore {
|
||||
func OpenDB() (DataStore, error) {
|
||||
db, err := sql.Open("sqlite3", "server.db")
|
||||
if err != nil {
|
||||
log.Fatalln("Error opening db file")
|
||||
return DataStore{}, err
|
||||
}
|
||||
ds := DataStore{db: db}
|
||||
ds.CreateTables()
|
||||
return ds
|
||||
err = ds.CreateTables()
|
||||
if err != nil {
|
||||
return DataStore{}, err
|
||||
}
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
func (ds DataStore) CreateTables() error {
|
||||
|
@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error {
|
|||
userCert BLOB
|
||||
)`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating users table", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error {
|
|||
FOREIGN KEY(toUID) REFERENCES users(UID)
|
||||
)`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating messages table", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error {
|
|||
END;
|
||||
`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating trigger", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
|
||||
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) {
|
||||
|
||||
// Retrieve the total count of unread messages
|
||||
var totalCount int
|
||||
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
||||
if err == sql.ErrNoRows {
|
||||
log.Printf("No unread messages for UID %v: %v", toUID, err)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil
|
||||
}
|
||||
|
||||
// Query to retrieve all messages from the user's queue
|
||||
|
@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
|||
var timestamp time.Time
|
||||
var queuePosition, status int
|
||||
if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil {
|
||||
panic(err)
|
||||
return protocol.Packet{}, err
|
||||
}
|
||||
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
|
||||
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
||||
return protocol.NewReportErrorPacket(err.Error())
|
||||
return protocol.Packet{}, err
|
||||
}
|
||||
|
||||
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
||||
currentPage := min(numberOfPages, page)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil
|
||||
}
|
||||
|
||||
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
|
||||
|
@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool {
|
|||
// Execute the SQL query
|
||||
err := ds.db.QueryRow(query, uid).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
log.Printf("User with UID %v does not exist", uid)
|
||||
log.Println("user with UID %v does not exist", uid)
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
|
||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
|
||||
// Check if the user already exists
|
||||
if ds.userExists(uid) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert the user certificate
|
||||
|
@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate)
|
|||
`
|
||||
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
|
||||
if err != nil {
|
||||
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
|
||||
return
|
||||
return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err))
|
||||
}
|
||||
log.Printf("User certificate for UID %s stored successfully.\n", uid)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package server
|
|||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
|
@ -13,7 +12,3 @@ func readStdin(message string) string {
|
|||
scanner.Scan()
|
||||
return scanner.Text()
|
||||
}
|
||||
|
||||
func LogFatal(err error) {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
|
|
@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
|
|||
//Check if certificate usage is MSG SERVICE
|
||||
usage := oidMap["2.5.4.11"]
|
||||
if usage == "" {
|
||||
log.Println("User certificate does not have the correct usage")
|
||||
return
|
||||
log.Fatalln("User certificate does not have the correct usage")
|
||||
}
|
||||
//Get the UID of this user
|
||||
UID := oidMap["2.5.4.65"]
|
||||
if UID == "" {
|
||||
log.Println("User certificate does not specify it's PSEUDONYM")
|
||||
log.Fatalln("User certificate does not specify it's PSEUDONYM")
|
||||
}
|
||||
err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||
F:
|
||||
for {
|
||||
pac, active := connection.Receive()
|
||||
if !active {
|
||||
pac, err := connection.Receive()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch pac.Flag {
|
||||
case protocol.FlagGetUserCert:
|
||||
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
|
||||
reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
||||
if !connection.Send(userCertPacket) {
|
||||
if err := connection.Send(userCertPacket); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
|
||||
case protocol.FlagGetUnreadMsgsInfo:
|
||||
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||
getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var messages protocol.Packet
|
||||
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
|
||||
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
|
||||
} else {
|
||||
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||
messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
if !connection.Send(messages) {
|
||||
break F
|
||||
}
|
||||
if err := connection.Send(messages); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
case protocol.FlagGetMsg:
|
||||
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
|
||||
reqMsg, err := protocol.UnmarshalGetMsg(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var message protocol.Packet
|
||||
if reqMsg.Num <= 0 {
|
||||
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
|
||||
} else {
|
||||
message = dataStore.GetMessage(UID, reqMsg.Num)
|
||||
}
|
||||
if !connection.Send(message) {
|
||||
if err := connection.Send(message); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
||||
|
||||
case protocol.FlagSendMsg:
|
||||
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
|
||||
submitMsg, err := protocol.UnmarshalSendMsg(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var answerSendMsgPacket protocol.Packet
|
||||
if submitMsg.ToUID == UID {
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user")
|
||||
} else if !dataStore.userExists(submitMsg.ToUID) {
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist")
|
||||
} else {
|
||||
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
|
||||
}
|
||||
if !connection.Send(answerSendMsgPacket) {
|
||||
if err := connection.Send(answerSendMsgPacket); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Run(port int) {
|
||||
func Run() {
|
||||
//Open connection to DB
|
||||
dataStore := OpenDB()
|
||||
dataStore, err := OpenDB()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
defer dataStore.db.Close()
|
||||
|
||||
//FIX: Get the server's keystore path instead of hardcoding it
|
||||
|
||||
//Read server keystore
|
||||
password := readStdin("Insert keystore passphrase")
|
||||
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
|
||||
keystorePassphrase := readStdin("Insert keystore passphrase")
|
||||
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
|
||||
if err != nil {
|
||||
LogFatal(err)
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
//Create server listener
|
||||
server := networking.NewServer[protocol.Packet](&serverKeyStore, port)
|
||||
server, err := networking.NewServer[protocol.Packet](&serverKeyStore)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
go server.ListenLoop()
|
||||
|
||||
for {
|
||||
|
|
|
@ -2,7 +2,6 @@ package networking
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
)
|
||||
|
||||
|
||||
|
@ -14,11 +13,11 @@ type Client[T any] struct {
|
|||
Connection Connection[T]
|
||||
}
|
||||
|
||||
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] {
|
||||
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) {
|
||||
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
|
||||
if err != nil {
|
||||
log.Panicln("Server connection error:\n",err)
|
||||
return Client[T]{},err
|
||||
}
|
||||
conn := NewConnection[T](dialConn)
|
||||
return Client[T]{Connection: conn}
|
||||
return Client[T]{Connection: conn},nil
|
||||
}
|
||||
|
|
|
@ -22,33 +22,27 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
|
|||
}
|
||||
}
|
||||
|
||||
func (c Connection[T]) Send(obj T) bool {
|
||||
func (c Connection[T]) Send(obj T) error {
|
||||
if err := c.encoder.Encode(&obj); err!=nil {
|
||||
if err == io.EOF {
|
||||
log.Println("Connection closed by peer")
|
||||
//Return false as connection not active
|
||||
return false
|
||||
} else {
|
||||
log.Panic(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
//Return true as connection active
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Connection[T]) Receive() (*T, bool) {
|
||||
func (c Connection[T]) Receive() (*T, error) {
|
||||
var obj T
|
||||
if err := c.decoder.Decode(&obj); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Println("Connection closed by peer")
|
||||
//Return false as connection not active
|
||||
return nil,false
|
||||
} else {
|
||||
log.Panic(err)
|
||||
}
|
||||
return nil,err
|
||||
}
|
||||
//Return true as connection active
|
||||
return &obj, true
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {
|
||||
|
|
|
@ -2,7 +2,6 @@ package networking
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
@ -16,16 +15,16 @@ type Server[T any] struct {
|
|||
C chan Connection[T]
|
||||
}
|
||||
|
||||
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] {
|
||||
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) {
|
||||
|
||||
listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig())
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig())
|
||||
if err != nil {
|
||||
log.Fatalln("Server could not bind to address")
|
||||
return Server[T]{}, err
|
||||
}
|
||||
return Server[T]{
|
||||
listener: listener,
|
||||
C: make(chan Connection[T]),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server[T]) ListenLoop() {
|
||||
|
@ -39,7 +38,9 @@ func (s *Server[T]) ListenLoop() {
|
|||
if !ok {
|
||||
log.Fatalln("Connection is not a TLS connection")
|
||||
}
|
||||
tlsConn.Handshake()
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
|
|
Binary file not shown.
|
@ -13,7 +13,7 @@ cmd="go build"
|
|||
cmd="go run ./cmd/server/server.go"
|
||||
|
||||
[targets.send]
|
||||
cmd="echo client1 | go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
|
||||
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
|
||||
|
||||
[targets.askQueue]
|
||||
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"
|
||||
|
|
Loading…
Reference in a new issue