import codecs
import socket
import ssl
import sys
import threading
from socket import create_connection
from ssl import CERT_REQUIRED

# Connection target: ``hostname`` is passed as ``server_hostname`` when the
# socket is wrapped; ``ip``/``port`` are the address actually connected to.
hostname = "example.org"
ip, port = "127.0.0.1", 8443

# Certificate/key material; relative paths resolve against the current
# working directory, not the script's location.
client_cert, client_key = "client/client.crt", "client/client.key"
# NOTE(review): server_cert is not referenced in the visible code.
server_cert = "server/server.crt"

# Client-side context that authenticates the server (SERVER_AUTH purpose).
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)

# Pin the protocol to exactly TLS 1.3 (minimum assigned first, then maximum).
context.minimum_version = context.maximum_version = ssl.TLSVersion.TLSv1_3

# NOTE: SSLContext.set_ciphers() does not control TLS 1.3 cipher suites, so a
# specific suite such as TLS_CHACHA20_POLY1305_SHA256 cannot be forced here.

# Present our own certificate (mutual TLS) and trust the project CA on top of
# the default trust store loaded by create_default_context().
context.load_cert_chain(client_cert, client_key)
context.load_verify_locations("./CA/CA.pem")
context.verify_mode = ssl.CERT_REQUIRED


def receive_messages(tls: ssl.SSLSocket) -> None:
    """Print everything received on *tls* until the connection ends.

    Intended to run in a background thread next to the stdin input loop.
    Each decoded chunk is printed as ``Server: <text>`` and followed by a
    fresh ``> `` prompt, so the user's prompt is redrawn after incoming text.

    The function returns when ``recv`` reports end-of-stream (empty bytes)
    or raises ``OSError`` (``ssl.SSLError`` is a subclass), e.g. because the
    main thread shut the socket down.

    Args:
        tls: A connected socket; only its ``recv`` method is used.
    """
    # Chunks are cut at arbitrary byte offsets, so a multi-byte UTF-8
    # character can be split across two recv() calls. An incremental decoder
    # buffers the partial bytes instead of raising UnicodeDecodeError, and
    # "replace" keeps one malformed byte from killing the receiver thread.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def show(text: str) -> None:
        # Skip empty output (e.g. a chunk holding only half a character).
        if text:
            print(f"\rServer: {text}\n> ", end="")
            sys.stdout.flush()

    try:
        while True:
            data = tls.recv(1024)
            if not data:
                break
            show(decoder.decode(data))
    except OSError:
        # KeyboardInterrupt is only ever raised in the main thread, so the
        # failure seen here is the socket being shut down/closed underneath
        # us. End the thread quietly instead of printing a traceback.
        return
    # Flush any trailing incomplete sequence left when the peer closed.
    show(decoder.decode(b"", final=True))


# Connect, perform the TLS 1.3 handshake, then run a line-based chat: a
# background thread prints what the server sends while this thread reads
# lines from stdin and sends them. "exit" (any case), Ctrl-D or Ctrl-C ends
# the session.
with create_connection((ip, port)) as client:
    with context.wrap_socket(
        client, server_side=False, server_hostname=hostname
    ) as tls:
        print(f"Using the following TLS1.3 cipher: {tls.cipher()}")

        # daemon=True: a receiver still blocked in recv() must not keep the
        # interpreter alive once the user has quit.
        receiver = threading.Thread(
            target=receive_messages, args=(tls,), daemon=True
        )
        receiver.start()

        try:
            while True:
                try:
                    message = input("> ")
                except EOFError:
                    # stdin closed (Ctrl-D / end of piped input): quit cleanly.
                    break
                # Move the cursor up onto the echoed "> ..." line and clear it,
                # so the message is shown once, as "Client: ...".
                print("\033[A\033[2K", end="")
                print(f"Client: {message}")
                try:
                    tls.sendall(message.encode())
                except OSError as err:
                    # The peer went away; stop instead of crashing with a
                    # traceback on the next send.
                    print(f"Connection closed: {err}")
                    break
                if message.lower() == "exit":
                    break
        except KeyboardInterrupt:
            pass
        finally:
            # shutdown() (unlike close() alone) reliably wakes the receiver
            # out of a blocking recv(). The enclosing ``with`` closes the
            # socket afterwards, once the receiver has had a chance to exit.
            try:
                tls.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # already disconnected; nothing left to shut down
            receiver.join(timeout=1.0)