[PD1] merge

This commit is contained in:
Tiago Sousa 2024-04-22 19:30:50 +01:00
commit b8efcf19b7
Signed by: tiago
SSH key fingerprint: SHA256:odOD9vln9U7qNe1R8o3UCbE3jkQCkr5/q5mgd5hwua0
14 changed files with 729 additions and 192 deletions

View file

@ -4,7 +4,11 @@ import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"crypto/x509"
"flag"
"log"
"sort"
"strconv"
)
func Run() {
@ -15,7 +19,7 @@ func Run() {
if flag.NArg() == 0 {
panic("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
//Get user KeyStore
password := AskUserPassword()
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
@ -26,30 +30,73 @@ func Run() {
panic("Insufficient arguments for 'send' command. Usage: send <UID> <SUBJECT>")
}
uid := flag.Arg(1)
//subject := flag.Arg(2)
//messageContent := readMessageContent()
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
//Turn content to bytes
plainSubjectBytes := Marshal(plainSubject)
plainBodyBytes := Marshal(plainBody)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
certRequestPacket := protocol.NewRequestUserCertPacket(uid)
cl.Connection.Send(certRequestPacket)
//certPacket := cl.Connection.Receive()
// TODO: Encrypt message
//submitMessage(cl, uid, cipherContent)
receiverCert := getUserCert(cl, uid)
if receiverCert == nil {
return
}
subject := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
body := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if !cl.Connection.Send(sendMsgPacket) {
return
}
cl.Connection.Conn.Close()
case "askqueue":
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
askQueue(cl,clientKeyStore, page, pageSize)
case "getmsg":
if flag.NArg() < 2 {
panic("Insufficient arguments for 'getmsg' command. Usage: getmsg <NUM>")
}
//num := flag.Arg(1)
numString := flag.Arg(1)
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
defer cl.Connection.Conn.Close()
num, err := strconv.Atoi(numString)
if err != nil {
log.Panicln("NUM argument provided is not a number")
}
packet := protocol.NewGetMsgPacket(num)
cl.Connection.Send(packet)
receivedMsgPacket, active := cl.Connection.Receive()
if !active {
return
}
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
senderCert := getUserCert(cl, answerGetMsg.FromUID)
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
decBodyBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
subject := Unmarshal(decSubjectBytes)
body := Unmarshal(decBodyBytes)
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
case "help":
showHelp()
@ -60,7 +107,73 @@ func Run() {
}
func submitMessage(cl networking.Client[protocol.Packet], uid string, content []byte) {
pack := protocol.NewSubmitMessagePacket(uid, content)
cl.Connection.Send(pack)
func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certificate {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if !cl.Connection.Send(getUserCertPacket) {
return nil
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, active := cl.Connection.Receive()
if !active {
return nil
}
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil
}
return userCert
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate) {
answerGetUnreadMsgsInfoPacket, active := cl.Connection.Receive()
if !active {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
}
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert := getUserCert(cl, senderUID)
certificatesMap[senderUID] = senderCert
}
return answerGetUnreadMsgsInfo, certificatesMap
}
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
return
}
unreadMsgsInfo, certificates := getManyMessagesInfo(cl)
var clientMessages []ClientMessageInfo
for _, message := range unreadMsgsInfo.MessagesInfo {
senderCert, ok := certificates[message.FromUID]
if ok {
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
subject := Unmarshal(decryptedSubjectBytes)
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
clientMessages = append(clientMessages, clientMessage)
}
}
//Sort the messages
sort.Slice(clientMessages, func(i, j int) bool {
return clientMessages[i].Num > clientMessages[j].Num
})
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize)
case 0:
return
case 1:
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize)
}
}

View file

@ -1,15 +1,47 @@
package client
import "time"
import (
"encoding/json"
"log"
"time"
)
type Content struct {
Subject []byte
Body []byte
}
type RecievedMessage struct {
type ClientMessage struct {
FromUID string
ToUID string
Content Content
Subject string
Body string
Timestamp time.Time
}
type ClientMessageInfo struct {
Num int
FromUID string
Timestamp time.Time
Subject string
}
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
}
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time) ClientMessageInfo {
return ClientMessageInfo{Num:num,FromUID: fromUID,Subject: subject,Timestamp: timestamp}
}
func Marshal(data any) []byte {
subject, err := json.Marshal(data)
if err != nil {
log.Panicf("Error when marshalling message: %v", err)
}
return subject
}
func Unmarshal(data []byte) string {
var c string
err := json.Unmarshal(data, &c)
if err != nil {
log.Panicln("Could not unmarshal data")
}
return c
}

