[PD2] Server done?

Co-authored-by: tsousa111 <tiagao2001@hotmail.com>
This commit is contained in:
Afonso Franco 2024-05-28 20:08:25 +01:00
parent 08a73a4f76
commit 49a29e43a7
Signed by: afonso
SSH key fingerprint: SHA256:PQTRDHPH3yALEGtHXnXBp3Orfcn21pK20t0tS1kHg54
66 changed files with 2777 additions and 5 deletions

View file

@ -0,0 +1,323 @@
package client
import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"PD1/internal/utils/networking"
"crypto/x509"
"errors"
"flag"
"log"
"os"
"sort"
"strconv"
)
func Run() {
var userFile string
flag.StringVar(&userFile, "user", "userdata.p12", "Specify user data file")
flag.Parse()
if flag.NArg() == 0 {
log.Fatalln("No command provided. Use 'help' for instructions.")
}
//Get user KeyStore
password := readStdin("Insert keystore passphrase")
clientKeyStore, err := cryptoUtils.LoadKeyStore(userFile, password)
if err != nil {
log.Fatalln(err)
}
command := flag.Arg(0)
switch command {
case "send":
if flag.NArg() != 3 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
uid := flag.Arg(1)
plainSubject := flag.Arg(2)
plainBody := readStdin("Enter message content (limited to 1000 bytes):")
err := sendCommand(clientKeyStore, plainSubject, plainBody, uid)
if err != nil {
log.Fatalln(err)
}
case "askqueue":
if flag.NArg() > 3 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
pageInput := flag.Arg(1)
page := 1
if pageInput != "" {
if val, err := strconv.Atoi(pageInput); err == nil {
page = max(1, val)
}
}
pageSizeInput := flag.Arg(2)
pageSize := 5
if pageSizeInput != "" {
if val, err := strconv.Atoi(pageSizeInput); err == nil {
pageSize = max(1, val)
}
}
err := askQueueCommand(clientKeyStore, page, pageSize)
if err != nil {
log.Fatalln(err)
}
case "getmsg":
if flag.NArg() < 2 {
printError("MSG SERVICE: command error!")
showHelp()
os.Exit(1)
}
numString := flag.Arg(1)
num, err := strconv.Atoi(numString)
if err != nil {
log.Fatalln(err)
}
err = getMsgCommand(clientKeyStore, num)
if err != nil {
printError(err.Error())
}
case "help":
showHelp()
default:
printError("MSG SERVICE: command error!")
showHelp()
}
}
func sendCommand(clientKeyStore cryptoUtils.KeyStore, plainSubject, plainBody, uid string) error {
//Turn content to bytes
plainSubjectBytes, err := Marshal(plainSubject)
if err != nil {
return err
}
plainBodyBytes, err := Marshal(plainBody)
if err != nil {
return err
}
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
receiverCert, err := getUserCert(cl, clientKeyStore, uid)
if err != nil {
return err
}
subject, err := clientKeyStore.EncryptMessageContent(receiverCert, plainSubjectBytes)
if err != nil {
return err
}
body, err := clientKeyStore.EncryptMessageContent(receiverCert, plainBodyBytes)
if err != nil {
return err
}
sendMsgPacket := protocol.NewSendMsgPacket(uid, subject, body)
if err := cl.Connection.Send(sendMsgPacket); err != nil {
return err
}
answerSendMsg, err := cl.Connection.Receive()
if err != nil {
return err
}
if answerSendMsg.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerSendMsg.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
return nil
}
func getMsgCommand(clientKeyStore cryptoUtils.KeyStore, num int) error {
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
packet := protocol.NewGetMsgPacket(num)
if err := cl.Connection.Send(packet); err != nil {
return err
}
receivedMsgPacket, err := cl.Connection.Receive()
if err != nil {
return err
}
if receivedMsgPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(receivedMsgPacket.Body)
if err != nil {
return err
}
return errors.New(reportError.ErrorMessage)
}
answerGetMsg, err := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
if err != nil {
return err
}
senderCert, err := getUserCert(cl, clientKeyStore, answerGetMsg.FromUID)
if err != nil {
return err
}
decSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
if err != nil {
return err
}
decBodyBytes, err := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Body)
if err != nil {
return err
}
subject, err := Unmarshal(decSubjectBytes)
if err != nil {
return err
}
body, err := Unmarshal(decBodyBytes)
if err != nil {
return err
}
message := newClientMessage(answerGetMsg.FromUID, answerGetMsg.ToUID, subject, body, answerGetMsg.Timestamp)
showMessage(message)
return nil
}
func getUserCert(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore, uid string) (*x509.Certificate, error) {
getUserCertPacket := protocol.NewGetUserCertPacket(uid)
if err := cl.Connection.Send(getUserCertPacket); err != nil {
return nil, err
}
var answerGetUserCertPacket *protocol.Packet
answerGetUserCertPacket, err := cl.Connection.Receive()
if err != nil {
return nil, err
}
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
return nil, errors.New(reportError.ErrorMessage)
}
answerGetUserCert, err := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
if err != nil {
return nil, err
}
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
if err != nil {
return nil, err
}
if err := keyStore.CheckCert(userCert, uid); err != nil {
return nil, err
}
return userCert, nil
}
func getManyMessagesInfo(cl networking.Client[protocol.Packet], keyStore cryptoUtils.KeyStore) (protocol.AnswerGetUnreadMsgsInfo, map[string]*x509.Certificate, error) {
answerGetUnreadMsgsInfoPacket, err := cl.Connection.Receive()
if err != nil {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, err
}
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
reportError, err := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
if err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
}
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil, errors.New(reportError.ErrorMessage)
}
answerGetUnreadMsgsInfo, err := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
if err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, nil, err
}
//Create Set of needed certificates
senderSet := map[string]bool{}
for _, messageInfo := range answerGetUnreadMsgsInfo.MessagesInfo {
senderSet[messageInfo.FromUID] = true
}
certificatesMap := map[string]*x509.Certificate{}
//Get senders' certificates
for senderUID := range senderSet {
senderCert, err := getUserCert(cl, keyStore, senderUID)
if err == nil {
certificatesMap[senderUID] = senderCert
}
}
return answerGetUnreadMsgsInfo, certificatesMap, nil
}
func askQueueCommand(clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
cl, err := networking.NewClient[protocol.Packet](&clientKeyStore)
if err != nil {
return err
}
defer cl.Connection.Conn.Close()
return askQueueRec(cl, clientKeyStore, page, pageSize)
}
func askQueueRec(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) error {
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
if err := cl.Connection.Send(requestUnreadMsgsQueuePacket); err != nil {
return err
}
unreadMsgsInfo, certificates, err := getManyMessagesInfo(cl, clientKeyStore)
if err != nil {
return err
}
var clientMessages []ClientMessageInfo
for _, message := range unreadMsgsInfo.MessagesInfo {
var clientMessageInfo ClientMessageInfo
senderCert, ok := certificates[message.FromUID]
if !ok {
clientMessageInfo = newClientMessageInfo(message.Num,
message.FromUID,
"",
message.Timestamp,
errors.New("certificate needed to decrypt not received"))
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
decryptedSubjectBytes, err := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
subject, err := Unmarshal(decryptedSubjectBytes)
if err != nil {
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, "", message.Timestamp, err)
clientMessages = append(clientMessages, clientMessageInfo)
continue
}
clientMessageInfo = newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp, nil)
clientMessages = append(clientMessages, clientMessageInfo)
}
//Sort the messages
sort.Slice(clientMessages, func(i, j int) bool {
return clientMessages[i].Num > clientMessages[j].Num
})
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
switch action {
case -1:
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
case 1:
return askQueueRec(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
default:
return nil
}
}

