[PD2] Server done?

Co-authored-by: tsousa111 <tiagao2001@hotmail.com>
This commit is contained in:
Afonso Franco 2024-05-28 20:08:25 +01:00
parent 08a73a4f76
commit 49a29e43a7
Signed by: afonso
SSH key fingerprint: SHA256:PQTRDHPH3yALEGtHXnXBp3Orfcn21pK20t0tS1kHg54
66 changed files with 2777 additions and 5 deletions

View file

@ -0,0 +1,245 @@
package server
import (
"PD1/internal/protocol"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"log"
"time"
_ "github.com/mattn/go-sqlite3"
)
type DataStore struct {
db *sql.DB
}
func OpenDB() (DataStore, error) {
db, err := sql.Open("sqlite3", "server.db")
if err != nil {
return DataStore{}, err
}
ds := DataStore{db: db}
err = ds.CreateTables()
if err != nil {
return DataStore{}, err
}
return ds, nil
}
func (ds DataStore) CreateTables() error {
// Create users table
_, err := ds.db.Exec(`CREATE TABLE IF NOT EXISTS users (
UID TEXT PRIMARY KEY,
userCert BLOB
)`)
if err != nil {
return err
}
// Create messages table
_, err = ds.db.Exec(`CREATE TABLE IF NOT EXISTS messages (
fromUID TEXT,
toUID TEXT,
timestamp TIMESTAMP,
queue_position INT DEFAULT 0,
subject BLOB,
body BLOB,
status INT CHECK (status IN (0,1)),
PRIMARY KEY (toUID, fromUID, timestamp),
FOREIGN KEY(fromUID) REFERENCES users(UID),
FOREIGN KEY(toUID) REFERENCES users(UID)
)`)
if err != nil {
return err
}
// Define a trigger to automatically assign numbers for each message of each user starting from 1
_, err = ds.db.Exec(`
CREATE TRIGGER IF NOT EXISTS assign_queue_position
AFTER INSERT ON messages
FOR EACH ROW
BEGIN
UPDATE messages
SET queue_position = (
SELECT COUNT(*)
FROM messages
WHERE toUID = NEW.toUID
)
WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
END;
`)
if err != nil {
return err
}
return nil
}
func (ds DataStore) GetMessage(toUID string, position int) (*protocol.AnswerGetMsg, error) {
var serverMessage protocol.AnswerGetMsg
query := `
SELECT fromUID, toUID, subject, body, timestamp
FROM messages
WHERE toUID = ? AND queue_position = ?
`
// Execute the query
row := ds.db.QueryRow(query, toUID, position)
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
if err == sql.ErrNoRows {
log.Printf("No message with NUM %v for UID %v\n", position, toUID)
errorMessage := fmt.Sprintln("MSG SERVICE: unknown message!")
error := errors.New(errorMessage)
return nil, error
}
answer := protocol.NewAnswerGetMsg(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
return &answer, nil
}
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
query := `
UPDATE messages
SET status = 1
WHERE (fromUID,toUID,timestamp) = (
SELECT fromUID,toUID,timestamp
FROM messages
WHERE toUID = ? AND queue_position = ?
)
`
// Execute the SQL statement
_, err := ds.db.Exec(query, toUID, position)
if err != nil {
log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
}
}
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.AnswerGetUnreadMsgsInfo, error) {
// Retrieve the total count of unread messages
var totalCount int
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
if err == sql.ErrNoRows {
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, []protocol.MsgInfo{}), nil
}
// Query to retrieve all messages from the user's queue
query := `
SELECT
fromUID,
toUID,
timestamp,
queue_position,
subject,
status
FROM messages
WHERE
toUID = ? AND status = 0
ORDER BY
queue_position DESC
LIMIT ? OFFSET ?;
`
// Execute the query
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
if err != nil {
log.Printf("Error getting unread messages for UID %v: %v", toUID, err)
}
defer rows.Close()
messageInfoPackets := []protocol.MsgInfo{}
for rows.Next() {
var fromUID string
var subject []byte
var timestamp time.Time
var queuePosition, status int
if err := rows.Scan(&fromUID, &toUID, &timestamp, &queuePosition, &subject, &status); err != nil {
return protocol.AnswerGetUnreadMsgsInfo{}, err
}
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
}
if err := rows.Err(); err != nil {
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
return protocol.AnswerGetUnreadMsgsInfo{}, err
}
numberOfPages := (totalCount + pageSize - 1) / pageSize
currentPage := min(numberOfPages, page)
return protocol.NewAnswerGetUnreadMsgsInfo(currentPage, numberOfPages, messageInfoPackets), nil
}
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) error {
query := `
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
VALUES (?, ?, ?, ?, ?, 0)
`
// Execute the SQL statement
currentTime := time.Now()
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
if err != nil {
log.Printf("Error adding message to UID %v: %v", fromUID, err)
return err
}
return nil
}
func (ds DataStore) GetUserCertificate(uid string) (protocol.AnswerGetUserCert,error) {
query := `
SELECT userCert
FROM users
WHERE UID = ?
`
// Execute the SQL query
var userCertBytes []byte
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
if err == sql.ErrNoRows {
errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid)
log.Println(errorMessage)
return protocol.AnswerGetUserCert{},errors.New(errorMessage)
}
return protocol.NewAnswerGetUserCert(uid, userCertBytes),nil
}
func (ds DataStore) userExists(uid string) bool {
// Prepare the SQL statement for checking if a user exists
query := `
SELECT COUNT(*)
FROM users
WHERE UID = ?
`
var count int
// Execute the SQL query
err := ds.db.QueryRow(query, uid).Scan(&count)
if err != nil || count == 0 {
log.Printf("user with UID %v does not exist\n", uid)
return false
}
return true
}
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
// Check if the user already exists
if ds.userExists(uid) {
return nil
}
// Insert the user certificate
insertQuery := `
INSERT INTO users (UID, userCert)
VALUES (?, ?)
`
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
if err != nil {
return fmt.Errorf("error storing user certificate for UID %s: %v", uid, err)
}
log.Printf("User certificate for UID %s stored successfully.\n", uid)
return nil
}

