import os
import sys
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.exceptions import InvalidSignature
import argparse
import datetime

def mkpair(x, y):
    """ produz uma byte-string contendo o tuplo '(x,y)' ('x' e 'y' são byte-strings) """
    len_x = len(x)
    len_x_bytes = len_x.to_bytes(2, 'little')
    return len_x_bytes + x + y

def unpair(xy):
    """ extrai componentes de um par codificado com 'mkpair' """
    len_x = int.from_bytes(xy[:2], 'little')
    x = xy[2:len_x+2]
    y = xy[len_x+2:]
    return x, y

def cert_load(fname):
   """ lê certificado de ficheiro """
   with open(fname, "rb") as fcert:
        cert = x509.load_pem_x509_certificate(fcert.read())
   return cert

def cert_validtime(cert, now=None):
    """ valida que 'now' se encontra no período
    de validade do certificado. """
    if now is None:
        now = datetime.datetime.now(tz=datetime.timezone.utc)
    if now < cert.not_valid_before_utc or now > cert.not_valid_after_utc:
        raise x509.verification.VerificationError("Certificate is not valid at this time")

def cert_validsubject(cert, attrs=[]):
    """ verifica atributos do campo 'subject'. 'attrs'
    é uma lista de pares '(attr,value)' que condiciona
    os valores de 'attr' a 'value'. """
    print(cert.subject)
    for attr in attrs:
        if cert.subject.get_attributes_for_oid(attr[0])[0].value != attr[1]:
            raise x509.verification.VerificationError("Certificate subject does not match expected value")    

def cert_validexts(cert, policy=[]):
    """ valida extensões do certificado. 'policy' é uma lista de pares '(ext,pred)' onde 'ext' é o OID de uma extensão e 'pred'
    o predicado responsável por verificar o conteúdo dessa extensão. """
    for check in policy:
        ext = cert.extensions.get_extension_for_oid(check[0]).value
        if not check[1](ext):
            raise x509.verification.VerificationError("Certificate extensions does not match expected value")

def valida_cert(cert, ca_cert, user):
    try:
        ca_cert.public_key().verify(
            cert.signature,
            cert.tbs_certificate_bytes,
            padding.PKCS1v15(), 
            cert.signature_hash_algorithm
        )
    except InvalidSignature:
        raise x509.verification.VerificationError("Certificate signature is invalid")
    
    cert_validtime(cert)

    cert_validsubject(cert, [(NameOID.COMMON_NAME, user)])

    

"""
sign <user> <fich> -- em que assina o conteúdo de <fich> usando a chave privada armazenada em <user>.key. 
Deve produzir o ficheiro <fich>.sig contendo o par composto pela assinatura e certificado do assinante;
"""
def sign(user, filename):

    with open(user + ".key", "rb") as key_file:
        private_key = serialization.load_pem_private_key(key_file.read(), password=b'1234', backend=default_backend())

    user_cert = cert_load(user + ".crt")

    with open(filename, "rb") as file:
        data_to_sign = file.read()



    signature = private_key.sign(data_to_sign,
                                 padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                                    salt_length=padding.PSS.MAX_LENGTH), 
                                 hashes.SHA256())

     
    signature_and_cert = mkpair(signature, user_cert.public_bytes(serialization.Encoding.PEM))
    
    with open(filename + ".sig", "wb") as sig_file:
        sig_file.write(signature_and_cert)



"""
verify <fich>-- verifica a assinatura contida em <fich>.sig usando a 
informação do signatário contida no certificado (também incluído em <fich>.sig). 
Deve apresentar o status de validade da assinatura (Válida/Inválida) e, 
no caso de ser válida, ainda os dados do signatário.
"""
def verify(filename, user):

    with open(filename + ".sig", "rb") as sig_file:
        signature_and_cert = sig_file.read()

    signature, cert_bytes = unpair(signature_and_cert)

    cert = x509.load_pem_x509_certificate(cert_bytes, default_backend())

    ca_cert = cert_load("EC.crt")

    try:
        valida_cert(cert, ca_cert, user)

        with open(filename, "rb") as file:
            data_to_verify = file.read()

        cert.public_key().verify(signature, data_to_verify,
                                    padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                                        salt_length=padding.PSS.MAX_LENGTH), 
                                    hashes.SHA256())
        
        print("Valid signature")
        print(cert.subject)

    except x509.verification.VerificationError as e:
        print("Invalid signature: " + str(e))
        return

def main():
    parser = argparse.ArgumentParser(description="Sign or verify files using X.509 certificates.")
    parser.add_argument("command", choices=["sign", "verify"], help="Command to execute: 'sign' or 'verify'")
    parser.add_argument("user", help="Username")
    parser.add_argument("filename", help="File name")
    parser.add_argument("--test", action="store_true", help="Perform testing (simulate errors)")

    args = parser.parse_args()

    match args.command:
        case "sign" if os.path.exists(args.user + ".crt") and os.path.exists(args.user + ".key"):
            sign(args.user, args.filename)
            print("File signed successfully.")
        case "sign":
            print("User certificate or private key not found.")
        case "verify" if os.path.exists(args.filename + ".sig"):
            verify(args.filename, args.user)
        case "verify":
            print("Signature file not found.")
        case _:
            parser.print_help()
            sys.exit(1)



if __name__ == "__main__":
    main()