package cryptoUtils

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/binary"
	"errors"
	"time"

	"os"

	"golang.org/x/crypto/chacha20poly1305"
	"software.sslmate.com/src/go-pkcs12"
)

// KeyStore holds the credentials loaded by LoadKeyStore from a PKCS#12 file:
//   - cert: the local certificate, presented to peers by GetTLSConfig;
//   - caCertPool: the CA certificates bundled in the keystore, used as the
//     trust roots by CheckCertCA;
//   - privKey: the RSA private key from the keystore, used in the TLS
//     certificate, for signing outgoing messages and for decrypting
//     incoming data keys.
//
// The zero value is not usable (for example GetTLSConfig dereferences cert);
// construct a KeyStore with LoadKeyStore.
type KeyStore struct {
	cert       *x509.Certificate
	caCertPool *x509.CertPool
	privKey    *rsa.PrivateKey
}

// GetCert returns the local certificate that was loaded from the keystore.
func (k KeyStore) GetCert() (cert *x509.Certificate) {
	cert = k.cert
	return cert
}

// GetPrivKey returns the RSA private key that was loaded from the keystore.
func (k KeyStore) GetPrivKey() (key *rsa.PrivateKey) {
	key = k.privKey
	return key
}

// ExtractAllOIDValues returns the attributes of cert's subject name as a map
// from dotted OID string (for example "2.5.4.3") to attribute value.
//
// Only string-valued attributes are included. pkix.AttributeTypeAndValue.Value
// is an interface, so other value types are skipped rather than causing a
// panic. If an OID occurs more than once, the value that appears last in
// cert.Subject.Names wins. A nil cert yields an empty map.
func ExtractAllOIDValues(cert *x509.Certificate) map[string]string {
	if cert == nil {
		return map[string]string{}
	}
	oidValueMap := make(map[string]string, len(cert.Subject.Names))
	for _, name := range cert.Subject.Names {
		value, ok := name.Value.(string)
		if !ok {
			continue
		}
		oidValueMap[name.Type.String()] = value
	}
	return oidValueMap
}

// LoadKeyStore reads the PKCS#12 file at keyStorePath and decrypts it with
// password. It returns a KeyStore holding the certificate, the private key and
// a CA pool built from the certificates bundled in the file.
//
// It returns an error if the file cannot be read or decoded, if the private
// key is not an RSA key, or if the RSA key fails Validate.
func LoadKeyStore(keyStorePath string, password string) (KeyStore, error) {
	keystoreBytes, err := os.ReadFile(keyStorePath)
	if err != nil {
		return KeyStore{}, err
	}

	privKeyInterface, cert, caCerts, err := pkcs12.DecodeChain(keystoreBytes, password)
	if err != nil {
		return KeyStore{}, err
	}

	privKey, ok := privKeyInterface.(*rsa.PrivateKey)
	if !ok {
		// err is nil at this point. Returning it would report success
		// together with an empty KeyStore, so produce a real error instead.
		return KeyStore{}, errors.New("keystore private key is not an RSA key")
	}
	if err := privKey.Validate(); err != nil {
		return KeyStore{}, err
	}

	caCertPool := x509.NewCertPool()
	for _, caCert := range caCerts {
		caCertPool.AddCert(caCert)
	}

	return KeyStore{cert: cert, caCertPool: caCertPool, privKey: privKey}, nil
}

// CheckCertCA verifies that cert chains to one of the CA certificates held in
// the keystore's CA pool.
//
// It returns an error in three cases:
//   - cert is nil;
//   - the KeyStore has no CA pool (for example a zero-value KeyStore);
//   - x509 verification fails. The error message then keeps the verifier's
//     reason after the "certificate not signed by trusted CA" prefix.
//
// NOTE(review): VerifyOptions.KeyUsages is left empty, which crypto/x509
// treats as requiring the ServerAuth extended key usage. Confirm that client
// certificates checked through GetServerTLSConfig carry it.
func (k KeyStore) CheckCertCA(cert *x509.Certificate) error {
	if cert == nil {
		return errors.New("certificate is nil")
	}
	// With a nil Roots pool crypto/x509 falls back to the system trust store.
	// Only the keystore's own CAs may be trusted here, so refuse instead.
	if k.caCertPool == nil {
		return errors.New("no trusted CA certificates loaded")
	}
	if _, err := cert.Verify(x509.VerifyOptions{Roots: k.caCertPool}); err != nil {
		return errors.New("certificate not signed by trusted CA: " + err.Error())
	}
	return nil
}

// CheckCertTime checks cert's validity window against the current time.
// It returns an error if NotAfter has already passed or if NotBefore has
// not been reached yet, and nil otherwise.
func (k KeyStore) CheckCertTime(cert *x509.Certificate) error {
	switch {
	case cert.NotAfter.Before(time.Now()):
		return errors.New("certificate has expired")
	case cert.NotBefore.After(time.Now()):
		return errors.New("certificate is not valid yet")
	default:
		return nil
	}
}

// CheckCertPseudonym returns an error unless the pseudonym attribute
// (OID 2.5.4.65) reported by ExtractAllOIDValues for cert's subject equals
// pseudonym. An absent attribute is looked up as the empty string.
func (k KeyStore) CheckCertPseudonym(cert *x509.Certificate, pseudonym string) error {
	const pseudonymOID = "2.5.4.65"
	if got := ExtractAllOIDValues(cert)[pseudonymOID]; got == pseudonym {
		return nil
	}
	return errors.New("Certificate does not belong to the correct pseudonym")
}

// CheckCertUsage returns nil if cert's subject contains an organizational-unit
// attribute (OID 2.5.4.11) whose value equals usage, and an error otherwise.
//
// A subject may carry several OU attributes, and the check passes if any of
// them matches. A last-wins OID map (as built by ExtractAllOIDValues) would
// only see the final OU, so the result would depend on attribute order.
// Non-string attribute values are ignored instead of causing a panic.
func (k KeyStore) CheckCertUsage(cert *x509.Certificate, usage string) error {
	const ouOID = "2.5.4.11"
	for _, name := range cert.Subject.Names {
		if name.Type.String() != ouOID {
			continue
		}
		if value, ok := name.Value.(string); ok && value == usage {
			return nil
		}
	}
	return errors.New("Certificate does not have the correct usage")
}

// CheckCert runs the certificate checks in a fixed order: CA chain, validity
// period, pseudonym attribute, then usage attribute. It stops at and returns
// the first error, or returns nil when cert passes all four.
func (k KeyStore) CheckCert(cert *x509.Certificate, pseudonym string, usage string) error {
	checks := [...]func() error{
		func() error { return k.CheckCertCA(cert) },
		func() error { return k.CheckCertTime(cert) },
		func() error { return k.CheckCertPseudonym(cert, pseudonym) },
		func() error { return k.CheckCertUsage(cert, usage) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}

// GetTLSConfig returns a fresh base TLS configuration that presents the
// keystore's certificate and private key. Only the certificate itself is
// sent; no intermediate CA certificates are added to the chain.
//
// The role-specific Get*TLSConfig methods start from this config and add their
// own client-authentication and peer-verification settings.
func (k *KeyStore) GetTLSConfig() *tls.Config {
	certificate := tls.Certificate{
		Certificate: [][]byte{k.cert.Raw},
		PrivateKey:  k.privKey,
		Leaf:        k.cert,
	}

	// Pin the protocol floor explicitly. Older Go releases let servers
	// negotiate TLS 1.0/1.1 when MinVersion is left unset.
	config := &tls.Config{
		Certificates: []tls.Certificate{certificate},
		MinVersion:   tls.VersionTLS12,
	}
	return config
}
// GetGatewayIncomingTLSConfig returns the base configuration from
// GetTLSConfig with client authentication explicitly set to NoClientCert:
// peers connecting in are not asked for a certificate.
func (k *KeyStore) GetGatewayIncomingTLSConfig() *tls.Config {
	cfg := k.GetTLSConfig()
	cfg.ClientAuth = tls.NoClientCert
	return cfg
}

// GetGatewayOutgoingTLSConfig returns the TLS configuration for outgoing
// gateway connections.
//
// Built-in verification is turned off with InsecureSkipVerify, so no host-name
// check is made. Instead, VerifyPeerCertificate requires the peer's leaf
// certificate to pass CheckCert with pseudonym "SERVER" and usage
// "MSG SERVICE". A handshake that presents no certificate is rejected.
func (k *KeyStore) GetGatewayOutgoingTLSConfig() *tls.Config {
	tlsConfig := k.GetTLSConfig()

	tlsConfig.RootCAs = k.caCertPool
	tlsConfig.InsecureSkipVerify = true
	tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		// Without this guard an empty list would skip every check and be accepted.
		if len(rawCerts) == 0 {
			return errors.New("peer presented no certificate")
		}
		// Only the leaf (first) certificate identifies the peer. Checking every
		// chain element as if it were a leaf rejects any peer that also sends an
		// intermediate, because intermediates lack the pseudonym/usage attributes.
		leaf, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return err
		}
		return k.CheckCert(leaf, "SERVER", "MSG SERVICE")
	}
	return tlsConfig
}

// GetClientTLSConfig returns the TLS configuration for client connections.
//
// Built-in verification is turned off with InsecureSkipVerify, so no host-name
// check is made. Instead, VerifyPeerCertificate requires the peer's leaf
// certificate to pass CheckCert with pseudonym "GATEWAY" and usage
// "MSG SERVICE". A handshake that presents no certificate is rejected.
func (k *KeyStore) GetClientTLSConfig() *tls.Config {
	tlsConfig := k.GetTLSConfig()

	tlsConfig.RootCAs = k.caCertPool
	tlsConfig.InsecureSkipVerify = true
	tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		// Without this guard an empty list would skip every check and be accepted.
		if len(rawCerts) == 0 {
			return errors.New("peer presented no certificate")
		}
		// Only the leaf (first) certificate identifies the peer. Checking every
		// chain element as if it were a leaf rejects any peer that also sends an
		// intermediate, because intermediates lack the pseudonym/usage attributes.
		leaf, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return err
		}
		return k.CheckCert(leaf, "GATEWAY", "MSG SERVICE")
	}
	return tlsConfig
}

// GetServerTLSConfig returns the server-side TLS configuration.
//
// Clients must send a certificate (RequireAnyClientCert). VerifyPeerCertificate
// then requires the client's leaf certificate to pass CheckCert with pseudonym
// "GATEWAY" and usage "MSG SERVICE". An empty certificate list is rejected as
// defense in depth.
//
// NOTE(review): crypto/tls does not invoke VerifyPeerCertificate on resumed
// sessions. Consider VerifyConnection if every connection must be re-checked.
func (k *KeyStore) GetServerTLSConfig() *tls.Config {
	tlsConfig := k.GetTLSConfig()

	tlsConfig.ClientAuth = tls.RequireAnyClientCert
	tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.New("peer presented no certificate")
		}
		// Only the leaf (first) certificate identifies the peer. Checking every
		// chain element as if it were a leaf rejects any peer that also sends an
		// intermediate, because intermediates lack the pseudonym/usage attributes.
		leaf, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return err
		}
		return k.CheckCert(leaf, "GATEWAY", "MSG SERVICE")
	}
	return tlsConfig
}

// EncryptMessageContent builds a digital envelope for content addressed to the
// holder of receiverCert. It works in four steps:
//  1. Sign SHA-256(content || recieverId) with this keystore's private key
//     (RSA PKCS#1 v1.5).
//  2. Prepend the signature to content with pair.
//  3. Seal the result with ChaCha20-Poly1305 under a fresh random 256-bit
//     data key.
//  4. Encrypt the data key to the receiver's RSA public key with
//     OAEP/SHA-256.
//
// Output layout: pair(encryptedDataKey, nonce || AEAD ciphertext), where the
// AEAD plaintext is pair(signature, content).
//
// It returns an error if receiverCert is nil or has no RSA public key, or if
// any random-number, signing or encryption step fails. content is not
// modified.
//
// NOTE(review): content and recieverId are hashed by plain concatenation with
// no separator or length prefix, so ("ab", "c") and ("a", "bc") sign the same
// digest. Changing this would break compatibility with existing messages and
// DecryptMessageContent, so it is left as is.
func (k KeyStore) EncryptMessageContent(receiverCert *x509.Certificate, recieverId string, content []byte) ([]byte, error) {
	if receiverCert == nil {
		return nil, errors.New("receiver certificate is nil")
	}
	receiverPubKey, ok := receiverCert.PublicKey.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("receiver certificate does not contain an RSA public key")
	}

	// Digital envelope: a fresh random symmetric key for this message only.
	dataKey := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(dataKey); err != nil {
		return nil, err
	}
	aead, err := chacha20poly1305.New(dataKey)
	if err != nil {
		return nil, err
	}

	// Streaming both parts into the hash gives the same digest as hashing
	// append(content, recieverId...). Unlike append, it never writes into
	// spare capacity of the caller's content slice.
	digest := sha256.New()
	digest.Write(content) // hash.Hash.Write never returns an error
	digest.Write([]byte(recieverId))
	// PKCS#1 v1.5 signatures are deterministic. Passing rand.Reader lets
	// older Go releases apply RSA blinding, which they skip when rand is nil.
	signature, err := rsa.SignPKCS1v15(rand.Reader, k.privKey, crypto.SHA256, digest.Sum(nil))
	if err != nil {
		return nil, err
	}
	plaintext := pair(signature, content)

	// Size the buffer for nonce + sealed plaintext so Seal need not reallocate.
	nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(plaintext)+aead.Overhead())
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	// Seal appends to its first argument, so ciphertext is nonce || sealed data.
	ciphertext := aead.Seal(nonce, nonce, plaintext, nil)

	encryptedDataKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, receiverPubKey, dataKey, nil)
	if err != nil {
		return nil, err
	}
	return pair(encryptedDataKey, ciphertext), nil
}

