[PD1] small changes

This commit is contained in:
Afonso Franco 2024-04-23 11:12:18 +01:00
parent 568b6e6739
commit 2cafc3163c
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
10 changed files with 160 additions and 71 deletions

1
Projs/PD1/.ignore Normal file
View file

@ -0,0 +1 @@
certs

View file

@ -20,7 +20,7 @@ func Run() {
panic("No command provided. Use 'help' for instructions.") panic("No command provided. Use 'help' for instructions.")
} }
//Get user KeyStore //Get user KeyStore
password := AskUserPassword() password := readStdin("Insert keystore passphrase")
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password) clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
command := flag.Arg(0) command := flag.Arg(0)
@ -49,7 +49,14 @@ func Run() {
if !cl.Connection.Send(sendMsgPacket) { if !cl.Connection.Send(sendMsgPacket) {
return return
} }
cl.Connection.Conn.Close() answerSendMsg, active := cl.Connection.Receive()
if !active {
return
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
log.Println(reportError.ErrorMessage)
}
case "askqueue": case "askqueue":
pageInput := flag.Arg(1) pageInput := flag.Arg(1)
@ -69,7 +76,7 @@ func Run() {
cl := networking.NewClient[protocol.Packet](&clientKeyStore) cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close() defer cl.Connection.Conn.Close()
askQueue(cl,clientKeyStore, page, pageSize) askQueue(cl, clientKeyStore, page, pageSize)
case "getmsg": case "getmsg":
if flag.NArg() < 2 { if flag.NArg() < 2 {
@ -89,6 +96,11 @@ func Run() {
if !active { if !active {
return return
} }
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
log.Println(reportError.ErrorMessage)
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body) answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID) senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject) decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
@ -117,6 +129,11 @@ func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certif
if !active { if !active {
return nil return nil
} }
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
log.Println(reportError.ErrorMessage)
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body) answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate) userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil { if err != nil {
@ -130,6 +147,11 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
if !active { if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
} }
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
log.Println(reportError.ErrorMessage)
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body) answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates //Create Set of needed certificates
@ -146,7 +168,7 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
return answerGetUnreadMsgsInfo, certificatesMap return answerGetUnreadMsgsInfo, certificatesMap
} }
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) { func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize) requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) { if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return return
@ -156,8 +178,13 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
for _, message := range unreadMsgsInfo.MessagesInfo { for _, message := range unreadMsgsInfo.MessagesInfo {
senderCert, ok := certificates[message.FromUID] senderCert, ok := certificates[message.FromUID]
if ok { if ok {
var subject string
if senderCert != nil {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject) decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject := Unmarshal(decryptedSubjectBytes) subject = Unmarshal(decryptedSubjectBytes)
} else {
subject = ""
}
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp) clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
clientMessages = append(clientMessages, clientMessage) clientMessages = append(clientMessages, clientMessage)
} }
@ -170,10 +197,10 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages) action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action { switch action {
case -1: case -1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize) askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 0: case 0:
return return
case 1: case 1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize) askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
} }
} }

View file

@ -11,15 +11,6 @@ func readStdin(message string) string {
fmt.Println(message) fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
scanner.Scan() scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text()
}
func AskUserPassword() string {
fmt.Println("Enter key store password")
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text() return scanner.Text()
} }
@ -43,9 +34,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
return 0 return 0
} }
for _, message := range messages { for _, message := range messages {
if message.Subject == "" {
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
continue
}
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject) fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
} }
fmt.Printf("Page %v/%v\n",page,numPages) fmt.Printf("Page %v/%v\n", page, numPages)
return messagesInfoPageNavigation(page, numPages) return messagesInfoPageNavigation(page, numPages)
} }
@ -89,11 +84,9 @@ func messagesInfoPageNavigation(page int, numPages int) int {
return 0 return 0
} }
func showMessage(message ClientMessage) { func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID) fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID) fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject) fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body) fmt.Printf("Body: %s\n", message.Body)
} }

View file