View file

@ -4,10 +4,11 @@ import (
"bufio"
"fmt"
"os"
"strings"
)
func readMessageContent() string {
fmt.Println("Enter message content (limited to 1000 bytes):")
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
// FIX: make sure this doesnt die
@ -23,11 +24,10 @@ func AskUserPassword() string {
}
func commandError() {
fmt.Println("MSG SERVICE: command error!")
showHelp()
fmt.Println("MSG SERVICE: command error!")
showHelp()
}
func showHelp() {
fmt.Println("Comandos da aplicação cliente:")
fmt.Println("-user <FNAME>: Especifica o ficheiro com dados do utilizador. Por omissão, será assumido que esse ficheiro é userdata.p12.")
@ -36,3 +36,64 @@ func showHelp() {
fmt.Println("getmsg <NUM>: Solicita ao servidor o envio da mensagem da sua queue com número <NUM>.")
fmt.Println("help: Imprime instruções de uso do programa.")
}
func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int {
if messages == nil {
fmt.Println("No unread messages in the queue")
return 0
}
for _, message := range messages {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
fmt.Printf("Page %v/%v\n",page,numPages)
return messagesInfoPageNavigation(page, numPages)
}
func messagesInfoPageNavigation(page int, numPages int) int {
var action string
switch page {
case 1:
if page == numPages {
action = readStdin("Actions: quit")
} else {
action = readStdin("Actions: quit/next")
}
case numPages:
action = readStdin("Actions: prev/quit")
default:
action = readStdin("prev/quit/next")
}
switch strings.ToLower(action) {
case "prev":
if page == 1 {
fmt.Println("Unavailable action: Already in first page")
messagesInfoPageNavigation(page, numPages)
} else {
return -1
}
case "quit":
return 0
case "next":
if page == numPages {
fmt.Println("Unavailable action: Already in last page")
messagesInfoPageNavigation(page, numPages)
} else {
return 1
}
default:
fmt.Println("Unknown action")
messagesInfoPageNavigation(page, numPages)
}
return 0
}
func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body)
}

View file

@ -1,116 +1,272 @@
package protocol
import (
"encoding/json"
"fmt"
"time"
)
type PacketType int
const (
ReqUserCertPkt PacketType = iota
ReqAllMsgPkt
ReqMsgPkt
SubmitMsgPkt
SendUserCertPkt
ServerMsgPkt
// Client requests user certificate
FlagGetUserCert PacketType = iota
// Client requests unread message info
FlagGetUnreadMsgsInfo
// Client requests a message from the queue
FlagGetMsg
// Client sends a message
FlagSendMsg
// Server sends user certificate
FlagAnswerGetUserCert
// Server sends list of unread messages
FlagAnswerGetUnreadMsgsInfo
// Server sends requested message
FlagAnswerGetMsg
)
type (
GetUserCert struct {
UID string `json:"uid"`
}
GetUnreadMsgsInfo struct {
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
GetMsg struct {
Num int `json:"num"`
}
SendMsg struct {
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
Body []byte `json:"body"`
}
AnswerGetUserCert struct {
UID string `json:"uid"`
Certificate []byte `json:"certificate"`
}
AnswerGetUnreadMsgsInfo struct {
Page int `json:"page"`
NumPages int `json:"num_pages"`
MessagesInfo []MsgInfo `json:"messages_info"`
}
MsgInfo struct {
Num int `json:"num"`
FromUID string `json:"from_uid"`
Subject []byte `json:"subject"`
Timestamp time.Time `json:"timestamp"`
}
AnswerGetMsg struct {
FromUID string `json:"from_uid"`
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
Body []byte `json:"body"`
Timestamp time.Time `json:"timestamp"`
}
)
type PacketBody interface{}
type Packet struct {
Flag PacketType
Body PacketBody
Flag PacketType `json:"flag"`
Body PacketBody `json:"body"`
}
// Client --> Server: Ask for a user's certificate
type RequestUserCertPacket struct {
UID string
}
func NewRequestUserCertPacket(UID string) Packet {
func NewPacket(fl PacketType, body PacketBody) Packet {
return Packet{
Flag: ReqUserCertPkt,
Body: RequestUserCertPacket{
UID: UID,
},
Flag: fl,
Body: body,
}
}
func NewGetUserCert(UID string) GetUserCert {
return GetUserCert{
UID: UID,
}
}
// Client --> Server: Ask for all the client's messages in the queue
type RequestAllMsgPacket struct {
FromUID string
func NewGetUnreadMsgsInfo(page int, pageSize int) GetUnreadMsgsInfo {
return GetUnreadMsgsInfo{
Page: page,
PageSize: pageSize}
}
func NewRequestAllMsgPacket(fromUID string) Packet {
return Packet{
Flag: ReqAllMsgPkt,
Body: RequestAllMsgPacket{
FromUID: fromUID,
},
func NewGetMsg(num int) GetMsg {
return GetMsg{
Num: num,
}
}
// Client --> Server: Ask for a specific message in the queue
type RequestMsgPacket struct {
Num uint16
}
func NewRequestMsgPacket(num uint16) Packet {
return Packet{
Flag: ReqMsgPkt,
Body: RequestMsgPacket{
Num: num,
},
func NewSendMsg(toUID string, subject []byte, body []byte) SendMsg {
return SendMsg{
ToUID: toUID,
Subject: subject,
Body: body,
}
}
// Client --> Server: Send message from client to server
type SubmitMessagePacket struct {
ToUID string
Content []byte
}
func NewSubmitMessagePacket(toUID string, content []byte) Packet {
return Packet{
Flag: SubmitMsgPkt,
Body: SubmitMessagePacket{
ToUID: toUID,
Content: content},
func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
return AnswerGetUserCert{
UID: uid,
Certificate: certificate,
}
}
// Server --> Client: Send the client the requested public key
type SendUserCertPacket struct {
UID string
Key []byte
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo}
}
func NewSendUserCertPacket(uid string, key []byte) Packet {
return Packet{
Flag: SendUserCertPkt,
Body: SendUserCertPacket{
UID: uid,
Key: key,
},
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{
Num: num,
FromUID: fromUID,
Subject: subject,
Timestamp: timestamp,
}
}
// Server --> Client: Send the client a message
type ServerMessagePacket struct {
FromUID string
ToUID string
Content []byte
Timestamp time.Time
}
func NewServerMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet {
return Packet{
Flag: ServerMsgPkt,
Body: ServerMessagePacket{
FromUID: fromUID,
ToUID: toUID,
Content: content,
Timestamp: timestamp,
},
func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) AnswerGetMsg {
return AnswerGetMsg{
FromUID: fromUID,
ToUID: toUID,
Subject: subject,
Body: body,
Timestamp: timestamp,
}
}
func NewGetUserCertPacket(UID string) Packet {
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
}
func NewGetUnreadMsgsInfoPacket(page int, pageSize int) Packet {
return NewPacket(FlagGetUnreadMsgsInfo, NewGetUnreadMsgsInfo(page, pageSize))
}
func NewGetMsgPacket(num int) Packet {
return NewPacket(FlagGetMsg, NewGetMsg(num))
}
func NewSendMsgPacket(toUID string, subject []byte, body []byte) Packet {
return NewPacket(FlagSendMsg, NewSendMsg(toUID, subject, body))
}
func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
return NewPacket(FlagAnswerGetUserCert, NewAnswerGetUserCert(uid, certificate))
}
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo))
}
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
}
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet GetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUserCert: %v", err))
}
return packet
}
func UnmarshalGetUnreadMsgsInfo(data PacketBody) GetUnreadMsgsInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet GetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetUnreadMsgsInfo: %v", err))
}
return packet
}
func UnmarshalGetMsg(data PacketBody) GetMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet GetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into GetMsg: %v", err))
}
return packet
}
func UnmarshalSendMsg(data PacketBody) SendMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet SendMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into SendMsg: %v", err))
}
return packet
}
func UnmarshalAnswerGetUserCert(data PacketBody) AnswerGetUserCert {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet AnswerGetUserCert
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUserCert: %v", err))
}
return packet
}
func UnmarshalUnreadMsgInfo(data PacketBody) MsgInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet MsgInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into UnreadMsgInfo: %v", err))
}
return packet
}
func UnmarshalAnswerGetUnreadMsgsInfo(data PacketBody) AnswerGetUnreadMsgsInfo {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet AnswerGetUnreadMsgsInfo
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetUnreadMsgsInfo: %v", err))
}
return packet
}
func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
jsonData, err := json.Marshal(data)
if err != nil {
panic(fmt.Errorf("failed to marshal data: %v", err))
}
var packet AnswerGetMsg
if err := json.Unmarshal(jsonData, &packet); err != nil {
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
}
return packet
}