View file

@ -0,0 +1,47 @@
package client
import (
"encoding/json"
"time"
)
type ClientMessage struct {
FromUID string
ToUID string
Subject string
Body string
Timestamp time.Time
}
type ClientMessageInfo struct {
Num int
FromUID string
Timestamp time.Time
Subject string
decryptError error
}
func newClientMessage(fromUID string, toUID string, subject string, body string, timestamp time.Time) ClientMessage {
return ClientMessage{FromUID: fromUID, ToUID: toUID, Subject: subject, Body: body, Timestamp: timestamp}
}
func newClientMessageInfo(num int, fromUID string, subject string, timestamp time.Time, err error) ClientMessageInfo {
return ClientMessageInfo{Num: num, FromUID: fromUID, Subject: subject, Timestamp: timestamp, decryptError: err}
}
func Marshal(data any) ([]byte, error) {
subject, err := json.Marshal(data)
if err != nil {
return nil, err
}
return subject, nil
}
func Unmarshal(data []byte) (string, error) {
var c string
err := json.Unmarshal(data, &c)
if err != nil {
return "", err
}
return c, nil
}

View file

@ -0,0 +1,92 @@
package client
import (
"bufio"
"fmt"
"os"
"strings"
)
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
return scanner.Text()
}
func printError(err string) {
fmt.Fprintln(os.Stderr, err)
}
func showHelp() {
fmt.Println("Comandos da aplicação cliente:")
fmt.Println("-user <FNAME>: Especifica o ficheiro com dados do utilizador. Por omissão, será assumido que esse ficheiro é userdata.p12.")
fmt.Println("send <UID> <SUBJECT>: Envia uma mensagem com assunto <SUBJECT> destinada ao utilizador com identificador <UID>. O conteúdo da mensagem será lido do stdin, e o tamanho deve ser limitado a 1000 bytes.")
fmt.Println("askqueue: Solicita ao servidor que lhe envie a lista de mensagens não lidas da queue do utilizador.")
fmt.Println("getmsg <NUM>: Solicita ao servidor o envio da mensagem da sua queue com número <NUM>.")
fmt.Println("help: Imprime instruções de uso do programa.")
}
func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int {
if messages == nil {
fmt.Println("No unread messages in the queue")
return 0
}
for _, message := range messages {
if message.decryptError != nil {
fmt.Printf("ERROR: %v:%v:%v:", message.Num, message.FromUID, message.Timestamp)
fmt.Println(message.decryptError)
} else {
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
}
}
fmt.Printf("Page %v/%v\n", page, numPages)
return messagesInfoPageNavigation(page, numPages)
}
func messagesInfoPageNavigation(page int, numPages int) int {
var action string
switch page {
case 1:
if page == numPages {
return 0
} else {
action = readStdin("Actions: quit/next")
}
case numPages:
action = readStdin("Actions: prev/quit")
default:
action = readStdin("prev/quit/next")
}
switch strings.ToLower(action) {
case "prev":
if page == 1 {
fmt.Println("Unavailable action: Already in first page")
messagesInfoPageNavigation(page, numPages)
} else {
return -1
}
case "quit":
return 0
case "next":
if page == numPages {
fmt.Println("Unavailable action: Already in last page")
messagesInfoPageNavigation(page, numPages)
} else {
return 1
}
default:
fmt.Println("Unknown action")
messagesInfoPageNavigation(page, numPages)
}
return 0
}
func showMessage(message ClientMessage) {
fmt.Printf("From: %s\n", message.FromUID)
fmt.Printf("To: %s\n", message.ToUID)
fmt.Printf("Subject: %s\n", message.Subject)
fmt.Printf("Body: %s\n", message.Body)
}

