[PD1] small changes

This commit is contained in:
Afonso Franco 2024-04-23 11:12:18 +01:00
parent 568b6e6739
commit 2cafc3163c
Signed by: afonso
SSH key fingerprint: SHA256:aiLbdlPwXKJS5wMnghdtod0SPy8imZjlVvCyUX9DJNk
10 changed files with 160 additions and 71 deletions

1
Projs/PD1/.ignore Normal file
View file

@ -0,0 +1 @@
certs

View file

@ -20,7 +20,7 @@ func Run() {
panic("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
password := AskUserPassword()
password := readStdin("Insert keystore passphrase")
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
command := flag.Arg(0)
@ -49,7 +49,14 @@ func Run() {
if !cl.Connection.Send(sendMsgPacket) {
return
}
cl.Connection.Conn.Close()
answerSendMsg, active := cl.Connection.Receive()
if !active {
return
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
log.Println(reportError.ErrorMessage)
}
case "askqueue":
pageInput := flag.Arg(1)
@ -69,7 +76,7 @@ func Run() {
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
askQueue(cl,clientKeyStore, page, pageSize)
askQueue(cl, clientKeyStore, page, pageSize)
case "getmsg":
if flag.NArg() < 2 {
@ -89,6 +96,11 @@ func Run() {
if !active {
return
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
log.Println(reportError.ErrorMessage)
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
@ -117,6 +129,11 @@ func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certif
if !active {
return nil
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
log.Println(reportError.ErrorMessage)
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
@ -130,6 +147,11 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
log.Println(reportError.ErrorMessage)
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates
@ -146,7 +168,7 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
return answerGetUnreadMsgsInfo, certificatesMap
}
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return
@ -156,8 +178,13 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
for _, message := range unreadMsgsInfo.MessagesInfo {
senderCert, ok := certificates[message.FromUID]
if ok {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject := Unmarshal(decryptedSubjectBytes)
var subject string
if senderCert != nil {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject = Unmarshal(decryptedSubjectBytes)
} else {
subject = ""
}
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
clientMessages = append(clientMessages, clientMessage)
}
@ -167,13 +194,13 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
return clientMessages[i].Num > clientMessages[j].Num
})
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize)
case 0:
return
case 1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize)
}
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 0:
return
case 1:
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
}
}

View file

@ -11,15 +11,6 @@ func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text()
}
func AskUserPassword() string {
fmt.Println("Enter key store password")
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text()
}
@ -43,9 +34,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
return 0
}
for _, message := range messages {
if message.Subject == "" {
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
continue
}
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
fmt.Printf("Page %v/%v\n",page,numPages)
fmt.Printf("Page %v/%v\n", page, numPages)
return messagesInfoPageNavigation(page, numPages)
}
@ -89,11 +84,9 @@ func messagesInfoPageNavigation(page int, numPages int) int {
return 0
}
func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body)
}

View file

@ -29,6 +29,12 @@ const (
// Server sends requested message
FlagAnswerGetMsg
// Server tells the client that the message was successfully sent
FlagAnswerSendMsg
// Report an error
FlagReportError
)
type (
@ -76,6 +82,10 @@ type (
Body []byte `json:"body"`
Timestamp time.Time `json:"timestamp"`
}
ReportError struct {
ErrorMessage string `json:"error"`
}
)
type PacketBody interface{}
@ -127,7 +137,7 @@ func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
}
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo}
return AnswerGetUnreadMsgsInfo{Page: page, NumPages: numPages, MessagesInfo: messagesInfo}
}
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{
@ -148,6 +158,12 @@ func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timesta
}
}
func NewReportError(errorMessage string) ReportError {
return ReportError{
ErrorMessage: errorMessage,
}
}
func NewGetUserCertPacket(UID string) Packet {
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
}
@ -169,13 +185,22 @@ func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
}
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo))
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page, numPages, messagesInfo))
}
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
}
func NewAnswerSendMsgPacket() Packet{
//This packet has no body
return NewPacket(FlagAnswerSendMsg,nil)
}
func NewReportErrorPacket(errorMessage string) Packet {
return NewPacket(FlagReportError, NewReportError(errorMessage))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
jsonData, err := json.Marshal(data)
if err != nil {
@ -270,3 +295,15 @@ func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
}
return packet
}
func UnmarshalReportError(data PacketBody) ReportError {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet ReportError
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
}
return packet
}

View file