@ -29,6 +29,12 @@ const (
// Server sends requested message // Server sends requested message
FlagAnswerGetMsg FlagAnswerGetMsg
// Server tells the client that the message was successfully sent
FlagAnswerSendMsg
// Report an error
FlagReportError
) )
type ( type (
@ -76,6 +82,10 @@ type (
Body []byte `json:"body"` Body []byte `json:"body"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
} }
ReportError struct {
ErrorMessage string `json:"error"`
}
) )
type PacketBody interface{} type PacketBody interface{}
@ -127,7 +137,7 @@ func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
} }
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo { func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo} return AnswerGetUnreadMsgsInfo{Page: page, NumPages: numPages, MessagesInfo: messagesInfo}
} }
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo { func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{ return MsgInfo{
@ -148,6 +158,12 @@ func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timesta
} }
} }
func NewReportError(errorMessage string) ReportError {
return ReportError{
ErrorMessage: errorMessage,
}
}
func NewGetUserCertPacket(UID string) Packet { func NewGetUserCertPacket(UID string) Packet {
return NewPacket(FlagGetUserCert, NewGetUserCert(UID)) return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
} }
@ -169,13 +185,22 @@ func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
} }
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet { func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo)) return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page, numPages, messagesInfo))
} }
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet { func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last)) return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
} }
func NewAnswerSendMsgPacket() Packet{
//This packet has no body
return NewPacket(FlagAnswerSendMsg,nil)
}
func NewReportErrorPacket(errorMessage string) Packet {
return NewPacket(FlagReportError, NewReportError(errorMessage))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert { func UnmarshalGetUserCert(data PacketBody) GetUserCert {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
@ -270,3 +295,15 @@ func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
} }
return packet return packet
} }
func UnmarshalReportError(data PacketBody) ReportError {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ReportError
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
}
return packet
}

View file

@ -88,8 +88,10 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
// Execute the query // Execute the query
row := ds.db.QueryRow(query, toUID, position) row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp) err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err != nil { if err == sql.ErrNoRows {
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err) log.Printf("No message with NUM %v for UID %v\n", position, toUID)
errorMessage := fmt.Sprintf("No message with NUM %v", position)
return protocol.NewReportErrorPacket(errorMessage)
} }
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
@ -119,8 +121,9 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
// Retrieve the total count of unread messages // Retrieve the total count of unread messages
var totalCount int var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount) err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil { if err == sql.ErrNoRows {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err) log.Printf("No unread messages for UID %v: %v", toUID, err)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
} }
// Query to retrieve all messages from the user's queue // Query to retrieve all messages from the user's queue
@ -143,7 +146,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
// Execute the query // Execute the query
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize) rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil { if err != nil {
log.Printf("Error getting all messages for UID %v: %v", toUID, err) log.Printf("Error getting unread messages for UID %v: %v", toUID, err)
} }
defer rows.Close() defer rows.Close()
@ -161,6 +164,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err) log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.NewReportErrorPacket(err.Error())
} }
numberOfPages := (totalCount + pageSize - 1) / pageSize numberOfPages := (totalCount + pageSize - 1) / pageSize
@ -168,7 +172,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets) return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
} }
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) { func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
query := ` query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status) INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0) VALUES (?, ?, ?, ?, ?, 0)
@ -179,7 +183,9 @@ func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg)
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime) _, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil { if err != nil {
log.Printf("Error adding message to UID %v: %v", fromUID, err) log.Printf("Error adding message to UID %v: %v", fromUID, err)
return protocol.NewReportErrorPacket(err.Error())
} }
return protocol.NewAnswerSendMsgPacket()
} }
func (ds DataStore) GetUserCertificate(uid string) protocol.Packet { func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
@ -193,12 +199,10 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
var userCertBytes []byte var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes) err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
log.Panicf("No certificate for UID %v found in the database", uid) errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid)
log.Println(errorMessage)
return protocol.NewReportErrorPacket(errorMessage)
} }
//userCert,err := x509.ParseCertificate(userCertBytes)
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes) return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
} }
@ -224,7 +228,6 @@ func (ds DataStore) userExists(uid string) bool {
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) { func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
// Check if the user already exists // Check if the user already exists
if ds.userExists(uid) { if ds.userExists(uid) {
log.Printf("User certificate for UID %s already exists.\n", uid)
return return
} }

View file

@ -6,10 +6,9 @@ import (
"os" "os"
) )
func AskServerPassword() string { func readStdin(message string) string {
fmt.Println("Enter key store password") fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
scanner.Scan() scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text() return scanner.Text()
} }

View file

