[PD1] Errors handling project-wide
This commit is contained in:
parent
f5b3726673
commit
64791174b4
13 changed files with 364 additions and 245 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"PD1/internal/protocol"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
@ -15,14 +16,17 @@ type DataStore struct {
|
|||
db *sql.DB
|
||||
}
|
||||
|
||||
func OpenDB() DataStore {
|
||||
func OpenDB() (DataStore, error) {
|
||||
db, err := sql.Open("sqlite3", "server.db")
|
||||
if err != nil {
|
||||
log.Fatalln("Error opening db file")
|
||||
return DataStore{}, err
|
||||
}
|
||||
ds := DataStore{db: db}
|
||||
ds.CreateTables()
|
||||
return ds
|
||||
err = ds.CreateTables()
|
||||
if err != nil {
|
||||
return DataStore{}, err
|
||||
}
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
func (ds DataStore) CreateTables() error {
|
||||
|
@ -32,7 +36,6 @@ func (ds DataStore) CreateTables() error {
|
|||
userCert BLOB
|
||||
)`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating users table", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -50,7 +53,6 @@ func (ds DataStore) CreateTables() error {
|
|||
FOREIGN KEY(toUID) REFERENCES users(UID)
|
||||
)`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating messages table", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -70,7 +72,6 @@ func (ds DataStore) CreateTables() error {
|
|||
END;
|
||||
`)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating trigger", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -91,7 +92,7 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
|
|||
if err == sql.ErrNoRows {
|
||||
log.Printf("No message with NUM %v for UID %v\n", position, toUID)
|
||||
errorMessage := fmt.Sprintf("No message with NUM %v", position)
|
||||
return protocol.NewReportErrorPacket(errorMessage)
|
||||
return protocol.NewReportErrorPacket(errorMessage)
|
||||
}
|
||||
|
||||
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
|
||||
|
@ -116,14 +117,13 @@ func (ds DataStore) MarkMessageInQueueAsRead(toUID string, position int) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) protocol.Packet {
|
||||
func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) (protocol.Packet, error) {
|
||||
|
||||
// Retrieve the total count of unread messages
|
||||
var totalCount int
|
||||
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
||||
if err == sql.ErrNoRows {
|
||||
log.Printf("No unread messages for UID %v: %v", toUID, err)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{}), nil
|
||||
}
|
||||
|
||||
// Query to retrieve all messages from the user's queue
|
||||
|
@ -157,19 +157,18 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
|||
var timestamp time.Time
|
||||
var queuePosition, status int
|
||||
if err := rows.Scan(&fromUID, &toUID, ×tamp, &queuePosition, &subject, &status); err != nil {
|
||||
panic(err)
|
||||
return protocol.Packet{}, err
|
||||
}
|
||||
answerGetUnreadMsgsInfo := protocol.NewMsgInfo(queuePosition, fromUID, subject, timestamp)
|
||||
messageInfoPackets = append(messageInfoPackets, answerGetUnreadMsgsInfo)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
||||
return protocol.NewReportErrorPacket(err.Error())
|
||||
return protocol.Packet{}, err
|
||||
}
|
||||
|
||||
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
||||
currentPage := min(numberOfPages, page)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
|
||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets), nil
|
||||
}
|
||||
|
||||
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
|
||||
|
@ -218,17 +217,16 @@ func (ds DataStore) userExists(uid string) bool {
|
|||
// Execute the SQL query
|
||||
err := ds.db.QueryRow(query, uid).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
log.Printf("User with UID %v does not exist", uid)
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
log.Println("user with UID %v does not exist", uid)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
|
||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) error {
|
||||
// Check if the user already exists
|
||||
if ds.userExists(uid) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert the user certificate
|
||||
|
@ -238,8 +236,8 @@ func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate)
|
|||
`
|
||||
_, err := ds.db.Exec(insertQuery, uid, cert.Raw)
|
||||
if err != nil {
|
||||
log.Printf("Error storing user certificate for UID %s: %v\n", uid, err)
|
||||
return
|
||||
return errors.New(fmt.Sprintf("Error storing user certificate for UID %s: %v\n", uid, err))
|
||||
}
|
||||
log.Printf("User certificate for UID %s stored successfully.\n", uid)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package server
|
|||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
|
@ -13,7 +12,3 @@ func readStdin(message string) string {
|
|||
scanner.Scan()
|
||||
return scanner.Text()
|
||||
}
|
||||
|
||||
func LogFatal(err error) {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
|
|
@ -19,84 +19,111 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
|
|||
//Check if certificate usage is MSG SERVICE
|
||||
usage := oidMap["2.5.4.11"]
|
||||
if usage == "" {
|
||||
log.Println("User certificate does not have the correct usage")
|
||||
return
|
||||
log.Fatalln("User certificate does not have the correct usage")
|
||||
}
|
||||
//Get the UID of this user
|
||||
UID := oidMap["2.5.4.65"]
|
||||
if UID == "" {
|
||||
log.Println("User certificate does not specify it's PSEUDONYM")
|
||||
log.Fatalln("User certificate does not specify it's PSEUDONYM")
|
||||
}
|
||||
err := dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||
F:
|
||||
for {
|
||||
pac, active := connection.Receive()
|
||||
if !active {
|
||||
pac, err := connection.Receive()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch pac.Flag {
|
||||
case protocol.FlagGetUserCert:
|
||||
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
|
||||
reqUserCert, err := protocol.UnmarshalGetUserCert(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
||||
if !connection.Send(userCertPacket) {
|
||||
if err := connection.Send(userCertPacket); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
|
||||
case protocol.FlagGetUnreadMsgsInfo:
|
||||
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||
getUnreadMsgsInfo, err := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var messages protocol.Packet
|
||||
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
|
||||
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
|
||||
} else {
|
||||
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||
messages, err = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
if !connection.Send(messages) {
|
||||
break F
|
||||
if err := connection.Send(messages); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
case protocol.FlagGetMsg:
|
||||
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
|
||||
reqMsg, err := protocol.UnmarshalGetMsg(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var message protocol.Packet
|
||||
if reqMsg.Num <= 0 {
|
||||
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
|
||||
} else {
|
||||
message = dataStore.GetMessage(UID, reqMsg.Num)
|
||||
}
|
||||
if !connection.Send(message) {
|
||||
if err := connection.Send(message); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
||||
|
||||
case protocol.FlagSendMsg:
|
||||
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
|
||||
submitMsg, err := protocol.UnmarshalSendMsg(pac.Body)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var answerSendMsgPacket protocol.Packet
|
||||
if submitMsg.ToUID == UID {
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message sender and receiver cannot be the same user")
|
||||
} else if !dataStore.userExists(submitMsg.ToUID) {
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
|
||||
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist")
|
||||
} else {
|
||||
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
|
||||
}
|
||||
if !connection.Send(answerSendMsgPacket) {
|
||||
if err := connection.Send(answerSendMsgPacket); err != nil {
|
||||
log.Fatalln(err)
|
||||
break F
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Run(port int) {
|
||||
func Run() {
|
||||
//Open connection to DB
|
||||
dataStore := OpenDB()
|
||||
dataStore, err := OpenDB()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
defer dataStore.db.Close()
|
||||
|
||||
//FIX: Get the server's keystore path instead of hardcoding it
|
||||
|
||||
//Read server keystore
|
||||
password := readStdin("Insert keystore passphrase")
|
||||
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
|
||||
keystorePassphrase := readStdin("Insert keystore passphrase")
|
||||
serverKeyStore, err := cryptoUtils.LoadKeyStore("certs/server/server.p12", keystorePassphrase)
|
||||
if err != nil {
|
||||
LogFatal(err)
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
//Create server listener
|
||||
server := networking.NewServer[protocol.Packet](&serverKeyStore, port)
|
||||
server, err := networking.NewServer[protocol.Packet](&serverKeyStore)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
go server.ListenLoop()
|
||||
|
||||
for {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue