#!/usr/bin/env python3

import argparse
import getpass
import os

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC



def encrypt(input_file, password):
    """Encrypt *input_file* with AES-256-CTR and write ``<input_file>.enc``.

    The 32-byte key is derived from *password* with PBKDF2-HMAC-SHA256
    (100000 iterations) using a fresh random 16-byte salt. The output file
    layout is::

        salt (16 bytes) || iv (16 bytes) || ciphertext

    CTR is a stream mode, so the ciphertext has the same length as the
    plaintext and no padding is involved.

    NOTE(review): CTR gives confidentiality only. The output carries no MAC,
    so tampering is not detected and decrypting with a wrong password yields
    garbage instead of an error. An AEAD mode (e.g. AES-GCM) would fix this
    but changes the file format.

    Args:
        input_file: path of the file to encrypt.
        password: pass-phrase the key is derived from (UTF-8 encoded).
    """
    with open(input_file, "rb") as inp:
        plaintext = inp.read()

    # Derive the key from the password using PBKDF2 with a random salt.
    salt = os.urandom(16)
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = kdf.derive(password.encode("utf-8"))
    iv = os.urandom(16)  # random initial counter block

    encryptor = Cipher(algorithms.AES(key), modes.CTR(iv)).encryptor()
    ciphertext = encryptor.update(plaintext) + encryptor.finalize()

    # Prepend salt and IV so decrypt() can re-derive the key and counter.
    output = salt + iv + ciphertext

    print(f"plaintext len : {len(plaintext)}")
    print(f"ciphertext len : {len(output)}")
    print(f"iv len : {len(iv)}")

    # Create the output only after everything above succeeded, so a failure
    # never leaves an empty or truncated .enc file behind.
    with open(f"{input_file}.enc", "wb") as out:
        out.write(output)

def decrypt(input_file, password):
    """Decrypt a file produced by encrypt() and write ``<input_file>.dec``.

    Expects the layout ``salt (16 bytes) || iv (16 bytes) || ciphertext``.
    The key is re-derived from *password* with the same PBKDF2-HMAC-SHA256
    parameters (32-byte key, 100000 iterations) that encrypt() uses.

    NOTE(review): there is no integrity check. A wrong password or a
    tampered file produces garbage output rather than an error.

    Args:
        input_file: path of the encrypted file.
        password: pass-phrase used at encryption time (UTF-8 encoded).

    Raises:
        ValueError: if the file is too short to contain the salt and IV.
    """
    with open(input_file, "rb") as inp:
        input_bytes = inp.read()

    # Reject truncated input up front with a clear message, before any
    # output file is created.
    if len(input_bytes) < 32:
        raise ValueError(
            f"{input_file}: too short to be an encrypted file "
            f"({len(input_bytes)} bytes, need at least 32 for salt + IV)"
        )

    salt = input_bytes[:16]
    iv = input_bytes[16:32]
    ciphertext = input_bytes[32:]

    print(f"ciphertext len : {len(ciphertext)}")
    print(f"iv len : {len(iv)}")
    print(f"salt len : {len(salt)}")

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = kdf.derive(password.encode("utf-8"))

    # CTR is a stream mode: plaintext and ciphertext have equal length and
    # no padding/unpadding is needed.
    decryptor = Cipher(algorithms.AES(key), modes.CTR(iv)).decryptor()
    plaintext = decryptor.update(ciphertext) + decryptor.finalize()

    with open(f"{input_file}.dec", "wb") as out:
        out.write(plaintext)

def main():
    """Parse the command line and run the requested AES file operation.

    Usage::

        prog enc FILE [PASSWORD]   # writes FILE.enc
        prog dec FILE [PASSWORD]   # writes FILE.dec

    When PASSWORD is omitted it is read interactively with getpass, so it
    does not end up in shell history or the process list. When encrypting,
    the prompted password is asked for twice so a typo cannot make the file
    unrecoverable. Exits with status 2 on usage errors, including a missing
    subcommand.
    """
    parser = argparse.ArgumentParser(
        description="Program to perform operations using AES cipher on files",
    )

    # required=True: otherwise running with no subcommand silently did
    # nothing and exited successfully.
    subparsers = parser.add_subparsers(
        dest="operation", required=True, help="Operation to perform"
    )

    # Encrypt subcommand
    enc_parser = subparsers.add_parser("enc", help="Encrypt a file")
    enc_parser.add_argument("fich", help="File to be encrypted")
    enc_parser.add_argument(
        "password",
        nargs="?",
        default=None,
        help="Pass-phrase to derive the key (prompted for if omitted)",
    )

    # Decrypt subcommand
    dec_parser = subparsers.add_parser("dec", help="Decrypt a file")
    dec_parser.add_argument("fich", help="File to be decrypted")
    dec_parser.add_argument(
        "password",
        nargs="?",
        default=None,
        help="Pass-phrase to derive the key (prompted for if omitted)",
    )

    args = parser.parse_args()

    password = args.password
    if password is None:
        password = getpass.getpass("Password: ")
        if args.operation == "enc" and getpass.getpass("Confirm password: ") != password:
            parser.error("passwords do not match")

    operations = {"enc": encrypt, "dec": decrypt}
    operations[args.operation](args.fich, password)


if __name__ == "__main__":
    main()