from socket import socket, AF_INET, SOCK_STREAM
from ssl import CERT_REQUIRED
import ssl
import sys
import threading


# Address and port the server socket is bound to.
ip, port = "127.0.0.1", 8443

# Paths to the certificate material. `client_cert` is defined here but is not
# referenced anywhere in this section of the file.
client_cert = "client/client.crt"
server_cert = "server/server.crt"
server_key = "server/server.key"

# Server-side TLS context (Purpose.CLIENT_AUTH is the server-side purpose).
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

# Pin the protocol to TLS 1.3 by setting both bounds to the same version.
context.minimum_version = context.maximum_version = ssl.TLSVersion.TLSv1_3

# The ssl library does not support editing the TLS 1.3 ciphers
# context.set_ciphers("TLS_CHACHA20_POLY1305_SHA256")

# Mutual TLS: the client must present a certificate, which is checked against
# the CA file loaded below.
context.verify_mode = CERT_REQUIRED
context.load_cert_chain(keyfile=server_key, certfile=server_cert)
context.load_verify_locations(cafile="./CA/CA.pem")


def handle_client(connection: ssl.SSLSocket, address):
    """Print every message received from one client until the session ends.

    The loop stops when the peer closes the stream (``recv`` returns no
    data), when the peer sends the literal message ``exit`` (compared
    case-insensitively), or when the socket reports an error. The connection
    is closed on every one of those paths.

    Args:
        connection: Established TLS socket for the client; only ``cipher()``,
            ``recv()`` and ``close()`` are used.
        address: Peer address, used only for display.
    """
    try:
        print("\033[92mClient connected\033[0m: ", address)
        print("Using the following TLS1.3 cipher: ", str(connection.cipher()))

        while True:
            message = connection.recv(1024)
            if not message:
                # Empty read: the peer closed its side of the connection.
                break
            # Decode once, and never let malformed (or split multi-byte)
            # UTF-8 from the peer raise and kill this receiver.
            text = message.decode(errors="replace")
            if text.lower() == "exit":
                return
            # "\r" overwrites the pending "> " prompt, which is then redrawn.
            print(f"\rClient: {text}\n> ", end="")
            sys.stdout.flush()
    except KeyboardInterrupt:
        # Kept for callers that run this function in the main thread.
        return
    except OSError as error:
        # ssl.SSLError is a subclass of OSError, so this also covers TLS
        # failures as well as resets and use of an already-closed socket.
        print(f"\nConnection error with {address}: {error}", file=sys.stderr)
    finally:
        connection.close()


# Main server loop. Clients are served one at a time: after a client is
# accepted, a background thread prints what it sends while this (main) thread
# reads operator input and sends it to the client. The next accept() only
# happens once the current chat has ended.
with socket(AF_INET, SOCK_STREAM) as server:
    server.bind((ip, port))
    server.listen(5)
    while True:
        client_socket, client_address = server.accept()
        try:
            ssl_socket = context.wrap_socket(client_socket, server_side=True)
        except OSError as handshake_error:
            # ssl.SSLError is a subclass of OSError. A client that fails the
            # handshake (e.g. no acceptable certificate) must not take the
            # whole server down; drop it and wait for the next one.
            print(f"\033[91mTLS handshake failed\033[0m: {handshake_error}")
            client_socket.close()
            continue

        threading.Thread(
            target=handle_client, args=(ssl_socket, client_address)
        ).start()

        try:
            while True:
                server_message = input("> ")
                # Move the cursor up and blank the echoed input line so the
                # message can be reprinted with the "Server: " prefix.
                print("\033[A                             \033[A")
                ssl_socket.sendall(server_message.encode())
                print(f"Server: {server_message}\n", end="")
                if server_message.lower() == "exit":
                    break
        except KeyboardInterrupt:
            # Ctrl-C ends the current chat only; the server keeps accepting.
            pass
        except Exception:
            # Deliberately broad: any failure while chatting (send on a dead
            # socket, closed stdin, ...) ends this session, not the server.
            print("\033[91mClient disconnected\033[0m")
        finally:
            # Release the connection on every exit path, including the error
            # path that previously left the sockets open.
            ssl_socket.close()
            client_socket.close()