package server

import (
	"crypto/x509"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"time"

	"PD1/internal/protocol"
	_ "github.com/mattn/go-sqlite3"
)

// DataStore wraps the SQLite database that holds registered users, their
// certificates and the per-recipient message queues.
type DataStore struct {
	db *sql.DB
}

// OpenDB opens the SQLite database file "server.db" (relative to the current
// working directory) and makes sure the schema exists. The server cannot work
// without its database, so any failure here terminates the process.
func OpenDB() DataStore {
	db, err := sql.Open("sqlite3", "server.db")
	if err != nil {
		log.Fatalf("Error opening db file: %v", err)
	}
	ds := DataStore{db: db}
	if err := ds.CreateTables(); err != nil {
		log.Fatalf("Error creating database schema: %v", err)
	}
	return ds
}

// CreateTables creates the users and messages tables and the
// assign_queue_position trigger, leaving any that already exist untouched.
//
// messages.status is 0 for unread and 1 for read (enforced by a CHECK).
// messages.queue_position is set by the trigger after every insert to the
// number of messages the recipient has, so each recipient's queue is
// numbered 1, 2, 3, ... in insertion order as long as no rows are deleted.
//
// It returns the first error encountered, wrapped with the failing step.
func (ds DataStore) CreateTables() error {
	// Users and their certificate bytes (storeUserCertIfNotExists stores cert.Raw).
	_, err := ds.db.Exec(`CREATE TABLE IF NOT EXISTS users (
		UID TEXT PRIMARY KEY,
		userCert BLOB
	)`)
	if err != nil {
		return fmt.Errorf("creating users table: %w", err)
	}

	// NOTE(review): SQLite only enforces these FOREIGN KEY constraints when
	// foreign_keys is enabled on the connection; nothing in OpenDB enables it.
	_, err = ds.db.Exec(`CREATE TABLE IF NOT EXISTS messages (
		fromUID TEXT,
		toUID TEXT,
		timestamp TIMESTAMP,
		queue_position INT DEFAULT 0,
		subject BLOB,
		body BLOB,
		status INT CHECK (status IN (0,1)),
		PRIMARY KEY (toUID, fromUID, timestamp),
		FOREIGN KEY(fromUID) REFERENCES users(UID),
		FOREIGN KEY(toUID) REFERENCES users(UID)
	)`)
	if err != nil {
		return fmt.Errorf("creating messages table: %w", err)
	}

	// Number each recipient's messages starting from 1: the COUNT(*) runs
	// AFTER INSERT, so it already includes the new row.
	_, err = ds.db.Exec(`
		CREATE TRIGGER IF NOT EXISTS assign_queue_position
		AFTER INSERT ON messages
		FOR EACH ROW
		BEGIN
			UPDATE messages
			SET queue_position = (
				SELECT COUNT(*)
				FROM messages
				WHERE toUID = NEW.toUID
			)
			WHERE toUID = NEW.toUID AND rowid = NEW.rowid;
		END;
	`)
	if err != nil {
		return fmt.Errorf("creating assign_queue_position trigger: %w", err)
	}
	return nil
}

// GetMessage returns the message at the given queue position (1-based) of
// toUID's queue, wrapped in an AnswerGetMsg packet.
//
// If the message cannot be read (for example, no message at that position),
// the error is logged and the packet carries zero-valued fields.
// NOTE(review): the last argument to NewAnswerGetMsgPacket is always true,
// even when the lookup failed — confirm what it signals to the client.
func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
	const query = `
		SELECT fromUID, toUID, subject, body, timestamp
		FROM messages
		WHERE toUID = ? AND queue_position = ?`

	var serverMessage protocol.AnswerGetMsg
	err := ds.db.QueryRow(query, toUID, position).Scan(
		&serverMessage.FromUID,
		&serverMessage.ToUID,
		&serverMessage.Subject,
		&serverMessage.Body,
		&serverMessage.Timestamp,
	)
	if err != nil {
		log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
	}
	return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
}

// MarkMessageInQueueAsRead sets status = 1 (read) on the message at the given
// queue position of toUID's queue. Errors are logged, not returned.
func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
	// The subquery resolves the position to the row's primary key
	// (fromUID, toUID, timestamp), which the UPDATE then targets.
	const query = `
		UPDATE messages
		SET status = 1
		WHERE (fromUID, toUID, timestamp) = (
			SELECT fromUID, toUID, timestamp
			FROM messages
			WHERE toUID = ? AND queue_position = ?
		)`

	if _, err := ds.db.Exec(query, toUID, position); err != nil {
		log.Printf("Error marking the message in position %v from UID %v as read: %v", position, toUID, err)
	}
}

