diff --git a/Projs/PD1/certs/CLI1.p12 b/Projs/PD1/certs/CLI1.p12 new file mode 100644 index 0000000..c070f92 Binary files /dev/null and b/Projs/PD1/certs/CLI1.p12 differ diff --git a/Projs/PD1/certs/serverdata.p12 b/Projs/PD1/certs/serverdata.p12 new file mode 100644 index 0000000..89b0823 Binary files /dev/null and b/Projs/PD1/certs/serverdata.p12 differ diff --git a/Projs/PD1/internal/client/client.go b/Projs/PD1/internal/client/client.go index 4268ca5..00aa35c 100644 --- a/Projs/PD1/internal/client/client.go +++ b/Projs/PD1/internal/client/client.go @@ -5,7 +5,6 @@ import ( "PD1/internal/utils/cryptoUtils" "PD1/internal/utils/networking" "flag" - "fmt" ) func Run() { @@ -16,6 +15,9 @@ func Run() { if flag.NArg() == 0 { panic("No command provided. Use 'help' for instructions.") } + //Get user KeyStore + password := AskUserPassword() + clientKeyStore := cryptoUtils.LoadKeyStore(userFile, password) command := flag.Arg(0) switch command { @@ -24,44 +26,41 @@ func Run() { panic("Insufficient arguments for 'send' command. Usage: send ") } uid := flag.Arg(1) - subject := flag.Arg(2) - messageContent := readMessageContent() + //subject := flag.Arg(2) + //messageContent := readMessageContent() - clientCert := cryptoUtils.LoadKeyStore("userdata.p12") + cl := networking.NewClient[protocol.Packet](&clientKeyStore) + defer cl.Connection.Conn.Close() - cl := networking.NewClient[protocol.Packet](clientCert) - defer cl.Connection.Conn.Close() + certRequestPacket := protocol.NewRequestUserCertPacket(uid) + cl.Connection.Send(certRequestPacket) + //certPacket := cl.Connection.Receive() - - certRequestPacket := protocol.NewRequestUserCertPacket(uid) - cl.Connection.Send(certRequestPacket) - certPacket := cl.Connection.Receive() - - // TODO: Encrypt message - submitMessage(cl,uid,cipherContent) + // TODO: Encrypt message + //submitMessage(cl, uid, cipherContent) case "askqueue": - cl := networking.NewClient[protocol.Packet]() - defer cl.Connection.Conn.Close() + cl := networking.NewClient[protocol.Packet](&clientKeyStore) + defer cl.Connection.Conn.Close() case "getmsg": if flag.NArg() < 2 { panic("Insufficient arguments for 'getmsg' command. Usage: getmsg ") } - num := flag.Arg(1) - cl := networking.NewClient[protocol.Packet]() - defer cl.Connection.Conn.Close() + //num := flag.Arg(1) + cl := networking.NewClient[protocol.Packet](&clientKeyStore) + defer cl.Connection.Conn.Close() case "help": - showHelp() + showHelp() default: - commandError() + commandError() } } -func submitMessage(cl networking.Client[protocol.Packet],uid string, content []byte) { - pack := protocol.NewSubmitMessage(uid,content) +func submitMessage(cl networking.Client[protocol.Packet], uid string, content []byte) { + pack := protocol.NewSubmitMessagePacket(uid, content) cl.Connection.Send(pack) } diff --git a/Projs/PD1/internal/client/interface.go b/Projs/PD1/internal/client/interface.go index 0e48357..02cee71 100644 --- a/Projs/PD1/internal/client/interface.go +++ b/Projs/PD1/internal/client/interface.go @@ -14,11 +14,8 @@ func readMessageContent() string { return scanner.Text() } -//FIX: Why is this function in the client if it's called by crypto? -// It should be called by the client and the result -// should then be passed into the crypto library func AskUserPassword() string { - fmt.Println("Enter message content (limited to 1000 bytes):") + fmt.Println("Enter key store password") scanner := bufio.NewScanner(os.Stdin) scanner.Scan() // FIX: make sure this doesnt die diff --git a/Projs/PD1/internal/protocol/protocol.go b/Projs/PD1/internal/protocol/protocol.go index be05e8d..1e08ad6 100644 --- a/Projs/PD1/internal/protocol/protocol.go +++ b/Projs/PD1/internal/protocol/protocol.go @@ -103,9 +103,9 @@ type ServerMessagePacket struct { Timestamp time.Time } -func NewMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet { +func NewServerMessagePacket(fromUID, toUID string, content []byte, timestamp time.Time) Packet { return Packet{ - Flag: Msg, + Flag: ServerMsgPkt, Body: ServerMessagePacket{ FromUID: fromUID, ToUID: toUID, diff --git a/Projs/PD1/internal/server/datastore.go b/Projs/PD1/internal/server/datastore.go index b92cba2..9513fdc 100644 --- a/Projs/PD1/internal/server/datastore.go +++ b/Projs/PD1/internal/server/datastore.go @@ -115,7 +115,7 @@ func (ds DataStore) GetAllMessages(toUID string) []protocol.Packet { if err := rows.Scan(&fromUID, &toUID, &content, ×tamp); err != nil { log.Panicln("Failed to scan row:", err) } - message := protocol.NewMessagePacket(fromUID, toUID, content, timestamp) + message := protocol.NewServerMessagePacket(fromUID, toUID, content, timestamp) messagePackets = append(messagePackets, message) } if err := rows.Err(); err != nil { diff --git a/Projs/PD1/internal/server/interface.go b/Projs/PD1/internal/server/interface.go new file mode 100644 index 0000000..f593c69 --- /dev/null +++ b/Projs/PD1/internal/server/interface.go @@ -0,0 +1,15 @@ +package server + +import ( + "bufio" + "fmt" + "os" +) + +func AskServerPassword() string { + fmt.Println("Enter key store password") + scanner := bufio.NewScanner(os.Stdin) + scanner.Scan() + // FIX: make sure this doesnt die + return scanner.Text() +} diff --git a/Projs/PD1/internal/server/server.go b/Projs/PD1/internal/server/server.go index 046b7a3..2d1fb93 100644 --- a/Projs/PD1/internal/server/server.go +++ b/Projs/PD1/internal/server/server.go @@ -10,15 +10,17 @@ import ( func clientHandler(connection networking.Connection[protocol.Packet], dataStore DataStore) { defer connection.Conn.Close() - // FIX: GET THE UID FROM THE USER CERTIFICATE FROM THE TLS SESSION - uid := "0" + clientCert := connection.GetPeerCertificate() + oidValueMap := cryptoUtils.ExtractAllOIDValues(clientCert) + fmt.Println(oidValueMap) + for { pac := connection.Receive() switch pac.Flag { case protocol.ReqUserCertPkt: - userCertPacket := dataStore.GetUserCertificate(uid) - connection.Send(userCertPacket) + //userCertPacket := dataStore.GetUserCertificate(uid) + //connection.Send(userCertPacket) case protocol.ReqAllMsgPkt: fmt.Println("ReqAllMsg") case protocol.ReqMsgPkt: @@ -35,13 +37,14 @@ func Run(port int) { dataStore := OpenDB() defer dataStore.db.Close() - //TODO: Get the server's keystore path instead of hardcoding it + //FIX: Get the server's keystore path instead of hardcoding it //Read server keystore - serverKeyStore := cryptoUtils.LoadKeyStore("serverdata.p12") + password := AskServerPassword() + serverKeyStore := cryptoUtils.LoadKeyStore("certs/serverdata.p12",password) //Create server listener - server := networking.NewServer[protocol.Packet](serverKeyStore,port) + server := networking.NewServer[protocol.Packet](&serverKeyStore,port) go server.ListenLoop() for { diff --git a/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go b/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go index ce91c51..48a5185 100644 --- a/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go +++ b/Projs/PD1/internal/utils/cryptoUtils/cryptoUtils.go @@ -1,7 +1,6 @@ package cryptoUtils import ( - "PD1/internal/client" "crypto/rsa" "crypto/tls" "crypto/x509" @@ -15,7 +14,7 @@ import ( type KeyStore struct { cert *x509.Certificate caCertChain []*x509.Certificate - privKey rsa.PrivateKey + privKey *rsa.PrivateKey } func (k KeyStore) GetCert() *x509.Certificate { @@ -26,67 +25,90 @@ func (k KeyStore) GetCACertChain() []*x509.Certificate { return k.caCertChain } -func (k KeyStore) GetPrivKey() rsa.PrivateKey { +func (k KeyStore) GetPrivKey() *rsa.PrivateKey { return k.privKey } -func LoadKeyStore(keyStorePath string) KeyStore { +func ExtractAllOIDValues(cert *x509.Certificate) map[string]string { + oidValueMap := make(map[string]string) + for _, name := range cert.Subject.Names { + oid := name.Type.String() + value := name.Value.(string) + oidValueMap[oid] = value + } + return oidValueMap +} - var privKey rsa.PrivateKey +func LoadKeyStore(keyStorePath string, password string) KeyStore { + + var privKey *rsa.PrivateKey certFile, err := os.ReadFile(keyStorePath) if err != nil { - log.Panicln("Provided certificate %v couldn't be opened", keyStorePath) + log.Panicln("Provided certificate couldn't be opened") } - password := client.AskUserPassword() privKeyInterface, cert, caCerts, err := pkcs12.DecodeChain(certFile, password) - privKey = privKeyInterface.(rsa.PrivateKey) if err != nil { log.Panicln("PKCS12 key store couldn't be decoded") } + privKey, ok := privKeyInterface.(*rsa.PrivateKey) + if !ok { + log.Panicln("Failed to convert private key to RSA private key") + } + if err := privKey.Validate(); err != nil { log.Panicln("Private key is not valid") } return KeyStore{cert: cert, caCertChain: caCerts, privKey: privKey} } -func (k KeyStore)GetTLSConfig() *tls.Config { - certificate ,err := tls.X509KeyPair(k.cert.Raw, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(&k.privKey)})) - if err!=nil{ - log.Panicln("Could not load certificate and privkey to TLS") - } - - //Add the CA certificate chain to a CertPool - caCertPool := x509.NewCertPool() +func (k *KeyStore) GetTLSConfig() *tls.Config { + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: k.cert.Raw}) + privKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k.privKey)}) + certificate, err := tls.X509KeyPair(certPEM, privKeyPEM) + if err != nil { + log.Panicln("Could not load certificate and privkey to TLS", err) + } + + config := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + } + return config +} + +func (k *KeyStore) GetTLSConfigServer() *tls.Config { + config := k.GetTLSConfig() + + //Add the CA certificate chain to a CertPool + caCertPool := x509.NewCertPool() for _, caCert := range k.caCertChain { caCertPool.AddCert(caCert) } - - config := &tls.Config{ - Certificates: []tls.Certificate{certificate}, - ClientCAs: caCertPool, - } - return config -} - -func (k KeyStore)GetTLSConfigServer() *tls.Config { - config := k.GetTLSConfig() - + config.ClientCAs = caCertPool config.ClientAuth = tls.RequireAndVerifyClientCert return config } -func (k KeyStore)GetTLSConfigClient() *tls.Config { - config:= k.GetTLSConfig() +func (k *KeyStore) GetTLSConfigClient() *tls.Config { + config := k.GetTLSConfig() - config.ServerName = "SERVER" + //Add the CA certificate chain to a CertPool + caCertPool := x509.NewCertPool() + for _, caCert := range k.caCertChain { + caCertPool.AddCert(caCert) + } + config.RootCAs = caCertPool + + //TODO: FIX THE VERIFICATION OF THE SERVER + //config.ServerName = "SERVER" return config } -func (k KeyStore)EncryptMessageContent(peerPubKey rsa.PublicKey, content []byte) []byte { +func (k KeyStore) EncryptMessageContent(peerPubKey rsa.PublicKey, content []byte) []byte { // Digital envolope - return nil + return nil } diff --git a/Projs/PD1/internal/utils/networking/client.go b/Projs/PD1/internal/utils/networking/client.go index 292c13a..11accd8 100644 --- a/Projs/PD1/internal/utils/networking/client.go +++ b/Projs/PD1/internal/utils/networking/client.go @@ -2,7 +2,7 @@ package networking import ( "crypto/tls" - "net" + "log" ) @@ -17,7 +17,7 @@ type Client[T any] struct { func NewClient[T any](clientTLSConfigProvider ClientTLSConfigProvider) Client[T] { dialConn, err := tls.Dial("tcp", "localhost:8080", clientTLSConfigProvider.GetTLSConfigClient()) if err != nil { - panic("Could not open connection to server") + log.Panicln("Could not open connection to server",err) } conn := NewConnection[T](dialConn) return Client[T]{Connection: conn} diff --git a/Projs/PD1/internal/utils/networking/connection.go b/Projs/PD1/internal/utils/networking/connection.go index d6ab641..e1bff4d 100644 --- a/Projs/PD1/internal/utils/networking/connection.go +++ b/Projs/PD1/internal/utils/networking/connection.go @@ -1,18 +1,18 @@ package networking import ( + "crypto/tls" + "crypto/x509" "encoding/json" - "net" ) type Connection[T any] struct { - Conn net.Conn + Conn *tls.Conn encoder *json.Encoder decoder *json.Decoder } - -func NewConnection[T any](netConn net.Conn) Connection[T] { +func NewConnection[T any](netConn *tls.Conn) Connection[T] { return Connection[T]{ Conn: netConn, encoder: json.NewEncoder(netConn), @@ -20,16 +20,21 @@ func NewConnection[T any](netConn net.Conn) Connection[T] { } } -func (jc Connection[T]) Send(obj T) { - if err := jc.encoder.Encode(&obj); err != nil { +func (c Connection[T]) Send(obj T) { + if err := c.encoder.Encode(&obj); err != nil { panic("Failed encoding data or sending it to connection") } } -func (jc Connection[T]) Receive() T { +func (c Connection[T]) Receive() T { var obj T - if err := jc.decoder.Decode(&obj); err != nil { + if err := c.decoder.Decode(&obj); err != nil { panic("Failed decoding data or reading it from connection") } return obj } + +func (c Connection[T]) GetPeerCertificate() *x509.Certificate { + state := c.Conn.ConnectionState() + return state.PeerCertificates[0] +} diff --git a/Projs/PD1/internal/utils/networking/server.go b/Projs/PD1/internal/utils/networking/server.go index 0d86cdd..d1c3f03 100644 --- a/Projs/PD1/internal/utils/networking/server.go +++ b/Projs/PD1/internal/utils/networking/server.go @@ -3,11 +3,12 @@ package networking import ( "crypto/tls" "fmt" + "log" "net" ) type ServerTLSConfigProvider interface { - GetServerTLSConfig() *tls.Config + GetTLSConfigServer() *tls.Config } type Server[T any] struct { @@ -15,16 +16,16 @@ type Server[T any] struct { C chan Connection[T] } -func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider,port int) Server[T]{ +func NewServer[T any](serverTLSConfigProvider ServerTLSConfigProvider, port int) Server[T] { - listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetServerTLSConfig()) + listener, err := tls.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port), serverTLSConfigProvider.GetTLSConfigServer()) if err != nil { panic("Server could not bind to address") } - return Server[T]{ - listener:listener, - C: make(chan Connection[T]), - } + return Server[T]{ + listener: listener, + C: make(chan Connection[T]), + } } func (s *Server[T]) ListenLoop() { @@ -34,7 +35,16 @@ func (s *Server[T]) ListenLoop() { if err != nil { panic("Server could not accept connection") } - conn := NewConnection[T](listenerConn) + tlsConn, ok := listenerConn.(*tls.Conn) + if !ok { + panic("Connection is not a TLS connection") + } + + state := tlsConn.ConnectionState() + if len(state.PeerCertificates) == 0 { + log.Panicln("Client did not provide a certificate") + } + conn := NewConnection[T](tlsConn) s.C <- conn } } diff --git a/Projs/PD1/tokefile.toml b/Projs/PD1/tokefile.toml index 61d046d..513abf6 100644 --- a/Projs/PD1/tokefile.toml +++ b/Projs/PD1/tokefile.toml @@ -15,4 +15,4 @@ cmd="go run ./cmd/server/server.go" [targets.client] deps=["check"] -cmd="go run ./cmd/client/client.go" +cmd="go run ./cmd/client/client.go -user certs/CLI1.p12 send CLI1 testsubject"