CSI-ES-2324/Projs/PD1/internal/client/client.go
2024-04-23 11:12:18 +01:00

206 lines
6.4 KiB
Go

package client
import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"crypto/x509"
"flag"
"log"
"sort"
"strconv"
)
func Run() {
var userFile string
flag.StringVar(&userFile, "user", "userdata.p12", "Specify user data file")
flag.Parse()
if flag.NArg() == 0 {
panic("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
password := readStdin("Insert keystore passphrase")
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
command := flag.Arg(0)
switch command {
case "send":
if flag.NArg() < 3 {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
}
uid := flag.Arg(1)
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes
plainSubjectBytes := Marshal(plainSubject)
plainBodyBytes := Marshal(plainBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
receiverCert := getUserCert(cl, uid)
if receiverCert == nil {
return
}
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(sendMsgPacket) {
return
}
answerSendMsg, active := cl.Connection.Receive()
if !active {
return
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
log.Println(reportError.ErrorMessage)
}
case "askqueue":
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
askQueue(cl, clientKeyStore, page, pageSize)
case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString)
if err != nil {
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewGetMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
log.Println(reportError.ErrorMessage)
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help":
showHelp()
default:
commandError()
}
}
// getUserCert requests the certificate of the user identified by uid over the
// connection held by cl and parses it as an X.509 certificate.
//
// It returns nil when the request cannot be sent, when no answer is received,
// when the server answers with an error report (the reported message is
// logged) or when the received bytes are not a parseable certificate (the
// parse error is logged). Callers must check the result for nil before use.
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
	if !cl.Connection.Send(protocol.NewGetUserCertPacket(uid)) {
		return nil
	}
	answerPacket, active := cl.Connection.Receive()
	if !active {
		return nil
	}
	if answerPacket.Flag == protocol.FlagReportError {
		reportError := protocol.UnmarshalReportError(answerPacket.Body)
		log.Println(reportError.ErrorMessage)
		return nil
	}
	answer := protocol.UnmarshalAnswerGetUserCert(answerPacket.Body)
	userCert, err := x509.ParseCertificate(answer.Certificate)
	if err != nil {
		// Report the failure instead of dropping it, so a nil result can be
		// traced back to a malformed certificate.
		log.Printf("invalid certificate received for user %q: %v", uid, err)
		return nil
	}
	return userCert
}
// getManyMessagesInfo reads the server's answer to an unread-messages request
// from cl and fetches the certificate of every distinct sender listed in it.
//
// On success it returns the decoded answer together with a non-nil map from
// sender UID to certificate; an entry is nil when that sender's certificate
// could not be obtained. When no answer is received, or the server answers
// with an error report (which is logged), it returns an empty answer and a
// nil map.
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
	packet, alive := cl.Connection.Receive()
	if !alive {
		return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
	}
	if packet.Flag == protocol.FlagReportError {
		log.Println(protocol.UnmarshalReportError(packet.Body).ErrorMessage)
		return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
	}
	info := protocol.UnmarshalAnswerGetUnreadMsgsInfo(packet.Body)

	// Collect each sender once so that no certificate is requested twice.
	uniqueSenders := make(map[string]struct{})
	for _, msg := range info.MessagesInfo {
		uniqueSenders[msg.FromUID] = struct{}{}
	}

	// Ask the server for the certificate of every sender.
	certs := make(map[string]*x509.Certificate)
	for uid := range uniqueSenders {
		certs[uid] = getUserCert(cl, uid)
	}
	return info, certs
}
// askQueue shows the unread-message queue one page at a time.
//
// For each page it sends a request for page with pageSize entries, reads the
// answer and the senders' certificates through getManyMessagesInfo, decrypts
// every subject (a message whose sender certificate could not be obtained is
// listed with an empty subject), sorts the entries by descending Num and
// hands them to showMessagesInfo. The value returned by showMessagesInfo
// selects what happens next: -1 requests the previous page, 1 the next page
// and any other value ends the listing. Page numbers never drop below 1.
//
// The function returns as soon as a request cannot be sent or no valid
// answer is obtained.
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
	// Pagination is a loop rather than a recursive call so that browsing
	// through many pages does not keep growing the call stack.
	for {
		request := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
		if !cl.Connection.Send(request) {
			return
		}

		unreadMsgsInfo, certificates := getManyMessagesInfo(cl)
		if certificates == nil {
			// getManyMessagesInfo yields a nil map only when the answer was
			// missing or was an error report; there is nothing to display.
			return
		}

		var clientMessages []ClientMessageInfo
		for _, message := range unreadMsgsInfo.MessagesInfo {
			senderCert, ok := certificates[message.FromUID]
			if !ok {
				continue
			}
			subject := ""
			if senderCert != nil {
				decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
				subject = Unmarshal(decryptedSubjectBytes)
			}
			clientMessages = append(clientMessages, newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp))
		}

		// Sort the messages, highest Num first.
		sort.Slice(clientMessages, func(i, j int) bool {
			return clientMessages[i].Num > clientMessages[j].Num
		})

		switch showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages) {
		case -1:
			page = max(1, unreadMsgsInfo.Page-1)
		case 1:
			page = max(1, unreadMsgsInfo.Page+1)
		default:
			return
		}
	}
}