View file

@ -0,0 +1,5 @@
package gateway
func Run(){
}

View file

@ -0,0 +1,119 @@
package protocol
import (
"time"
)
type Body interface{}
type (
GetUserCert struct {
UID string `json:"uid"`
}
GetUnreadMsgsInfo struct {
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
GetMsg struct {
Num int `json:"num"`
}
SendMsg struct {
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
Body []byte `json:"body"`
}
AnswerGetUserCert struct {
UID string `json:"uid"`
Certificate []byte `json:"certificate"`
}
AnswerGetUnreadMsgsInfo struct {
Page int `json:"page"`
NumPages int `json:"num_pages"`
MessagesInfo []MsgInfo `json:"messages_info"`
}
MsgInfo struct {
Num int `json:"num"`
FromUID string `json:"from_uid"`
Subject []byte `json:"subject"`
Timestamp time.Time `json:"timestamp"`
}
AnswerGetMsg struct {
FromUID string `json:"from_uid"`
ToUID string `json:"to_uid"`
Subject []byte `json:"subject"`
Body []byte `json:"body"`
Timestamp time.Time `json:"timestamp"`
}
ReportError struct {
ErrorMessage string `json:"error"`
}
)
func NewGetUserCert(UID string) GetUserCert {
return GetUserCert{
UID: UID,
}
}
func NewGetUnreadMsgsInfo(page int, pageSize int) GetUnreadMsgsInfo {
return GetUnreadMsgsInfo{
Page: page,
PageSize: pageSize}
}
func NewGetMsg(num int) GetMsg {
return GetMsg{
Num: num,
}
}
func NewSendMsg(toUID string, subject []byte, body []byte) SendMsg {
return SendMsg{
ToUID: toUID,
Subject: subject,
Body: body,
}
}
func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
return AnswerGetUserCert{
UID: uid,
Certificate: certificate,
}
}
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
return AnswerGetUnreadMsgsInfo{Page: page, NumPages: numPages, MessagesInfo: messagesInfo}
}
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
return MsgInfo{
Num: num,
FromUID: fromUID,
Subject: subject,
Timestamp: timestamp,
}
}
func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) AnswerGetMsg {
return AnswerGetMsg{
FromUID: fromUID,
ToUID: toUID,
Subject: subject,
Body: body,
Timestamp: timestamp,
}
}
func NewReportError(errorMessage string) ReportError {
return ReportError{
ErrorMessage: errorMessage,
}
}