View file

@ -0,0 +1,14 @@
package server
import (
"bufio"
"fmt"
"os"
)
func readStdin(message string) string {
fmt.Println(message)
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
return scanner.Text()
}

View file

@ -0,0 +1,151 @@
package server
import (
"PD1/internal/protocol"
"PD1/internal/utils/cryptoUtils"
"log"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
//func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
// defer connection.Conn.Close()
//
// //Get certificate sent by user
// clientCert := connection.GetPeerCertificate()
// //Get the OID values
// oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
// //Check if certificate usage is MSG SERVICE
// usage := oidMap["2.5.4.11"]
// if usage == "" {
// log.Fatalln("User certificate does not have the correct usage")
// }
// //Get the UID of this user
// UID := oidMap["2.5.4.65"]
// if UID == "" {
// log.Fatalln("User certificate does not specify it's PSEUDONYM")
// }
// err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
// if err != nil {
// log.Fatalln(err)
// }
//}
func HandleGetUserCert(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
userCertPacket, err := dataStore.GetUserCertificate(user)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
} else {
c.JSON(http.StatusOK, userCertPacket)
}
}
func HandleGetUnreadMsgsInfo(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
var getUnreadMsgsInfo protocol.GetUnreadMsgsInfo
if err := c.BindJSON(getUnreadMsgsInfo); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Page and PageSize need to be >= 1"})
return
}
unreadMsgsInfo, err := dataStore.GetUnreadMsgsInfo(user, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, unreadMsgsInfo)
}
func HandleSendMessage(c *gin.Context, dataStore DataStore) {
sender := c.Param("user")
var message protocol.SendMsg
if err := c.BindJSON(message); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if message.ToUID == sender {
c.JSON(http.StatusBadRequest, gin.H{"error": "Message sender and receiver cannot be the same user"})
return
}
if !dataStore.userExists(message.ToUID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Message receiver does not exist"})
return
}
err := dataStore.AddMessageToQueue(sender, message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
}
func HandleGetMessage(c *gin.Context, dataStore DataStore) {
user := c.Param("user")
numStr := c.Param("num")
num, err := strconv.Atoi(numStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
message, reportError := dataStore.GetMessage(user, num)
if reportError != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
}
dataStore.MarkMessageInQueueAsRead(user, num)
c.JSON(http.StatusOK, message)
}
func Run() {
//Open connection to DB
dataStore, err := OpenDB()
if err != nil {
log.Fatalln(err)
}
defer dataStore.db.Close()
//Read server keystore
keystorePassphrase := readStdin("Insert keystore passphrase")
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
if err != nil {
log.Fatalln(err)
}
r := gin.Default()
r.GET("/message/:user/:num", func(c *gin.Context) {
HandleGetMessage(c, dataStore)
})
r.GET("/queue/:user", func(c *gin.Context) {
HandleGetUnreadMsgsInfo(c, dataStore)
})
r.GET("/cert/:user", func(c *gin.Context) {
HandleGetUserCert(c, dataStore)
})
r.POST("/message/:user", func(c *gin.Context) {
HandleSendMessage(c, dataStore)
})
server := http.Server{
Addr: "0.0.0.0:8080",
Handler: r,
TLSConfig: serverKeyStore.GetTLSConfig(),
}
err = server.ListenAndServeTLS("", "")
if err!=nil {
log.Fatal(err.Error())
}
}