package server import ( "PD2/internal/protocol" "crypto/x509" "database/sql" "errors" "fmt" "log" "time" _ "github.com/mattn/go-sqlite3" ) type DataStore struct { db *sql.DB } func OpenDB() (DataStore, error) { db, err := sql.Open("sqlite3", "server.db") if err != nil { return DataStore{}, err } ds := DataStore{db: db} err = ds.CreateTables() if err != nil { return DataStore{}, err } return ds, nil } func (ds DataStore) CreateTables() error { // Create users table _, err := ds.db.Exec(`CREATE TABLE IF NOT EXISTS users ( UID TEXT PRIMARY KEY, userCert BLOB )`) if err != nil { return err } // Create messages table _, err = ds.db.Exec(`CREATE TABLE IF NOT EXISTS messages ( fromUID TEXT, toUID TEXT, timestamp TIMESTAMP, queue_position INT DEFAULT 0, subject BLOB, body BLOB, status INT CHECK (status IN (0,1)), PRIMARY KEY (toUID, fromUID, timestamp), FOREIGN KEY(fromUID) REFERENCES users(UID), FOREIGN KEY(toUID) REFERENCES users(UID) )`) if err != nil { return err } // Define a trigger to automatically assign numbers for each message of each user starting from 1 _, err = ds.db.Exec(` CREATE TRIGGER IF NOT EXISTS assign_queue_position AFTER INSERT ON messages FOR EACH ROW BEGIN UPDATE messages SET queue_position = ( SELECT COUNT(*) FROM messages WHERE toUID = NEW.toUID ) WHERE toUID = NEW.toUID AND rowid = NEW.rowid; END; `) if err != nil { return err } return nil } func (ds DataStore) GetMessage(toUID string, position int) (*protocol.AnswerGetMsg, error) { var serverMessage protocol.AnswerGetMsg query := ` SELECT fromUID, toUID, subject, body, timestamp FROM messages WHERE toUID = ? AND queue_position = ? ` // Execute the query row := ds.db.QueryRow(query, toUID, position) err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp) if err == sql.ErrNoRows { log.Printf("No message with NUM %v for UID %v\n", position, toUID) errorMessage := fmt.Sprintln("MSG SERVICE: unknown message!") error := errors.New(errorMessage) return nil, error } answer := protocol.NewAnswerGetMsg(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true) return &answer, nil } func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) { query := ` UPDATE messages SET status = 1 WHERE (fromUID,toUID,timestamp) = ( SELECT fromUID,toUID,timestamp FROM messages WHERE toUID = ? AND queue_position = ? ) ` // Execute the SQL statement _, err := ds.db.Exec(query, toUID, position) if err != nil { log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err) } } func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.AnswerGetUnreadMsgsInfo, error) { // Retrieve the total count of unread messages var totalCount int err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount) if err == sql.ErrNoRows { return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, []protocol.MsgInfo{}), nil } // Query to retrieve all messages from the user's queue query := ` SELECT fromUID, toUID, timestamp, queue_position, subject, status FROM messages WHERE toUID = ? AND status = 0 ORDER BY queue_position DESC LIMIT ? OFFSET ?; ` // Execute the query rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize) if err != nil { log.Printf("Error getting unread messages for UID %v: %v", toUID, err) } defer rows.Close() messageInfoPackets := []protocol.MsgInfo{} for rows.Next() { var fromUID string var subject []byte var timestamp time.Time var queuePosition, status int if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil { return protocol.AnswerGetUnreadMsgsInfo{}, err } answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp) messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo) } if err := rows.Err(); err != nil { log.Printf("Error when getting messages for UID %v: %v", toUID, err) return protocol.AnswerGetUnreadMsgsInfo{}, err } numberOfPages := (totalCount + pageSize - 1) / pageSize currentPage := min(numberOfPages, page) return protocol.NewAnswerGetUnreadMsgsInfo(currentPage, numberOfPages, messageInfoPackets), nil } func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) error { query := ` INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status) VALUES (?, ?, ?, ?, ?, 0) ` // Execute the SQL statement currentTime := time.Now() _, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime) if err != nil { log.Printf("Error adding message to UID %v: %v", fromUID, err) return err } return nil } func (ds DataStore) GetUserCertificate(uid string) (protocol.AnswerGetUserCert,error) { query := ` SELECT userCert FROM users WHERE UID = ? ` // Execute the SQL query var userCertBytes []byte err := ds.db.QueryRow(query, uid).Scan(&userCertBytes) if err == sql.ErrNoRows { errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid) log.Println(errorMessage) return protocol.AnswerGetUserCert{},errors.New(errorMessage) } return protocol.NewAnswerGetUserCert(uid, userCertBytes),nil } func (ds DataStore) userExists(uid string) bool { // Prepare the SQL statement for checking if a user exists query := ` SELECT COUNT(*) FROM users WHERE UID = ? ` var count int // Execute the SQL query err := ds.db.QueryRow(query, uid).Scan(&count) if err != nil || count == 0 { log.Printf("user with UID %v does not exist\n", uid) return false } return true } func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error { // Insert the user certificate insertQuery := ` INSERT INTO users (UID, userCert) VALUES (?, ?) ` _, err := ds.db.Exec(insertQuery, uid, cert.Raw) if err != nil { return fmt.Errorf("error storing user certificate for UID %s: %v", uid, err) } log.Printf("User certificate for UID %s stored successfully.\n", uid) return nil }