View file

@ -0,0 +1,245 @@
package server
import (
"PD1/internal/protocol"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"log"
"time"
_ "github.com/mattn/go-sqlite3"
)
type DataStore struct {
db *sql.DB
}
func OpenDB() (DataStore, error) {
db, err := sql.Open("sqlite3", "server.db")
if err != nil {
return DataStore{}, err
}
ds := DataStore{db: db}
err = ds.CreateTables()
if err != nil {
return DataStore{}, err
}
return ds, nil
}
func (ds DataStore) CreateTables() error {
// Create users table
_, err := ds.db.Exec(`CREATE TABLE IF NOT EXISTS users (
UID TEXT PRIMARY KEY,
userCert BLOB
)`)
if err != nil {
return err
}
// Create messages table
_, err = ds.db.Exec(`CREATE TABLE IF NOT EXISTS messages (
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
queue_position INT DEFAULT 0,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
PRIMARY KEY (toUID, fromUID, timestamp),
FOREIGN KEY(fromUID) REFERENCES users(UID),
FOREIGN KEY(toUID) REFERENCES users(UID)
)`)
if err != nil {
return err
}
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) (*protocol.AnswerGetMsg, error) {
var serverMessage protocol.AnswerGetMsg
query := `
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ? AND queue_position = ?
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err == sql.ErrNoRows {
log.Printf("No message with NUM %v for UID %v\n", position, toUID)
errorMessage := fmt.Sprintln("MSG SERVICE: unknown message!")
error := errors.New(errorMessage)
return nil, error
}
answer := protocol.NewAnswerGetMsg(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
return &answer, nil
}
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
query := `
UPDATE messages
SET status = 1
WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp
FROM messages
WHERE toUID = ? AND queue_position = ?
)
`
// Execute the SQL statement
_, err := ds.db.Exec(query, toUID, position)
if err != nil {
log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
}
}
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.AnswerGetUnreadMsgsInfo, error) {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err == sql.ErrNoRows {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, []protocol.MsgInfo{}), nil
}
// Query to retrieve all messages from the user's queue
query := `
SELECT
fromUID,
toUID,
timestamp,
queue_position,
subject,
status
FROM messages
WHERE
toUID = ? AND status = 0
ORDER BY
queue_position DESC
LIMIT ? OFFSET ?;
`
// Execute the query
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Printf("Error getting unread messages for UID %v: %v", toUID, err)
}
defer rows.Close()
messageInfoPackets := []protocol.MsgInfo{}
for rows.Next() {
var fromUID string
var subject []byte
var timestamp time.Time
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, err
}
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.AnswerGetUnreadMsgsInfo{}, err
}
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfo(currentPage, numberOfPages, messageInfoPackets), nil
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) error {
query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
`
// Execute the SQL statement
currentTime := time.Now()
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil {
log.Printf("Error adding message to UID %v: %v", fromUID, err)
return err
}
return nil
}
func (ds DataStore) GetUserCertificate(uid string) (protocol.AnswerGetUserCert,error) {
query := `
SELECT userCert
FROM users
WHERE UID = ?
`
// Execute the SQL query
var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows {
errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid)
log.Println(errorMessage)
return protocol.AnswerGetUserCert{},errors.New(errorMessage)
}
return protocol.NewAnswerGetUserCert(uid, userCertBytes),nil
}
func (ds DataStore) userExists(uid string) bool {
// Prepare the SQL statement for checking if a user exists
query := `
SELECT COUNT(*)
FROM users
WHERE UID = ?
`
var count int
// Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count)
if err != nil || count == 0 {
log.Printf("user with UID %v does not exist\n", uid)
return false
}
return true
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
// Check if the user already exists
if ds.userExists(uid) {
return nil
}
// Insert the user certificate
insertQuery := `
INSERT INTO users (UID, userCert)
VALUES (?, ?)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
return fmt.Errorf("error storing user certificate for UID %s: %v", uid, err)
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
return nil
}

View file

@ -0,0 +1,14 @@
package server
import (
"bufio"
"fmt"
"os"
)
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
return scanner.Text()
}

View file

@ -0,0 +1,151 @@
package server
import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"log"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
//func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
// defer connection.Conn.Close()
//
// //Get certificate sent by user
// clientCert := connection.GetPeerCertificate()
// //Get the OID values
// oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
// //Check if certificate usage is MSG SERVICE
// usage := oidMap["2.5.4.11"]
// if usage == "" {
// log.Fatalln("User certificate does not have the correct usage")
// }
// //Get the UID of this user
// UID := oidMap["2.5.4.65"]
// if UID == "" {
// log.Fatalln("User certificate does not specify it's PSEUDONYM")
// }
// err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
// if err != nil {
// log.Fatalln(err)
// }
//}
func HandleGetUserCert(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
userCertPacket, err := dataStore.GetUserCertificate(user)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
} else {
c.JSON(http.StatusOK, userCertPacket)
}
}
func HandleGetUnreadMsgsInfo(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
var getUnreadMsgsInfo protocol.GetUnreadMsgsInfo
if err := c.BindJSON(getUnreadMsgsInfo); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Page and PageSize need to be >= 1"})
return
}
unreadMsgsInfo, err := dataStore.GetUnreadMsgsInfo(user, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, unreadMsgsInfo)
}
func HandleSendMessage(c *gin.Context, dataStore DataStore) {
sender := c.Param("user")
var message protocol.SendMsg
if err := c.BindJSON(message); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if message.ToUID == sender {
c.JSON(http.StatusBadRequest, gin.H{"error": "Message sender and receiver cannot be the same user"})
return
}
if !dataStore.userExists(message.ToUID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Message receiver does not exist"})
return
}
err := dataStore.AddMessageToQueue(sender, message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
}
func HandleGetMessage(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
numStr := c.Param("num")
num, err := strconv.Atoi(numStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
message, reportError := dataStore.GetMessage(user, num)
if reportError != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
}
dataStore.MarkMessageInQueueAsRead(user, num)
c.JSON(http.StatusOK, message)
}
func Run() {
//Open connection to DB
dataStore, err := OpenDB()
if err != nil {
log.Fatalln(err)
}
defer dataStore.db.Close()
//Read server keystore
keystorePassphrase := readStdin("Insert keystore passphrase")
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
if err != nil {
log.Fatalln(err)
}
r := gin.Default()
r.GET("/message/:user/:num", func(c *gin.Context) {
HandleGetMessage(c, dataStore)
})
r.GET("/queue/:user", func(c *gin.Context) {
HandleGetUnreadMsgsInfo(c, dataStore)
})
r.GET("/cert/:user", func(c *gin.Context) {
HandleGetUserCert(c, dataStore)
})
r.POST("/message/:user", func(c *gin.Context) {
HandleSendMessage(c, dataStore)
})
server := http.Server{
Addr: "0.0.0.0:8080",
Handler: r,
TLSConfig: serverKeyStore.GetTLSConfig(),
}
err = server.ListenAndServeTLS("", "")
if err!=nil {
log.Fatal(err.Error())
}
}

View file

@ -0,0 +1,283 @@
package cryptoUtils
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"time"
"log"
"os"
"golang.org/x/crypto/chacha20poly1305"
"software.sslmate.com/src/go-pkcs12"
)
type KeyStore struct {
cert *x509.Certificate
caCertChain []*x509.Certificate
privKey *rsa.PrivateKey
}
func (k KeyStore) GetCert() *x509.Certificate {
return k.cert
}
func (k KeyStore) GetCACertChain() []*x509.Certificate {
return k.caCertChain
}
func (k KeyStore) GetPrivKey() *rsa.PrivateKey {
return k.privKey
}
func ExtractAllOIDValues(cert *x509.Certificate) map[string]string {
oidValueMap := make(map[string]string)
for _, name := range cert.Subject.Names {
oid := name.Type.String()
value := name.Value.(string)
oidValueMap[oid] = value
}
return oidValueMap
}
func LoadKeyStore(keyStorePath string, password string) (KeyStore, error) {
var privKey *rsa.PrivateKey
keystoreBytes, err := os.ReadFile(keyStorePath)
if err != nil {
return KeyStore{}, err
}
privKeyInterface, cert, caCerts, err := pkcs12.DecodeChain(keystoreBytes, password)
if err != nil {
return KeyStore{}, err
}
privKey, ok := privKeyInterface.(*rsa.PrivateKey)
if !ok {
return KeyStore{}, err
}
if err := privKey.Validate(); err != nil {
return KeyStore{}, err
}
return KeyStore{cert: cert, caCertChain: caCerts, privKey: privKey}, err
}
// Check if the cert is signed by the CA and is for the correct user
func (k KeyStore) CheckCert(cert *x509.Certificate, uid string) error {
caCertPool := x509.NewCertPool()
for _, caCert := range k.caCertChain {
caCertPool.AddCert(caCert)
}
opts := x509.VerifyOptions{
Roots: caCertPool,
}
// Check if the certificate is signed by the specified CA
_, err := cert.Verify(opts)
if err != nil {
log.Println("Certificate not signed by a trusted CA")
return err
}
if cert.NotAfter.Before(time.Now()) {
return errors.New("certificate has expired")
}
if cert.NotBefore.After(time.Now()) {
return errors.New("certificate is not valid yet")
}
//Check if the pseudonym field is set to UID
oidMap := ExtractAllOIDValues(cert)
if oidMap["2.5.4.65"] != uid {
log.Println("Certificate does not belong to the message's receiver")
return err
}
return nil
}
func (k *KeyStore) GetTLSConfig() *tls.Config {
certificate := tls.Certificate{Certificate: [][]byte{k.cert.Raw}, PrivateKey: k.privKey, Leaf: k.cert}
//Add the CA certificate chain to a CertPool
caCertPool := x509.NewCertPool()
for _, caCert := range k.caCertChain {
caCertPool.AddCert(caCert)
}
config := &tls.Config{
Certificates: []tls.Certificate{certificate},
}
return config
}
func (k *KeyStore) GetServerTLSConfig() *tls.Config {
tlsConfig := k.GetTLSConfig()
//Add the CA certificate chain to a CertPool
caCertPool := x509.NewCertPool()
for _, caCert := range k.caCertChain {
caCertPool.AddCert(caCert)
}
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAnyClientCert
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Verify the peer's certificate
opts := x509.VerifyOptions{
Roots: caCertPool,
}
for _, certBytes := range rawCerts {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return err
}
if cert.NotAfter.Before(time.Now()) {
return errors.New("certificate has expired")
}
if cert.NotBefore.After(time.Now()) {
return errors.New("certificate is not valid yet")
}
// Check if the certificate is signed by the specified CA
_, err = cert.Verify(opts)
if err != nil {
return errors.New("certificate not signed by trusted CA")
}
}
return nil
}
return tlsConfig
}
func (k *KeyStore) GetClientTLSConfig() *tls.Config {
tlsConfig := k.GetTLSConfig()
//Add the CA certificate chain to a CertPool
caCertPool := x509.NewCertPool()
for _, caCert := range k.caCertChain {
caCertPool.AddCert(caCert)
}
tlsConfig.RootCAs = caCertPool
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Verify the peer's certificate
opts := x509.VerifyOptions{
Roots: caCertPool,
}
for _, certBytes := range rawCerts {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return err
}
if cert.NotAfter.Before(time.Now()) {
return errors.New("certificate has expired")
}
if cert.NotBefore.After(time.Now()) {
return errors.New("certificate is not valid yet")
}
oidMap := ExtractAllOIDValues(cert)
// Check if the certificate is signed by the specified CA
_, err = cert.Verify(opts)
if err != nil {
return errors.New("certificate not signed by trusted CA")
}
//Check if the pseudonym field is set to "SERVER"
if oidMap["2.5.4.65"] != "SERVER" {
return errors.New("peer isn't the server")
}
}
return nil
}
return tlsConfig
}
func (k KeyStore) EncryptMessageContent(receiverCert *x509.Certificate, content []byte) ([]byte, error) {
// Digital envolope
// Create a random symmetric key
dataKey := make([]byte, 32)
if _, err := rand.Read(dataKey); err != nil {
return nil, err
}
cipher, err := chacha20poly1305.New(dataKey)
if err != nil {
return nil, err
}
nonce := make([]byte, cipher.NonceSize(), cipher.NonceSize()+len(content)+cipher.Overhead())
if _, err = rand.Read(nonce); err != nil {
return nil, err
}
// sign the message and append the signature
hashedContent := sha256.Sum256(content)
signature, err := rsa.SignPKCS1v15(nil, k.privKey, crypto.SHA256, hashedContent[:])
if err != nil {
return nil, err
}
content = pair(signature, content)
ciphertext := cipher.Seal(nonce, nonce, content, nil)
receiverPubKey := receiverCert.PublicKey.(*rsa.PublicKey)
encryptedDataKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, receiverPubKey, dataKey, nil)
if err != nil {
return nil, err
}
return pair(encryptedDataKey, ciphertext), nil
}
func (k KeyStore) DecryptMessageContent(senderCert *x509.Certificate, cipherContent []byte) ([]byte, error) {
encryptedDataKey, encryptedMsg := unPair(cipherContent)
dataKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, k.GetPrivKey(), encryptedDataKey, nil)
if err != nil {
return nil, err
}
// decrypt ciphertext
cipher, err := chacha20poly1305.New(dataKey)
if err != nil {
return nil, err
}
nonce, ciphertext := encryptedMsg[:cipher.NonceSize()], encryptedMsg[cipher.NonceSize():]
contentAndSig, err := cipher.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
// check signature with sender public key
signature, content := unPair(contentAndSig)
hashedContent := sha256.Sum256(content)
senderKey := senderCert.PublicKey.(*rsa.PublicKey)
if err := rsa.VerifyPKCS1v15(senderKey, crypto.SHA256, hashedContent[:], signature); err != nil {
return nil, err
}
return content, nil
}
func pair(l []byte, r []byte) []byte {
length := len(l)
lenBytes := make([]byte, 2)
binary.BigEndian.PutUint16(lenBytes, uint16(length))
lWithLen := append(lenBytes, l...)
return append(lWithLen, r...)
}
func unPair(pair []byte) ([]byte, []byte) {
lenBytes := pair[:2]
pair = pair[2:]
length := binary.BigEndian.Uint16(lenBytes)
l := pair[:length]
r := pair[length:]
return l, r
}

View file

@ -0,0 +1,23 @@
package networking
import (
"crypto/tls"
)
type ClientTLSConfigProvider interface {
GetClientTLSConfig() *tls.Config
}
type Client[T any] struct {
Connection Connection[T]
}
func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) (Client[T],error) {
dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetClientTLSConfig())
if err != nil {
return Client[T]{},err
}
conn := NewConnection[T](dialConn)
return Client[T]{Connection: conn},nil
}

View file

@ -0,0 +1,51 @@
package networking
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"io"
"log"
)
type Connection[T any] struct {
Conn *tls.Conn
encoder *json.Encoder
decoder *json.Decoder
}
func NewConnection[T any](netConn *tls.Conn) Connection[T] {
return Connection[T]{
Conn: netConn,
encoder: json.NewEncoder(netConn),
decoder: json.NewDecoder(netConn),
}
}
func (c Connection[T]) Send(obj T) error {
if err := c.encoder.Encode(&obj); err!=nil {
if err == io.EOF {
log.Println("Connection closed by peer")
}
return err
}
//Return true as connection active
return nil
}
func (c Connection[T]) Receive() (*T, error) {
var obj T
if err := c.decoder.Decode(&obj); err != nil {
if err == io.EOF {
log.Println("Connection closed by peer")
}
return nil,err
}
//Return true as connection active
return &obj, nil
}
func (c Connection[T]) GetPeerCertificate() *x509.Certificate {
state := c.Conn.ConnectionState()
return state.PeerCertificates[0]
}

View file

@ -0,0 +1,56 @@
package networking
import (
"crypto/tls"
"log"
"net"
)
type ServerTLSConfigProvider interface {
GetServerTLSConfig() *tls.Config
}
type Server[T any] struct {
listener net.Listener
C chan Connection[T]
}
func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider) (Server[T], error) {
listener, err := tls.Listen("tcp", "127.0.0.1:8080", serverTLSConfigProvider.GetServerTLSConfig())
if err != nil {
return Server[T]{}, err
}
return Server[T]{
listener: listener,
C: make(chan Connection[T]),
}, nil
}
func (s *Server[T]) ListenLoop() {
for {
listenerConn, err := s.listener.Accept()
if err != nil {
log.Println("Server could not accept connection")
continue
}
tlsConn, ok := listenerConn.(*tls.Conn)
if !ok {
log.Println("Connection is not a TLS connection")
continue
}
if err := tlsConn.Handshake(); err != nil {
log.Println(err)
continue
}
state := tlsConn.ConnectionState()
if len(state.PeerCertificates) == 0 {
log.Println("Client did not provide a certificate")
continue
}
conn := NewConnection[T](tlsConn)
s.C <- conn
}
}