[PD1] Error handling project-wide
This commit is contained in:
parent
f5b3726673
commit
b918211736
13 changed files with 364 additions and 245 deletions
|
@ -5,5 +5,5 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main(){
|
func main(){
|
||||||
server.Run(8080)
|
server.Run()
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"PD1/internal/utils/cryptoUtils"
|
"PD1/internal/utils/cryptoUtils"
|
||||||
"PD1/internal/utils/networking"
|
"PD1/internal/utils/networking"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -17,45 +18,27 @@ func Run() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if flag.NArg() == 0 {
|
if flag.NArg() == 0 {
|
||||||
panic("No command provided. Use 'help' for instructions.")
|
log.Fatalln("No command provided. Use 'help' for instructions.")
|
||||||
}
|
}
|
||||||
//Get user KeyStore
|
//Get user KeyStore
|
||||||
password := readStdin("Insert keystore passphrase")
|
password := readStdin("Insert keystore passphrase")
|
||||||
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
|
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
command := flag.Arg(0)
|
command := flag.Arg(0)
|
||||||
switch command {
|
switch command {
|
||||||
case "send":
|
case "send":
|
||||||
if flag.NArg() < 3 {
|
if flag.NArg() < 3 {
|
||||||
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
|
log.Fatalln("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
|
||||||
}
|
}
|
||||||
uid := flag.Arg(1)
|
uid := flag.Arg(1)
|
||||||
plainSubject := flag.Arg(2)
|
plainSubject := flag.Arg(2)
|
||||||
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
|
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
|
||||||
//Turn content to bytes
|
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
|
||||||
plainSubjectBytes := Marshal(plainSubject)
|
if err != nil {
|
||||||
plainBodyBytes := Marshal(plainBody)
|
log.Fatalln(err)
|
||||||
|
|
||||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
|
||||||
defer cl.Connection.Conn.Close()
|
|
||||||
|
|
||||||
receiverCert := getUserCert(cl, clientKeyStore, uid)
|
|
||||||
if receiverCert == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
|
|
||||||
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
|
|
||||||
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
|
|
||||||
if !cl.Connection.Send(sendMsgPacket) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
answerSendMsg, active := cl.Connection.Receive()
|
|
||||||
if !active {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if answerSendMsg.Flag == protocol.FlagReportError {
|
|
||||||
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
|
|
||||||
log.Println(reportError.ErrorMessage)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "askqueue":
|
case "askqueue":
|
||||||
|
@ -74,41 +57,24 @@ func Run() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
err := askQueueCommand(clientKeyStore, page, pageSize)
|
||||||
defer cl.Connection.Conn.Close()
|
if err != nil {
|
||||||
askQueue(cl, clientKeyStore, page, pageSize)
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
case "getmsg":
|
case "getmsg":
|
||||||
if flag.NArg() < 2 {
|
if flag.NArg() < 2 {
|
||||||
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
|
log.Fatalln("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
|
||||||
}
|
}
|
||||||
numString := flag.Arg(1)
|
numString := flag.Arg(1)
|
||||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
|
||||||
defer cl.Connection.Conn.Close()
|
|
||||||
num, err := strconv.Atoi(numString)
|
num, err := strconv.Atoi(numString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicln("NUM argument provided is not a number")
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
packet := protocol.NewGetMsgPacket(num)
|
err = getMsgCommand(clientKeyStore, num)
|
||||||
cl.Connection.Send(packet)
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
receivedMsgPacket, active := cl.Connection.Receive()
|
|
||||||
if !active {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
|
||||||
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
|
||||||
log.Println(reportError.ErrorMessage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
|
||||||
senderCert := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
|
|
||||||
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
|
||||||
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
|
|
||||||
subject := Unmarshal(decSubjectBytes)
|
|
||||||
body := Unmarshal(decBodyBytes)
|
|
||||||
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
|
|
||||||
showMessage(message)
|
|
||||||
|
|
||||||
case "help":
|
case "help":
|
||||||
showHelp()
|
showHelp()
|
||||||
|
@ -119,43 +85,152 @@ func Run() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) *x509.Certificate {
|
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
|
||||||
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
|
//Turn content to bytes
|
||||||
if !cl.Connection.Send(getUserCertPacket) {
|
plainSubjectBytes, err := Marshal(plainSubject)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var answerGetUserCertPacket *protocol.Packet
|
|
||||||
answerGetUserCertPacket, active := cl.Connection.Receive()
|
|
||||||
if !active {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
|
||||||
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
|
||||||
log.Println(reportError.ErrorMessage)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
|
||||||
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
if !keyStore.CheckCert(userCert, uid){
|
plainBodyBytes, err := Marshal(plainBody)
|
||||||
return nil
|
if err != nil {
|
||||||
}
|
return err
|
||||||
return userCert
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||||
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
|
if err != nil {
|
||||||
if !active {
|
return err
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
}
|
||||||
|
defer cl.Connection.Conn.Close()
|
||||||
|
|
||||||
|
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
|
||||||
|
if err := cl.Connection.Send(sendMsgPacket); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
answerSendMsg, err := cl.Connection.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if answerSendMsg.Flag == protocol.FlagReportError {
|
||||||
|
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errors.New(reportError.ErrorMessage)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
|
||||||
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cl.Connection.Conn.Close()
|
||||||
|
packet := protocol.NewGetMsgPacket(num)
|
||||||
|
if err := cl.Connection.Send(packet); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedMsgPacket, err := cl.Connection.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
||||||
|
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errors.New(reportError.ErrorMessage)
|
||||||
|
}
|
||||||
|
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
subject, err := Unmarshal(decSubjectBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
body, err := Unmarshal(decBodyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
|
||||||
|
showMessage(message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
|
||||||
|
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
|
||||||
|
if err := cl.Connection.Send(getUserCertPacket); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var answerGetUserCertPacket *protocol.Packet
|
||||||
|
answerGetUserCertPacket, err := cl.Connection.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
||||||
|
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, errors.New(reportError.ErrorMessage)
|
||||||
|
}
|
||||||
|
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := keyStore.CheckCert(userCert, uid); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return userCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
|
||||||
|
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
|
||||||
}
|
}
|
||||||
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
|
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
|
||||||
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
||||||
log.Println(reportError.ErrorMessage)
|
if err != nil {
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||||||
|
}
|
||||||
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
|
||||||
|
}
|
||||||
|
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
||||||
|
if err != nil {
|
||||||
|
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||||||
}
|
}
|
||||||
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
|
||||||
|
|
||||||
//Create Set of needed certificates
|
//Create Set of needed certificates
|
||||||
senderSet := map[string]bool{}
|
senderSet := map[string]bool{}
|
||||||
|
@ -165,32 +240,60 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoU
|
||||||
certificatesMap := map[string]*x509.Certificate{}
|
certificatesMap := map[string]*x509.Certificate{}
|
||||||
//Get senders' certificates
|
//Get senders' certificates
|
||||||
for senderUID := range senderSet {
|
for senderUID := range senderSet {
|
||||||
senderCert := getUserCert(cl, keyStore, senderUID)
|
senderCert, err := getUserCert(cl, keyStore, senderUID)
|
||||||
|
if err == nil {
|
||||||
certificatesMap[senderUID] = senderCert
|
certificatesMap[senderUID] = senderCert
|
||||||
}
|
}
|
||||||
return answerGetUnreadMsgsInfo, certificatesMap
|
}
|
||||||
|
return answerGetUnreadMsgsInfo, certificatesMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
|
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||||||
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||||
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
|
}
|
||||||
|
defer cl.Connection.Conn.Close()
|
||||||
|
return askQueueRec(cl, clientKeyStore, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||||||
|
|
||||||
|
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
||||||
|
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
unreadMsgsInfo, certificates := getManyMessagesInfo(cl, clientKeyStore)
|
|
||||||
var clientMessages []ClientMessageInfo
|
var clientMessages []ClientMessageInfo
|
||||||
for _, message := range unreadMsgsInfo.MessagesInfo {
|
for _, message := range unreadMsgsInfo.MessagesInfo {
|
||||||
|
var clientMessageInfo ClientMessageInfo
|
||||||
senderCert, ok := certificates[message.FromUID]
|
senderCert, ok := certificates[message.FromUID]
|
||||||
if ok {
|
if !ok {
|
||||||
var subject string
|
clientMessageInfo = newClientMessageInfo(message.Num,
|
||||||
if senderCert != nil {
|
message.FromUID,
|
||||||
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
"",
|
||||||
subject = Unmarshal(decryptedSubjectBytes)
|
message.Timestamp,
|
||||||
} else {
|
errors.New("certificate needed to decrypt not received"))
|
||||||
subject = ""
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
|
decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
||||||
clientMessages = append(clientMessages, clientMessage)
|
if err != nil {
|
||||||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||||||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
subject, err := Unmarshal(decryptedSubjectBytes)
|
||||||
|
if err != nil {
|
||||||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||||||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
|
||||||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||||||
}
|
}
|
||||||
//Sort the messages
|
//Sort the messages
|
||||||
sort.Slice(clientMessages, func(i, j int) bool {
|
sort.Slice(clientMessages, func(i, j int) bool {
|
||||||
|
@ -200,10 +303,10 @@ func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.
|
||||||
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
||||||
switch action {
|
switch action {
|
||||||
case -1:
|
case -1:
|
||||||
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
||||||
case 0:
|
|
||||||
return
|
|
||||||
case 1:
|
case 1:
|
||||||
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,29 +18,30 @@ type ClientMessageInfo struct {
|
||||||
FromUID string
|
FromUID string
|
||||||
Timestamp time.Time
|
Timestamp time.Time
|
||||||
Subject string
|
Subject string
|
||||||
|
decryptError error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
|
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
|
||||||
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
|
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo {
|
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo {
|
||||||
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp}
|
return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Marshal(data any) []byte {
|
func Marshal(data any) ([]byte, error) {
|
||||||
subject, err := json.Marshal(data)
|
subject, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("Error when marshalling message: %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
return subject
|
return subject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Unmarshal(data []byte) string {
|
func Unmarshal(data []byte) (string, error) {
|
||||||
var c string
|
var c string
|
||||||
err := json.Unmarshal(data, &c)
|
err := json.Unmarshal(data, &c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicln("Could not unmarshal data")
|
return "", err
|
||||||
}
|
}
|
||||||
return c
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package client
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -34,12 +35,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
if message.Subject == "" {
|
if message.decryptError != nil {
|
||||||
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
|
fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp)
|
||||||
continue
|
log.Println(message.decryptError)
|
||||||
}
|
} else {
|
||||||
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
|
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
fmt.Printf("Page %v/%v\n", page, numPages)
|
fmt.Printf("Page %v/%v\n", page, numPages)
|
||||||
return messagesInfoPageNavigation(page, numPages)
|
return messagesInfoPageNavigation(page, numPages)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -201,109 +200,109 @@ func NewReportErrorPacket(errorMessage string) Packet {
|
||||||
return NewPacket(FlagReportError, NewReportError(errorMessage))
|
return NewPacket(FlagReportError, NewReportError(errorMessage))
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
|
func UnmarshalGetUserCert(data PacketBody) (GetUserCert, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return GetUserCert{}, err
|
||||||
}
|
}
|
||||||
var packet GetUserCert
|
var packet GetUserCert
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
|
return GetUserCert{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
|
func UnmarshalGetUnreadMsgsInfo(data PacketBody) (GetUnreadMsgsInfo, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return GetUnreadMsgsInfo{}, err
|
||||||
}
|
}
|
||||||
var packet GetUnreadMsgsInfo
|
var packet GetUnreadMsgsInfo
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
|
return GetUnreadMsgsInfo{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalGetMsg(data PacketBody) GetMsg {
|
func UnmarshalGetMsg(data PacketBody) (GetMsg, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return GetMsg{}, err
|
||||||
}
|
}
|
||||||
var packet GetMsg
|
var packet GetMsg
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
|
return GetMsg{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalSendMsg(data PacketBody) SendMsg {
|
func UnmarshalSendMsg(data PacketBody) (SendMsg, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return SendMsg{}, err
|
||||||
}
|
}
|
||||||
var packet SendMsg
|
var packet SendMsg
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
|
return SendMsg{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
|
func UnmarshalAnswerGetUserCert(data PacketBody) (AnswerGetUserCert, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return AnswerGetUserCert{}, err
|
||||||
}
|
}
|
||||||
var packet AnswerGetUserCert
|
var packet AnswerGetUserCert
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
|
return AnswerGetUserCert{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
|
func UnmarshalUnreadMsgInfo(data PacketBody) (MsgInfo, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return MsgInfo{}, err
|
||||||
}
|
}
|
||||||
var packet MsgInfo
|
var packet MsgInfo
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
|
return MsgInfo{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
|
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) (AnswerGetUnreadMsgsInfo, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return AnswerGetUnreadMsgsInfo{}, err
|
||||||
}
|
}
|
||||||
var packet AnswerGetUnreadMsgsInfo
|
var packet AnswerGetUnreadMsgsInfo
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
|
return AnswerGetUnreadMsgsInfo{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
|
func UnmarshalAnswerGetMsg(data PacketBody) (AnswerGetMsg, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return AnswerGetMsg{}, err
|
||||||
}
|
}
|
||||||
var packet AnswerGetMsg
|
var packet AnswerGetMsg
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
|
return AnswerGetMsg{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalReportError(data PacketBody) ReportError {
|
func UnmarshalReportError(data PacketBody) (ReportError, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to marshal data: %v", err))
|
return ReportError{}, err
|
||||||
}
|
}
|
||||||
var packet ReportError
|
var packet ReportError
|
||||||
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
|
return ReportError{}, err
|
||||||
}
|
}
|
||||||
return packet
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"PD1/internal/protocol"
|
"PD1/internal/protocol"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
@ -15,14 +16,17 @@ type DataStore struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenDB() DataStore {
|
func OpenDB() (DataStore, error) {
|
||||||
db, err := sql.Open("sqlite3", "server.db")
|
db, err := sql.Open("sqlite3", "server.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln("Error opening db file")
|
return DataStore{}, err
|
||||||
}
|
}
|
||||||
ds := DataStore{db: db}
|
ds := DataStore{db: db}
|
||||||
ds.CreateTables()
|
err = ds.CreateTables()
|
||||||
return ds
|
if err != nil {
|
||||||
|
return DataStore{}, err
|
||||||
|
}
|
||||||
|
return ds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ds DataStore) CreateTables() error {
|
func (ds DataStore) CreateTables() error {
|
||||||
|
@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error {
|
||||||
userCert BLOB
|
userCert BLOB
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error creating users table", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error {
|
||||||
FOREIGN KEY(toUID) REFERENCES users(UID)
|
FOREIGN KEY(toUID) REFERENCES users(UID)
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error creating messages table", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error {
|
||||||
END;
|
END;
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error creating trigger", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
|
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) {
|
||||||
|
|
||||||
// Retrieve the total count of unread messages
|
// Retrieve the total count of unread messages
|
||||||
var totalCount int
|
var totalCount int
|
||||||
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
log.Printf("No unread messages for UID %v: %v", toUID, err)
|
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query to retrieve all messages from the user's queue
|
// Query to retrieve all messages from the user's queue
|
||||||
|
@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
||||||
var timestamp time.Time
|
var timestamp time.Time
|
||||||
var queuePosition, status int
|
var queuePosition, status int
|
||||||
if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil {
|
if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil {
|
||||||
panic(err)
|
return protocol.Packet{}, err
|
||||||
}
|
}
|
||||||
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
|
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
|
||||||
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
|
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
||||||
return protocol.NewReportErrorPacket(err.Error())
|
return protocol.Packet{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
||||||
currentPage := min(numberOfPages, page)
|
currentPage := min(numberOfPages, page)
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
|
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
|
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
|
||||||
|
@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool {
|
||||||
// Execute the SQL query
|
// Execute the SQL query
|
||||||
err := ds.db.QueryRow(query, uid).Scan(&count)
|
err := ds.db.QueryRow(query, uid).Scan(&count)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
log.Printf("User with UID %v does not exist", uid)
|
log.Println("user with UID %v does not exist", uid)
|
||||||
return false
|
return false
|
||||||
} else {
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
|
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
|
||||||
// Check if the user already exists
|
// Check if the user already exists
|
||||||
if ds.userExists(uid) {
|
if ds.userExists(uid) {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert the user certificate
|
// Insert the user certificate
|
||||||
|
@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate)
|
||||||
`
|
`
|
||||||
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
|
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
|
return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
log.Printf("User certificate for UID %s stored successfully.\n", uid)
|
log.Printf("User certificate for UID %s stored successfully.\n", uid)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package server
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,7 +12,3 @@ func readStdin(message string) string {
|
||||||
scanner.Scan()
|
scanner.Scan()
|
||||||
return scanner.Text()
|
return scanner.Text()
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogFatal(err error) {
|
|
||||||
log.Fatalln(err)
|
|
||||||
}
|
|
||||||
|
|
|
@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
|
||||||
//Check if certificate usage is MSG SERVICE
|
//Check if certificate usage is MSG SERVICE
|
||||||
usage := oidMap["2.5.4.11"]
|
usage := oidMap["2.5.4.11"]
|
||||||
if usage == "" {
|
if usage == "" {
|
||||||
log.Println("User certificate does not have the correct usage")
|
log.Fatalln("User certificate does not have the correct usage")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
//Get the UID of this user
|
//Get the UID of this user
|
||||||
UID := oidMap["2.5.4.65"]
|
UID := oidMap["2.5.4.65"]
|
||||||
if UID == "" {
|
if UID == "" {
|
||||||
log.Println("User certificate does not specify it's PSEUDONYM")
|
log.Fatalln("User certificate does not specify it's PSEUDONYM")
|
||||||
|
}
|
||||||
|
err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
|
||||||
F:
|
F:
|
||||||
for {
|
for {
|
||||||
pac, active := connection.Receive()
|
pac, err := connection.Receive()
|
||||||
if !active {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
switch pac.Flag {
|
switch pac.Flag {
|
||||||
case protocol.FlagGetUserCert:
|
case protocol.FlagGetUserCert:
|
||||||
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
|
reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
||||||
if !connection.Send(userCertPacket) {
|
if err := connection.Send(userCertPacket); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
|
|
||||||
case protocol.FlagGetUnreadMsgsInfo:
|
case protocol.FlagGetUnreadMsgsInfo:
|
||||||
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
var messages protocol.Packet
|
var messages protocol.Packet
|
||||||
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
|
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
|
||||||
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
|
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
|
||||||
} else {
|
} else {
|
||||||
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
if !connection.Send(messages) {
|
|
||||||
break F
|
|
||||||
}
|
}
|
||||||
|
if err := connection.Send(messages); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
case protocol.FlagGetMsg:
|
case protocol.FlagGetMsg:
|
||||||
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
|
reqMsg, err := protocol.UnmarshalGetMsg(pac.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
var message protocol.Packet
|
var message protocol.Packet
|
||||||
if reqMsg.Num <= 0 {
|
if reqMsg.Num <= 0 {
|
||||||
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
|
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
|
||||||
} else {
|
} else {
|
||||||
message = dataStore.GetMessage(UID, reqMsg.Num)
|
message = dataStore.GetMessage(UID, reqMsg.Num)
|
||||||
}
|
}
|
||||||
if !connection.Send(message) {
|
if err := connection.Send(message); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
||||||
|
|
||||||
case protocol.FlagSendMsg:
|
case protocol.FlagSendMsg:
|
||||||
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
|
submitMsg, err := protocol.UnmarshalSendMsg(pac.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
var answerSendMsgPacket protocol.Packet
|
var answerSendMsgPacket protocol.Packet
|
||||||
if submitMsg.ToUID == UID {
|
if submitMsg.ToUID == UID {
|
||||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
|
answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user")
|
||||||
} else if !dataStore.userExists(submitMsg.ToUID) {
|
} else if !dataStore.userExists(submitMsg.ToUID) {
|
||||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
|
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist")
|
||||||
} else {
|
} else {
|
||||||
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
|
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
|
||||||
}
|
}
|
||||||
if !connection.Send(answerSendMsgPacket) {
|
if err := connection.Send(answerSendMsgPacket); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run(port int) {
|
func Run() {
|
||||||
//Open connection to DB
|
//Open connection to DB
|
||||||
dataStore := OpenDB()
|
dataStore, err := OpenDB()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
defer dataStore.db.Close()
|
defer dataStore.db.Close()
|
||||||
|
|
||||||
//FIX: Get the server's keystore path instead of hardcoding it
|
|
||||||
|
|
||||||
//Read server keystore
|
//Read server keystore
|
||||||
password := readStdin("Insert keystore passphrase")
|
keystorePassphrase := readStdin("Insert keystore passphrase")
|
||||||
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
|
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
LogFatal(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//Create server listener
|
//Create server listener
|
||||||
server := networking.NewServer[protocol.Packet](&serverKeyStore, port)
|
server, err := networking.NewServer[protocol.Packet](&serverKeyStore)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
go server.ListenLoop()
|
go server.ListenLoop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package networking
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,11 +13,11 @@ type Client[T any] struct {
|
||||||
Connection Connection[T]
|
Connection Connection[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] {
|
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) {
|
||||||
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
|
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicln("Server connection error:\n",err)
|
return Client[T]{},err
|
||||||
}
|
}
|
||||||
conn := NewConnection[T](dialConn)
|
conn := NewConnection[T](dialConn)
|
||||||
return Client[T]{Connection: conn}
|
return Client[T]{Connection: conn},nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,33 +22,27 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Connection[T]) Send(obj T) bool {
|
func (c Connection[T]) Send(obj T) error {
|
||||||
if err := c.encoder.Encode(&obj); err!=nil {
|
if err := c.encoder.Encode(&obj); err!=nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
log.Println("Connection closed by peer")
|
log.Println("Connection closed by peer")
|
||||||
//Return false as connection not active
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
//Return true as connection active
|
//Return true as connection active
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Connection[T]) Receive() (*T, bool) {
|
func (c Connection[T]) Receive() (*T, error) {
|
||||||
var obj T
|
var obj T
|
||||||
if err := c.decoder.Decode(&obj); err != nil {
|
if err := c.decoder.Decode(&obj); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
log.Println("Connection closed by peer")
|
log.Println("Connection closed by peer")
|
||||||
//Return false as connection not active
|
|
||||||
return nil,false
|
|
||||||
} else {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
}
|
||||||
|
return nil,err
|
||||||
}
|
}
|
||||||
//Return true as connection active
|
//Return true as connection active
|
||||||
return &obj, true
|
return &obj, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {
|
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package networking
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
@ -16,16 +15,16 @@ type Server[T any] struct {
|
||||||
C chan Connection[T]
|
C chan Connection[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] {
|
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) {
|
||||||
|
|
||||||
listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig())
|
listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln("Server could not bind to address")
|
return Server[T]{}, err
|
||||||
}
|
}
|
||||||
return Server[T]{
|
return Server[T]{
|
||||||
listener: listener,
|
listener: listener,
|
||||||
C: make(chan Connection[T]),
|
C: make(chan Connection[T]),
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server[T]) ListenLoop() {
|
func (s *Server[T]) ListenLoop() {
|
||||||
|
@ -39,7 +38,9 @@ func (s *Server[T]) ListenLoop() {
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Fatalln("Connection is not a TLS connection")
|
log.Fatalln("Connection is not a TLS connection")
|
||||||
}
|
}
|
||||||
tlsConn.Handshake()
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
state := tlsConn.ConnectionState()
|
state := tlsConn.ConnectionState()
|
||||||
if len(state.PeerCertificates) == 0 {
|
if len(state.PeerCertificates) == 0 {
|
||||||
|
|
Binary file not shown.
|
@ -13,7 +13,7 @@ cmd="go build"
|
||||||
cmd="go run ./cmd/server/server.go"
|
cmd="go run ./cmd/server/server.go"
|
||||||
|
|
||||||
[targets.send]
|
[targets.send]
|
||||||
cmd="echo client1 | go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
|
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
|
||||||
|
|
||||||
[targets.askQueue]
|
[targets.askQueue]
|
||||||
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"
|
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"
|
||||||
|
|
Loading…
Reference in a new issue