CSI-ES-2324/Projs/PD2/internal/client/client.go

605 lines
16 KiB
Go

package client
import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"sort"
	"strconv"
	"time"

	"PD2/internal/protocol"
	"PD2/internal/utils/cryptoUtils"
)
const baseURL = "https://127.0.0.1:9090"
const tokenFolder = "token/"
func Run() {
var userFile string
flag.StringVar(&userFile, "user", "userdata.p12", "Specify user data file")
flag.Parse()
if flag.NArg() == 0 {
log.Fatalln("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
password := readStdin("Insert keystore passphrase")
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
if err != nil {
log.Fatalln(err)
}
oidMap := cryptoUtils.ExtractAllOIDValues(clientKeyStore.GetCert())
myUID := oidMap["2.5.4.65"]
if myUID == "" {
printError("no pseudonym field on my certificate")
os.Exit(1)
}
command := flag.Arg(0)
// Check if token is in memory
var token string
if command != "login" && command != "register" {
tokenFile, err := os.ReadFile(tokenFolder + myUID)
if err != nil {
printError(err.Error())
os.Exit(1)
}
token := string(tokenFile)
if token == "" {
printError("MSG SERVICE: token read error")
os.Exit(1)
}
}
switch command {
case "send":
if flag.NArg() != 3 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
uid := flag.Arg(1)
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid, token)
if err != nil {
log.Fatalln(err)
}
case "askqueue":
if flag.NArg() > 3 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
page, pageSize, listClientMessageInfo, err := askQueueCommand(clientKeyStore, page, pageSize, token)
if err != nil {
log.Fatalln(err)
}
showMessagesInfo(page, pageSize, listClientMessageInfo)
case "getmsg":
if flag.NArg() < 2 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
numString := flag.Arg(1)
num, err := strconv.Atoi(numString)
if err != nil {
log.Fatalln(err)
}
msg, err := getMsgCommand(clientKeyStore, num, token)
if err != nil {
printError(err.Error())
os.Exit(1)
}
showMessage(msg)
case "register":
// call register
if flag.NArg() > 2 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
userId := flag.Arg(1)
if userId == "" {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
password := readStdin("Enter password: ")
if password == "" {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
passwordConfirmation := readStdin("Confirm password: ")
if password != passwordConfirmation {
printError("MSG SERVICE: passwords do not match")
os.Exit(1)
}
err := registerUser(userId, password, clientKeyStore)
if err != nil {
printError(err.Error())
os.Exit(1)
}
// TODO: print register successful
case "login":
if flag.NArg() != 1 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
password := readStdin("Enter password: ")
if password == "" {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
token, err := login(myUID, password, clientKeyStore)
if err != nil {
printError(err.Error())
os.Exit(1)
}
tokenFile, err := os.OpenFile(tokenFolder+myUID, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
printError(err.Error())
os.Exit(1)
}
defer tokenFile.Close()
// TODO: Maybe encrypt token
_, err = tokenFile.WriteString(token)
if err != nil {
printError(err.Error())
os.Exit(1)
}
// TODO: print logged in
case "help":
showHelp()
default:
printError("MSG SERVICE: command error!")
showHelp()
}
}
// getHTTPClient returns an HTTP client whose transport uses tlsConfig for TLS
// connections to the server.
//
// The client has an overall per-request timeout (connection, TLS handshake,
// redirects and reading the body). Without it, a stalled server would block
// the CLI indefinitely.
func getHTTPClient(tlsConfig *tls.Config) *http.Client {
	const requestTimeout = 30 * time.Second
	transport := &http.Transport{TLSClientConfig: tlsConfig}
	return &http.Client{
		Transport: transport,
		Timeout:   requestTimeout,
	}
}
// registerUser registers userId on the server by POSTing a JSON
// protocol.NewPostRegister payload (userId, password and the DER bytes of the
// keystore certificate) to <baseURL>/register.
//
// Any status other than 200 OK is decoded as a protocol.ReportError and its
// message is returned as the error.
func registerUser(userId string, password string, clientKeyStore cryptoUtils.KeyStore) error {
	postRegister := protocol.NewPostRegister(userId, password, clientKeyStore.GetCert().Raw)
	jsonData, err := json.Marshal(postRegister)
	if err != nil {
		return fmt.Errorf("error marshaling JSON: %w", err)
	}
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL and leaves the receiver unchanged, so keep its result.
	parsedURL = parsedURL.JoinPath("register")
	req, err := http.NewRequest(http.MethodPost, parsedURL.String(), bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := getHTTPClient(clientKeyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(responseBody, &reportError); err != nil {
			// Keep the HTTP status: the body was not the expected error JSON.
			return fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return errors.New(reportError.ErrorMessage)
	}
	return nil
}
// login authenticates userId with password by POSTing a JSON
// protocol.NewPostLogin payload to <baseURL>/login.
//
// On 200 OK it returns the "token" field of the JSON response. Any other
// status is decoded as a protocol.ReportError and its message is returned as
// the error.
func login(userId string, password string, clientKeyStore cryptoUtils.KeyStore) (string, error) {
	postLogin := protocol.NewPostLogin(userId, password)
	jsonData, err := json.Marshal(postLogin)
	if err != nil {
		return "", fmt.Errorf("error marshaling JSON: %w", err)
	}
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return "", fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL and leaves the receiver unchanged, so keep its result.
	parsedURL = parsedURL.JoinPath("login")
	req, err := http.NewRequest(http.MethodPost, parsedURL.String(), bytes.NewReader(jsonData))
	if err != nil {
		return "", fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := getHTTPClient(clientKeyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(responseBody, &reportError); err != nil {
			return "", fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return "", errors.New(reportError.ErrorMessage)
	}

	var token struct {
		TokenString string `json:"token"`
	}
	if err := json.Unmarshal(responseBody, &token); err != nil {
		return "", fmt.Errorf("error decoding login response: %w", err)
	}
	return token.TokenString, nil
}
// sendCommand encrypts a message for receiverUID and POSTs it to
// <baseURL>/message.
//
// The receiver's certificate is fetched and checked with getUserCert. The
// subject and body are each serialized with Marshal and encrypted with
// clientKeyStore.EncryptMessageContent for that certificate. A non-200
// response is decoded as a protocol.ReportError and its message is returned as
// the error.
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, receiverUID string, token string) error {
	// Turn content to bytes.
	plainSubjectBytes, err := Marshal(plainSubject)
	if err != nil {
		return fmt.Errorf("error serializing subject: %w", err)
	}
	plainBodyBytes, err := Marshal(plainBody)
	if err != nil {
		return fmt.Errorf("error serializing body: %w", err)
	}
	receiverCert, err := getUserCert(clientKeyStore, receiverUID, token)
	if err != nil {
		return err
	}
	subject, err := clientKeyStore.EncryptMessageContent(receiverCert, receiverUID, plainSubjectBytes)
	if err != nil {
		return fmt.Errorf("error encrypting subject: %w", err)
	}
	body, err := clientKeyStore.EncryptMessageContent(receiverCert, receiverUID, plainBodyBytes)
	if err != nil {
		return fmt.Errorf("error encrypting body: %w", err)
	}

	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL and leaves the receiver unchanged, so keep its result.
	parsedURL = parsedURL.JoinPath("message")
	jsonData, err := json.Marshal(protocol.NewSendMsg(receiverUID, subject, body))
	if err != nil {
		return fmt.Errorf("error marshaling JSON: %w", err)
	}
	// TODO: ADD THE HEADER WITH THE TOKEN
	req, err := http.NewRequest(http.MethodPost, parsedURL.String(), bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := getHTTPClient(clientKeyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(responseBody, &reportError); err != nil {
			return fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return errors.New(reportError.ErrorMessage)
	}
	return nil
}
// getMsgCommand fetches message number num from <baseURL>/message and
// returns it decrypted.
//
// The request is a GET whose JSON body is protocol.NewGetMsg(num). On 200 OK
// the response is decoded as protocol.AnswerGetMsg. The sender's certificate
// is fetched with getUserCert. Subject and body are decrypted with
// clientKeyStore.DecryptMessageContent (using our pseudonym OID 2.5.4.65 as
// the recipient UID) and deserialized with Unmarshal. A non-200 response is
// decoded as a protocol.ReportError and its message is returned as the error.
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int, token string) (ClientMessage, error) {
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return ClientMessage{}, fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL and leaves the receiver unchanged, so keep its result.
	parsedURL = parsedURL.JoinPath("message")
	jsonData, err := json.Marshal(protocol.NewGetMsg(num))
	if err != nil {
		return ClientMessage{}, fmt.Errorf("error marshaling JSON: %w", err)
	}
	// TODO: ADD THE HEADER WITH THE TOKEN
	req, err := http.NewRequest(http.MethodGet, parsedURL.String(), bytes.NewReader(jsonData))
	if err != nil {
		return ClientMessage{}, fmt.Errorf("error creating request: %w", err)
	}

	client := getHTTPClient(clientKeyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return ClientMessage{}, fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ClientMessage{}, fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(body, &reportError); err != nil {
			return ClientMessage{}, fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return ClientMessage{}, errors.New(reportError.ErrorMessage)
	}

	var answerGetMsg protocol.AnswerGetMsg
	if err := json.Unmarshal(body, &answerGetMsg); err != nil {
		return ClientMessage{}, fmt.Errorf("error decoding message: %w", err)
	}
	senderCert, err := getUserCert(clientKeyStore, answerGetMsg.FromUID, token)
	if err != nil {
		return ClientMessage{}, err
	}
	oidMap := cryptoUtils.ExtractAllOIDValues(clientKeyStore.GetCert())
	myUID := oidMap["2.5.4.65"]
	if myUID == "" {
		return ClientMessage{}, errors.New("no pseudonym field on my certificate")
	}

	decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, myUID, answerGetMsg.Subject)
	if err != nil {
		return ClientMessage{}, err
	}
	decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, myUID, answerGetMsg.Body)
	if err != nil {
		return ClientMessage{}, err
	}
	messageSubject, err := Unmarshal(decSubjectBytes)
	if err != nil {
		return ClientMessage{}, err
	}
	messageBody, err := Unmarshal(decBodyBytes)
	if err != nil {
		return ClientMessage{}, err
	}
	return newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, messageSubject, messageBody, answerGetMsg.Timestamp), nil
}
// getUserCert fetches the certificate of user uid from <baseURL>/cert/<uid>.
//
// The JSON response (protocol.AnswerGetUserCert) carries DER bytes that are
// parsed with x509.ParseCertificate. The result is then passed to
// keyStore.CheckCert(userCert, uid, "MSG SERVICE"), and any error from that
// check is returned. A non-200 response is decoded as a protocol.ReportError
// and its message is returned as the error.
func getUserCert(keyStore cryptoUtils.KeyStore, uid string, token string) (*x509.Certificate, error) {
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL (keep its result) and treats its elements as
	// already-escaped path text, so escape uid to keep it a single segment.
	parsedURL = parsedURL.JoinPath("cert", url.PathEscape(uid))
	// TODO: ADD THE HEADER WITH THE TOKEN
	req, err := http.NewRequest(http.MethodGet, parsedURL.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}

	client := getHTTPClient(keyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(body, &reportError); err != nil {
			return nil, fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return nil, errors.New(reportError.ErrorMessage)
	}

	var answerGetUserCert protocol.AnswerGetUserCert
	if err := json.Unmarshal(body, &answerGetUserCert); err != nil {
		return nil, fmt.Errorf("error decoding certificate response: %w", err)
	}
	userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
	if err != nil {
		return nil, fmt.Errorf("error parsing certificate of %q: %w", uid, err)
	}
	if err := keyStore.CheckCert(userCert, uid, "MSG SERVICE"); err != nil {
		return nil, err
	}
	return userCert, nil
}
// getUnreadMessagesInfo requests one page of unread-message metadata from
// <baseURL>/queue and fetches the certificate of every distinct sender on it.
//
// The request is a GET whose JSON body is
// protocol.NewGetUnreadMsgsInfo(page, pageSize). It returns the decoded
// answer plus a map from sender UID to certificate. Certificate fetching is
// best-effort: senders whose certificate cannot be obtained or checked are
// simply absent from the map, and callers must handle the missing entries. A
// non-200 response is decoded as a protocol.ReportError and its message is
// returned as the error.
func getUnreadMessagesInfo(keyStore cryptoUtils.KeyStore, page int, pageSize int, token string) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error parsing URL: %w", err)
	}
	// JoinPath returns a new URL and leaves the receiver unchanged, so keep its result.
	parsedURL = parsedURL.JoinPath("queue")
	jsonData, err := json.Marshal(protocol.NewGetUnreadMsgsInfo(page, pageSize))
	if err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error marshaling JSON: %w", err)
	}
	// TODO: ADD THE HEADER WITH THE TOKEN
	req, err := http.NewRequest(http.MethodGet, parsedURL.String(), bytes.NewReader(jsonData))
	if err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error creating request: %w", err)
	}

	client := getHTTPClient(keyStore.GetClientTLSConfig())
	resp, err := client.Do(req)
	if err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		var reportError protocol.ReportError
		if err := json.Unmarshal(body, &reportError); err != nil {
			return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("server returned %s: %w", resp.Status, err)
		}
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, errors.New(reportError.ErrorMessage)
	}

	var answer protocol.AnswerGetUnreadMsgsInfo
	if err := json.Unmarshal(body, &answer); err != nil {
		return protocol.AnswerGetUnreadMsgsInfo{}, nil, fmt.Errorf("error decoding queue response: %w", err)
	}

	// Collect the distinct senders so each certificate is fetched only once.
	senders := make(map[string]struct{}, len(answer.MessagesInfo))
	for _, messageInfo := range answer.MessagesInfo {
		senders[messageInfo.FromUID] = struct{}{}
	}
	certificates := make(map[string]*x509.Certificate, len(senders))
	for senderUID := range senders {
		senderCert, err := getUserCert(keyStore, senderUID, token)
		if err != nil {
			// Deliberate best-effort: the caller reports undecryptable messages
			// individually instead of failing the whole listing.
			continue
		}
		certificates[senderUID] = senderCert
	}
	return answer, certificates, nil
}
// askQueueCommand lists one page of unread messages with decrypted subjects.
//
// It returns the page number and total page count reported by the server,
// followed by one ClientMessageInfo per message, sorted by Num in descending
// order. A message whose sender certificate is missing, or whose subject fails
// to decrypt or deserialize, is still listed, with an empty subject and the
// failure recorded in its ClientMessageInfo. It fails outright if the queue
// request fails, or if a subject must be decrypted but our certificate has no
// pseudonym (OID 2.5.4.65).
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int, token string) (int, int, []ClientMessageInfo, error) {
	unreadMsgsInfo, certificates, err := getUnreadMessagesInfo(clientKeyStore, page, pageSize, token)
	if err != nil {
		return 0, 0, nil, err
	}

	// Our own UID does not change per message; extract it once, outside the loop.
	oidMap := cryptoUtils.ExtractAllOIDValues(clientKeyStore.GetCert())
	myUID := oidMap["2.5.4.65"]

	var clientMessages []ClientMessageInfo
	for _, message := range unreadMsgsInfo.MessagesInfo {
		senderCert, ok := certificates[message.FromUID]
		if !ok {
			clientMessages = append(clientMessages, newClientMessageInfo(message.Num,
				message.FromUID,
				"",
				message.Timestamp,
				errors.New("certificate needed to decrypt not received")))
			continue
		}
		// Only fatal once a subject actually needs decrypting.
		if myUID == "" {
			return 0, 0, nil, errors.New("no pseudonym field on my certificate")
		}
		decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, myUID, message.Subject)
		if err != nil {
			clientMessages = append(clientMessages, newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err))
			continue
		}
		subject, err := Unmarshal(decryptedSubjectBytes)
		if err != nil {
			clientMessages = append(clientMessages, newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err))
			continue
		}
		clientMessages = append(clientMessages, newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil))
	}

	// Newest (highest Num) first.
	sort.Slice(clientMessages, func(i, j int) bool {
		return clientMessages[i].Num > clientMessages[j].Num
	})
	return unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages, nil
}