// GetUnreadMsgsInfo returns one page of summaries of toUID's unread messages
// (status = 0), highest queue position first, together with paging info.
//
// page is 1-based and pageSize is the number of entries per page. Values
// below 1 are treated as 1, so the offset is never negative and the page
// count never divides by zero. The reported current page is capped at the
// total number of pages, which is 0 when there are no unread messages.
// Database errors are logged, and whatever could be read is still returned.
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
	page = max(page, 1)
	pageSize = max(pageSize, 1)

	// Total number of unread messages, used to compute the page count.
	var totalCount int
	err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
	if err != nil {
		log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
	}
	numberOfPages := (totalCount + pageSize - 1) / pageSize
	currentPage := min(numberOfPages, page)

	// Non-nil on purpose: callers get an empty (not nil) list when nothing matches.
	messageInfoPackets := []protocol.MsgInfo{}

	const query = `
		SELECT fromUID, timestamp, queue_position, subject
		FROM messages
		WHERE toUID = ? AND status = 0
		ORDER BY queue_position DESC
		LIMIT ? OFFSET ?`
	rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
	if err != nil {
		// rows is nil on error; touching it (Close/Next) would panic.
		log.Printf("Error getting all messages for UID %v: %v", toUID, err)
		return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			fromUID       string
			subject       []byte
			timestamp     time.Time
			queuePosition int
		)
		if err := rows.Scan(&fromUID, &timestamp, &queuePosition, &subject); err != nil {
			log.Printf("Error reading an unread message for UID %v: %v", toUID, err)
			continue
		}
		messageInfoPackets = append(messageInfoPackets, protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp))
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error when getting messages for UID %v: %v", toUID, err)
	}

	return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
}

// AddMessageToQueue stores message, sent by fromUID, in the queue of
// message.ToUID as unread (status 0), stamped with the current time. Its
// queue position is filled in by the assign_queue_position trigger.
// Errors are logged, not returned.
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
	const query = `
		INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
		VALUES (?, ?, ?, ?, ?, 0)`

	_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, time.Now())
	if err != nil {
		log.Printf("Error adding message from UID %v to UID %v: %v", fromUID, message.ToUID, err)
	}
}

// GetUserCertificate returns an AnswerGetUserCert packet with the certificate
// bytes stored for uid (the raw certificate saved by storeUserCertIfNotExists).
// If uid is unknown or the lookup fails, the error is logged and the packet
// carries a nil certificate rather than panicking the server.
func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
	const query = `
		SELECT userCert
		FROM users
		WHERE UID = ?`

	var userCertBytes []byte
	err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		log.Printf("No certificate for UID %v found in the database", uid)
		userCertBytes = nil
	case err != nil:
		log.Printf("Error getting certificate for UID %v: %v", uid, err)
		userCertBytes = nil
	}
	return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
}

// userExists reports whether the users table has a row for uid.
// A failed query is logged and reported as false.
func (ds DataStore) userExists(uid string) bool {
	// COUNT(*) always yields exactly one row, so absence shows up as 0,
	// never as sql.ErrNoRows.
	const query = `
		SELECT COUNT(*)
		FROM users
		WHERE UID = ?`

	var count int
	if err := ds.db.QueryRow(query, uid).Scan(&count); err != nil {
		log.Printf("Error checking whether user with UID %v exists: %v", uid, err)
		return false
	}
	if count == 0 {
		log.Printf("User with UID %v does not exist", uid)
		return false
	}
	return true
}

// storeUserCertIfNotExists saves cert (its raw bytes, cert.Raw) for uid unless
// a user with that UID is already registered. Outcomes are logged, not returned.
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
	if ds.userExists(uid) {
		log.Printf("User certificate for UID %s already exists.\n", uid)
		return
	}

	// The UID primary key still rejects a duplicate if another insert for the
	// same uid lands between the check above and this statement.
	const insertQuery = `
		INSERT INTO users (UID, userCert)
		VALUES (?, ?)`
	if _, err := ds.db.Exec(insertQuery, uid, cert.Raw); err != nil {
		log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
		return
	}
	log.Printf("User certificate for UID %s stored successfully.\n", uid)
}