import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import argparse

def setup(key_file):
    """Generate a new random 128-bit AES-GCM key and save it to *key_file*.

    The key is stored as its raw 16 bytes with no encoding; any existing
    file at that path is overwritten (opened in ``"wb"`` mode).
    """
    new_key = AESGCM.generate_key(bit_length=128)
    with open(key_file, "wb") as out:
        out.write(new_key)


def encrypt(input_file, key_file):
    """Encrypt *input_file* with AES-GCM, using the raw key stored in *key_file*.

    The result is written to ``<input_file>.enc`` as the 12-byte random nonce
    followed by the output of ``AESGCM.encrypt`` (ciphertext with the
    library-appended authentication tag). A fixed associated-data string is
    authenticated but not stored; ``decrypt`` must supply the same value.
    """
    with open(input_file, "rb") as src:
        data = src.read()

    with open(key_file, "rb") as kf:
        raw_key = kf.read()

    associated = b"authenticated but unencrypted data"
    cipher = AESGCM(raw_key)
    # A fresh random 96-bit nonce per call: GCM must never reuse a nonce
    # under the same key.
    iv = os.urandom(12)

    sealed = cipher.encrypt(iv, data, associated)

    # The nonce is not secret; it is stored in front so decrypt() can read it back.
    with open(f"{input_file}.enc", "wb") as dst:
        dst.write(iv + sealed)


def decrypt(input_file, key_file):
    """Decrypt a file produced by ``encrypt`` and write ``<input_file>.dec``.

    The first 12 bytes of *input_file* are the nonce; the remainder is the
    ciphertext-plus-tag. The same fixed associated data used at encryption
    time is supplied. ``AESGCM.decrypt`` raises if authentication fails, and
    because the output file is only opened afterwards, nothing is written
    in that case.
    """
    with open(input_file, "rb") as src:
        # Tuple elements are evaluated left to right: nonce first, then the rest.
        iv, sealed = src.read(12), src.read()

    with open(key_file, "rb") as kf:
        raw_key = kf.read()

    associated = b"authenticated but unencrypted data"
    plain = AESGCM(raw_key).decrypt(iv, sealed, associated)

    with open(f"{input_file}.dec", "wb") as dst:
        dst.write(plain)



def main():
    parser = argparse.ArgumentParser(
        description="Program to perform operations using AES-GCM cipher on files",
    )

    subparsers = parser.add_subparsers(dest="operation", help="Operation to perform")

    # Encrypt subcommand
    enc_parser = subparsers.add_parser("enc", help="Encrypt a file")
    enc_parser.add_argument("fich", help="File to be encrypted")
    enc_parser.add_argument("password", help="Pass-phrase to derive the key")

    # Decrypt subcommand
    dec_parser = subparsers.add_parser("dec", help="Decrypt a file")
    dec_parser.add_argument("fich", help="File to be decrypted")
    dec_parser.add_argument("password", help="Pass-phrase to derive the key")

    args = parser.parse_args()
    match args.operation:
        case "enc":
            input_file = args.fich
            password = args.password
            encrypt(input_file,password)
        case "dec":
            input_file = args.fich
            password = args.password
            decrypt(input_file,password)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()