@ -4,12 +4,10 @@ import (
"PD1/internal/protocol" "PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils" "PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking" "PD1/internal/utils/networking"
"log"
) )
//TODO: CREATE SERVER SIDE CHECKS FOR EVERYTHING
//TODO: LOGGING SYSTEM //TODO: LOGGING SYSTEM
//TODO: TELL THE USER THAT THE MESSAGE HAS BEEN RECEIVED BY THE SERVER
//TODO: ERROR PACKET TO SEND BACK TO USER
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) { func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
defer connection.Conn.Close() defer connection.Conn.Close()
@ -18,10 +16,16 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
clientCert := connection.GetPeerCertificate() clientCert := connection.GetPeerCertificate()
//Get the OID values //Get the OID values
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert) oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
//Check if certificate usage is MSG SERVICE
usage := oidMap["2.5.4.11"]
if usage == "" {
log.Println("User certificate does not have the correct usage")
return
}
//Get the UID of this user //Get the UID of this user
UID := oidMap["2.5.4.65"] UID := oidMap["2.5.4.65"]
if UID == "" { if UID == "" {
panic("User certificate does not specify it's PSEUDONYM") log.Println("User certificate does not specify it's PSEUDONYM")
} }
dataStore.storeUserCertIfNotExists(UID, *clientCert) dataStore.storeUserCertIfNotExists(UID, *clientCert)
F: F:
@ -34,31 +38,47 @@ F:
case protocol.FlagGetUserCert: case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body) reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID) userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active { if !connection.Send(userCertPacket) {
break F break F
} }
case protocol.FlagGetUnreadMsgsInfo: case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body) getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize) var messages protocol.Packet
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
} else {
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
}
if !connection.Send(messages) { if !connection.Send(messages) {
break F break F
} }
case protocol.FlagGetMsg: case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body) reqMsg := protocol.UnmarshalGetMsg(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num) var message protocol.Packet
if active := connection.Send(message); !active { if reqMsg.Num <= 0 {
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
} else {
message = dataStore.GetMessage(UID, reqMsg.Num)
}
if !connection.Send(message) {
break F break F
} }
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num) dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg: case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body) submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) { var answerSendMsgPacket protocol.Packet
dataStore.AddMessageToQueue(UID, submitMsg) if submitMsg.ToUID == UID {
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
} else if !dataStore.userExists(submitMsg.ToUID) {
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
} else {
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
}
if !connection.Send(answerSendMsgPacket) {
break F
} }
} }
} }
} }
func Run(port int) { func Run(port int) {
@ -69,7 +89,7 @@ func Run(port int) {
//FIX: Get the server's keystore path instead of hardcoding it //FIX: Get the server's keystore path instead of hardcoding it
//Read server keystore //Read server keystore
password := AskServerPassword() password := readStdin("Insert keystore passphrase")
serverKeyStore := cryptoUtils.LoadKeyStore("certs/server/server.p12", password) serverKeyStore := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
//Create server listener //Create server listener

View file

@ -10,7 +10,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
//"errors"
"log" "log"
"os" "os"
@ -94,9 +93,25 @@ func (k *KeyStore) GetServerTLSConfig() *tls.Config {
caCertPool.AddCert(caCert) caCertPool.AddCert(caCert)
} }
tlsConfig.ClientCAs = caCertPool tlsConfig.ClientCAs = caCertPool
//FIX: SERVER ACCEPTS CONNECTIONS WITH UNMATCHING OR tlsConfig.ClientAuth = tls.RequestClientCert
// NO CERTIFICATE, NEEDS TO BE CHANGED SOMEHOW tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Verify the peer's certificate
opts := x509.VerifyOptions{
Roots: caCertPool,
}
for _, certBytes := range rawCerts {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return err
}
// Check if the certificate is signed by the specified CA
_, err = cert.Verify(opts)
if err != nil {
return errors.New("certificate not signed by trusted CA")
}
}
return nil
}
return tlsConfig return tlsConfig
} }

Binary file not shown.

View file

@ -12,14 +12,8 @@ cmd="go build"
[targets.server] [targets.server]
cmd="go run ./cmd/server/server.go" cmd="go run ./cmd/server/server.go"
[targets.client1] [targets.queue]
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject" cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 askqueue"
[targets.FakeClient1] [targets.send]
cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CL2 testsubject" cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 send ${DEST} ${SUBJECT}"
[targets.client2]
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CL3 testsubject"
[targets.client3]
cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CL1 testsubject"