@ -88,8 +88,10 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err != nil {
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
if err == sql.ErrNoRows {
log.Printf("No message with NUM %v for UID %v\n", position, toUID)
errorMessage := fmt.Sprintf("No message with NUM %v", position)
return protocol.NewReportErrorPacket(errorMessage)
}
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
@ -119,8 +121,9 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
if err == sql.ErrNoRows {
log.Printf("No unread messages for UID %v: %v", toUID, err)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
}
// Query to retrieve all messages from the user's queue
@ -143,7 +146,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
// Execute the query
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
log.Printf("Error getting unread messages for UID %v: %v", toUID, err)
}
defer rows.Close()
@ -161,6 +164,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.NewReportErrorPacket(err.Error())
}
numberOfPages := (totalCount + pageSize - 1) / pageSize
@ -168,7 +172,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
@ -179,7 +183,9 @@ func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg)
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil {
log.Printf("Error adding message to UID %v: %v", fromUID, err)
return protocol.NewReportErrorPacket(err.Error())
}
return protocol.NewAnswerSendMsgPacket()
}
func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
@ -193,12 +199,10 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows {
log.Panicf("No certificate for UID %v found in the database", uid)
errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid)
log.Println(errorMessage)
return protocol.NewReportErrorPacket(errorMessage)
}
//userCert,err := x509.ParseCertificate(userCertBytes)
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
}
@ -224,7 +228,6 @@ func (ds DataStore) userExists(uid string) bool {
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
// Check if the user already exists
if ds.userExists(uid) {
log.Printf("User certificate for UID %s already exists.\n", uid)
return
}

View file

@ -6,10 +6,9 @@ import (
"os"
)
func AskServerPassword() string {
fmt.Println("Enter key store password")
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
return scanner.Text()
}

View file

@ -4,12 +4,10 @@ import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"log"
)
//TODO: CREATE SERVER SIDE CHECKS FOR EVERYTHING
//TODO: LOGGING SYSTEM
//TODO: TELL THE USER THAT THE MESSAGE HAS BEEN RECEIVED BY THE SERVER
//TODO: ERROR PACKET TO SEND BACK TO USER
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
defer connection.Conn.Close()
@ -18,10 +16,16 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
clientCert := connection.GetPeerCertificate()
//Get the OID values
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
//Check if certificate usage is MSG SERVICE
usage := oidMap["2.5.4.11"]
if usage == "" {
log.Println("User certificate does not have the correct usage")
return
}
//Get the UID of this user
UID := oidMap["2.5.4.65"]
if UID == "" {
panic("User certificate does not specify it's PSEUDONYM")
log.Println("User certificate does not specify it's PSEUDONYM")
}
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F:
@ -34,31 +38,47 @@ F:
case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active {
if !connection.Send(userCertPacket) {
break F
}
case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
var messages protocol.Packet
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
} else {
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
}
if !connection.Send(messages) {
break F
}
case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
var message protocol.Packet
if reqMsg.Num <= 0 {
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
} else {
message = dataStore.GetMessage(UID, reqMsg.Num)
}
if !connection.Send(message) {
break F
}
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
var answerSendMsgPacket protocol.Packet
if submitMsg.ToUID == UID {
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
} else if !dataStore.userExists(submitMsg.ToUID) {
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
} else {
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
}
if !connection.Send(answerSendMsgPacket) {
break F
}
}
}
}
func Run(port int) {
@ -69,7 +89,7 @@ func Run(port int) {
//FIX: Get the server's keystore path instead of hardcoding it
//Read server keystore
password := AskServerPassword()
password := readStdin("Insert keystore passphrase")
serverKeyStore := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
//Create server listener

View file

@ -10,7 +10,6 @@ import (
"encoding/binary"
"errors"
//"errors"
"log"
"os"
@ -94,9 +93,25 @@ func (k *KeyStore) GetServerTLSConfig() *tls.Config {
caCertPool.AddCert(caCert)
}
tlsConfig.ClientCAs = caCertPool
//FIX: SERVER ACCEPTS CONNECTIONS WITH UNMATCHING OR
// NO CERTIFICATE, NEEDS TO BE CHANGED SOMEHOW
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientAuth = tls.RequestClientCert
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Verify the peer's certificate
opts := x509.VerifyOptions{
Roots: caCertPool,
}
for _, certBytes := range rawCerts {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return err
}
// Check if the certificate is signed by the specified CA
_, err = cert.Verify(opts)
if err != nil {
return errors.New("certificate not signed by trusted CA")
}
}
return nil
}
return tlsConfig
}

Binary file not shown.

View file

@ -12,14 +12,8 @@ cmd="go build"
[targets.server]
cmd="go run ./cmd/server/server.go"
[targets.client1]
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
[targets.queue]
cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 askqueue"
[targets.FakeClient1]
cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CL2 testsubject"
[targets.client2]
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CL3 testsubject"
[targets.client3]
cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CL1 testsubject"
[targets.send]
cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 send ${DEST} ${SUBJECT}"