324 lines
9 KiB
Go
324 lines
9 KiB
Go
|
package client
|
||
|
|
||
|
import (
|
||
|
"PD1/internal/protocol"
|
||
|
"PD1/internal/utils/cryptoUtils"
|
||
|
"PD1/internal/utils/networking"
|
||
|
"crypto/x509"
|
||
|
"errors"
|
||
|
"flag"
|
||
|
"log"
|
||
|
"os"
|
||
|
"sort"
|
||
|
"strconv"
|
||
|
)
|
||
|
|
||
|
func Run() {
|
||
|
var userFile string
|
||
|
flag.StringVar(&userFile, "user", "userdata.p12", "Specify user data file")
|
||
|
flag.Parse()
|
||
|
|
||
|
if flag.NArg() == 0 {
|
||
|
log.Fatalln("No command provided. Use 'help' for instructions.")
|
||
|
}
|
||
|
//Get user KeyStore
|
||
|
password := readStdin("Insert keystore passphrase")
|
||
|
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
|
||
|
if err != nil {
|
||
|
log.Fatalln(err)
|
||
|
}
|
||
|
|
||
|
command := flag.Arg(0)
|
||
|
switch command {
|
||
|
case "send":
|
||
|
if flag.NArg() != 3 {
|
||
|
printError("MSG SERVICE: command error!")
|
||
|
showHelp()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
uid := flag.Arg(1)
|
||
|
plainSubject := flag.Arg(2)
|
||
|
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
|
||
|
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
|
||
|
if err != nil {
|
||
|
log.Fatalln(err)
|
||
|
}
|
||
|
|
||
|
case "askqueue":
|
||
|
if flag.NArg() > 3 {
|
||
|
printError("MSG SERVICE: command error!")
|
||
|
showHelp()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
pageInput := flag.Arg(1)
|
||
|
page := 1
|
||
|
if pageInput != "" {
|
||
|
if val, err := strconv.Atoi(pageInput); err == nil {
|
||
|
page = max(1, val)
|
||
|
}
|
||
|
}
|
||
|
pageSizeInput := flag.Arg(2)
|
||
|
pageSize := 5
|
||
|
if pageSizeInput != "" {
|
||
|
if val, err := strconv.Atoi(pageSizeInput); err == nil {
|
||
|
pageSize = max(1, val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err := askQueueCommand(clientKeyStore, page, pageSize)
|
||
|
if err != nil {
|
||
|
log.Fatalln(err)
|
||
|
}
|
||
|
|
||
|
case "getmsg":
|
||
|
if flag.NArg() < 2 {
|
||
|
printError("MSG SERVICE: command error!")
|
||
|
showHelp()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
numString := flag.Arg(1)
|
||
|
num, err := strconv.Atoi(numString)
|
||
|
if err != nil {
|
||
|
log.Fatalln(err)
|
||
|
}
|
||
|
err = getMsgCommand(clientKeyStore, num)
|
||
|
if err != nil {
|
||
|
printError(err.Error())
|
||
|
}
|
||
|
|
||
|
case "help":
|
||
|
showHelp()
|
||
|
|
||
|
default:
|
||
|
printError("MSG SERVICE: command error!")
|
||
|
showHelp()
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
|
||
|
//Turn content to bytes
|
||
|
plainSubjectBytes, err := Marshal(plainSubject)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
plainBodyBytes, err := Marshal(plainBody)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer cl.Connection.Conn.Close()
|
||
|
|
||
|
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
|
||
|
if err := cl.Connection.Send(sendMsgPacket); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
answerSendMsg, err := cl.Connection.Receive()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if answerSendMsg.Flag == protocol.FlagReportError {
|
||
|
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return errors.New(reportError.ErrorMessage)
|
||
|
}
|
||
|
return nil
|
||
|
|
||
|
}
|
||
|
|
||
|
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
|
||
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer cl.Connection.Conn.Close()
|
||
|
packet := protocol.NewGetMsgPacket(num)
|
||
|
if err := cl.Connection.Send(packet); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
receivedMsgPacket, err := cl.Connection.Receive()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
||
|
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return errors.New(reportError.ErrorMessage)
|
||
|
}
|
||
|
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
subject, err := Unmarshal(decSubjectBytes)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
body, err := Unmarshal(decBodyBytes)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
|
||
|
showMessage(message)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
|
||
|
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
|
||
|
if err := cl.Connection.Send(getUserCertPacket); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
var answerGetUserCertPacket *protocol.Packet
|
||
|
answerGetUserCertPacket, err := cl.Connection.Receive()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
||
|
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return nil, errors.New(reportError.ErrorMessage)
|
||
|
}
|
||
|
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := keyStore.CheckCert(userCert, uid); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return userCert, nil
|
||
|
}
|
||
|
|
||
|
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
|
||
|
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
|
||
|
if err != nil {
|
||
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
|
||
|
}
|
||
|
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
|
||
|
reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
||
|
if err != nil {
|
||
|
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||
|
}
|
||
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
|
||
|
}
|
||
|
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
||
|
if err != nil {
|
||
|
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
|
||
|
}
|
||
|
|
||
|
//Create Set of needed certificates
|
||
|
senderSet := map[string]bool{}
|
||
|
for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo {
|
||
|
senderSet[messageInfo.FromUID] = true
|
||
|
}
|
||
|
certificatesMap := map[string]*x509.Certificate{}
|
||
|
//Get senders' certificates
|
||
|
for senderUID := range senderSet {
|
||
|
senderCert, err := getUserCert(cl, keyStore, senderUID)
|
||
|
if err == nil {
|
||
|
certificatesMap[senderUID] = senderCert
|
||
|
}
|
||
|
}
|
||
|
return answerGetUnreadMsgsInfo, certificatesMap, nil
|
||
|
}
|
||
|
|
||
|
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||
|
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer cl.Connection.Conn.Close()
|
||
|
return askQueueRec(cl, clientKeyStore, page, pageSize)
|
||
|
}
|
||
|
|
||
|
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
|
||
|
|
||
|
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
||
|
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
var clientMessages []ClientMessageInfo
|
||
|
for _, message := range unreadMsgsInfo.MessagesInfo {
|
||
|
var clientMessageInfo ClientMessageInfo
|
||
|
senderCert, ok := certificates[message.FromUID]
|
||
|
if !ok {
|
||
|
clientMessageInfo = newClientMessageInfo(message.Num,
|
||
|
message.FromUID,
|
||
|
"",
|
||
|
message.Timestamp,
|
||
|
errors.New("certificate needed to decrypt not received"))
|
||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||
|
continue
|
||
|
}
|
||
|
decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
||
|
if err != nil {
|
||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||
|
continue
|
||
|
}
|
||
|
subject, err := Unmarshal(decryptedSubjectBytes)
|
||
|
if err != nil {
|
||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
|
||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||
|
continue
|
||
|
}
|
||
|
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
|
||
|
clientMessages = append(clientMessages, clientMessageInfo)
|
||
|
}
|
||
|
//Sort the messages
|
||
|
sort.Slice(clientMessages, func(i, j int) bool {
|
||
|
return clientMessages[i].Num > clientMessages[j].Num
|
||
|
})
|
||
|
|
||
|
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
||
|
switch action {
|
||
|
case -1:
|
||
|
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
||
|
case 1:
|
||
|
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|