View file

@ -2,7 +2,9 @@ package server
import (
"PD1/internal/protocol"
"crypto/x509"
"database/sql"
"fmt"
"log"
"time"
@ -18,7 +20,9 @@ func OpenDB() DataStore {
if err != nil {
log.Fatalln("Error opening db file")
}
return DataStore{db: db}
ds := DataStore{db: db}
ds.CreateTables()
return ds
}
func (ds DataStore) CreateTables() error {
@ -28,6 +32,7 @@ func (ds DataStore) CreateTables() error {
userCert BLOB
)`)
if err != nil {
fmt.Println("Error creating users table", err)
return err
}
@ -36,106 +41,144 @@ func (ds DataStore) CreateTables() error {
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
content BLOB,
queue_position INT DEFAULT 0,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
PRIMARY KEY (toUID, fromUID, timestamp),
FOREIGN KEY(fromUID) REFERENCES users(UID),
FOREIGN KEY(toUID) REFERENCES users(UID)
)`)
if err != nil {
fmt.Println("Error creating messages table", err)
return err
}
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
fmt.Println("Error creating trigger", err)
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) protocol.ServerMessagePacket {
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
var serverMessage protocol.ServerMessagePacket
var serverMessage protocol.AnswerGetMsg
query := `
SELECT fromUID, toUID, content, timestamp
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ?
AND status = 0
ORDER BY timestamp
LIMIT 1 OFFSET ?
WHERE toUID = ? AND queue_position = ?
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Content, &serverMessage.Timestamp)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err != nil {
log.Panicln("Could not map DB query to ServerMessage")
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
}
return serverMessage
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
}
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) error {
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
query := `
UPDATE messages
SET status = 1
WHERE toUID = ? AND status = 0
ORDER BY timestamp
LIMIT 1 OFFSET ?
SET status = 1
WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp
FROM messages
WHERE toUID = ? AND queue_position = ?
)
`
// Execute the SQL statement
_, err := ds.db.Exec(query, toUID, position)
if err != nil {
return err
log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
}
return nil
}
func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet {
var messagePackets []protocol.Packet
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err != nil {
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
}
// Query to retrieve all messages from the user's queue
query := `
SELECT fromUID, toUID, content, timestamp
SELECT
fromUID,
toUID,
timestamp,
queue_position,
subject,
status
FROM messages
WHERE toUID = ?
AND status = 0
ORDER BY timestamp
WHERE
toUID = ? AND status = 0
ORDER BY
queue_position DESC
LIMIT ? OFFSET ?;
`
// Execute the query
rows, err := ds.db.Query(query, toUID)
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Panicln("Failed to execute query:", err)
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
}
defer rows.Close()
// Iterate through the result set and scan each row into a ServerMessage struct
messageInfoPackets := []protocol.MsgInfo{}
for rows.Next() {
var fromUID string
var toUID string
var content []byte
var subject []byte
var timestamp time.Time
if err := rows.Scan(&fromUID, &toUID, &content, &timestamp); err != nil {
log.Panicln("Failed to scan row:", err)
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
panic(err)
}
message := protocol.NewServerMessagePacket(fromUID, toUID, content, timestamp)
messagePackets = append(messagePackets, message)
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Panicln("Error when getting user's messages")
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
}
return messagePackets
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
}
func (ds DataStore) AddMessageToQueue(uid string, message protocol.SubmitMessagePacket) {
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
query := `
INSERT INTO messages (fromUID, toUID, content, timestamp, status)
VALUES (?, ?, ?, ?, 0)
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
`
// Execute the SQL statement
currentTime := time.Now()
_, err := ds.db.Exec(query, uid, message.ToUID, message.Content, currentTime)
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil {
log.Panicln("Error adding message to database")
log.Printf("Error adding message to UID %v: %v", fromUID, err)
}
}
@ -147,10 +190,53 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
`
// Execute the SQL query
var userCert []byte
err := ds.db.QueryRow(query, uid).Scan(&userCert)
if err != nil {
log.Panicln("Error getting user certificate from the database")
var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows {
log.Panicf("No certificate for UID %v found in the database", uid)
}
return protocol.NewSendUserCertPacket(uid, userCert)
//userCert,err := x509.ParseCertificate(userCertBytes)
//if err!=nil {
// log.Panicf("Error parsing certificate for UID %v",uid)
//}
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
}
func (ds DataStore) userExists(uid string) bool {
// Prepare the SQL statement for checking if a user exists
query := `
SELECT COUNT(*)
FROM users
WHERE UID = ?
`
var count int
// Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count)
if err == sql.ErrNoRows {
log.Printf("User with UID %v does not exist", uid)
return false
} else {
return true
}
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
// Check if the user already exists
if ds.userExists(uid) {
log.Printf("User certificate for UID %s already exists.\n", uid)
return
}
// Insert the user certificate
insertQuery := `
INSERT INTO users (UID, userCert)
VALUES (?, ?)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
return
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
}

View file

@ -4,28 +4,58 @@ import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"fmt"
)
//TODO: CREATE SERVER SIDE CHECKS FOR EVERYTHING
//TODO: LOGGING SYSTEM
//TODO: TELL THE USER THAT THE MESSAGE HAS BEEN RECEIVED BY THE SERVER
//TODO: ERROR PACKET TO SEND BACK TO USER
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
defer connection.Conn.Close()
_ = dataStore
clientCert := connection.GetPeerCertificate()
oidValueMap := cryptoUtils.ExtractAllOIDValues(clientCert)
fmt.Println(oidValueMap)
//Get certificate sent by user
clientCert := connection.GetPeerCertificate()
//Get the OID values
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
//Get the UID of this user
UID := oidMap["2.5.4.65"]
if UID == "" {
panic("User certificate does not specify it's PSEUDONYM")
}
dataStore.storeUserCertIfNotExists(UID, *clientCert)
F:
for {
pac := connection.Receive()
pac, active := connection.Receive()
if !active {
break
}
switch pac.Flag {
case protocol.ReqUserCertPkt:
//userCertPacket := dataStore.GetUserCertificate(uid)
//connection.Send(userCertPacket)
case protocol.ReqAllMsgPkt:
fmt.Println("ReqAllMsg")
case protocol.ReqMsgPkt:
fmt.Println("ReqMsg")
case protocol.SubmitMsgPkt:
fmt.Println("SubmitMsg")
case protocol.FlagGetUserCert:
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
if active := connection.Send(userCertPacket); !active {
break F
}
case protocol.FlagGetUnreadMsgsInfo:
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
if !connection.Send(messages) {
break F
}
case protocol.FlagGetMsg:
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
message := dataStore.GetMessage(UID, reqMsg.Num)
if active := connection.Send(message); !active {
break F
}
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
case protocol.FlagSendMsg:
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
dataStore.AddMessageToQueue(UID, submitMsg)
}
}
}

View file

@ -94,9 +94,8 @@ func (k *KeyStore) GetServerTLSConfig() *tls.Config {
caCertPool.AddCert(caCert)
}
tlsConfig.ClientCAs = caCertPool
//Request one valid or invalid certificate
// FIX: SERVER ACCEPTS CONNECTIONS WITH UNMATCHING OR
// NO CERTIFICATE, NEEDS TO BE CHANGED SOMEHOW
//FIX: SERVER ACCEPTS CONNECTIONS WITH UNMATCHING OR
// NO CERTIFICATE, NEEDS TO BE CHANGED SOMEHOW
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
return tlsConfig
}

View file

@ -4,6 +4,8 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"io"
"log"
)
type Connection[T any] struct {
@ -20,18 +22,33 @@ func NewConnection[T any](netConn *tls.Conn) Connection[T] {
}
}
func (c Connection[T]) Send(obj T) {
if err := c.encoder.Encode(&obj); err != nil {
panic("Failed encoding data or sending it to connection")
}
func (c Connection[T]) Send(obj T) bool {
if err := c.encoder.Encode(&obj); err!=nil {
if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return false
} else {
log.Panic(err)
}
}
//Return true as connection active
return true
}
func (c Connection[T]) Receive() T {
func (c Connection[T]) Receive() (*T, bool) {
var obj T
if err := c.decoder.Decode(&obj); err != nil {
panic("Failed decoding data or reading it from connection")
if err == io.EOF {
log.Println("Connection closed by peer")
//Return false as connection not active
return nil,false
} else {
log.Panic(err)
}
}
return obj
//Return true as connection active
return &obj, true
}
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {

View file

@ -43,7 +43,6 @@ func (s *Server[T]) ListenLoop() {
state := tlsConn.ConnectionState()
if len(state.PeerCertificates) == 0 {
fmt.Println(state.PeerCertificates)
log.Panicln("Client did not provide a certificate")
}
conn := NewConnection[T](tlsConn)