// DecryptMessageContent opens an envelope produced by EncryptMessageContent.
// It returns the original content once the sender's signature has been
// verified.
//
// cipherContent is expected to be pair(encryptedDataKey, nonce || AEAD
// ciphertext). Decryption proceeds as follows:
//   - the data key is recovered with RSA-OAEP/SHA-256 using this keystore's
//     private key;
//   - the payload is opened with ChaCha20-Poly1305;
//   - the embedded PKCS#1 v1.5 signature over SHA-256(content || recieverId)
//     is checked against senderCert's RSA public key.
//
// cipherContent is untrusted input, so malformed framing returns an error
// instead of panicking. This covers a missing length prefix, a length that
// overruns the buffer, and a ciphertext shorter than the nonce. An error is
// also returned if senderCert is nil or has no RSA public key, or if
// decryption, authentication or signature verification fails.
func (k KeyStore) DecryptMessageContent(senderCert *x509.Certificate, recieverId string, cipherContent []byte) ([]byte, error) {
	if senderCert == nil {
		return nil, errors.New("sender certificate is nil")
	}
	senderKey, ok := senderCert.PublicKey.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("sender certificate does not contain an RSA public key")
	}

	// splittable reports whether b starts with a 2-byte big-endian length that
	// does not exceed the bytes following it, i.e. whether unPair can split b
	// without slicing out of range.
	splittable := func(b []byte) bool {
		return len(b) >= 2 && int(binary.BigEndian.Uint16(b[:2])) <= len(b)-2
	}

	if !splittable(cipherContent) {
		return nil, errors.New("malformed message envelope")
	}
	encryptedDataKey, encryptedMsg := unPair(cipherContent)
	dataKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, k.privKey, encryptedDataKey, nil)
	if err != nil {
		return nil, err
	}

	aead, err := chacha20poly1305.New(dataKey)
	if err != nil {
		return nil, err
	}
	nonceSize := aead.NonceSize()
	if len(encryptedMsg) < nonceSize {
		return nil, errors.New("malformed message: ciphertext shorter than nonce")
	}
	contentAndSig, err := aead.Open(nil, encryptedMsg[:nonceSize], encryptedMsg[nonceSize:], nil)
	if err != nil {
		return nil, err
	}

	// The data key is encrypted only to the receiver's public key, so anyone
	// can build a validly sealed payload; its framing is untrusted too.
	if !splittable(contentAndSig) {
		return nil, errors.New("malformed message payload")
	}
	signature, content := unPair(contentAndSig)

	digest := sha256.New()
	digest.Write(content) // hash.Hash.Write never returns an error
	digest.Write([]byte(recieverId))
	if err := rsa.VerifyPKCS1v15(senderKey, crypto.SHA256, digest.Sum(nil), signature); err != nil {
		return nil, err
	}
	return content, nil
}

// pair serializes two byte slices as a 2-byte big-endian length of l,
// followed by l, followed by r. unPair reverses the encoding.
//
// The length prefix can represent at most 65535 bytes. A longer l is a
// programming error in this package, so pair panics rather than silently
// writing a truncated length and producing a corrupt encoding. The result is
// built in a single allocation of exactly 2+len(l)+len(r) bytes.
func pair(l []byte, r []byte) []byte {
	const maxLen = 0xFFFF
	if len(l) > maxLen {
		panic("cryptoUtils: pair: left part longer than 65535 bytes")
	}
	out := make([]byte, 2, 2+len(l)+len(r))
	binary.BigEndian.PutUint16(out, uint16(len(l)))
	out = append(out, l...)
	return append(out, r...)
}

// unPair splits a buffer produced by pair back into its two parts: the
// length-prefixed left part and everything after it.
//
// If the buffer is too short to hold the 2-byte length prefix, or the prefix
// claims more bytes than are present, unPair returns (nil, nil) instead of
// panicking with an out-of-range slice. Both results alias the input buffer.
// The left part is capped at its own length, so appending to it cannot
// overwrite the right part.
func unPair(pair []byte) ([]byte, []byte) {
	if len(pair) < 2 {
		return nil, nil
	}
	length := int(binary.BigEndian.Uint16(pair[:2]))
	body := pair[2:]
	if length > len(body) {
		return nil, nil
	}
	return body[:length:length], body[length:]
}