diff --git a/Projs/PD1/cmd/server/server.go b/Projs/PD1/cmd/server/server.go index 0526a10..c93ba9b 100644 --- a/Projs/PD1/cmd/server/server.go +++ b/Projs/PD1/cmd/server/server.go @@ -5,5 +5,5 @@ import ( ) func main(){ - server.Run(8080) + server.Run() } diff --git a/Projs/PD1/internal/client/client.go b/Projs/PD1/internal/client/client.go index d4d2535..82a1a7f 100644 --- a/Projs/PD1/internal/client/client.go +++ b/Projs/PD1/internal/client/client.go @@ -5,6 +5,7 @@ import ( "PD1/internal/utils/cryptoUtils" "PD1/internal/utils/networking" "crypto/x509" + "errors" "flag" "log" "sort" @@ -17,45 +18,27 @@ func Run() { flag.Parse() if flag.NArg() == 0 { - panic("No command provided. Use 'help' for instructions.") + log.Fatalln("No command provided. Use 'help' for instructions.") } //Get user KeyStore password := readStdin("Insert keystore passphrase") - clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password) + clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password) + if err != nil { + log.Fatalln(err) + } command := flag.Arg(0) switch command { case "send": if flag.NArg() < 3 { - panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>") + log.Fatalln("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>") } uid := flag.Arg(1) plainSubject := flag.Arg(2) plainBody := readStdin("Enter message content (limited to 1000 bytes):") - //Turn content to bytes - plainSubjectBytes := Marshal(plainSubject) - plainBodyBytes := Marshal(plainBody) - - cl := networking.NewClient[protocol.Packet](&clientKeyStore) - defer cl.Connection.Conn.Close() - - receiverCert := getUserCert(cl, clientKeyStore, uid) - if receiverCert == nil { - return - } - subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes) - body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes) - sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body) - if !cl.Connection.Send(sendMsgPacket) { - return - } - answerSendMsg, active := cl.Connection.Receive() - if !active { - return - } - if answerSendMsg.Flag == protocol.FlagReportError { - reportError := protocol.UnmarshalReportError(answerSendMsg.Body) - log.Println(reportError.ErrorMessage) + err := sendCommand(clientKeyStore, plainSubject, plainBody, uid) + if err != nil { + log.Fatalln(err) } case "askqueue": @@ -74,41 +57,24 @@ func Run() { } } - cl := networking.NewClient[protocol.Packet](&clientKeyStore) - defer cl.Connection.Conn.Close() - askQueue(cl, clientKeyStore, page, pageSize) + err := askQueueCommand(clientKeyStore, page, pageSize) + if err != nil { + log.Fatalln(err) + } case "getmsg": if flag.NArg() < 2 { - panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>") + log.Fatalln("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>") } numString := flag.Arg(1) - cl := networking.NewClient[protocol.Packet](&clientKeyStore) - defer cl.Connection.Conn.Close() num, err := strconv.Atoi(numString) if err != nil { - log.Panicln("NUM argument provided is not a number") + log.Fatalln(err) } - packet := protocol.NewGetMsgPacket(num) - cl.Connection.Send(packet) - - receivedMsgPacket, active := cl.Connection.Receive() - if !active { - return + err = getMsgCommand(clientKeyStore, num) + if err != nil { + log.Fatalln(err) } - if receivedMsgPacket.Flag == protocol.FlagReportError { - reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body) - log.Println(reportError.ErrorMessage) - return - } - answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body) - senderCert := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID) - decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject) - decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body) - subject := Unmarshal(decSubjectBytes) - body := Unmarshal(decBodyBytes) - message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp) - showMessage(message) case "help": showHelp() @@ -119,43 +85,152 @@ func Run() { } -func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) *x509.Certificate { - getUserCertPacket := protocol.NewGetUserCertPacket(uid) - if !cl.Connection.Send(getUserCertPacket) { - return nil - } - var answerGetUserCertPacket *protocol.Packet - answerGetUserCertPacket, active := cl.Connection.Receive() - if !active { - return nil - } - if answerGetUserCertPacket.Flag == protocol.FlagReportError { - reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body) - log.Println(reportError.ErrorMessage) - return nil - } - answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body) - userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate) +func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error { + //Turn content to bytes + plainSubjectBytes, err := Marshal(plainSubject) if err != nil { - return nil + return err } - if !keyStore.CheckCert(userCert, uid){ - return nil - } - return userCert + plainBodyBytes, err := Marshal(plainBody) + if err != nil { + return err + } + + cl, err := networking.NewClient[protocol.Packet](&clientKeyStore) + if err != nil { + return err + } + defer cl.Connection.Conn.Close() + + receiverCert, err := getUserCert(cl, clientKeyStore, uid) + if err != nil { + return err + } + subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes) + if err != nil { + return err + } + body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes) + if err != nil { + return err + } + sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body) + if err := cl.Connection.Send(sendMsgPacket); err != nil { + return err + } + answerSendMsg, err := cl.Connection.Receive() + if err != nil { + return err + } + if answerSendMsg.Flag == protocol.FlagReportError { + reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body) + if err != nil { + return err + } + return errors.New(reportError.ErrorMessage) + } + return nil + } -func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) { - answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive() - if !active { - return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil +func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error { + cl, err := networking.NewClient[protocol.Packet](&clientKeyStore) + if err != nil { + return err + } + defer cl.Connection.Conn.Close() + packet := protocol.NewGetMsgPacket(num) + if err := cl.Connection.Send(packet); err != nil { + return err + } + + receivedMsgPacket, err := cl.Connection.Receive() + if err != nil { + return err + } + if receivedMsgPacket.Flag == protocol.FlagReportError { + reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body) + if err != nil { + return err + } + return errors.New(reportError.ErrorMessage) + } + answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body) + if err != nil { + return err + } + senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID) + if err != nil { + return err + } + decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject) + if err != nil { + return err + } + decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body) + if err != nil { + return err + } + subject, err := Unmarshal(decSubjectBytes) + if err != nil { + return err + } + body, err := Unmarshal(decBodyBytes) + if err != nil { + return err + } + message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp) + showMessage(message) + return nil +} + +func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) { + getUserCertPacket := protocol.NewGetUserCertPacket(uid) + if err := cl.Connection.Send(getUserCertPacket); err != nil { + return nil, err + } + var answerGetUserCertPacket *protocol.Packet + answerGetUserCertPacket, err := cl.Connection.Receive() + if err != nil { + return nil, err + } + if answerGetUserCertPacket.Flag == protocol.FlagReportError { + reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body) + if err != nil { + return nil, err + } + return nil, errors.New(reportError.ErrorMessage) + } + answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body) + if err != nil { + return nil, err + } + userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate) + if err != nil { + return nil, err + } + if err := keyStore.CheckCert(userCert, uid); err != nil { + return nil, err + } + return userCert, nil +} + +func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) { + answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive() + if err != nil { + return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err } if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError { - reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body) - log.Println(reportError.ErrorMessage) - return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil + reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body) + if err != nil { + return protocol.AnswerGetUnreadMsgsInfo{}, nil, err + } + return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage) + } + answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body) + if err != nil { + return protocol.AnswerGetUnreadMsgsInfo{}, nil, err } - answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body) //Create Set of needed certificates senderSet := map[string]bool{} @@ -165,32 +240,60 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoU certificatesMap := map[string]*x509.Certificate{} //Get senders' certificates for senderUID := range senderSet { - senderCert := getUserCert(cl, keyStore, senderUID) - certificatesMap[senderUID] = senderCert + senderCert, err := getUserCert(cl, keyStore, senderUID) + if err == nil { + certificatesMap[senderUID] = senderCert + } } - return answerGetUnreadMsgsInfo, certificatesMap + return answerGetUnreadMsgsInfo, certificatesMap, nil } -func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) { - requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize) - if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { - return +func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error { + cl, err := networking.NewClient[protocol.Packet](&clientKeyStore) + if err != nil { + return err + } + defer cl.Connection.Conn.Close() + return askQueueRec(cl, clientKeyStore, page, pageSize) +} + +func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error { + + requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize) + if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil { + return err + } + unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore) + if err != nil { + return err } - unreadMsgsInfo, certificates := getManyMessagesInfo(cl, clientKeyStore) var clientMessages []ClientMessageInfo for _, message := range unreadMsgsInfo.MessagesInfo { + var clientMessageInfo ClientMessageInfo senderCert, ok := certificates[message.FromUID] - if ok { - var subject string - if senderCert != nil { - decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) - subject = Unmarshal(decryptedSubjectBytes) - } else { - subject = "" - } - clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) - clientMessages = append(clientMessages, clientMessage) + if !ok { + clientMessageInfo = newClientMessageInfo(message.Num, + message.FromUID, + "", + message.Timestamp, + errors.New("certificate needed to decrypt not received")) + clientMessages = append(clientMessages, clientMessageInfo) + continue } + decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) + if err != nil { + clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err) + clientMessages = append(clientMessages, clientMessageInfo) + continue + } + subject, err := Unmarshal(decryptedSubjectBytes) + if err != nil { + clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err) + clientMessages = append(clientMessages, clientMessageInfo) + continue + } + clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil) + clientMessages = append(clientMessages, clientMessageInfo) } //Sort the messages sort.Slice(clientMessages, func(i, j int) bool { @@ -200,10 +303,10 @@ func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils. action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages) switch action { case -1: - askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize) - case 0: - return + return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize) case 1: - askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize) + return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize) + default: + return nil } } diff --git a/Projs/PD1/internal/client/datastore.go b/Projs/PD1/internal/client/datastore.go index a2d115c..7036d9e 100644 --- a/Projs/PD1/internal/client/datastore.go +++ b/Projs/PD1/internal/client/datastore.go @@ -1,7 +1,7 @@ package client import ( - "log" + "encoding/json" "time" ) @@ -14,33 +14,34 @@ type ClientMessage struct { } type ClientMessageInfo struct { - Num int - FromUID string - Timestamp time.Time - Subject string + Num int + FromUID string + Timestamp time.Time + Subject string + decryptError error } func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage { return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp} } -func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo { - return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp} +func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo { + return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err} } -func Marshal(data any) []byte { +func Marshal(data any) ([]byte, error) { subject, err := json.Marshal(data) if err != nil { - log.Panicf("Error when marshalling message: %v", err) + return nil, err } - return subject + return subject, nil } -func Unmarshal(data []byte) string { +func Unmarshal(data []byte) (string, error) { var c string err := json.Unmarshal(data, &c) if err != nil { - log.Panicln("Could not unmarshal data") + return "", err } - return c + return c, nil } diff --git a/Projs/PD1/internal/client/interface.go b/Projs/PD1/internal/client/interface.go index 21551c1..0a5817b 100644 --- a/Projs/PD1/internal/client/interface.go +++ b/Projs/PD1/internal/client/interface.go @@ -3,6 +3,7 @@ package client import ( "bufio" "fmt" + "log" "os" "strings" ) @@ -34,11 +35,12 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int return 0 } for _, message := range messages { - if message.Subject == "" { - fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID) - continue + if message.decryptError != nil { + fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp) + log.Println(message.decryptError) + } else { + fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) } - fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) } fmt.Printf("Page %v/%v\n", page, numPages) return messagesInfoPageNavigation(page, numPages) diff --git a/Projs/PD1/internal/protocol/protocol.go b/Projs/PD1/internal/protocol/protocol.go index 0e57d08..3aefe0e 100644 --- a/Projs/PD1/internal/protocol/protocol.go +++ b/Projs/PD1/internal/protocol/protocol.go @@ -2,7 +2,6 @@ package protocol import ( "encoding/json" - "fmt" "time" ) @@ -30,8 +29,8 @@ const ( // Server sends requested message FlagAnswerGetMsg - // Server tells the client that the message was successfully sent - FlagAnswerSendMsg + // Server tells the client that the message was successfully sent + FlagAnswerSendMsg // Report an error FlagReportError @@ -192,118 +191,118 @@ func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, t return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last)) } -func NewAnswerSendMsgPacket() Packet{ - //This packet has no body - return NewPacket(FlagAnswerSendMsg,nil) +func NewAnswerSendMsgPacket() Packet { + //This packet has no body + return NewPacket(FlagAnswerSendMsg, nil) } func NewReportErrorPacket(errorMessage string) Packet { return NewPacket(FlagReportError, NewReportError(errorMessage)) } -func UnmarshalGetUserCert(data PacketBody) GetUserCert { +func UnmarshalGetUserCert(data PacketBody) (GetUserCert, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return GetUserCert{}, err } var packet GetUserCert if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err)) + return GetUserCert{}, err } - return packet + return packet, nil } -func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo { +func UnmarshalGetUnreadMsgsInfo(data PacketBody) (GetUnreadMsgsInfo, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return GetUnreadMsgsInfo{}, err } var packet GetUnreadMsgsInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err)) + return GetUnreadMsgsInfo{}, err } - return packet + return packet, nil } -func UnmarshalGetMsg(data PacketBody) GetMsg { +func UnmarshalGetMsg(data PacketBody) (GetMsg, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return GetMsg{}, err } var packet GetMsg if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err)) + return GetMsg{}, err } - return packet + return packet, nil } -func UnmarshalSendMsg(data PacketBody) SendMsg { +func UnmarshalSendMsg(data PacketBody) (SendMsg, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return SendMsg{}, err } var packet SendMsg if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err)) + return SendMsg{}, err } - return packet + return packet, nil } -func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert { +func UnmarshalAnswerGetUserCert(data PacketBody) (AnswerGetUserCert, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return AnswerGetUserCert{}, err } var packet AnswerGetUserCert if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err)) + return AnswerGetUserCert{}, err } - return packet + return packet, nil } -func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo { +func UnmarshalUnreadMsgInfo(data PacketBody) (MsgInfo, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return MsgInfo{}, err } var packet MsgInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err)) + return MsgInfo{}, err } - return packet + return packet, nil } -func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo { +func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) (AnswerGetUnreadMsgsInfo, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return AnswerGetUnreadMsgsInfo{}, err } var packet AnswerGetUnreadMsgsInfo if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err)) + return AnswerGetUnreadMsgsInfo{}, err } - return packet + return packet, nil } -func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg { +func UnmarshalAnswerGetMsg(data PacketBody) (AnswerGetMsg, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return AnswerGetMsg{}, err } var packet AnswerGetMsg if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err)) + return AnswerGetMsg{}, err } - return packet + return packet, nil } -func UnmarshalReportError(data PacketBody) ReportError { +func UnmarshalReportError(data PacketBody) (ReportError, error) { jsonData, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal data: %v", err)) + return ReportError{}, err } var packet ReportError if err := json.Unmarshal(jsonData, &packet); err != nil { - panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err)) + return ReportError{}, err } - return packet + return packet, nil } diff --git a/Projs/PD1/internal/server/datastore.go b/Projs/PD1/internal/server/datastore.go index b578630..f55ef1d 100644 --- a/Projs/PD1/internal/server/datastore.go +++ b/Projs/PD1/internal/server/datastore.go @@ -4,6 +4,7 @@ import ( "PD1/internal/protocol" "crypto/x509" "database/sql" + "errors" "fmt" "log" "time" @@ -15,14 +16,17 @@ type DataStore struct { db *sql.DB } -func OpenDB() DataStore { +func OpenDB() (DataStore, error) { db, err := sql.Open("sqlite3", "server.db") if err != nil { - log.Fatalln("Error opening db file") + return DataStore{}, err } ds := DataStore{db: db} - ds.CreateTables() - return ds + err = ds.CreateTables() + if err != nil { + return DataStore{}, err + } + return ds, nil } func (ds DataStore) CreateTables() error { @@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error { userCert BLOB )`) if err != nil { - fmt.Println("Error creating users table", err) return err } @@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error { FOREIGN KEY(toUID) REFERENCES users(UID) )`) if err != nil { - fmt.Println("Error creating messages table", err) return err } @@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error { END; `) if err != nil { - fmt.Println("Error creating trigger", err) return err } @@ -91,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet { if err == sql.ErrNoRows { log.Printf("No message with NUM %v for UID %v\n", position, toUID) errorMessage := fmt.Sprintf("No message with NUM %v", position) - return protocol.NewReportErrorPacket(errorMessage) + return protocol.NewReportErrorPacket(errorMessage) } return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) @@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) { } } -func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet { +func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) { // Retrieve the total count of unread messages var totalCount int err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount) if err == sql.ErrNoRows { - log.Printf("No unread messages for UID %v: %v", toUID, err) - return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}) + return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil } // Query to retrieve all messages from the user's queue @@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot var timestamp time.Time var queuePosition, status int if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil { - panic(err) + return protocol.Packet{}, err } answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp) messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo) } if err := rows.Err(); err != nil { log.Printf("Error when getting messages for UID %v: %v", toUID, err) - return protocol.NewReportErrorPacket(err.Error()) + return protocol.Packet{}, err } - numberOfPages := (totalCount + pageSize - 1) / pageSize currentPage := min(numberOfPages, page) - return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets) + return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil } func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet { @@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool { // Execute the SQL query err := ds.db.QueryRow(query, uid).Scan(&count) if err == sql.ErrNoRows { - log.Printf("User with UID %v does not exist", uid) - return false - } else { - return true + log.Println("user with UID %v does not exist", uid) + return false } + return true } -func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) { +func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error { // Check if the user already exists if ds.userExists(uid) { - return + return nil } // Insert the user certificate @@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) ` _, err := ds.db.Exec(insertQuery, uid, cert.Raw) if err != nil { - log.Printf("Error storing user certificate for UID %s: %v\n", uid, err) - return + return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err)) } log.Printf("User certificate for UID %s stored successfully.\n", uid) + return nil } diff --git a/Projs/PD1/internal/server/interface.go b/Projs/PD1/internal/server/interface.go index b2eec56..40e156f 100644 --- a/Projs/PD1/internal/server/interface.go +++ b/Projs/PD1/internal/server/interface.go @@ -3,7 +3,6 @@ package server import ( "bufio" "fmt" - "log" "os" ) @@ -13,7 +12,3 @@ func readStdin(message string) string { scanner.Scan() return scanner.Text() } - -func LogFatal(err error) { - log.Fatalln(err) -} diff --git a/Projs/PD1/internal/server/server.go b/Projs/PD1/internal/server/server.go index 37652b9..54b62da 100644 --- a/Projs/PD1/internal/server/server.go +++ b/Projs/PD1/internal/server/server.go @@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore //Check if certificate usage is MSG SERVICE usage := oidMap["2.5.4.11"] if usage == "" { - log.Println("User certificate does not have the correct usage") - return + log.Fatalln("User certificate does not have the correct usage") } //Get the UID of this user UID := oidMap["2.5.4.65"] if UID == "" { - log.Println("User certificate does not specify it's PSEUDONYM") + log.Fatalln("User certificate does not specify it's PSEUDONYM") + } + err := dataStore.storeUserCertIfNotExists(UID, *clientCert) + if err != nil { + log.Fatalln(err) } - dataStore.storeUserCertIfNotExists(UID, *clientCert) F: for { - pac, active := connection.Receive() - if !active { + pac, err := connection.Receive() + if err != nil { break } switch pac.Flag { case protocol.FlagGetUserCert: - reqUserCert := protocol.UnmarshalGetUserCert(pac.Body) + reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body) + if err != nil { + log.Fatalln(err) + } userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) - if !connection.Send(userCertPacket) { + if err := connection.Send(userCertPacket); err != nil { + log.Fatalln(err) break F } + case protocol.FlagGetUnreadMsgsInfo: - getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body) + getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body) + if err != nil { + log.Fatalln(err) + } var messages protocol.Packet if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 { messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1") } else { - messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize) + messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize) + if err != nil { + log.Fatalln(err) + } } - if !connection.Send(messages) { - break F + if err := connection.Send(messages); err != nil { + log.Fatalln(err) } + case protocol.FlagGetMsg: - reqMsg := protocol.UnmarshalGetMsg(pac.Body) + reqMsg, err := protocol.UnmarshalGetMsg(pac.Body) + if err != nil { + log.Fatalln(err) + } var message protocol.Packet if reqMsg.Num <= 0 { message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1") } else { message = dataStore.GetMessage(UID, reqMsg.Num) } - if !connection.Send(message) { + if err := connection.Send(message); err != nil { + log.Fatalln(err) break F } dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) + case protocol.FlagSendMsg: - submitMsg := protocol.UnmarshalSendMsg(pac.Body) + submitMsg, err := protocol.UnmarshalSendMsg(pac.Body) + if err != nil { + log.Fatalln(err) + } var answerSendMsgPacket protocol.Packet if submitMsg.ToUID == UID { - answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself") + answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user") } else if !dataStore.userExists(submitMsg.ToUID) { - answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database") + answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist") } else { answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg) } - if !connection.Send(answerSendMsgPacket) { + if err := connection.Send(answerSendMsgPacket); err != nil { + log.Fatalln(err) break F } } } } -func Run(port int) { +func Run() { //Open connection to DB - dataStore := OpenDB() + dataStore, err := OpenDB() + if err != nil { + log.Fatalln(err) + } defer dataStore.db.Close() - //FIX: Get the server's keystore path instead of hardcoding it - //Read server keystore - password := readStdin("Insert keystore passphrase") - serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password) + keystorePassphrase := readStdin("Insert keystore passphrase") + serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase) if err != nil { - LogFatal(err) + log.Fatalln(err) } //Create server listener - server := networking.NewServer[protocol.Packet](&serverKeyStore, port) + server, err := networking.NewServer[protocol.Packet](&serverKeyStore) + if err != nil { + log.Fatalln(err) + } go server.ListenLoop() for { diff --git a/Projs/PD1/internal/utils/networking/client.go b/Projs/PD1/internal/utils/networking/client.go index 19ebb05..cb49c4c 100644 --- a/Projs/PD1/internal/utils/networking/client.go +++ b/Projs/PD1/internal/utils/networking/client.go @@ -2,7 +2,6 @@ package networking import ( "crypto/tls" - "log" ) @@ -14,11 +13,11 @@ type Client[T any] struct { Connection Connection[T] } -func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] { +func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) { dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig()) if err != nil { - log.Panicln("Server connection error:\n",err) + return Client[T]{},err } conn := NewConnection[T](dialConn) - return Client[T]{Connection: conn} + return Client[T]{Connection: conn},nil } diff --git a/Projs/PD1/internal/utils/networking/connection.go b/Projs/PD1/internal/utils/networking/connection.go index 1e40f9b..82efba6 100644 --- a/Projs/PD1/internal/utils/networking/connection.go +++ b/Projs/PD1/internal/utils/networking/connection.go @@ -22,33 +22,27 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] { } } -func (c Connection[T]) Send(obj T) bool { +func (c Connection[T]) Send(obj T) error { if err := c.encoder.Encode(&obj); err!=nil { if err == io.EOF { log.Println("Connection closed by peer") - //Return false as connection not active - return false - } else { - log.Panic(err) - } + } + return err } //Return true as connection active - return true + return nil } -func (c Connection[T]) Receive() (*T, bool) { +func (c Connection[T]) Receive() (*T, error) { var obj T if err := c.decoder.Decode(&obj); err != nil { if err == io.EOF { log.Println("Connection closed by peer") - //Return false as connection not active - return nil,false - } else { - log.Panic(err) - } - } + } + return nil,err + } //Return true as connection active - return &obj, true + return &obj, nil } func (c Connection[T]) GetPeerCertificate() *x509.Certificate { diff --git a/Projs/PD1/internal/utils/networking/server.go b/Projs/PD1/internal/utils/networking/server.go index 16bf4a7..5a960b4 100644 --- a/Projs/PD1/internal/utils/networking/server.go +++ b/Projs/PD1/internal/utils/networking/server.go @@ -2,7 +2,6 @@ package networking import ( "crypto/tls" - "fmt" "log" "net" ) @@ -16,16 +15,16 @@ type Server[T any] struct { C chan Connection[T] } -func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] { +func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) { - listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig()) + listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig()) if err != nil { - log.Fatalln("Server could not bind to address") + return Server[T]{}, err } return Server[T]{ listener: listener, C: make(chan Connection[T]), - } + }, nil } func (s *Server[T]) ListenLoop() { @@ -39,7 +38,9 @@ func (s *Server[T]) ListenLoop() { if !ok { log.Fatalln("Connection is not a TLS connection") } - tlsConn.Handshake() + if err := tlsConn.Handshake(); err != nil { + log.Fatalln(err) + } state := tlsConn.ConnectionState() if len(state.PeerCertificates) == 0 { diff --git a/Projs/PD1/server.db b/Projs/PD1/server.db index 9e8a125..e587c74 100644 Binary files a/Projs/PD1/server.db and b/Projs/PD1/server.db differ diff --git a/Projs/PD1/tokefile.toml b/Projs/PD1/tokefile.toml index 9b456d3..9e98440 100644 --- a/Projs/PD1/tokefile.toml +++ b/Projs/PD1/tokefile.toml @@ -13,7 +13,7 @@ cmd="go build" cmd="go run ./cmd/server/server.go" [targets.send] -cmd="echo client1 | go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject" +cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject" [targets.askQueue] cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 askqueue"