[PD1] small changes
This commit is contained in:
parent
568b6e6739
commit
2cafc3163c
10 changed files with 160 additions and 71 deletions
1
Projs/PD1/.ignore
Normal file
1
Projs/PD1/.ignore
Normal file
|
@ -0,0 +1 @@
|
||||||
|
certs
|
|
@ -20,7 +20,7 @@ func Run() {
|
||||||
panic("No command provided. Use 'help' for instructions.")
|
panic("No command provided. Use 'help' for instructions.")
|
||||||
}
|
}
|
||||||
//Get user KeyStore
|
//Get user KeyStore
|
||||||
password := AskUserPassword()
|
password := readStdin("Insert keystore passphrase")
|
||||||
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
|
clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password)
|
||||||
|
|
||||||
command := flag.Arg(0)
|
command := flag.Arg(0)
|
||||||
|
@ -49,7 +49,14 @@ func Run() {
|
||||||
if !cl.Connection.Send(sendMsgPacket) {
|
if !cl.Connection.Send(sendMsgPacket) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cl.Connection.Conn.Close()
|
answerSendMsg, active := cl.Connection.Receive()
|
||||||
|
if !active {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if answerSendMsg.Flag == protocol.FlagReportError {
|
||||||
|
reportError := protocol.UnmarshalReportError(answerSendMsg.Body)
|
||||||
|
log.Println(reportError.ErrorMessage)
|
||||||
|
}
|
||||||
|
|
||||||
case "askqueue":
|
case "askqueue":
|
||||||
pageInput := flag.Arg(1)
|
pageInput := flag.Arg(1)
|
||||||
|
@ -69,7 +76,7 @@ func Run() {
|
||||||
|
|
||||||
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
cl := networking.NewClient[protocol.Packet](&clientKeyStore)
|
||||||
defer cl.Connection.Conn.Close()
|
defer cl.Connection.Conn.Close()
|
||||||
askQueue(cl,clientKeyStore, page, pageSize)
|
askQueue(cl, clientKeyStore, page, pageSize)
|
||||||
|
|
||||||
case "getmsg":
|
case "getmsg":
|
||||||
if flag.NArg() < 2 {
|
if flag.NArg() < 2 {
|
||||||
|
@ -89,6 +96,11 @@ func Run() {
|
||||||
if !active {
|
if !active {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if receivedMsgPacket.Flag == protocol.FlagReportError {
|
||||||
|
reportError := protocol.UnmarshalReportError(receivedMsgPacket.Body)
|
||||||
|
log.Println(reportError.ErrorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
answerGetMsg := protocol.UnmarshalAnswerGetMsg(receivedMsgPacket.Body)
|
||||||
senderCert := getUserCert(cl, answerGetMsg.FromUID)
|
senderCert := getUserCert(cl, answerGetMsg.FromUID)
|
||||||
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
decSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, answerGetMsg.Subject)
|
||||||
|
@ -117,6 +129,11 @@ func getUserCert(cl networking.Client[protocol.Packet], uid string) *x509.Certif
|
||||||
if !active {
|
if !active {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if answerGetUserCertPacket.Flag == protocol.FlagReportError {
|
||||||
|
reportError := protocol.UnmarshalReportError(answerGetUserCertPacket.Body)
|
||||||
|
log.Println(reportError.ErrorMessage)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
answerGetUserCert := protocol.UnmarshalAnswerGetUserCert(answerGetUserCertPacket.Body)
|
||||||
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
userCert, err := x509.ParseCertificate(answerGetUserCert.Certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -130,6 +147,11 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
|
||||||
if !active {
|
if !active {
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
||||||
}
|
}
|
||||||
|
if answerGetUnreadMsgsInfoPacket.Flag == protocol.FlagReportError {
|
||||||
|
reportError := protocol.UnmarshalReportError(answerGetUnreadMsgsInfoPacket.Body)
|
||||||
|
log.Println(reportError.ErrorMessage)
|
||||||
|
return protocol.NewAnswerGetUnreadMsgsInfo(0, 0, nil), nil
|
||||||
|
}
|
||||||
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
answerGetUnreadMsgsInfo := protocol.UnmarshalAnswerGetUnreadMsgsInfo(answerGetUnreadMsgsInfoPacket.Body)
|
||||||
|
|
||||||
//Create Set of needed certificates
|
//Create Set of needed certificates
|
||||||
|
@ -146,7 +168,7 @@ func getManyMessagesInfo(cl networking.Client[protocol.Packet]) (protocol.Answer
|
||||||
return answerGetUnreadMsgsInfo, certificatesMap
|
return answerGetUnreadMsgsInfo, certificatesMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
|
func askQueue(cl networking.Client[protocol.Packet], clientKeyStore cryptoUtils.KeyStore, page int, pageSize int) {
|
||||||
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
requestUnreadMsgsQueuePacket := protocol.NewGetUnreadMsgsInfoPacket(page, pageSize)
|
||||||
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
|
if !cl.Connection.Send(requestUnreadMsgsQueuePacket) {
|
||||||
return
|
return
|
||||||
|
@ -156,8 +178,13 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
|
||||||
for _, message := range unreadMsgsInfo.MessagesInfo {
|
for _, message := range unreadMsgsInfo.MessagesInfo {
|
||||||
senderCert, ok := certificates[message.FromUID]
|
senderCert, ok := certificates[message.FromUID]
|
||||||
if ok {
|
if ok {
|
||||||
|
var subject string
|
||||||
|
if senderCert != nil {
|
||||||
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
decryptedSubjectBytes := clientKeyStore.DecryptMessageContent(senderCert, message.Subject)
|
||||||
subject := Unmarshal(decryptedSubjectBytes)
|
subject = Unmarshal(decryptedSubjectBytes)
|
||||||
|
} else {
|
||||||
|
subject = ""
|
||||||
|
}
|
||||||
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
|
clientMessage := newClientMessageInfo(message.Num, message.FromUID, subject, message.Timestamp)
|
||||||
clientMessages = append(clientMessages, clientMessage)
|
clientMessages = append(clientMessages, clientMessage)
|
||||||
}
|
}
|
||||||
|
@ -170,10 +197,10 @@ func askQueue(cl networking.Client[protocol.Packet],clientKeyStore cryptoUtils.K
|
||||||
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
action := showMessagesInfo(unreadMsgsInfo.Page, unreadMsgsInfo.NumPages, clientMessages)
|
||||||
switch action {
|
switch action {
|
||||||
case -1:
|
case -1:
|
||||||
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page-1) , pageSize)
|
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page-1), pageSize)
|
||||||
case 0:
|
case 0:
|
||||||
return
|
return
|
||||||
case 1:
|
case 1:
|
||||||
askQueue(cl, clientKeyStore , max(1,unreadMsgsInfo.Page+1) , pageSize)
|
askQueue(cl, clientKeyStore, max(1, unreadMsgsInfo.Page+1), pageSize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,15 +11,6 @@ func readStdin(message string) string {
|
||||||
fmt.Println(message)
|
fmt.Println(message)
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
scanner.Scan()
|
scanner.Scan()
|
||||||
// FIX: make sure this doesnt die
|
|
||||||
return scanner.Text()
|
|
||||||
}
|
|
||||||
|
|
||||||
func AskUserPassword() string {
|
|
||||||
fmt.Println("Enter key store password")
|
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
|
||||||
scanner.Scan()
|
|
||||||
// FIX: make sure this doesnt die
|
|
||||||
return scanner.Text()
|
return scanner.Text()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,9 +34,13 @@ func showMessagesInfo(page int, numPages int, messages []ClientMessageInfo) int
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
|
if message.Subject == "" {
|
||||||
|
fmt.Printf("ERROR DECRYPTING MESSAGE %v IN QUEUE FROM UID %v\n", message.Num, message.FromUID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
|
fmt.Printf("%v:%v:%v:%v\n", message.Num, message.FromUID, message.Timestamp, message.Subject)
|
||||||
}
|
}
|
||||||
fmt.Printf("Page %v/%v\n",page,numPages)
|
fmt.Printf("Page %v/%v\n", page, numPages)
|
||||||
return messagesInfoPageNavigation(page, numPages)
|
return messagesInfoPageNavigation(page, numPages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,11 +84,9 @@ func messagesInfoPageNavigation(page int, numPages int) int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func showMessage(message ClientMessage) {
|
func showMessage(message ClientMessage) {
|
||||||
fmt.Printf("From: %s\n", message.FromUID)
|
fmt.Printf("From: %s\n", message.FromUID)
|
||||||
fmt.Printf("To: %s\n", message.ToUID)
|
fmt.Printf("To: %s\n", message.ToUID)
|
||||||
fmt.Printf("Subject: %s\n", message.Subject)
|
fmt.Printf("Subject: %s\n", message.Subject)
|
||||||
fmt.Printf("Body: %s\n", message.Body)
|
fmt.Printf("Body: %s\n", message.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,12 @@ const (
|
||||||
|
|
||||||
// Server sends requested message
|
// Server sends requested message
|
||||||
FlagAnswerGetMsg
|
FlagAnswerGetMsg
|
||||||
|
|
||||||
|
// Server tells the client that the message was successfully sent
|
||||||
|
FlagAnswerSendMsg
|
||||||
|
|
||||||
|
// Report an error
|
||||||
|
FlagReportError
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -76,6 +82,10 @@ type (
|
||||||
Body []byte `json:"body"`
|
Body []byte `json:"body"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ReportError struct {
|
||||||
|
ErrorMessage string `json:"error"`
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketBody interface{}
|
type PacketBody interface{}
|
||||||
|
@ -127,7 +137,7 @@ func NewAnswerGetUserCert(uid string, certificate []byte) AnswerGetUserCert {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
|
func NewAnswerGetUnreadMsgsInfo(page int, numPages int, messagesInfo []MsgInfo) AnswerGetUnreadMsgsInfo {
|
||||||
return AnswerGetUnreadMsgsInfo{Page:page,NumPages:numPages,MessagesInfo: messagesInfo}
|
return AnswerGetUnreadMsgsInfo{Page: page, NumPages: numPages, MessagesInfo: messagesInfo}
|
||||||
}
|
}
|
||||||
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
|
func NewMsgInfo(num int, fromUID string, subject []byte, timestamp time.Time) MsgInfo {
|
||||||
return MsgInfo{
|
return MsgInfo{
|
||||||
|
@ -148,6 +158,12 @@ func NewAnswerGetMsg(fromUID, toUID string, subject []byte, body []byte, timesta
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewReportError(errorMessage string) ReportError {
|
||||||
|
return ReportError{
|
||||||
|
ErrorMessage: errorMessage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewGetUserCertPacket(UID string) Packet {
|
func NewGetUserCertPacket(UID string) Packet {
|
||||||
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
|
return NewPacket(FlagGetUserCert, NewGetUserCert(UID))
|
||||||
}
|
}
|
||||||
|
@ -169,13 +185,22 @@ func NewAnswerGetUserCertPacket(uid string, certificate []byte) Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
|
func NewAnswerGetUnreadMsgsInfoPacket(page int, numPages int, messagesInfo []MsgInfo) Packet {
|
||||||
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page,numPages,messagesInfo))
|
return NewPacket(FlagAnswerGetUnreadMsgsInfo, NewAnswerGetUnreadMsgsInfo(page, numPages, messagesInfo))
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
|
func NewAnswerGetMsgPacket(fromUID, toUID string, subject []byte, body []byte, timestamp time.Time, last bool) Packet {
|
||||||
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
|
return NewPacket(FlagAnswerGetMsg, NewAnswerGetMsg(fromUID, toUID, subject, body, timestamp, last))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewAnswerSendMsgPacket() Packet{
|
||||||
|
//This packet has no body
|
||||||
|
return NewPacket(FlagAnswerSendMsg,nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReportErrorPacket(errorMessage string) Packet {
|
||||||
|
return NewPacket(FlagReportError, NewReportError(errorMessage))
|
||||||
|
}
|
||||||
|
|
||||||
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
|
func UnmarshalGetUserCert(data PacketBody) GetUserCert {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -270,3 +295,15 @@ func UnmarshalAnswerGetMsg(data PacketBody) AnswerGetMsg {
|
||||||
}
|
}
|
||||||
return packet
|
return packet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UnmarshalReportError(data PacketBody) ReportError {
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to marshal data: %v", err))
|
||||||
|
}
|
||||||
|
var packet ReportError
|
||||||
|
if err := json.Unmarshal(jsonData, &packet); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to unmarshal into AnswerGetMsg: %v", err))
|
||||||
|
}
|
||||||
|
return packet
|
||||||
|
}
|
||||||
|
|
|
@ -88,8 +88,10 @@ func (ds DataStore) GetMessage(toUID string, position int) protocol.Packet {
|
||||||
// Execute the query
|
// Execute the query
|
||||||
row := ds.db.QueryRow(query, toUID, position)
|
row := ds.db.QueryRow(query, toUID, position)
|
||||||
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
|
err := row.Scan(&serverMessage.FromUID, &serverMessage.ToUID, &serverMessage.Subject, &serverMessage.Body, &serverMessage.Timestamp)
|
||||||
if err != nil {
|
if err == sql.ErrNoRows {
|
||||||
log.Printf("Error getting the message in position %v from UID %v: %v", position, toUID, err)
|
log.Printf("No message with NUM %v for UID %v\n", position, toUID)
|
||||||
|
errorMessage := fmt.Sprintf("No message with NUM %v", position)
|
||||||
|
return protocol.NewReportErrorPacket(errorMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
|
return protocol.NewAnswerGetMsgPacket(serverMessage.FromUID, serverMessage.ToUID, serverMessage.Subject, serverMessage.Body, serverMessage.Timestamp, true)
|
||||||
|
@ -119,8 +121,9 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
||||||
// Retrieve the total count of unread messages
|
// Retrieve the total count of unread messages
|
||||||
var totalCount int
|
var totalCount int
|
||||||
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
err := ds.db.QueryRow("SELECT COUNT(*) FROM messages WHERE toUID = ? AND status = 0", toUID).Scan(&totalCount)
|
||||||
if err != nil {
|
if err == sql.ErrNoRows {
|
||||||
log.Printf("Error getting total count of unread messages for UID %v: %v", toUID, err)
|
log.Printf("No unread messages for UID %v: %v", toUID, err)
|
||||||
|
return protocol.NewAnswerGetUnreadMsgsInfoPacket(0, 0, []protocol.MsgInfo{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query to retrieve all messages from the user's queue
|
// Query to retrieve all messages from the user's queue
|
||||||
|
@ -143,7 +146,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
||||||
// Execute the query
|
// Execute the query
|
||||||
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
|
rows, err := ds.db.Query(query, toUID, pageSize, (page-1)*pageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error getting all messages for UID %v: %v", toUID, err)
|
log.Printf("Error getting unread messages for UID %v: %v", toUID, err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
|
@ -161,6 +164,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
log.Printf("Error when getting messages for UID %v: %v", toUID, err)
|
||||||
|
return protocol.NewReportErrorPacket(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
numberOfPages := (totalCount + pageSize - 1) / pageSize
|
||||||
|
@ -168,7 +172,7 @@ func (ds DataStore) GetUnreadMsgsInfo(toUID string, page int, pageSize int) prot
|
||||||
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
|
return protocol.NewAnswerGetUnreadMsgsInfoPacket(currentPage, numberOfPages, messageInfoPackets)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) {
|
func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg) protocol.Packet {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
|
INSERT INTO messages (fromUID, toUID, subject, body, timestamp, status)
|
||||||
VALUES (?, ?, ?, ?, ?, 0)
|
VALUES (?, ?, ?, ?, ?, 0)
|
||||||
|
@ -179,7 +183,9 @@ func (ds DataStore) AddMessageToQueue(fromUID string, message protocol.SendMsg)
|
||||||
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
|
_, err := ds.db.Exec(query, fromUID, message.ToUID, message.Subject, message.Body, currentTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error adding message to UID %v: %v", fromUID, err)
|
log.Printf("Error adding message to UID %v: %v", fromUID, err)
|
||||||
|
return protocol.NewReportErrorPacket(err.Error())
|
||||||
}
|
}
|
||||||
|
return protocol.NewAnswerSendMsgPacket()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
|
func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
|
||||||
|
@ -193,12 +199,10 @@ func (ds DataStore) GetUserCertificate(uid string) protocol.Packet {
|
||||||
var userCertBytes []byte
|
var userCertBytes []byte
|
||||||
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
|
err := ds.db.QueryRow(query, uid).Scan(&userCertBytes)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
log.Panicf("No certificate for UID %v found in the database", uid)
|
errorMessage := fmt.Sprintf("No certificate for UID %v found in the database", uid)
|
||||||
|
log.Println(errorMessage)
|
||||||
|
return protocol.NewReportErrorPacket(errorMessage)
|
||||||
}
|
}
|
||||||
//userCert,err := x509.ParseCertificate(userCertBytes)
|
|
||||||
//if err!=nil {
|
|
||||||
// log.Panicf("Error parsing certificate for UID %v",uid)
|
|
||||||
//}
|
|
||||||
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
|
return protocol.NewAnswerGetUserCertPacket(uid, userCertBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,7 +228,6 @@ func (ds DataStore) userExists(uid string) bool {
|
||||||
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
|
func (ds DataStore) storeUserCertIfNotExists(uid string, cert x509.Certificate) {
|
||||||
// Check if the user already exists
|
// Check if the user already exists
|
||||||
if ds.userExists(uid) {
|
if ds.userExists(uid) {
|
||||||
log.Printf("User certificate for UID %s already exists.\n", uid)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,9 @@ import (
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AskServerPassword() string {
|
func readStdin(message string) string {
|
||||||
fmt.Println("Enter key store password")
|
fmt.Println(message)
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
scanner.Scan()
|
scanner.Scan()
|
||||||
// FIX: make sure this doesnt die
|
|
||||||
return scanner.Text()
|
return scanner.Text()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,10 @@ import (
|
||||||
"PD1/internal/protocol"
|
"PD1/internal/protocol"
|
||||||
"PD1/internal/utils/cryptoUtils"
|
"PD1/internal/utils/cryptoUtils"
|
||||||
"PD1/internal/utils/networking"
|
"PD1/internal/utils/networking"
|
||||||
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: CREATE SERVER SIDE CHECKS FOR EVERYTHING
|
|
||||||
//TODO: LOGGING SYSTEM
|
//TODO: LOGGING SYSTEM
|
||||||
//TODO: TELL THE USER THAT THE MESSAGE HAS BEEN RECEIVED BY THE SERVER
|
|
||||||
//TODO: ERROR PACKET TO SEND BACK TO USER
|
|
||||||
|
|
||||||
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
|
func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) {
|
||||||
defer connection.Conn.Close()
|
defer connection.Conn.Close()
|
||||||
|
@ -18,10 +16,16 @@ func clientHandler(connection networking.Connection[protocol.Packet], dataStore
|
||||||
clientCert := connection.GetPeerCertificate()
|
clientCert := connection.GetPeerCertificate()
|
||||||
//Get the OID values
|
//Get the OID values
|
||||||
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
|
oidMap := cryptoUtils.ExtractAllOIDValues(clientCert)
|
||||||
|
//Check if certificate usage is MSG SERVICE
|
||||||
|
usage := oidMap["2.5.4.11"]
|
||||||
|
if usage == "" {
|
||||||
|
log.Println("User certificate does not have the correct usage")
|
||||||
|
return
|
||||||
|
}
|
||||||
//Get the UID of this user
|
//Get the UID of this user
|
||||||
UID := oidMap["2.5.4.65"]
|
UID := oidMap["2.5.4.65"]
|
||||||
if UID == "" {
|
if UID == "" {
|
||||||
panic("User certificate does not specify it's PSEUDONYM")
|
log.Println("User certificate does not specify it's PSEUDONYM")
|
||||||
}
|
}
|
||||||
dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
dataStore.storeUserCertIfNotExists(UID, *clientCert)
|
||||||
F:
|
F:
|
||||||
|
@ -34,31 +38,47 @@ F:
|
||||||
case protocol.FlagGetUserCert:
|
case protocol.FlagGetUserCert:
|
||||||
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
|
reqUserCert := protocol.UnmarshalGetUserCert(pac.Body)
|
||||||
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
userCertPacket := dataStore.GetUserCertificate(reqUserCert.UID)
|
||||||
if active := connection.Send(userCertPacket); !active {
|
if !connection.Send(userCertPacket) {
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
case protocol.FlagGetUnreadMsgsInfo:
|
case protocol.FlagGetUnreadMsgsInfo:
|
||||||
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
getUnreadMsgsInfo := protocol.UnmarshalGetUnreadMsgsInfo(pac.Body)
|
||||||
messages := dataStore.GetUnreadMsgsInfo(UID,getUnreadMsgsInfo.Page,getUnreadMsgsInfo.PageSize)
|
var messages protocol.Packet
|
||||||
|
if getUnreadMsgsInfo.Page <= 0 || getUnreadMsgsInfo.PageSize <= 0 {
|
||||||
|
messages = protocol.NewReportErrorPacket("Page and PageSize need to be >= 1")
|
||||||
|
} else {
|
||||||
|
messages = dataStore.GetUnreadMsgsInfo(UID, getUnreadMsgsInfo.Page, getUnreadMsgsInfo.PageSize)
|
||||||
|
}
|
||||||
if !connection.Send(messages) {
|
if !connection.Send(messages) {
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
case protocol.FlagGetMsg:
|
case protocol.FlagGetMsg:
|
||||||
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
|
reqMsg := protocol.UnmarshalGetMsg(pac.Body)
|
||||||
message := dataStore.GetMessage(UID, reqMsg.Num)
|
var message protocol.Packet
|
||||||
if active := connection.Send(message); !active {
|
if reqMsg.Num <= 0 {
|
||||||
|
message = protocol.NewReportErrorPacket("Message NUM needs to be >= 1")
|
||||||
|
} else {
|
||||||
|
message = dataStore.GetMessage(UID, reqMsg.Num)
|
||||||
|
}
|
||||||
|
if !connection.Send(message) {
|
||||||
break F
|
break F
|
||||||
}
|
}
|
||||||
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
dataStore.MarkMessageInQueueAsRead(UID, reqMsg.Num)
|
||||||
case protocol.FlagSendMsg:
|
case protocol.FlagSendMsg:
|
||||||
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
|
submitMsg := protocol.UnmarshalSendMsg(pac.Body)
|
||||||
if submitMsg.ToUID != UID && dataStore.userExists(submitMsg.ToUID) {
|
var answerSendMsgPacket protocol.Packet
|
||||||
dataStore.AddMessageToQueue(UID, submitMsg)
|
if submitMsg.ToUID == UID {
|
||||||
|
answerSendMsgPacket = protocol.NewReportErrorPacket("Cannot message yourself")
|
||||||
|
} else if !dataStore.userExists(submitMsg.ToUID) {
|
||||||
|
answerSendMsgPacket = protocol.NewReportErrorPacket("Message receiver does not exist in database")
|
||||||
|
} else {
|
||||||
|
answerSendMsgPacket = dataStore.AddMessageToQueue(UID, submitMsg)
|
||||||
|
}
|
||||||
|
if !connection.Send(answerSendMsgPacket) {
|
||||||
|
break F
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run(port int) {
|
func Run(port int) {
|
||||||
|
@ -69,7 +89,7 @@ func Run(port int) {
|
||||||
//FIX: Get the server's keystore path instead of hardcoding it
|
//FIX: Get the server's keystore path instead of hardcoding it
|
||||||
|
|
||||||
//Read server keystore
|
//Read server keystore
|
||||||
password := AskServerPassword()
|
password := readStdin("Insert keystore passphrase")
|
||||||
serverKeyStore := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
|
serverKeyStore := cryptoUtils.LoadKeyStore("certs/server/server.p12", password)
|
||||||
|
|
||||||
//Create server listener
|
//Create server listener
|
||||||
|
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
//"errors"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -94,9 +93,25 @@ func (k *KeyStore) GetServerTLSConfig() *tls.Config {
|
||||||
caCertPool.AddCert(caCert)
|
caCertPool.AddCert(caCert)
|
||||||
}
|
}
|
||||||
tlsConfig.ClientCAs = caCertPool
|
tlsConfig.ClientCAs = caCertPool
|
||||||
//FIX: SERVER ACCEPTS CONNECTIONS WITH UNMATCHING OR
|
tlsConfig.ClientAuth = tls.RequestClientCert
|
||||||
// NO CERTIFICATE, NEEDS TO BE CHANGED SOMEHOW
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
// Verify the peer's certificate
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: caCertPool,
|
||||||
|
}
|
||||||
|
for _, certBytes := range rawCerts {
|
||||||
|
cert, err := x509.ParseCertificate(certBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Check if the certificate is signed by the specified CA
|
||||||
|
_, err = cert.Verify(opts)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("certificate not signed by trusted CA")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return tlsConfig
|
return tlsConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Binary file not shown.
|
@ -12,14 +12,8 @@ cmd="go build"
|
||||||
[targets.server]
|
[targets.server]
|
||||||
cmd="go run ./cmd/server/server.go"
|
cmd="go run ./cmd/server/server.go"
|
||||||
|
|
||||||
[targets.client1]
|
[targets.queue]
|
||||||
cmd="go run ./cmd/client/client.go -user certs/client1/client1.p12 send CL2 testsubject"
|
cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 askqueue"
|
||||||
|
|
||||||
[targets.FakeClient1]
|
[targets.send]
|
||||||
cmd="go run ./cmd/client/client.go -user certs/FakeClient1/client1.p12 send CL2 testsubject"
|
cmd="go run ./cmd/client/client.go -user certs/client${NUM}/client${NUM}.p12 send ${DEST} ${SUBJECT}"
|
||||||
|
|
||||||
[targets.client2]
|
|
||||||
cmd="go run ./cmd/client/client.go -user certs/client2/client2.p12 send CL3 testsubject"
|
|
||||||
|
|
||||||
[targets.client3]
|
|
||||||
cmd="go run ./cmd/client/client.go -user certs/client3/client3.p12 send CL1 testsubject"
|
|
||||||
|
|
Loading…
Reference in a new issue