diff --git a/daemon/core/api/grpc/client.py b/daemon/core/api/grpc/client.py index 6098db25..a2641e87 100644 --- a/daemon/core/api/grpc/client.py +++ b/daemon/core/api/grpc/client.py @@ -5,6 +5,7 @@ gRpc client for interfacing with CORE, when gRPC mode is enabled. import logging import threading from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, List import grpc import netaddr @@ -18,7 +19,7 @@ class InterfaceHelper: Convenience class to help generate IP4 and IP6 addresses for gRPC clients. """ - def __init__(self, ip4_prefix=None, ip6_prefix=None): + def __init__(self, ip4_prefix: str = None, ip6_prefix: str = None) -> None: """ Creates an InterfaceHelper object. @@ -36,7 +37,7 @@ class InterfaceHelper: if ip6_prefix: self.ip6 = netaddr.IPNetwork(ip6_prefix) - def ip4_address(self, node_id): + def ip4_address(self, node_id: int) -> str: """ Convenience method to return the IP4 address for a node. @@ -48,7 +49,7 @@ class InterfaceHelper: raise ValueError("ip4 prefixes have not been set") return str(self.ip4[node_id]) - def ip6_address(self, node_id): + def ip6_address(self, node_id: int) -> str: """ Convenience method to return the IP6 address for a node. @@ -60,15 +61,18 @@ class InterfaceHelper: raise ValueError("ip6 prefixes have not been set") return str(self.ip6[node_id]) - def create_interface(self, node_id, interface_id, name=None, mac=None): + def create_interface( + self, node_id: int, interface_id: int, name: str = None, mac: str = None + ) -> core_pb2.Interface: """ - Creates interface data for linking nodes, using the nodes unique id for generation, along with a random - mac address, unless provided. + Creates interface data for linking nodes, using the nodes unique id for + generation, along with a random mac address, unless provided. :param int node_id: node id to create interface for :param int interface_id: interface id for interface :param str name: name to set for interface, default is eth{id} - :param str mac: mac address to use for this interface, default is random generation + :param str mac: mac address to use for this interface, default is random + generation :return: new interface data for the provided node :rtype: core_pb2.Interface """ @@ -101,7 +105,7 @@ class InterfaceHelper: ) -def stream_listener(stream, handler): +def stream_listener(stream: Any, handler: Callable[[core_pb2.Event], None]) -> None: """ Listen for stream events and provide them to the handler. @@ -119,7 +123,7 @@ def stream_listener(stream, handler): logging.exception("stream error") -def start_streamer(stream, handler): +def start_streamer(stream: Any, handler: Callable[[core_pb2.Event], None]) -> None: """ Convenience method for starting a grpc stream thread for handling streamed events. @@ -137,7 +141,7 @@ class CoreGrpcClient: Provides convenience methods for interfacing with the CORE grpc server. """ - def __init__(self, address="localhost:50051", proxy=False): + def __init__(self, address: str = "localhost:50051", proxy: bool = False) -> None: """ Creates a CoreGrpcClient instance. @@ -150,19 +154,19 @@ class CoreGrpcClient: def start_session( self, - session_id, - nodes, - links, - location=None, - hooks=None, - emane_config=None, - emane_model_configs=None, - wlan_configs=None, - mobility_configs=None, - service_configs=None, - service_file_configs=None, - asymmetric_links=None, - ): + session_id: int, + nodes: List[core_pb2.Node], + links: List[core_pb2.Link], + location: core_pb2.SessionLocation = None, + hooks: List[core_pb2.Hook] = None, + emane_config: Dict[str, str] = None, + emane_model_configs: List[core_pb2.EmaneModelConfig] = None, + wlan_configs: List[core_pb2.WlanConfig] = None, + mobility_configs: List[core_pb2.MobilityConfig] = None, + service_configs: List[core_pb2.ServiceConfig] = None, + service_file_configs: List[core_pb2.ServiceFileConfig] = None, + asymmetric_links: List[core_pb2.Link] = None, + ) -> core_pb2.StartSessionResponse: """ Start a session. @@ -197,7 +201,7 @@ class CoreGrpcClient: ) return self.stub.StartSession(request) - def stop_session(self, session_id): + def stop_session(self, session_id: int) -> core_pb2.StopSessionResponse: """ Stop a running session. @@ -208,18 +212,19 @@ class CoreGrpcClient: request = core_pb2.StopSessionRequest(session_id=session_id) return self.stub.StopSession(request) - def create_session(self, session_id=None): + def create_session(self, session_id: int = None) -> core_pb2.CreateSessionResponse: """ Create a session. - :param int session_id: id for session, default is None and one will be created for you + :param int session_id: id for session, default is None and one will be created + for you :return: response with created session id :rtype: core_pb2.CreateSessionResponse """ request = core_pb2.CreateSessionRequest(session_id=session_id) return self.stub.CreateSession(request) - def delete_session(self, session_id): + def delete_session(self, session_id: int) -> core_pb2.DeleteSessionResponse: """ Delete a session. @@ -231,16 +236,17 @@ class CoreGrpcClient: request = core_pb2.DeleteSessionRequest(session_id=session_id) return self.stub.DeleteSession(request) - def get_sessions(self): + def get_sessions(self) -> core_pb2.GetSessionsResponse: """ Retrieves all currently known sessions. - :return: response with a list of currently known session, their state and number of nodes + :return: response with a list of currently known session, their state and + number of nodes :rtype: core_pb2.GetSessionsResponse """ return self.stub.GetSessions(core_pb2.GetSessionsRequest()) - def get_session(self, session_id): + def get_session(self, session_id: int) -> core_pb2.GetSessionResponse: """ Retrieve a session. @@ -252,7 +258,9 @@ class CoreGrpcClient: request = core_pb2.GetSessionRequest(session_id=session_id) return self.stub.GetSession(request) - def get_session_options(self, session_id): + def get_session_options( + self, session_id: int + ) -> core_pb2.GetSessionOptionsResponse: """ Retrieve session options as a dict with id mapping. @@ -264,7 +272,9 @@ class CoreGrpcClient: request = core_pb2.GetSessionOptionsRequest(session_id=session_id) return self.stub.GetSessionOptions(request) - def set_session_options(self, session_id, config): + def set_session_options( + self, session_id: int, config: Dict[str, str] + ) -> core_pb2.SetSessionOptionsResponse: """ Set options for a session. @@ -279,7 +289,9 @@ class CoreGrpcClient: ) return self.stub.SetSessionOptions(request) - def get_session_metadata(self, session_id): + def get_session_metadata( + self, session_id: int + ) -> core_pb2.GetSessionMetadataResponse: """ Retrieve session metadata as a dict with id mapping. @@ -291,7 +303,9 @@ class CoreGrpcClient: request = core_pb2.GetSessionMetadataRequest(session_id=session_id) return self.stub.GetSessionMetadata(request) - def set_session_metadata(self, session_id, config): + def set_session_metadata( + self, session_id: int, config: Dict[str, str] + ) -> core_pb2.SetSessionMetadataResponse: """ Set metadata for a session. @@ -306,7 +320,9 @@ class CoreGrpcClient: ) return self.stub.SetSessionMetadata(request) - def get_session_location(self, session_id): + def get_session_location( + self, session_id: int + ) -> core_pb2.GetSessionLocationResponse: """ Get session location. @@ -320,15 +336,15 @@ class CoreGrpcClient: def set_session_location( self, - session_id, - x=None, - y=None, - z=None, - lat=None, - lon=None, - alt=None, - scale=None, - ): + session_id: int, + x: float = None, + y: float = None, + z: float = None, + lat: float = None, + lon: float = None, + alt: float = None, + scale: float = None, + ) -> core_pb2.SetSessionLocationResponse: """ Set session location. @@ -352,7 +368,9 @@ class CoreGrpcClient: ) return self.stub.SetSessionLocation(request) - def set_session_state(self, session_id, state): + def set_session_state( + self, session_id: int, state: core_pb2.SessionState + ) -> core_pb2.SetSessionStateResponse: """ Set session state. @@ -365,7 +383,9 @@ class CoreGrpcClient: request = core_pb2.SetSessionStateRequest(session_id=session_id, state=state) return self.stub.SetSessionState(request) - def add_session_server(self, session_id, name, host): + def add_session_server( + self, session_id: int, name: str, host: str + ) -> core_pb2.AddSessionServerResponse: """ Add distributed session server. @@ -381,7 +401,12 @@ class CoreGrpcClient: ) return self.stub.AddSessionServer(request) - def events(self, session_id, handler, events=None): + def events( + self, + session_id: int, + handler: Callable[[core_pb2.Event], None], + events: List[core_pb2.Event] = None, + ) -> Any: """ Listen for session events. @@ -393,10 +418,13 @@ class CoreGrpcClient: """ request = core_pb2.EventsRequest(session_id=session_id, events=events) stream = self.stub.Events(request) + logging.info("STREAM TYPE: %s", type(stream)) start_streamer(stream, handler) return stream - def throughputs(self, session_id, handler): + def throughputs( + self, session_id: int, handler: Callable[[core_pb2.ThroughputsEvent], None] + ) -> Any: """ Listen for throughput events with information for interfaces and bridges. @@ -410,7 +438,9 @@ class CoreGrpcClient: start_streamer(stream, handler) return stream - def add_node(self, session_id, node): + def add_node( + self, session_id: int, node: core_pb2.Node + ) -> core_pb2.AddNodeResponse: """ Add node to session. @@ -423,7 +453,7 @@ class CoreGrpcClient: request = core_pb2.AddNodeRequest(session_id=session_id, node=node) return self.stub.AddNode(request) - def get_node(self, session_id, node_id): + def get_node(self, session_id: int, node_id: int) -> core_pb2.GetNodeResponse: """ Get node details. @@ -436,7 +466,14 @@ class CoreGrpcClient: request = core_pb2.GetNodeRequest(session_id=session_id, node_id=node_id) return self.stub.GetNode(request) - def edit_node(self, session_id, node_id, position, icon=None, source=None): + def edit_node( + self, + session_id: int, + node_id: int, + position: core_pb2.Position, + icon: str = None, + source: str = None, + ) -> core_pb2.EditNodeResponse: """ Edit a node, currently only changes position. @@ -458,7 +495,7 @@ class CoreGrpcClient: ) return self.stub.EditNode(request) - def delete_node(self, session_id, node_id): + def delete_node(self, session_id: int, node_id: int) -> core_pb2.DeleteNodeResponse: """ Delete node from session. @@ -471,12 +508,15 @@ class CoreGrpcClient: request = core_pb2.DeleteNodeRequest(session_id=session_id, node_id=node_id) return self.stub.DeleteNode(request) - def node_command(self, session_id, node_id, command): + def node_command( + self, session_id: int, node_id: int, command: str + ) -> core_pb2.NodeCommandResponse: """ Send command to a node and get the output. :param int session_id: session id :param int node_id: node id + :param str command: command to run on node :return: response with command combined stdout/stderr :rtype: core_pb2.NodeCommandResponse :raises grpc.RpcError: when session or node doesn't exist @@ -486,7 +526,9 @@ class CoreGrpcClient: ) return self.stub.NodeCommand(request) - def get_node_terminal(self, session_id, node_id): + def get_node_terminal( + self, session_id: int, node_id: int + ) -> core_pb2.GetNodeTerminalResponse: """ Retrieve terminal command string for launching a local terminal. @@ -501,7 +543,9 @@ class CoreGrpcClient: ) return self.stub.GetNodeTerminal(request) - def get_node_links(self, session_id, node_id): + def get_node_links( + self, session_id: int, node_id: int + ) -> core_pb2.GetNodeLinksResponse: """ Get current links for a node. @@ -516,13 +560,13 @@ class CoreGrpcClient: def add_link( self, - session_id, - node_one_id, - node_two_id, - interface_one=None, - interface_two=None, - options=None, - ): + session_id: int, + node_one_id: int, + node_two_id: int, + interface_one: core_pb2.Interface = None, + interface_two: core_pb2.Interface = None, + options: core_pb2.LinkOptions = None, + ) -> core_pb2.AddLinkResponse: """ Add a link between nodes. @@ -549,13 +593,13 @@ class CoreGrpcClient: def edit_link( self, - session_id, - node_one_id, - node_two_id, - options, - interface_one_id=None, - interface_two_id=None, - ): + session_id: int, + node_one_id: int, + node_two_id: int, + options: core_pb2.LinkOptions, + interface_one_id: int = None, + interface_two_id: int = None, + ) -> core_pb2.EditLinkResponse: """ Edit a link between nodes. @@ -581,12 +625,12 @@ class CoreGrpcClient: def delete_link( self, - session_id, - node_one_id, - node_two_id, - interface_one_id=None, - interface_two_id=None, - ): + session_id: int, + node_one_id: int, + node_two_id: int, + interface_one_id: int = None, + interface_two_id: int = None, + ) -> core_pb2.DeleteLinkResponse: """ Delete a link between nodes. @@ -608,7 +652,7 @@ class CoreGrpcClient: ) return self.stub.DeleteLink(request) - def get_hooks(self, session_id): + def get_hooks(self, session_id: int) -> core_pb2.GetHooksResponse: """ Get all hook scripts. @@ -620,7 +664,13 @@ class CoreGrpcClient: request = core_pb2.GetHooksRequest(session_id=session_id) return self.stub.GetHooks(request) - def add_hook(self, session_id, state, file_name, file_data): + def add_hook( + self, + session_id: int, + state: core_pb2.SessionState, + file_name: str, + file_data: bytes, + ) -> core_pb2.AddHookResponse: """ Add hook scripts. @@ -636,7 +686,9 @@ class CoreGrpcClient: request = core_pb2.AddHookRequest(session_id=session_id, hook=hook) return self.stub.AddHook(request) - def get_mobility_configs(self, session_id): + def get_mobility_configs( + self, session_id: int + ) -> core_pb2.GetMobilityConfigsResponse: """ Get all mobility configurations. @@ -648,7 +700,9 @@ class CoreGrpcClient: request = core_pb2.GetMobilityConfigsRequest(session_id=session_id) return self.stub.GetMobilityConfigs(request) - def get_mobility_config(self, session_id, node_id): + def get_mobility_config( + self, session_id: int, node_id: int + ) -> core_pb2.GetMobilityConfigResponse: """ Get mobility configuration for a node. @@ -663,7 +717,9 @@ class CoreGrpcClient: ) return self.stub.GetMobilityConfig(request) - def set_mobility_config(self, session_id, node_id, config): + def set_mobility_config( + self, session_id: int, node_id: int, config: Dict[str, str] + ) -> core_pb2.SetMobilityConfigResponse: """ Set mobility configuration for a node. @@ -680,7 +736,9 @@ class CoreGrpcClient: ) return self.stub.SetMobilityConfig(request) - def mobility_action(self, session_id, node_id, action): + def mobility_action( + self, session_id: int, node_id: int, action: core_pb2.ServiceAction + ) -> core_pb2.MobilityActionResponse: """ Send a mobility action for a node. @@ -696,7 +754,7 @@ class CoreGrpcClient: ) return self.stub.MobilityAction(request) - def get_services(self): + def get_services(self) -> core_pb2.GetServicesResponse: """ Get all currently loaded services. @@ -706,7 +764,9 @@ class CoreGrpcClient: request = core_pb2.GetServicesRequest() return self.stub.GetServices(request) - def get_service_defaults(self, session_id): + def get_service_defaults( + self, session_id: int + ) -> core_pb2.GetServiceDefaultsResponse: """ Get default services for different default node models. @@ -718,7 +778,9 @@ class CoreGrpcClient: request = core_pb2.GetServiceDefaultsRequest(session_id=session_id) return self.stub.GetServiceDefaults(request) - def set_service_defaults(self, session_id, service_defaults): + def set_service_defaults( + self, session_id: int, service_defaults: Dict[str, List[str]] + ) -> core_pb2.SetServiceDefaultsResponse: """ Set default services for node models. @@ -738,7 +800,9 @@ class CoreGrpcClient: ) return self.stub.SetServiceDefaults(request) - def get_node_service_configs(self, session_id): + def get_node_service_configs( + self, session_id: int + ) -> core_pb2.GetNodeServiceConfigsResponse: """ Get service data for a node. @@ -750,7 +814,9 @@ class CoreGrpcClient: request = core_pb2.GetNodeServiceConfigsRequest(session_id=session_id) return self.stub.GetNodeServiceConfigs(request) - def get_node_service(self, session_id, node_id, service): + def get_node_service( + self, session_id: int, node_id: int, service: str + ) -> core_pb2.GetNodeServiceResponse: """ Get service data for a node. @@ -766,7 +832,9 @@ class CoreGrpcClient: ) return self.stub.GetNodeService(request) - def get_node_service_file(self, session_id, node_id, service, file_name): + def get_node_service_file( + self, session_id: int, node_id: int, service: str, file_name: str + ) -> core_pb2.GetNodeServiceFileResponse: """ Get a service file for a node. @@ -784,8 +852,14 @@ class CoreGrpcClient: return self.stub.GetNodeServiceFile(request) def set_node_service( - self, session_id, node_id, service, startup, validate, shutdown - ): + self, + session_id: int, + node_id: int, + service: str, + startup: List[str], + validate: List[str], + shutdown: List[str], + ) -> core_pb2.SetNodeServiceResponse: """ Set service data for a node. @@ -809,7 +883,9 @@ class CoreGrpcClient: request = core_pb2.SetNodeServiceRequest(session_id=session_id, config=config) return self.stub.SetNodeService(request) - def set_node_service_file(self, session_id, node_id, service, file_name, data): + def set_node_service_file( + self, session_id: int, node_id: int, service: str, file_name: str, data: bytes + ) -> core_pb2.SetNodeServiceFileResponse: """ Set a service file for a node. @@ -830,14 +906,21 @@ class CoreGrpcClient: ) return self.stub.SetNodeServiceFile(request) - def service_action(self, session_id, node_id, service, action): + def service_action( + self, + session_id: int, + node_id: int, + service: str, + action: core_pb2.ServiceAction, + ) -> core_pb2.ServiceActionResponse: """ Send an action to a service for a node. :param int session_id: session id :param int node_id: node id :param str service: service name - :param core_pb2.ServiceAction action: action for service (start, stop, restart, validate) + :param core_pb2.ServiceAction action: action for service (start, stop, restart, + validate) :return: response with result of success or failure :rtype: core_pb2.ServiceActionResponse :raises grpc.RpcError: when session or node doesn't exist @@ -847,7 +930,7 @@ class CoreGrpcClient: ) return self.stub.ServiceAction(request) - def get_wlan_configs(self, session_id): + def get_wlan_configs(self, session_id: int) -> core_pb2.GetWlanConfigsResponse: """ Get all wlan configurations. @@ -859,7 +942,9 @@ class CoreGrpcClient: request = core_pb2.GetWlanConfigsRequest(session_id=session_id) return self.stub.GetWlanConfigs(request) - def get_wlan_config(self, session_id, node_id): + def get_wlan_config( + self, session_id: int, node_id: int + ) -> core_pb2.GetWlanConfigResponse: """ Get wlan configuration for a node. @@ -872,7 +957,9 @@ class CoreGrpcClient: request = core_pb2.GetWlanConfigRequest(session_id=session_id, node_id=node_id) return self.stub.GetWlanConfig(request) - def set_wlan_config(self, session_id, node_id, config): + def set_wlan_config( + self, session_id: int, node_id: int, config: Dict[str, str] + ) -> core_pb2.SetWlanConfigResponse: """ Set wlan configuration for a node. @@ -889,7 +976,7 @@ class CoreGrpcClient: ) return self.stub.SetWlanConfig(request) - def get_emane_config(self, session_id): + def get_emane_config(self, session_id: int) -> core_pb2.GetEmaneConfigResponse: """ Get session emane configuration. @@ -901,7 +988,9 @@ class CoreGrpcClient: request = core_pb2.GetEmaneConfigRequest(session_id=session_id) return self.stub.GetEmaneConfig(request) - def set_emane_config(self, session_id, config): + def set_emane_config( + self, session_id: int, config: Dict[str, str] + ) -> core_pb2.SetEmaneConfigResponse: """ Set session emane configuration. @@ -914,7 +1003,7 @@ class CoreGrpcClient: request = core_pb2.SetEmaneConfigRequest(session_id=session_id, config=config) return self.stub.SetEmaneConfig(request) - def get_emane_models(self, session_id): + def get_emane_models(self, session_id: int) -> core_pb2.GetEmaneModelsResponse: """ Get session emane models. @@ -926,7 +1015,9 @@ class CoreGrpcClient: request = core_pb2.GetEmaneModelsRequest(session_id=session_id) return self.stub.GetEmaneModels(request) - def get_emane_model_config(self, session_id, node_id, model, interface_id=-1): + def get_emane_model_config( + self, session_id: int, node_id: int, model: str, interface_id: int = -1 + ) -> core_pb2.GetEmaneModelConfigResponse: """ Get emane model configuration for a node or a node's interface. @@ -944,8 +1035,13 @@ class CoreGrpcClient: return self.stub.GetEmaneModelConfig(request) def set_emane_model_config( - self, session_id, node_id, model, config, interface_id=-1 - ): + self, + session_id: int, + node_id: int, + model: str, + config: Dict[str, str], + interface_id: int = -1, + ) -> core_pb2.SetEmaneModelConfigResponse: """ Set emane model configuration for a node or a node's interface. @@ -966,7 +1062,9 @@ class CoreGrpcClient: ) return self.stub.SetEmaneModelConfig(request) - def get_emane_model_configs(self, session_id): + def get_emane_model_configs( + self, session_id: int + ) -> core_pb2.GetEmaneModelConfigsResponse: """ Get all emane model configurations for a session. @@ -978,7 +1076,7 @@ class CoreGrpcClient: request = core_pb2.GetEmaneModelConfigsRequest(session_id=session_id) return self.stub.GetEmaneModelConfigs(request) - def save_xml(self, session_id, file_path): + def save_xml(self, session_id: int, file_path: str) -> core_pb2.SaveXmlResponse: """ Save the current scenario to an XML file. @@ -991,7 +1089,7 @@ class CoreGrpcClient: with open(file_path, "w") as xml_file: xml_file.write(response.data) - def open_xml(self, file_path, start=False): + def open_xml(self, file_path: str, start: bool = False) -> core_pb2.OpenXmlResponse: """ Load a local scenario XML file to open as a new session. @@ -1005,7 +1103,9 @@ class CoreGrpcClient: request = core_pb2.OpenXmlRequest(data=data, start=start, file=file_path) return self.stub.OpenXml(request) - def emane_link(self, session_id, nem_one, nem_two, linked): + def emane_link( + self, session_id: int, nem_one: int, nem_two: int, linked: bool + ) -> core_pb2.EmaneLinkResponse: """ Helps broadcast wireless link/unlink between EMANE nodes. @@ -1020,7 +1120,7 @@ class CoreGrpcClient: ) return self.stub.EmaneLink(request) - def get_interfaces(self): + def get_interfaces(self) -> core_pb2.GetInterfacesResponse: """ Retrieves a list of interfaces available on the host machine that are not a part of a CORE session. @@ -1030,7 +1130,7 @@ class CoreGrpcClient: request = core_pb2.GetInterfacesRequest() return self.stub.GetInterfaces(request) - def connect(self): + def connect(self) -> None: """ Open connection to server, must be closed manually. @@ -1041,7 +1141,7 @@ class CoreGrpcClient: ) self.stub = core_pb2_grpc.CoreApiStub(self.channel) - def close(self): + def close(self) -> None: """ Close currently opened server channel connection. @@ -1052,7 +1152,7 @@ class CoreGrpcClient: self.channel = None @contextmanager - def context_connect(self): + def context_connect(self) -> Generator: """ Makes a context manager based connection to the server, will close after context ends. diff --git a/daemon/core/api/grpc/events.py b/daemon/core/api/grpc/events.py index 5c4ee25e..2eebd8ae 100644 --- a/daemon/core/api/grpc/events.py +++ b/daemon/core/api/grpc/events.py @@ -1,5 +1,6 @@ import logging from queue import Empty, Queue +from typing import Iterable from core.api.grpc import core_pb2 from core.api.grpc.grpcutils import convert_value @@ -11,9 +12,10 @@ from core.emulator.data import ( LinkData, NodeData, ) +from core.emulator.session import Session -def handle_node_event(event): +def handle_node_event(event: NodeData) -> core_pb2.NodeEvent: """ Handle node event when there is a node event @@ -34,7 +36,7 @@ def handle_node_event(event): return core_pb2.NodeEvent(node=node_proto, source=event.source) -def handle_link_event(event): +def handle_link_event(event: LinkData) -> core_pb2.LinkEvent: """ Handle link event when there is a link event @@ -90,7 +92,7 @@ def handle_link_event(event): return core_pb2.LinkEvent(message_type=event.message_type, link=link) -def handle_session_event(event): +def handle_session_event(event: EventData) -> core_pb2.SessionEvent: """ Handle session event when there is a session event @@ -110,7 +112,7 @@ def handle_session_event(event): ) -def handle_config_event(event): +def handle_config_event(event: ConfigData) -> core_pb2.ConfigEvent: """ Handle configuration event when there is configuration event @@ -135,7 +137,7 @@ def handle_config_event(event): ) -def handle_exception_event(event): +def handle_exception_event(event: ExceptionData) -> core_pb2.ExceptionEvent: """ Handle exception event when there is exception event @@ -153,7 +155,7 @@ def handle_exception_event(event): ) -def handle_file_event(event): +def handle_file_event(event: FileData) -> core_pb2.FileEvent: """ Handle file event @@ -179,7 +181,9 @@ class EventStreamer: Processes session events to generate grpc events. """ - def __init__(self, session, event_types): + def __init__( + self, session: Session, event_types: Iterable[core_pb2.EventType] + ) -> None: """ Create a EventStreamer instance. @@ -191,7 +195,7 @@ class EventStreamer: self.queue = Queue() self.add_handlers() - def add_handlers(self): + def add_handlers(self) -> None: """ Add a session event handler for desired event types. @@ -210,7 +214,7 @@ class EventStreamer: if core_pb2.EventType.SESSION in self.event_types: self.session.event_handlers.append(self.queue.put) - def process(self): + def process(self) -> core_pb2.Event: """ Process the next event in the queue. @@ -239,7 +243,7 @@ class EventStreamer: event = None return event - def remove_handlers(self): + def remove_handlers(self) -> None: """ Remove session event handlers for events being watched. diff --git a/daemon/core/api/grpc/grpcutils.py b/daemon/core/api/grpc/grpcutils.py index 5468e617..89a1d298 100644 --- a/daemon/core/api/grpc/grpcutils.py +++ b/daemon/core/api/grpc/grpcutils.py @@ -1,16 +1,21 @@ import logging import time +from typing import Any, Dict, List, Tuple, Type from core import utils from core.api.grpc import core_pb2 +from core.config import ConfigurableOptions +from core.emulator.data import LinkData from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions from core.emulator.enumerations import LinkTypes, NodeTypes -from core.nodes.base import CoreNetworkBase +from core.emulator.session import Session +from core.nodes.base import CoreNetworkBase, NodeBase +from core.services.coreservices import CoreService WORKERS = 10 -def add_node_data(node_proto): +def add_node_data(node_proto: core_pb2.Node) -> Tuple[NodeTypes, int, NodeOptions]: """ Convert node protobuf message to data for creating a node. @@ -40,7 +45,7 @@ def add_node_data(node_proto): return _type, _id, options -def link_interface(interface_proto): +def link_interface(interface_proto: core_pb2.Interface) -> InterfaceData: """ Create interface data from interface proto. @@ -68,7 +73,9 @@ def link_interface(interface_proto): return interface -def add_link_data(link_proto): +def add_link_data( + link_proto: core_pb2.Link +) -> Tuple[InterfaceData, InterfaceData, LinkOptions]: """ Convert link proto to link interfaces and options data. @@ -102,7 +109,9 @@ def add_link_data(link_proto): return interface_one, interface_two, options -def create_nodes(session, node_protos): +def create_nodes( + session: Session, node_protos: List[core_pb2.Node] +) -> Tuple[List[NodeBase], List[Exception]]: """ Create nodes using a thread pool and wait for completion. @@ -123,7 +132,9 @@ def create_nodes(session, node_protos): return results, exceptions -def create_links(session, link_protos): +def create_links( + session: Session, link_protos: List[core_pb2.Link] +) -> Tuple[List[NodeBase], List[Exception]]: """ Create links using a thread pool and wait for completion. @@ -146,7 +157,9 @@ def create_links(session, link_protos): return results, exceptions -def edit_links(session, link_protos): +def edit_links( + session: Session, link_protos: List[core_pb2.Link] +) -> Tuple[List[None], List[Exception]]: """ Edit links using a thread pool and wait for completion. @@ -169,7 +182,7 @@ def edit_links(session, link_protos): return results, exceptions -def convert_value(value): +def convert_value(value: Any) -> str: """ Convert value into string. @@ -182,7 +195,9 @@ def convert_value(value): return value -def get_config_options(config, configurable_options): +def get_config_options( + config: Dict[str, str], configurable_options: Type[ConfigurableOptions] +) -> Dict[str, core_pb2.ConfigOption]: """ Retrieve configuration options in a form that is used by the grpc server. @@ -211,12 +226,12 @@ def get_config_options(config, configurable_options): return results -def get_links(session, node): +def get_links(session: Session, node: NodeBase): """ Retrieve a list of links for grpc to use :param core.emulator.Session session: node's section - :param core.nodes.base.CoreNode node: node to get links from + :param core.nodes.base.NodeBase node: node to get links from :return: [core.api.grpc.core_pb2.Link] """ links = [] @@ -226,7 +241,7 @@ def get_links(session, node): return links -def get_emane_model_id(node_id, interface_id): +def get_emane_model_id(node_id: int, interface_id: int) -> int: """ Get EMANE model id @@ -241,7 +256,7 @@ def get_emane_model_id(node_id, interface_id): return node_id -def parse_emane_model_id(_id): +def parse_emane_model_id(_id: int) -> Tuple[int, int]: """ Parses EMANE model id to get true node id and interface id. @@ -257,7 +272,7 @@ def parse_emane_model_id(_id): return node_id, interface -def convert_link(session, link_data): +def convert_link(session: Session, link_data: LinkData) -> core_pb2.Link: """ Convert link_data into core protobuf Link @@ -324,7 +339,7 @@ def convert_link(session, link_data): ) -def get_net_stats(): +def get_net_stats() -> Dict[str, Dict]: """ Retrieve status about the current interfaces in the system @@ -346,7 +361,7 @@ def get_net_stats(): return stats -def session_location(session, location): +def session_location(session: Session, location: core_pb2.SessionLocation) -> None: """ Set session location based on location proto. @@ -359,7 +374,7 @@ def session_location(session, location): session.location.refscale = location.scale -def service_configuration(session, config): +def service_configuration(session: Session, config: core_pb2.ServiceConfig) -> None: """ Convenience method for setting a node service configuration. @@ -374,7 +389,7 @@ def service_configuration(session, config): service.shutdown = tuple(config.shutdown) -def get_service_configuration(service): +def get_service_configuration(service: Type[CoreService]) -> core_pb2.NodeServiceData: """ Convenience for converting a service to service data proto. diff --git a/daemon/core/api/grpc/server.py b/daemon/core/api/grpc/server.py index ea343165..06fde7e8 100644 --- a/daemon/core/api/grpc/server.py +++ b/daemon/core/api/grpc/server.py @@ -7,6 +7,7 @@ import time from concurrent import futures import grpc +from grpc import ServicerContext from core.api.grpc import core_pb2, core_pb2_grpc, grpcutils from core.api.grpc.events import EventStreamer @@ -17,11 +18,14 @@ from core.api.grpc.grpcutils import ( get_net_stats, ) from core.emane.nodes import EmaneNet +from core.emulator.coreemu import CoreEmu from core.emulator.data import LinkData from core.emulator.emudata import LinkOptions, NodeOptions from core.emulator.enumerations import EventTypes, LinkTypes, MessageFlags +from core.emulator.session import Session from core.errors import CoreCommandError, CoreError from core.location.mobility import BasicRangeModel, Ns2ScriptedMobility +from core.nodes.base import NodeBase from core.nodes.docker import DockerNode from core.nodes.lxd import LxcNode from core.services.coreservices import ServiceManager @@ -37,24 +41,24 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): :param core.emulator.coreemu.CoreEmu coreemu: coreemu object """ - def __init__(self, coreemu): + def __init__(self, coreemu: CoreEmu) -> None: super().__init__() self.coreemu = coreemu self.running = True self.server = None atexit.register(self._exit_handler) - def _exit_handler(self): + def _exit_handler(self) -> None: logging.debug("catching exit, stop running") self.running = False - def _is_running(self, context): + def _is_running(self, context) -> bool: return self.running and context.is_active() - def _cancel_stream(self, context): + def _cancel_stream(self, context) -> None: context.abort(grpc.StatusCode.CANCELLED, "server stopping") - def listen(self, address): + def listen(self, address: str) -> None: logging.info("CORE gRPC API listening on: %s", address) self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) core_pb2_grpc.add_CoreApiServicer_to_server(self, self.server) @@ -67,7 +71,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): except KeyboardInterrupt: self.server.stop(None) - def get_session(self, session_id, context): + def get_session(self, session_id: int, context: ServicerContext) -> Session: """ Retrieve session given the session id @@ -82,7 +86,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): context.abort(grpc.StatusCode.NOT_FOUND, f"session {session_id} not found") return session - def get_node(self, session, node_id, context): + def get_node( + self, session: Session, node_id: int, context: ServicerContext + ) -> NodeBase: """ Retrieve node given session and node id @@ -97,7 +103,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): except CoreError: context.abort(grpc.StatusCode.NOT_FOUND, f"node {node_id} not found") - def StartSession(self, request, context): + def StartSession( + self, request: core_pb2.StartSessionRequest, context: ServicerContext + ) -> core_pb2.StartSessionResponse: """ Start a session. @@ -184,7 +192,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): return core_pb2.StartSessionResponse(result=True) - def StopSession(self, request, context): + def StopSession( + self, request: core_pb2.StopSessionRequest, context: ServicerContext + ) -> core_pb2.StopSessionResponse: """ Stop a running session. @@ -201,7 +211,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.set_state(EventTypes.SHUTDOWN_STATE, send_event=True) return core_pb2.StopSessionResponse(result=True) - def CreateSession(self, request, context): + def CreateSession( + self, request: core_pb2.CreateSessionRequest, context: ServicerContext + ) -> core_pb2.CreateSessionResponse: """ Create a session @@ -219,7 +231,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session_id=session.id, state=session.state ) - def DeleteSession(self, request, context): + def DeleteSession( + self, request: core_pb2.DeleteSessionRequest, context: ServicerContext + ) -> core_pb2.DeleteSessionResponse: """ Delete the session @@ -232,7 +246,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): result = self.coreemu.delete_session(request.session_id) return core_pb2.DeleteSessionResponse(result=result) - def GetSessions(self, request, context): + def GetSessions( + self, request: core_pb2.GetSessionsRequest, context: ServicerContext + ) -> core_pb2.GetSessionsResponse: """ Delete the session @@ -254,7 +270,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): sessions.append(session_summary) return core_pb2.GetSessionsResponse(sessions=sessions) - def GetSessionLocation(self, request, context): + def GetSessionLocation( + self, request: core_pb2.GetSessionLocationRequest, context: ServicerContext + ) -> core_pb2.GetSessionLocationResponse: """ Retrieve a requested session location @@ -273,7 +291,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.GetSessionLocationResponse(location=location) - def SetSessionLocation(self, request, context): + def SetSessionLocation( + self, request: core_pb2.SetSessionLocationRequest, context: ServicerContext + ) -> core_pb2.SetSessionLocationResponse: """ Set session location @@ -287,7 +307,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): grpcutils.session_location(session, request.location) return core_pb2.SetSessionLocationResponse(result=True) - def SetSessionState(self, request, context): + def SetSessionState( + self, request: core_pb2.SetSessionStateRequest, context: ServicerContext + ) -> core_pb2.SetSessionStateResponse: """ Set session state @@ -320,7 +342,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): return core_pb2.SetSessionStateResponse(result=result) - def GetSessionOptions(self, request, context): + def GetSessionOptions( + self, request: core_pb2.GetSessionOptionsRequest, context: ServicerContext + ) -> core_pb2.GetSessionOptionsResponse: """ Retrieve session options. @@ -338,7 +362,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config = get_config_options(default_config, session.options) return core_pb2.GetSessionOptionsResponse(config=config) - def SetSessionOptions(self, request, context): + def SetSessionOptions( + self, request: core_pb2.SetSessionOptionsRequest, context: ServicerContext + ) -> core_pb2.SetSessionOptionsResponse: """ Update a session's configuration @@ -353,7 +379,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config.update(request.config) return core_pb2.SetSessionOptionsResponse(result=True) - def GetSessionMetadata(self, request, context): + def GetSessionMetadata( + self, request: core_pb2.GetSessionMetadataRequest, context: ServicerContext + ) -> core_pb2.GetSessionMetadataResponse: """ Retrieve session metadata. @@ -367,7 +395,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session = self.get_session(request.session_id, context) return core_pb2.GetSessionMetadataResponse(config=session.metadata) - def SetSessionMetadata(self, request, context): + def SetSessionMetadata( + self, request: core_pb2.SetSessionMetadataRequest, context: ServicerContext + ) -> core_pb2.SetSessionMetadataResponse: """ Update a session's metadata. @@ -381,7 +411,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.metadata = dict(request.config) return core_pb2.SetSessionMetadataResponse(result=True) - def GetSession(self, request, context): + def GetSession( + self, request: core_pb2.GetSessionRequest, context: ServicerContext + ) -> core_pb2.GetSessionResponse: """ Retrieve requested session @@ -435,7 +467,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session_proto = core_pb2.Session(state=session.state, nodes=nodes, links=links) return core_pb2.GetSessionResponse(session=session_proto) - def AddSessionServer(self, request, context): + def AddSessionServer( + self, request: core_pb2.AddSessionServerRequest, context: ServicerContext + ) -> core_pb2.AddSessionServerResponse: """ Add distributed server to a session. @@ -449,7 +483,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.distributed.add_server(request.name, request.host) return core_pb2.AddSessionServerResponse(result=True) - def Events(self, request, context): + def Events(self, request: core_pb2.EventsRequest, context: ServicerContext) -> None: session = self.get_session(request.session_id, context) event_types = set(request.events) if not event_types: @@ -464,7 +498,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): streamer.remove_handlers() self._cancel_stream(context) - def Throughputs(self, request, context): + def Throughputs( + self, request: core_pb2.ThroughputsRequest, context: ServicerContext + ) -> None: """ Calculate average throughput after every certain amount of delay time @@ -532,7 +568,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): last_stats = stats time.sleep(delay) - def AddNode(self, request, context): + def AddNode( + self, request: core_pb2.AddNodeRequest, context: ServicerContext + ) -> core_pb2.AddNodeResponse: """ Add node to requested session @@ -547,7 +585,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): node = session.add_node(_type=_type, _id=_id, options=options) return core_pb2.AddNodeResponse(node_id=node.id) - def GetNode(self, request, context): + def GetNode( + self, request: core_pb2.GetNodeRequest, context: ServicerContext + ) -> core_pb2.GetNodeResponse: """ Retrieve node @@ -602,7 +642,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): return core_pb2.GetNodeResponse(node=node_proto, interfaces=interfaces) - def EditNode(self, request, context): + def EditNode( + self, request: core_pb2.EditNodeRequest, context: ServicerContext + ) -> core_pb2.EditNodeResponse: """ Edit node @@ -635,7 +677,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): result = False return core_pb2.EditNodeResponse(result=result) - def DeleteNode(self, request, context): + def DeleteNode( + self, request: core_pb2.DeleteNodeRequest, context: ServicerContext + ) -> core_pb2.DeleteNodeResponse: """ Delete node @@ -648,7 +692,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): result = session.delete_node(request.node_id) return core_pb2.DeleteNodeResponse(result=result) - def NodeCommand(self, request, context): + def NodeCommand( + self, request: core_pb2.NodeCommandRequest, context: ServicerContext + ) -> core_pb2.NodeCommandResponse: """ Run command on a node @@ -665,7 +711,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): output = e.stderr return core_pb2.NodeCommandResponse(output=output) - def GetNodeTerminal(self, request, context): + def GetNodeTerminal( + self, request: core_pb2.GetNodeTerminalRequest, context: ServicerContext + ) -> core_pb2.GetNodeTerminalResponse: """ Retrieve terminal command string of a node @@ -680,7 +728,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): terminal = node.termcmdstring("/bin/bash") return core_pb2.GetNodeTerminalResponse(terminal=terminal) - def GetNodeLinks(self, request, context): + def GetNodeLinks( + self, request: core_pb2.GetNodeLinksRequest, context: ServicerContext + ) -> core_pb2.GetNodeLinksResponse: """ Retrieve all links form a requested node @@ -695,7 +745,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): links = get_links(session, node) return core_pb2.GetNodeLinksResponse(links=links) - def AddLink(self, request, context): + def AddLink( + self, request: core_pb2.AddLinkRequest, context: ServicerContext + ) -> core_pb2.AddLinkResponse: """ Add link to a session @@ -718,7 +770,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.AddLinkResponse(result=True) - def EditLink(self, request, context): + def EditLink( + self, request: core_pb2.EditLinkRequest, context: ServicerContext + ) -> core_pb2.EditLinkResponse: """ Edit a link @@ -751,7 +805,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.EditLinkResponse(result=True) - def DeleteLink(self, request, context): + def DeleteLink( + self, request: core_pb2.DeleteLinkRequest, context: ServicerContext + ) -> core_pb2.DeleteLinkResponse: """ Delete a link @@ -771,7 +827,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.DeleteLinkResponse(result=True) - def GetHooks(self, request, context): + def GetHooks( + self, request: core_pb2.GetHooksRequest, context: ServicerContext + ) -> core_pb2.GetHooksResponse: """ Retrieve all hooks from a session @@ -790,7 +848,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): hooks.append(hook) return core_pb2.GetHooksResponse(hooks=hooks) - def AddHook(self, request, context): + def AddHook( + self, request: core_pb2.AddHookRequest, context: ServicerContext + ) -> core_pb2.AddHookResponse: """ Add hook to a session @@ -805,7 +865,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.add_hook(hook.state, hook.file, None, hook.data) return core_pb2.AddHookResponse(result=True) - def GetMobilityConfigs(self, request, context): + def GetMobilityConfigs( + self, request: core_pb2.GetMobilityConfigsRequest, context: ServicerContext + ) -> core_pb2.GetMobilityConfigsResponse: """ Retrieve all mobility configurations from a session @@ -831,7 +893,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): response.configs[node_id].CopyFrom(mapped_config) return response - def GetMobilityConfig(self, request, context): + def GetMobilityConfig( + self, request: core_pb2.GetMobilityConfigRequest, context: ServicerContext + ) -> core_pb2.GetMobilityConfigResponse: """ Retrieve mobility configuration of a node @@ -849,7 +913,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config = get_config_options(current_config, Ns2ScriptedMobility) return core_pb2.GetMobilityConfigResponse(config=config) - def SetMobilityConfig(self, request, context): + def SetMobilityConfig( + self, request: core_pb2.SetMobilityConfigRequest, context: ServicerContext + ) -> core_pb2.SetMobilityConfigResponse: """ Set mobility configuration of a node @@ -867,7 +933,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.SetMobilityConfigResponse(result=True) - def MobilityAction(self, request, context): + def MobilityAction( + self, request: core_pb2.MobilityActionRequest, context: ServicerContext + ) -> core_pb2.MobilityActionResponse: """ Take mobility action whether to start, pause, stop or none of those @@ -891,7 +959,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): result = False return core_pb2.MobilityActionResponse(result=result) - def GetServices(self, request, context): + def GetServices( + self, request: core_pb2.GetServicesRequest, context: ServicerContext + ) -> core_pb2.GetServicesResponse: """ Retrieve all the services that are running @@ -908,7 +978,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): services.append(service_proto) return core_pb2.GetServicesResponse(services=services) - def GetServiceDefaults(self, request, context): + def GetServiceDefaults( + self, request: core_pb2.GetServiceDefaultsRequest, context: ServicerContext + ) -> core_pb2.GetServiceDefaultsResponse: """ Retrieve all the default services of all node types in a session @@ -929,7 +1001,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): all_service_defaults.append(service_defaults) return core_pb2.GetServiceDefaultsResponse(defaults=all_service_defaults) - def SetServiceDefaults(self, request, context): + def SetServiceDefaults( + self, request: core_pb2.SetServiceDefaultsRequest, context: ServicerContext + ) -> core_pb2.SetServiceDefaultsResponse: """ Set new default services to the session after whipping out the old ones :param core.api.grpc.core_pb2.SetServiceDefaults request: set-service-defaults @@ -947,7 +1021,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ] = service_defaults.services return core_pb2.SetServiceDefaultsResponse(result=True) - def GetNodeServiceConfigs(self, request, context): + def GetNodeServiceConfigs( + self, request: core_pb2.GetNodeServiceConfigsRequest, context: ServicerContext + ) -> core_pb2.GetNodeServiceConfigsResponse: """ Retrieve all node service configurations. @@ -973,7 +1049,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): configs.append(config) return core_pb2.GetNodeServiceConfigsResponse(configs=configs) - def GetNodeService(self, request, context): + def GetNodeService( + self, request: core_pb2.GetNodeServiceRequest, context: ServicerContext + ) -> core_pb2.GetNodeServiceResponse: """ Retrieve a requested service from a node @@ -991,7 +1069,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): service_proto = grpcutils.get_service_configuration(service) return core_pb2.GetNodeServiceResponse(service=service_proto) - def GetNodeServiceFile(self, request, context): + def GetNodeServiceFile( + self, request: core_pb2.GetNodeServiceFileRequest, context: ServicerContext + ) -> core_pb2.GetNodeServiceFileResponse: """ Retrieve a requested service file from a node @@ -1009,7 +1089,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.GetNodeServiceFileResponse(data=file_data.data) - def SetNodeService(self, request, context): + def SetNodeService( + self, request: core_pb2.SetNodeServiceRequest, context: ServicerContext + ) -> core_pb2.SetNodeServiceResponse: """ Set a node service for a node @@ -1025,7 +1107,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): grpcutils.service_configuration(session, config) return core_pb2.SetNodeServiceResponse(result=True) - def SetNodeServiceFile(self, request, context): + def SetNodeServiceFile( + self, request: core_pb2.SetNodeServiceFileRequest, context: ServicerContext + ) -> core_pb2.SetNodeServiceFileResponse: """ Store the customized service file in the service config @@ -1043,9 +1127,12 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): ) return core_pb2.SetNodeServiceFileResponse(result=True) - def ServiceAction(self, request, context): + def ServiceAction( + self, request: core_pb2.ServiceActionRequest, context: ServicerContext + ) -> core_pb2.ServiceActionResponse: """ - Take action whether to start, stop, restart, validate the service or none of the above + Take action whether to start, stop, restart, validate the service or none of + the above. :param core.api.grpc.core_pb2.ServiceActionRequest request: service-action request :param grpcServicerContext context: context object @@ -1082,7 +1169,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): return core_pb2.ServiceActionResponse(result=result) - def GetWlanConfigs(self, request, context): + def GetWlanConfigs( + self, request: core_pb2.GetWlanConfigsRequest, context: ServicerContext + ) -> core_pb2.GetWlanConfigsResponse: """ Retrieve all wireless-lan configurations. @@ -1107,7 +1196,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): response.configs[node_id].CopyFrom(mapped_config) return response - def GetWlanConfig(self, request, context): + def GetWlanConfig( + self, request: core_pb2.GetWlanConfigRequest, context: ServicerContext + ) -> core_pb2.GetWlanConfigResponse: """ Retrieve wireless-lan configuration of a node @@ -1124,7 +1215,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config = get_config_options(current_config, BasicRangeModel) return core_pb2.GetWlanConfigResponse(config=config) - def SetWlanConfig(self, request, context): + def SetWlanConfig( + self, request: core_pb2.SetWlanConfigRequest, context: ServicerContext + ) -> core_pb2.SetWlanConfigResponse: """ Set configuration data for a model @@ -1144,7 +1237,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): node.updatemodel(wlan_config.config) return core_pb2.SetWlanConfigResponse(result=True) - def GetEmaneConfig(self, request, context): + def GetEmaneConfig( + self, request: core_pb2.GetEmaneConfigRequest, context: ServicerContext + ) -> core_pb2.GetEmaneConfigResponse: """ Retrieve EMANE configuration of a session @@ -1159,7 +1254,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config = get_config_options(current_config, session.emane.emane_config) return core_pb2.GetEmaneConfigResponse(config=config) - def SetEmaneConfig(self, request, context): + def SetEmaneConfig( + self, request: core_pb2.SetEmaneConfigRequest, context: ServicerContext + ) -> core_pb2.SetEmaneConfigResponse: """ Set EMANE configuration of a session @@ -1174,7 +1271,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config.update(request.config) return core_pb2.SetEmaneConfigResponse(result=True) - def GetEmaneModels(self, request, context): + def GetEmaneModels( + self, request: core_pb2.GetEmaneModelsRequest, context: ServicerContext + ) -> core_pb2.GetEmaneModelsResponse: """ Retrieve all the EMANE models in the session @@ -1192,7 +1291,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): models.append(model) return core_pb2.GetEmaneModelsResponse(models=models) - def GetEmaneModelConfig(self, request, context): + def GetEmaneModelConfig( + self, request: core_pb2.GetEmaneModelConfigRequest, context: ServicerContext + ) -> core_pb2.GetEmaneModelConfigResponse: """ Retrieve EMANE model configuration of a node @@ -1210,7 +1311,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): config = get_config_options(current_config, model) return core_pb2.GetEmaneModelConfigResponse(config=config) - def SetEmaneModelConfig(self, request, context): + def SetEmaneModelConfig( + self, request: core_pb2.SetEmaneModelConfigRequest, context: ServicerContext + ) -> core_pb2.SetEmaneModelConfigResponse: """ Set EMANE model configuration of a node @@ -1227,7 +1330,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.emane.set_model_config(_id, model_config.model, model_config.config) return core_pb2.SetEmaneModelConfigResponse(result=True) - def GetEmaneModelConfigs(self, request, context): + def GetEmaneModelConfigs( + self, request: core_pb2.GetEmaneModelConfigsRequest, context: ServicerContext + ) -> core_pb2.GetEmaneModelConfigsResponse: """ Retrieve all EMANE model configurations of a session @@ -1261,7 +1366,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): configs.append(model_config) return core_pb2.GetEmaneModelConfigsResponse(configs=configs) - def SaveXml(self, request, context): + def SaveXml( + self, request: core_pb2.SaveXmlRequest, context: ServicerContext + ) -> core_pb2.SaveXmlResponse: """ Export the session nto the EmulationScript XML format @@ -1281,7 +1388,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): return core_pb2.SaveXmlResponse(data=data) - def OpenXml(self, request, context): + def OpenXml( + self, request: core_pb2.OpenXmlRequest, context: ServicerContext + ) -> core_pb2.OpenXmlResponse: """ Import a session from the EmulationScript XML format @@ -1309,7 +1418,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): finally: os.unlink(temp.name) - def GetInterfaces(self, request, context): + def GetInterfaces( + self, request: core_pb2.GetInterfacesRequest, context: ServicerContext + ) -> core_pb2.GetInterfacesResponse: """ Retrieve all the interfaces of the system including bridges, virtual ethernet, and loopback @@ -1329,7 +1440,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): interfaces.append(interface) return core_pb2.GetInterfacesResponse(interfaces=interfaces) - def EmaneLink(self, request, context): + def EmaneLink( + self, request: core_pb2.EmaneLinkRequest, context: ServicerContext + ) -> core_pb2.EmaneLinkResponse: """ Helps broadcast wireless link/unlink between EMANE nodes. diff --git a/daemon/core/config.py b/daemon/core/config.py index e8e73300..b117ce54 100644 --- a/daemon/core/config.py +++ b/daemon/core/config.py @@ -4,8 +4,111 @@ Common support for configurable CORE objects. import logging from collections import OrderedDict +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union +from core.emane.nodes import EmaneNet from core.emulator.data import ConfigData +from core.emulator.enumerations import ConfigDataTypes +from core.nodes.network import WlanNode + +if TYPE_CHECKING: + from core.location.mobility import WirelessModel + + +class ConfigGroup: + """ + Defines configuration group tabs used for display by ConfigurationOptions. + """ + + def __init__(self, name: str, start: int, stop: int) -> None: + """ + Creates a ConfigGroup object. + + :param str name: configuration group display name + :param int start: configurations start index for this group + :param int stop: configurations stop index for this group + """ + self.name = name + self.start = start + self.stop = stop + + +class Configuration: + """ + Represents a configuration options. + """ + + def __init__( + self, + _id: str, + _type: ConfigDataTypes, + label: str = None, + default: str = "", + options: List[str] = None, + ) -> None: + """ + Creates a Configuration object. + + :param str _id: unique name for configuration + :param core.enumerations.ConfigDataTypes _type: configuration data type + :param str label: configuration label for display + :param str default: default value for configuration + :param list options: list options if this is a configuration with a combobox + """ + self.id = _id + self.type = _type + self.default = default + if not options: + options = [] + self.options = options + if not label: + label = _id + self.label = label + + def __str__(self): + return f"{self.__class__.__name__}(id={self.id}, type={self.type}, default={self.default}, options={self.options})" + + +class ConfigurableOptions: + """ + Provides a base for defining configuration options within CORE. + """ + + name = None + bitmap = None + options = [] + + @classmethod + def configurations(cls) -> List[Configuration]: + """ + Provides the configurations for this class. + + :return: configurations + :rtype: list[Configuration] + """ + return cls.options + + @classmethod + def config_groups(cls) -> List[ConfigGroup]: + """ + Defines how configurations are grouped. + + :return: configuration group definition + :rtype: list[ConfigGroup] + """ + return [ConfigGroup("Options", 1, len(cls.configurations()))] + + @classmethod + def default_values(cls) -> Dict[str, str]: + """ + Provides an ordered mapping of configuration keys to default values. + + :return: ordered configuration mapping default values + :rtype: OrderedDict + """ + return OrderedDict( + [(config.id, config.default) for config in cls.configurations()] + ) class ConfigShim: @@ -14,7 +117,7 @@ class ConfigShim: """ @classmethod - def str_to_dict(cls, key_values): + def str_to_dict(cls, key_values: str) -> Dict[str, str]: """ Converts a TLV key/value string into an ordered mapping. @@ -30,7 +133,7 @@ class ConfigShim: return values @classmethod - def groups_to_str(cls, config_groups): + def groups_to_str(cls, config_groups: List[ConfigGroup]) -> str: """ Converts configuration groups to a TLV formatted string. @@ -47,7 +150,14 @@ class ConfigShim: return "|".join(group_strings) @classmethod - def config_data(cls, flags, node_id, type_flags, configurable_options, config): + def config_data( + cls, + flags: int, + node_id: int, + type_flags: int, + configurable_options: ConfigurableOptions, + config: Dict[str, str], + ) -> ConfigData: """ Convert this class to a Config API message. Some TLVs are defined by the class, but node number, conf type flags, and values must @@ -102,50 +212,22 @@ class ConfigShim: ) -class Configuration: - """ - Represents a configuration options. - """ - - def __init__(self, _id, _type, label=None, default="", options=None): - """ - Creates a Configuration object. - - :param str _id: unique name for configuration - :param core.enumerations.ConfigDataTypes _type: configuration data type - :param str label: configuration label for display - :param str default: default value for configuration - :param list options: list options if this is a configuration with a combobox - """ - self.id = _id - self.type = _type - self.default = default - if not options: - options = [] - self.options = options - if not label: - label = _id - self.label = label - - def __str__(self): - return f"{self.__class__.__name__}(id={self.id}, type={self.type}, default={self.default}, options={self.options})" - - class ConfigurableManager: """ - Provides convenience methods for storing and retrieving configuration options for nodes. + Provides convenience methods for storing and retrieving configuration options for + nodes. """ _default_node = -1 _default_type = _default_node - def __init__(self): + def __init__(self) -> None: """ Creates a ConfigurableManager object. """ self.node_configurations = {} - def nodes(self): + def nodes(self) -> List[int]: """ Retrieves the ids of all node configurations known by this manager. @@ -154,7 +236,7 @@ class ConfigurableManager: """ return [x for x in self.node_configurations if x != self._default_node] - def config_reset(self, node_id=None): + def config_reset(self, node_id: int = None) -> None: """ Clears all configurations or configuration for a specific node. @@ -166,7 +248,13 @@ class ConfigurableManager: elif node_id in self.node_configurations: self.node_configurations.pop(node_id) - def set_config(self, _id, value, node_id=_default_node, config_type=_default_type): + def set_config( + self, + _id: str, + value: str, + node_id: int = _default_node, + config_type: str = _default_type, + ) -> None: """ Set a specific configuration value for a node and configuration type. @@ -180,7 +268,12 @@ class ConfigurableManager: node_type_configs = node_configs.setdefault(config_type, OrderedDict()) node_type_configs[_id] = value - def set_configs(self, config, node_id=_default_node, config_type=_default_type): + def set_configs( + self, + config: Dict[str, str], + node_id: int = _default_node, + config_type: str = _default_type, + ) -> None: """ Set configurations for a node and configuration type. @@ -196,8 +289,12 @@ class ConfigurableManager: node_configs[config_type] = config def get_config( - self, _id, node_id=_default_node, config_type=_default_type, default=None - ): + self, + _id: str, + node_id: int = _default_node, + config_type: str = _default_type, + default: str = None, + ) -> str: """ Retrieves a specific configuration for a node and configuration type. @@ -214,7 +311,9 @@ class ConfigurableManager: result = node_type_configs.get(_id, default) return result - def get_configs(self, node_id=_default_node, config_type=_default_type): + def get_configs( + self, node_id: int = _default_node, config_type: str = _default_type + ) -> Dict[str, str]: """ Retrieve configurations for a node and configuration type. @@ -229,7 +328,7 @@ class ConfigurableManager: result = node_configs.get(config_type) return result - def get_all_configs(self, node_id=_default_node): + def get_all_configs(self, node_id: int = _default_node) -> List[Dict[str, str]]: """ Retrieve all current configuration types for a node. @@ -240,72 +339,12 @@ class ConfigurableManager: return self.node_configurations.get(node_id) -class ConfigGroup: - """ - Defines configuration group tabs used for display by ConfigurationOptions. - """ - - def __init__(self, name, start, stop): - """ - Creates a ConfigGroup object. - - :param str name: configuration group display name - :param int start: configurations start index for this group - :param int stop: configurations stop index for this group - """ - self.name = name - self.start = start - self.stop = stop - - -class ConfigurableOptions: - """ - Provides a base for defining configuration options within CORE. - """ - - name = None - bitmap = None - options = [] - - @classmethod - def configurations(cls): - """ - Provides the configurations for this class. - - :return: configurations - :rtype: list[Configuration] - """ - return cls.options - - @classmethod - def config_groups(cls): - """ - Defines how configurations are grouped. - - :return: configuration group definition - :rtype: list[ConfigGroup] - """ - return [ConfigGroup("Options", 1, len(cls.configurations()))] - - @classmethod - def default_values(cls): - """ - Provides an ordered mapping of configuration keys to default values. - - :return: ordered configuration mapping default values - :rtype: OrderedDict - """ - return OrderedDict( - [(config.id, config.default) for config in cls.configurations()] - ) - - class ModelManager(ConfigurableManager): """ Helps handle setting models for nodes and managing their model configurations. """ - def __init__(self): + def __init__(self) -> None: """ Creates a ModelManager object. """ @@ -313,7 +352,9 @@ class ModelManager(ConfigurableManager): self.models = {} self.node_models = {} - def set_model_config(self, node_id, model_name, config=None): + def set_model_config( + self, node_id: int, model_name: str, config: Dict[str, str] = None + ) -> None: """ Set configuration data for a model. @@ -341,7 +382,7 @@ class ModelManager(ConfigurableManager): # set configuration self.set_configs(model_config, node_id=node_id, config_type=model_name) - def get_model_config(self, node_id, model_name): + def get_model_config(self, node_id: int, model_name: str) -> Dict[str, str]: """ Retrieve configuration data for a model. @@ -363,7 +404,12 @@ class ModelManager(ConfigurableManager): return config - def set_model(self, node, model_class, config=None): + def set_model( + self, + node: Union[WlanNode, EmaneNet], + model_class: "WirelessModel", + config: Dict[str, str] = None, + ) -> None: """ Set model and model configuration for node. @@ -379,7 +425,9 @@ class ModelManager(ConfigurableManager): config = self.get_model_config(node.id, model_class.name) node.setmodel(model_class, config) - def get_models(self, node): + def get_models( + self, node: Union[WlanNode, EmaneNet] + ) -> List[Tuple[Type, Dict[str, str]]]: """ Return a list of model classes and values for a net if one has been configured. This is invoked when exporting a session to XML. diff --git a/daemon/core/emane/bypass.py b/daemon/core/emane/bypass.py index 68e6eee4..24f5d45a 100644 --- a/daemon/core/emane/bypass.py +++ b/daemon/core/emane/bypass.py @@ -1,6 +1,8 @@ """ EMANE Bypass model for CORE """ +from typing import List + from core.config import ConfigGroup, Configuration from core.emane import emanemodel from core.emulator.enumerations import ConfigDataTypes @@ -29,11 +31,11 @@ class EmaneBypassModel(emanemodel.EmaneModel): phy_config = [] @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: # ignore default logic pass # override config groups @classmethod - def config_groups(cls): + def config_groups(cls) -> List[ConfigGroup]: return [ConfigGroup("Bypass Parameters", 1, 1)] diff --git a/daemon/core/emane/commeffect.py b/daemon/core/emane/commeffect.py index 33edc342..c7224068 100644 --- a/daemon/core/emane/commeffect.py +++ b/daemon/core/emane/commeffect.py @@ -4,11 +4,13 @@ commeffect.py: EMANE CommEffect model for CORE import logging import os +from typing import Dict, List from lxml import etree -from core.config import ConfigGroup +from core.config import ConfigGroup, Configuration from core.emane import emanemanifest, emanemodel +from core.nodes.interface import CoreInterface from core.xml import emanexml try: @@ -20,7 +22,7 @@ except ImportError: logging.debug("compatible emane python bindings not installed") -def convert_none(x): +def convert_none(x: float) -> int: """ Helper to use 0 for None values. """ @@ -45,19 +47,21 @@ class EmaneCommEffectModel(emanemodel.EmaneModel): external_config = [] @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: shim_xml_path = os.path.join(emane_prefix, "share/emane/manifest", cls.shim_xml) cls.config_shim = emanemanifest.parse(shim_xml_path, cls.shim_defaults) @classmethod - def configurations(cls): + def configurations(cls) -> List[Configuration]: return cls.config_shim @classmethod - def config_groups(cls): + def config_groups(cls) -> List[ConfigGroup]: return [ConfigGroup("CommEffect SHIM Parameters", 1, len(cls.configurations()))] - def build_xml_files(self, config, interface=None): + def build_xml_files( + self, config: Dict[str, str], interface: CoreInterface = None + ) -> None: """ Build the necessary nem and commeffect XMLs in the given path. If an individual NEM has a nonstandard config, we need to build @@ -109,14 +113,14 @@ class EmaneCommEffectModel(emanemodel.EmaneModel): def linkconfig( self, - netif, - bw=None, - delay=None, - loss=None, - duplicate=None, - jitter=None, - netif2=None, - ): + netif: CoreInterface, + bw: float = None, + delay: float = None, + loss: float = None, + duplicate: float = None, + jitter: float = None, + netif2: CoreInterface = None, + ) -> None: """ Generate CommEffect events when a Link Message is received having link parameters. diff --git a/daemon/core/emane/emanemanager.py b/daemon/core/emane/emanemanager.py index a6237a7a..8561c68e 100644 --- a/daemon/core/emane/emanemanager.py +++ b/daemon/core/emane/emanemanager.py @@ -6,6 +6,7 @@ import logging import os import threading from collections import OrderedDict +from typing import TYPE_CHECKING, Dict, List, Set, Tuple, Type from core import utils from core.config import ConfigGroup, Configuration, ModelManager @@ -19,8 +20,15 @@ from core.emane.rfpipe import EmaneRfPipeModel from core.emane.tdma import EmaneTdmaModel from core.emulator.enumerations import ConfigDataTypes, RegisterTlvs from core.errors import CoreCommandError, CoreError +from core.nodes.base import CoreNode +from core.nodes.interface import CoreInterface +from core.nodes.network import CtrlNet from core.xml import emanexml +if TYPE_CHECKING: + from core.emulator.session import Session + + try: from emane.events import EventService from emane.events import LocationEvent @@ -57,7 +65,7 @@ class EmaneManager(ModelManager): EVENTCFGVAR = "LIBEMANEEVENTSERVICECONFIG" DEFAULT_LOG_LEVEL = 3 - def __init__(self, session): + def __init__(self, session: "Session") -> None: """ Creates a Emane instance. @@ -86,7 +94,9 @@ class EmaneManager(ModelManager): self.event_device = None self.emane_check() - def getifcconfig(self, node_id, interface, model_name): + def getifcconfig( + self, node_id: int, interface: CoreInterface, model_name: str + ) -> Dict[str, str]: """ Retrieve interface configuration or node configuration if not provided. @@ -129,11 +139,11 @@ class EmaneManager(ModelManager): return config - def config_reset(self, node_id=None): + def config_reset(self, node_id: int = None) -> None: super().config_reset(node_id) self.set_configs(self.emane_config.default_values()) - def emane_check(self): + def emane_check(self) -> None: """ Check if emane is installed and load models. @@ -157,7 +167,7 @@ class EmaneManager(ModelManager): except CoreCommandError: logging.info("emane is not installed") - def deleteeventservice(self): + def deleteeventservice(self) -> None: if self.service: for fd in self.service._readFd, self.service._writeFd: if fd >= 0: @@ -168,7 +178,7 @@ class EmaneManager(ModelManager): self.service = None self.event_device = None - def initeventservice(self, filename=None, shutdown=False): + def initeventservice(self, filename: str = None, shutdown: bool = False) -> None: """ Re-initialize the EMANE Event service. The multicast group and/or port may be configured. @@ -186,7 +196,7 @@ class EmaneManager(ModelManager): logging.error( "invalid emane event service device provided: %s", self.event_device ) - return False + return # make sure the event control network is in place eventnet = self.session.add_remove_control_net( @@ -205,9 +215,7 @@ class EmaneManager(ModelManager): except EventServiceException: logging.exception("error instantiating emane EventService") - return True - - def load_models(self, emane_models): + def load_models(self, emane_models: List[Type[EmaneModel]]) -> None: """ Load EMANE models and make them available. """ @@ -219,7 +227,7 @@ class EmaneManager(ModelManager): emane_model.load(emane_prefix) self.models[emane_model.name] = emane_model - def add_node(self, emane_net): + def add_node(self, emane_net: EmaneNet) -> None: """ Add EMANE network object to this manager. @@ -233,7 +241,7 @@ class EmaneManager(ModelManager): ) self._emane_nets[emane_net.id] = emane_net - def getnodes(self): + def getnodes(self) -> Set[CoreNode]: """ Return a set of CoreNodes that are linked to an EMANE network, e.g. containers having one or more radio interfaces. @@ -245,7 +253,7 @@ class EmaneManager(ModelManager): nodes.add(netif.node) return nodes - def setup(self): + def setup(self) -> int: """ Setup duties for EMANE manager. @@ -303,7 +311,7 @@ class EmaneManager(ModelManager): self.check_node_models() return EmaneManager.SUCCESS - def startup(self): + def startup(self) -> int: """ After all the EMANE networks have been added, build XML files and start the daemons. @@ -347,7 +355,7 @@ class EmaneManager(ModelManager): return EmaneManager.SUCCESS - def poststartup(self): + def poststartup(self) -> None: """ Retransmit location events now that all NEMs are active. """ @@ -367,7 +375,7 @@ class EmaneManager(ModelManager): x, y, z = netif.node.position.get() emane_node.setnemposition(netif, x, y, z) - def reset(self): + def reset(self) -> None: """ Remove all EMANE networks from the dictionary, reset port numbers and nem id counters @@ -382,7 +390,7 @@ class EmaneManager(ModelManager): "emane_transform_port", 8200 ) - def shutdown(self): + def shutdown(self) -> None: """ stop all EMANE daemons """ @@ -394,7 +402,7 @@ class EmaneManager(ModelManager): self.stopdaemons() self.stopeventmonitor() - def buildxml(self): + def buildxml(self) -> None: """ Build XML files required to run EMANE on each node. NEMs run inside containers using the control network for passing @@ -410,7 +418,7 @@ class EmaneManager(ModelManager): self.buildnemxml() self.buildeventservicexml() - def check_node_models(self): + def check_node_models(self) -> None: """ Associate EMANE model classes with EMANE network nodes. """ @@ -438,7 +446,7 @@ class EmaneManager(ModelManager): model_class = self.models[model_name] emane_node.setmodel(model_class, config) - def nemlookup(self, nemid): + def nemlookup(self, nemid) -> Tuple[EmaneNet, CoreInterface]: """ Look for the given numerical NEM ID and return the first matching EMANE network and NEM interface. @@ -456,7 +464,7 @@ class EmaneManager(ModelManager): return emane_node, netif - def numnems(self): + def numnems(self) -> int: """ Return the number of NEMs emulated locally. """ @@ -466,7 +474,7 @@ class EmaneManager(ModelManager): count += len(emane_node.netifs()) return count - def buildplatformxml(self, ctrlnet): + def buildplatformxml(self, ctrlnet: CtrlNet) -> None: """ Build a platform.xml file now that all nodes are configured. """ @@ -480,7 +488,7 @@ class EmaneManager(ModelManager): self, ctrlnet, emane_node, nemid, platform_xmls ) - def buildnemxml(self): + def buildnemxml(self) -> None: """ Builds the nem, mac, and phy xml files for each EMANE network. """ @@ -488,7 +496,7 @@ class EmaneManager(ModelManager): emane_net = self._emane_nets[key] emanexml.build_xml_files(self, emane_net) - def buildeventservicexml(self): + def buildeventservicexml(self) -> None: """ Build the libemaneeventservice.xml file if event service options were changed in the global config. @@ -520,7 +528,7 @@ class EmaneManager(ModelManager): ) ) - def startdaemons(self): + def startdaemons(self) -> None: """ Start one EMANE daemon per node having a radio. Add a control network even if the user has not configured one. @@ -596,7 +604,7 @@ class EmaneManager(ModelManager): self.session.distributed.execute(lambda x: x.remote_cmd(emanecmd, cwd=path)) logging.info("host emane daemon running: %s", emanecmd) - def stopdaemons(self): + def stopdaemons(self) -> None: """ Kill the appropriate EMANE daemons. """ @@ -623,7 +631,7 @@ class EmaneManager(ModelManager): except CoreCommandError: logging.exception("error shutting down emane daemons") - def installnetifs(self): + def installnetifs(self) -> None: """ Install TUN/TAP virtual interfaces into their proper namespaces now that the EMANE daemons are running. @@ -633,7 +641,7 @@ class EmaneManager(ModelManager): logging.info("emane install netifs for node: %d", key) emane_node.installnetifs() - def deinstallnetifs(self): + def deinstallnetifs(self) -> None: """ Uninstall TUN/TAP virtual interfaces. """ @@ -641,7 +649,7 @@ class EmaneManager(ModelManager): emane_node = self._emane_nets[key] emane_node.deinstallnetifs() - def doeventmonitor(self): + def doeventmonitor(self) -> bool: """ Returns boolean whether or not EMANE events will be monitored. """ @@ -649,7 +657,7 @@ class EmaneManager(ModelManager): # generate the EMANE events when nodes are moved return self.session.options.get_config_bool("emane_event_monitor") - def genlocationevents(self): + def genlocationevents(self) -> bool: """ Returns boolean whether or not EMANE events will be generated. """ @@ -660,7 +668,7 @@ class EmaneManager(ModelManager): tmp = not self.doeventmonitor() return tmp - def starteventmonitor(self): + def starteventmonitor(self) -> None: """ Start monitoring EMANE location events if configured to do so. """ @@ -681,7 +689,7 @@ class EmaneManager(ModelManager): self.eventmonthread.daemon = True self.eventmonthread.start() - def stopeventmonitor(self): + def stopeventmonitor(self) -> None: """ Stop monitoring EMANE location events. """ @@ -697,7 +705,7 @@ class EmaneManager(ModelManager): self.eventmonthread.join() self.eventmonthread = None - def eventmonitorloop(self): + def eventmonitorloop(self) -> None: """ Thread target that monitors EMANE location events. """ @@ -724,7 +732,7 @@ class EmaneManager(ModelManager): threading.currentThread().getName(), ) - def handlelocationevent(self, rxnemid, eid, data): + def handlelocationevent(self, rxnemid: int, eid: int, data: str) -> None: """ Handle an EMANE location event. """ @@ -747,7 +755,9 @@ class EmaneManager(ModelManager): logging.debug("emane location event: %s,%s,%s", lat, lon, alt) self.handlelocationeventtoxyz(txnemid, lat, lon, alt) - def handlelocationeventtoxyz(self, nemid, lat, lon, alt): + def handlelocationeventtoxyz( + self, nemid: int, lat: float, lon: float, alt: float + ) -> bool: """ Convert the (NEM ID, lat, long, alt) from a received location event into a node and x,y,z coordinate values, sending a Node Message. @@ -800,11 +810,11 @@ class EmaneManager(ModelManager): # don"t use node.setposition(x,y,z) which generates an event node.position.set(x, y, z) - node_data = node.data(message_type=0, lat=str(lat), lon=str(lon), alt=str(alt)) + node_data = node.data(message_type=0, lat=lat, lon=lon, alt=alt) self.session.broadcast_node(node_data) return True - def emanerunning(self, node): + def emanerunning(self, node: CoreNode) -> bool: """ Return True if an EMANE process associated with the given node is running, False otherwise. @@ -827,7 +837,7 @@ class EmaneGlobalModel: name = "emane" bitmap = None - def __init__(self, session): + def __init__(self, session: "Session") -> None: self.session = session self.nem_config = [ Configuration( @@ -840,7 +850,7 @@ class EmaneGlobalModel: self.emulator_config = None self.parse_config() - def parse_config(self): + def parse_config(self) -> None: emane_prefix = self.session.options.get_config( "emane_prefix", default=DEFAULT_EMANE_PREFIX ) @@ -862,10 +872,10 @@ class EmaneGlobalModel: ), ) - def configurations(self): + def configurations(self) -> List[Configuration]: return self.emulator_config + self.nem_config - def config_groups(self): + def config_groups(self) -> List[ConfigGroup]: emulator_len = len(self.emulator_config) config_len = len(self.configurations()) return [ @@ -873,7 +883,7 @@ class EmaneGlobalModel: ConfigGroup("NEM Parameters", emulator_len + 1, config_len), ] - def default_values(self): + def default_values(self) -> Dict[str, str]: return OrderedDict( [(config.id, config.default) for config in self.configurations()] ) diff --git a/daemon/core/emane/emanemanifest.py b/daemon/core/emane/emanemanifest.py index a6583b9e..0fc5facc 100644 --- a/daemon/core/emane/emanemanifest.py +++ b/daemon/core/emane/emanemanifest.py @@ -1,4 +1,5 @@ import logging +from typing import Dict, List from core.config import Configuration from core.emulator.enumerations import ConfigDataTypes @@ -13,12 +14,12 @@ except ImportError: logging.debug("compatible emane python bindings not installed") -def _type_value(config_type): +def _type_value(config_type: str) -> ConfigDataTypes: """ Convert emane configuration type to core configuration value. :param str config_type: emane configuration type - :return: + :return: core config type """ config_type = config_type.upper() if config_type == "DOUBLE": @@ -28,7 +29,7 @@ def _type_value(config_type): return ConfigDataTypes[config_type] -def _get_possible(config_type, config_regex): +def _get_possible(config_type: str, config_regex: str) -> List[str]: """ Retrieve possible config value options based on emane regexes. @@ -47,7 +48,7 @@ def _get_possible(config_type, config_regex): return [] -def _get_default(config_type_name, config_value): +def _get_default(config_type_name: str, config_value: List[str]) -> str: """ Convert default configuration values to one used by core. @@ -72,7 +73,7 @@ def _get_default(config_type_name, config_value): return config_default -def parse(manifest_path, defaults): +def parse(manifest_path: str, defaults: Dict[str, str]) -> List[Configuration]: """ Parses a valid emane manifest file and converts the provided configuration values into ones used by core. diff --git a/daemon/core/emane/emanemodel.py b/daemon/core/emane/emanemodel.py index 3ca2a18f..4e5dbbfa 100644 --- a/daemon/core/emane/emanemodel.py +++ b/daemon/core/emane/emanemodel.py @@ -3,12 +3,14 @@ Defines Emane Models used within CORE. """ import logging import os +from typing import Dict, List from core.config import ConfigGroup, Configuration from core.emane import emanemanifest from core.emulator.enumerations import ConfigDataTypes from core.errors import CoreError from core.location.mobility import WirelessModel +from core.nodes.interface import CoreInterface from core.xml import emanexml @@ -45,7 +47,7 @@ class EmaneModel(WirelessModel): config_ignore = set() @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: """ Called after being loaded within the EmaneManager. Provides configured emane_prefix for parsing xml files. @@ -63,7 +65,7 @@ class EmaneModel(WirelessModel): cls.phy_config = emanemanifest.parse(phy_xml_path, cls.phy_defaults) @classmethod - def configurations(cls): + def configurations(cls) -> List[Configuration]: """ Returns the combination all all configurations (mac, phy, and external). @@ -73,7 +75,7 @@ class EmaneModel(WirelessModel): return cls.mac_config + cls.phy_config + cls.external_config @classmethod - def config_groups(cls): + def config_groups(cls) -> List[ConfigGroup]: """ Returns the defined configuration groups. @@ -89,10 +91,12 @@ class EmaneModel(WirelessModel): ConfigGroup("External Parameters", phy_len + 1, config_len), ] - def build_xml_files(self, config, interface=None): + def build_xml_files( + self, config: Dict[str, str], interface: CoreInterface = None + ) -> None: """ - Builds xml files for this emane model. Creates a nem.xml file that points to both mac.xml and phy.xml - definitions. + Builds xml files for this emane model. Creates a nem.xml file that points to + both mac.xml and phy.xml definitions. :param dict config: emane model configuration for the node and interface :param interface: interface for the emane node @@ -127,7 +131,7 @@ class EmaneModel(WirelessModel): phy_file = os.path.join(self.session.session_dir, phy_name) emanexml.create_phy_xml(self, config, phy_file, server) - def post_startup(self): + def post_startup(self) -> None: """ Logic to execute after the emane manager is finished with startup. @@ -135,7 +139,7 @@ class EmaneModel(WirelessModel): """ logging.debug("emane model(%s) has no post setup tasks", self.name) - def update(self, moved, moved_netifs): + def update(self, moved: bool, moved_netifs: List[CoreInterface]) -> None: """ Invoked from MobilityModel when nodes are moved; this causes emane location events to be generated for the nodes in the moved @@ -143,7 +147,7 @@ class EmaneModel(WirelessModel): :param bool moved: were nodes moved :param list moved_netifs: interfaces that were moved - :return: + :return: nothing """ try: wlan = self.session.get_node(self.id) @@ -153,14 +157,14 @@ class EmaneModel(WirelessModel): def linkconfig( self, - netif, - bw=None, - delay=None, - loss=None, - duplicate=None, - jitter=None, - netif2=None, - ): + netif: CoreInterface, + bw: float = None, + delay: float = None, + loss: float = None, + duplicate: float = None, + jitter: float = None, + netif2: CoreInterface = None, + ) -> None: """ Invoked when a Link Message is received. Default is unimplemented. diff --git a/daemon/core/emane/ieee80211abg.py b/daemon/core/emane/ieee80211abg.py index e7a4d0d7..ecfd3694 100644 --- a/daemon/core/emane/ieee80211abg.py +++ b/daemon/core/emane/ieee80211abg.py @@ -15,7 +15,7 @@ class EmaneIeee80211abgModel(emanemodel.EmaneModel): mac_xml = "ieee80211abgmaclayer.xml" @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: cls.mac_defaults["pcrcurveuri"] = os.path.join( emane_prefix, "share/emane/xml/models/mac/ieee80211abg/ieee80211pcr.xml" ) diff --git a/daemon/core/emane/nodes.py b/daemon/core/emane/nodes.py index bd76ed81..3a1834f3 100644 --- a/daemon/core/emane/nodes.py +++ b/daemon/core/emane/nodes.py @@ -4,9 +4,18 @@ share the same MAC+PHY model. """ import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Type +from core.emulator.distributed import DistributedServer from core.emulator.enumerations import LinkTypes, NodeTypes, RegisterTlvs from core.nodes.base import CoreNetworkBase +from core.nodes.interface import CoreInterface + +if TYPE_CHECKING: + from core.emulator.session import Session + from core.location.mobility import WirelessModel + + WirelessModelType = Type[WirelessModel] try: from emane.events import LocationEvent @@ -29,7 +38,14 @@ class EmaneNet(CoreNetworkBase): type = "wlan" is_emane = True - def __init__(self, session, _id=None, name=None, start=True, server=None): + def __init__( + self, + session: "Session", + _id: int = None, + name: str = None, + start: bool = True, + server: DistributedServer = None, + ) -> None: super().__init__(session, _id, name, start, server) self.conf = "" self.up = False @@ -39,20 +55,20 @@ class EmaneNet(CoreNetworkBase): def linkconfig( self, - netif, - bw=None, - delay=None, - loss=None, - duplicate=None, - jitter=None, - netif2=None, - ): + netif: CoreInterface, + bw: float = None, + delay: float = None, + loss: float = None, + duplicate: float = None, + jitter: float = None, + netif2: CoreInterface = None, + ) -> None: """ The CommEffect model supports link configuration. """ if not self.model: return - return self.model.linkconfig( + self.model.linkconfig( netif=netif, bw=bw, delay=delay, @@ -62,19 +78,19 @@ class EmaneNet(CoreNetworkBase): netif2=netif2, ) - def config(self, conf): + def config(self, conf: str) -> None: self.conf = conf - def shutdown(self): + def shutdown(self) -> None: pass - def link(self, netif1, netif2): + def link(self, netif1: CoreInterface, netif2: CoreInterface) -> None: pass - def unlink(self, netif1, netif2): + def unlink(self, netif1: CoreInterface, netif2: CoreInterface) -> None: pass - def updatemodel(self, config): + def updatemodel(self, config: Dict[str, str]) -> None: if not self.model: raise ValueError("no model set to update for node(%s)", self.id) logging.info( @@ -82,7 +98,7 @@ class EmaneNet(CoreNetworkBase): ) self.model.set_configs(config, node_id=self.id) - def setmodel(self, model, config): + def setmodel(self, model: "WirelessModelType", config: Dict[str, str]) -> None: """ set the EmaneModel associated with this node """ @@ -96,14 +112,14 @@ class EmaneNet(CoreNetworkBase): self.mobility = model(session=self.session, _id=self.id) self.mobility.update_config(config) - def setnemid(self, netif, nemid): + def setnemid(self, netif: CoreInterface, nemid: int) -> None: """ Record an interface to numerical ID mapping. The Emane controller object manages and assigns these IDs for all NEMs. """ self.nemidmap[netif] = nemid - def getnemid(self, netif): + def getnemid(self, netif: CoreInterface) -> Optional[int]: """ Given an interface, return its numerical ID. """ @@ -112,7 +128,7 @@ class EmaneNet(CoreNetworkBase): else: return self.nemidmap[netif] - def getnemnetif(self, nemid): + def getnemnetif(self, nemid: int) -> Optional[CoreInterface]: """ Given a numerical NEM ID, return its interface. This returns the first interface that matches the given NEM ID. @@ -122,13 +138,13 @@ class EmaneNet(CoreNetworkBase): return netif return None - def netifs(self, sort=True): + def netifs(self, sort: bool = True) -> List[CoreInterface]: """ Retrieve list of linked interfaces sorted by node number. """ return sorted(self._netif.values(), key=lambda ifc: ifc.node.id) - def installnetifs(self): + def installnetifs(self) -> None: """ Install TAP devices into their namespaces. This is done after EMANE daemons have been started, because that is their only chance @@ -159,7 +175,7 @@ class EmaneNet(CoreNetworkBase): x, y, z = netif.node.position.get() self.setnemposition(netif, x, y, z) - def deinstallnetifs(self): + def deinstallnetifs(self) -> None: """ Uninstall TAP devices. This invokes their shutdown method for any required cleanup; the device may be actually removed when @@ -170,7 +186,9 @@ class EmaneNet(CoreNetworkBase): netif.shutdown() netif.poshook = None - def setnemposition(self, netif, x, y, z): + def setnemposition( + self, netif: CoreInterface, x: float, y: float, z: float + ) -> None: """ Publish a NEM location change event using the EMANE event service. """ @@ -191,7 +209,7 @@ class EmaneNet(CoreNetworkBase): event.append(nemid, latitude=lat, longitude=lon, altitude=alt) self.session.emane.service.publish(0, event) - def setnempositions(self, moved_netifs): + def setnempositions(self, moved_netifs: List[CoreInterface]) -> None: """ Several NEMs have moved, from e.g. a WaypointMobilityModel calculation. Generate an EMANE Location Event having several diff --git a/daemon/core/emane/rfpipe.py b/daemon/core/emane/rfpipe.py index 51820b7d..23790b3c 100644 --- a/daemon/core/emane/rfpipe.py +++ b/daemon/core/emane/rfpipe.py @@ -15,7 +15,7 @@ class EmaneRfPipeModel(emanemodel.EmaneModel): mac_xml = "rfpipemaclayer.xml" @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: cls.mac_defaults["pcrcurveuri"] = os.path.join( emane_prefix, "share/emane/xml/models/mac/rfpipe/rfpipepcr.xml" ) diff --git a/daemon/core/emane/tdma.py b/daemon/core/emane/tdma.py index 59ed9e04..17f5328f 100644 --- a/daemon/core/emane/tdma.py +++ b/daemon/core/emane/tdma.py @@ -27,7 +27,7 @@ class EmaneTdmaModel(emanemodel.EmaneModel): config_ignore = {schedule_name} @classmethod - def load(cls, emane_prefix): + def load(cls, emane_prefix: str) -> None: cls.mac_defaults["pcrcurveuri"] = os.path.join( emane_prefix, "share/emane/xml/models/mac/tdmaeventscheduler/tdmabasemodelpcr.xml", @@ -43,7 +43,7 @@ class EmaneTdmaModel(emanemodel.EmaneModel): ), ) - def post_startup(self): + def post_startup(self) -> None: """ Logic to execute after the emane manager is finished with startup. diff --git a/daemon/core/emulator/coreemu.py b/daemon/core/emulator/coreemu.py index 158dc296..ed51e076 100644 --- a/daemon/core/emulator/coreemu.py +++ b/daemon/core/emulator/coreemu.py @@ -3,13 +3,14 @@ import logging import os import signal import sys +from typing import Mapping, Type import core.services from core.emulator.session import Session from core.services.coreservices import ServiceManager -def signal_handler(signal_number, _): +def signal_handler(signal_number: int, _) -> None: """ Handle signals and force an exit with cleanup. @@ -33,7 +34,7 @@ class CoreEmu: Provides logic for creating and configuring CORE sessions and the nodes within them. """ - def __init__(self, config=None): + def __init__(self, config: Mapping[str, str] = None) -> None: """ Create a CoreEmu object. @@ -57,7 +58,7 @@ class CoreEmu: # catch exit event atexit.register(self.shutdown) - def load_services(self): + def load_services(self) -> None: # load default services self.service_errors = core.services.load() @@ -70,7 +71,7 @@ class CoreEmu: custom_service_errors = ServiceManager.add_services(service_path) self.service_errors.extend(custom_service_errors) - def shutdown(self): + def shutdown(self) -> None: """ Shutdown all CORE session. @@ -83,7 +84,7 @@ class CoreEmu: session = sessions[_id] session.shutdown() - def create_session(self, _id=None, _cls=Session): + def create_session(self, _id: int = None, _cls: Type[Session] = Session) -> Session: """ Create a new CORE session. @@ -101,7 +102,7 @@ class CoreEmu: self.sessions[_id] = session return session - def delete_session(self, _id): + def delete_session(self, _id: int) -> bool: """ Shutdown and delete a CORE session. diff --git a/daemon/core/emulator/distributed.py b/daemon/core/emulator/distributed.py index a9cba815..105b767f 100644 --- a/daemon/core/emulator/distributed.py +++ b/daemon/core/emulator/distributed.py @@ -7,6 +7,7 @@ import os import threading from collections import OrderedDict from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Callable, Dict, Tuple import netaddr from fabric import Connection @@ -17,6 +18,9 @@ from core.errors import CoreCommandError from core.nodes.interface import GreTap from core.nodes.network import CoreNetwork, CtrlNet +if TYPE_CHECKING: + from core.emulator.session import Session + LOCK = threading.Lock() CMD_HIDE = True @@ -26,7 +30,7 @@ class DistributedServer: Provides distributed server interactions. """ - def __init__(self, name, host): + def __init__(self, name: str, host: str) -> None: """ Create a DistributedServer instance. @@ -38,7 +42,9 @@ class DistributedServer: self.conn = Connection(host, user="root") self.lock = threading.Lock() - def remote_cmd(self, cmd, env=None, cwd=None, wait=True): + def remote_cmd( + self, cmd: str, env: Dict[str, str] = None, cwd: str = None, wait: bool = True + ) -> str: """ Run command remotely using server connection. @@ -73,7 +79,7 @@ class DistributedServer: stdout, stderr = e.streams_for_display() raise CoreCommandError(e.result.exited, cmd, stdout, stderr) - def remote_put(self, source, destination): + def remote_put(self, source: str, destination: str) -> None: """ Push file to remote server. @@ -84,7 +90,7 @@ class DistributedServer: with self.lock: self.conn.put(source, destination) - def remote_put_temp(self, destination, data): + def remote_put_temp(self, destination: str, data: str) -> None: """ Remote push file contents to a remote server, using a temp file as an intermediate step. @@ -106,11 +112,11 @@ class DistributedController: Provides logic for dealing with remote tunnels and distributed servers. """ - def __init__(self, session): + def __init__(self, session: "Session") -> None: """ Create - :param session: + :param session: session """ self.session = session self.servers = OrderedDict() @@ -119,7 +125,7 @@ class DistributedController: "distributed_address", default=None ) - def add_server(self, name, host): + def add_server(self, name: str, host: str) -> None: """ Add distributed server configuration. @@ -132,7 +138,7 @@ class DistributedController: cmd = f"mkdir -p {self.session.session_dir}" server.remote_cmd(cmd) - def execute(self, func): + def execute(self, func: Callable[[DistributedServer], None]) -> None: """ Convenience for executing logic against all distributed servers. @@ -143,7 +149,7 @@ class DistributedController: server = self.servers[name] func(server) - def shutdown(self): + def shutdown(self) -> None: """ Shutdown logic for dealing with distributed tunnels and server session directories. @@ -165,7 +171,7 @@ class DistributedController: # clear tunnels self.tunnels.clear() - def start(self): + def start(self) -> None: """ Start distributed network tunnels. @@ -184,7 +190,9 @@ class DistributedController: server = self.servers[name] self.create_gre_tunnel(node, server) - def create_gre_tunnel(self, node, server): + def create_gre_tunnel( + self, node: CoreNetwork, server: DistributedServer + ) -> Tuple[GreTap, GreTap]: """ Create gre tunnel using a pair of gre taps between the local and remote server. @@ -222,7 +230,7 @@ class DistributedController: self.tunnels[key] = tunnel return tunnel - def tunnel_key(self, n1_id, n2_id): + def tunnel_key(self, n1_id: int, n2_id: int) -> int: """ Compute a 32-bit key used to uniquely identify a GRE tunnel. The hash(n1num), hash(n2num) values are used, so node numbers may be @@ -239,7 +247,7 @@ class DistributedController: ) return key & 0xFFFFFFFF - def get_tunnel(self, n1_id, n2_id): + def get_tunnel(self, n1_id: int, n2_id: int) -> Tuple[GreTap, GreTap]: """ Return the GreTap between two nodes if it exists. diff --git a/daemon/core/emulator/emudata.py b/daemon/core/emulator/emudata.py index 8929f72a..acf105eb 100644 --- a/daemon/core/emulator/emudata.py +++ b/daemon/core/emulator/emudata.py @@ -1,40 +1,32 @@ +from typing import List, Optional + import netaddr from core import utils +from core.api.grpc.core_pb2 import LinkOptions from core.emane.nodes import EmaneNet from core.emulator.enumerations import LinkTypes +from core.nodes.base import CoreNetworkBase, CoreNode +from core.nodes.interface import CoreInterface from core.nodes.physical import PhysicalNode class IdGen: - def __init__(self, _id=0): + def __init__(self, _id: int = 0) -> None: self.id = _id - def next(self): + def next(self) -> int: self.id += 1 return self.id -def create_interface(node, network, interface_data): - """ - Create an interface for a node on a network using provided interface data. - - :param node: node to create interface for - :param core.nodes.base.CoreNetworkBase network: network to associate interface with - :param core.emulator.emudata.InterfaceData interface_data: interface data - :return: created interface - """ - node.newnetif( - network, - addrlist=interface_data.get_addresses(), - hwaddr=interface_data.mac, - ifindex=interface_data.id, - ifname=interface_data.name, - ) - return node.netif(interface_data.id) - - -def link_config(network, interface, link_options, devname=None, interface_two=None): +def link_config( + network: CoreNetworkBase, + interface: CoreInterface, + link_options: LinkOptions, + devname: str = None, + interface_two: CoreInterface = None, +) -> None: """ Convenience method for configuring a link, @@ -68,7 +60,7 @@ class NodeOptions: Options for creating and updating nodes within core. """ - def __init__(self, name=None, model="PC", image=None): + def __init__(self, name: str = None, model: str = "PC", image: str = None) -> None: """ Create a NodeOptions object. @@ -93,7 +85,7 @@ class NodeOptions: self.image = image self.emane = None - def set_position(self, x, y): + def set_position(self, x: float, y: float) -> None: """ Convenience method for setting position. @@ -104,7 +96,7 @@ class NodeOptions: self.x = x self.y = y - def set_location(self, lat, lon, alt): + def set_location(self, lat: float, lon: float, alt: float) -> None: """ Convenience method for setting location. @@ -123,7 +115,7 @@ class LinkOptions: Options for creating and updating links within core. """ - def __init__(self, _type=LinkTypes.WIRED): + def __init__(self, _type: LinkTypes = LinkTypes.WIRED) -> None: """ Create a LinkOptions object. @@ -148,12 +140,96 @@ class LinkOptions: self.opaque = None +class InterfaceData: + """ + Convenience class for storing interface data. + """ + + def __init__( + self, + _id: int, + name: str, + mac: str, + ip4: str, + ip4_mask: int, + ip6: str, + ip6_mask: int, + ) -> None: + """ + Creates an InterfaceData object. + + :param int _id: interface id + :param str name: name for interface + :param str mac: mac address + :param str ip4: ipv4 address + :param int ip4_mask: ipv4 bit mask + :param str ip6: ipv6 address + :param int ip6_mask: ipv6 bit mask + """ + self.id = _id + self.name = name + self.mac = mac + self.ip4 = ip4 + self.ip4_mask = ip4_mask + self.ip6 = ip6 + self.ip6_mask = ip6_mask + + def has_ip4(self) -> bool: + """ + Determines if interface has an ip4 address. + + :return: True if has ip4, False otherwise + """ + return all([self.ip4, self.ip4_mask]) + + def has_ip6(self) -> bool: + """ + Determines if interface has an ip6 address. + + :return: True if has ip6, False otherwise + """ + return all([self.ip6, self.ip6_mask]) + + def ip4_address(self) -> Optional[str]: + """ + Retrieve a string representation of the ip4 address and netmask. + + :return: ip4 string or None + """ + if self.has_ip4(): + return f"{self.ip4}/{self.ip4_mask}" + else: + return None + + def ip6_address(self) -> Optional[str]: + """ + Retrieve a string representation of the ip6 address and netmask. + + :return: ip4 string or None + """ + if self.has_ip6(): + return f"{self.ip6}/{self.ip6_mask}" + else: + return None + + def get_addresses(self) -> List[str]: + """ + Returns a list of ip4 and ip6 address when present. + + :return: list of addresses + :rtype: list + """ + ip4 = self.ip4_address() + ip6 = self.ip6_address() + return [i for i in [ip4, ip6] if i] + + class IpPrefixes: """ Convenience class to help generate IP4 and IP6 addresses for nodes within CORE. """ - def __init__(self, ip4_prefix=None, ip6_prefix=None): + def __init__(self, ip4_prefix: str = None, ip6_prefix: str = None) -> None: """ Creates an IpPrefixes object. @@ -171,7 +247,7 @@ class IpPrefixes: if ip6_prefix: self.ip6 = netaddr.IPNetwork(ip6_prefix) - def ip4_address(self, node): + def ip4_address(self, node: CoreNode) -> str: """ Convenience method to return the IP4 address for a node. @@ -183,7 +259,7 @@ class IpPrefixes: raise ValueError("ip4 prefixes have not been set") return str(self.ip4[node.id]) - def ip6_address(self, node): + def ip6_address(self, node: CoreNode) -> str: """ Convenience method to return the IP6 address for a node. @@ -195,7 +271,9 @@ class IpPrefixes: raise ValueError("ip6 prefixes have not been set") return str(self.ip6[node.id]) - def create_interface(self, node, name=None, mac=None): + def create_interface( + self, node: CoreNode, name: str = None, mac: str = None + ) -> InterfaceData: """ Creates interface data for linking nodes, using the nodes unique id for generation, along with a random mac address, unless provided. @@ -239,76 +317,22 @@ class IpPrefixes: ) -class InterfaceData: +def create_interface( + node: CoreNode, network: CoreNetworkBase, interface_data: InterfaceData +): """ - Convenience class for storing interface data. + Create an interface for a node on a network using provided interface data. + + :param node: node to create interface for + :param core.nodes.base.CoreNetworkBase network: network to associate interface with + :param core.emulator.emudata.InterfaceData interface_data: interface data + :return: created interface """ - - def __init__(self, _id, name, mac, ip4, ip4_mask, ip6, ip6_mask): - """ - Creates an InterfaceData object. - - :param int _id: interface id - :param str name: name for interface - :param str mac: mac address - :param str ip4: ipv4 address - :param int ip4_mask: ipv4 bit mask - :param str ip6: ipv6 address - :param int ip6_mask: ipv6 bit mask - """ - self.id = _id - self.name = name - self.mac = mac - self.ip4 = ip4 - self.ip4_mask = ip4_mask - self.ip6 = ip6 - self.ip6_mask = ip6_mask - - def has_ip4(self): - """ - Determines if interface has an ip4 address. - - :return: True if has ip4, False otherwise - """ - return all([self.ip4, self.ip4_mask]) - - def has_ip6(self): - """ - Determines if interface has an ip6 address. - - :return: True if has ip6, False otherwise - """ - return all([self.ip6, self.ip6_mask]) - - def ip4_address(self): - """ - Retrieve a string representation of the ip4 address and netmask. - - :return: ip4 string or None - """ - if self.has_ip4(): - return f"{self.ip4}/{self.ip4_mask}" - else: - return None - - def ip6_address(self): - """ - Retrieve a string representation of the ip6 address and netmask. - - :return: ip4 string or None - """ - if self.has_ip6(): - return f"{self.ip6}/{self.ip6_mask}" - else: - return None - - def get_addresses(self): - """ - Returns a list of ip4 and ip6 address when present. - - :return: list of addresses - :rtype: list - """ - ip4 = self.ip4_address() - ip6 = self.ip6_address() - return [i for i in [ip4, ip6] if i] + node.newnetif( + network, + addrlist=interface_data.get_addresses(), + hwaddr=interface_data.mac, + ifindex=interface_data.id, + ifname=interface_data.name, + ) + return node.netif(interface_data.id) diff --git a/daemon/core/emulator/session.py b/daemon/core/emulator/session.py index a81ba103..ca585c31 100644 --- a/daemon/core/emulator/session.py +++ b/daemon/core/emulator/session.py @@ -12,14 +12,23 @@ import subprocess import tempfile import threading import time +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from core import constants, utils from core.emane.emanemanager import EmaneManager from core.emane.nodes import EmaneNet -from core.emulator.data import EventData, ExceptionData, NodeData +from core.emulator.data import ( + ConfigData, + EventData, + ExceptionData, + FileData, + LinkData, + NodeData, +) from core.emulator.distributed import DistributedController from core.emulator.emudata import ( IdGen, + InterfaceData, LinkOptions, NodeOptions, create_interface, @@ -31,8 +40,9 @@ from core.errors import CoreError from core.location.corelocation import CoreLocation from core.location.event import EventLoop from core.location.mobility import BasicRangeModel, MobilityManager -from core.nodes.base import CoreNetworkBase, CoreNode, CoreNodeBase +from core.nodes.base import CoreNetworkBase, CoreNode, CoreNodeBase, NodeBase from core.nodes.docker import DockerNode +from core.nodes.interface import GreTap from core.nodes.lxd import LxcNode from core.nodes.network import ( CtrlNet, @@ -45,7 +55,7 @@ from core.nodes.network import ( ) from core.nodes.physical import PhysicalNode, Rj45Node from core.plugins.sdt import Sdt -from core.services.coreservices import CoreServices +from core.services.coreservices import CoreServices, ServiceBootError from core.xml import corexml, corexmldeployment from core.xml.corexml import CoreXmlReader, CoreXmlWriter @@ -74,7 +84,9 @@ class Session: CORE session manager. """ - def __init__(self, _id, config=None, mkdir=True): + def __init__( + self, _id: int, config: Dict[str, str] = None, mkdir: bool = True + ) -> None: """ Create a Session instance. @@ -150,7 +162,7 @@ class Session: } @classmethod - def get_node_class(cls, _type): + def get_node_class(cls, _type: NodeTypes) -> Type[NodeBase]: """ Retrieve the class for a given node type. @@ -163,20 +175,25 @@ class Session: return node_class @classmethod - def get_node_type(cls, _class): + def get_node_type(cls, _class: Type[NodeBase]) -> NodeTypes: """ Retrieve node type for a given node class. :param _class: node class to get a node type for :return: node type :rtype: core.emulator.enumerations.NodeTypes + :raises CoreError: when node type does not exist """ node_type = NODES_TYPE.get(_class) if node_type is None: raise CoreError(f"invalid node class: {_class}") return node_type - def _link_nodes(self, node_one_id, node_two_id): + def _link_nodes( + self, node_one_id: int, node_two_id: int + ) -> Tuple[ + CoreNode, CoreNode, CoreNetworkBase, CoreNetworkBase, Tuple[GreTap, GreTap] + ]: """ Convenience method for retrieving nodes within link data. @@ -237,14 +254,15 @@ class Session: ) return node_one, node_two, net_one, net_two, tunnel - def _link_wireless(self, objects, connect): + def _link_wireless(self, objects: Iterable[CoreNodeBase], connect: bool) -> None: """ Objects to deal with when connecting/disconnecting wireless links. :param list objects: possible objects to deal with :param bool connect: link interfaces if True, unlink otherwise :return: nothing - :raises core.CoreError: when objects to link is less than 2, or no common networks are found + :raises core.CoreError: when objects to link is less than 2, or no common + networks are found """ objects = [x for x in objects if x] if len(objects) < 2: @@ -277,20 +295,23 @@ class Session: def add_link( self, - node_one_id, - node_two_id, - interface_one=None, - interface_two=None, - link_options=None, - ): + node_one_id: int, + node_two_id: int, + interface_one: InterfaceData = None, + interface_two: InterfaceData = None, + link_options: LinkOptions = None, + ) -> None: """ Add a link between nodes. :param int node_one_id: node one id :param int node_two_id: node two id - :param core.emulator.emudata.InterfaceData interface_one: node one interface data, defaults to none - :param core.emulator.emudata.InterfaceData interface_two: node two interface data, defaults to none - :param core.emulator.emudata.LinkOptions link_options: data for creating link, defaults to no options + :param core.emulator.emudata.InterfaceData interface_one: node one interface + data, defaults to none + :param core.emulator.emudata.InterfaceData interface_two: node two interface + data, defaults to none + :param core.emulator.emudata.LinkOptions link_options: data for creating link, + defaults to no options :return: nothing """ if not link_options: @@ -406,12 +427,12 @@ class Session: def delete_link( self, - node_one_id, - node_two_id, - interface_one_id, - interface_two_id, - link_type=LinkTypes.WIRED, - ): + node_one_id: int, + node_two_id: int, + interface_one_id: int, + interface_two_id: int, + link_type: LinkTypes = LinkTypes.WIRED, + ) -> None: """ Delete a link between nodes. @@ -512,12 +533,12 @@ class Session: def update_link( self, - node_one_id, - node_two_id, - interface_one_id=None, - interface_two_id=None, - link_options=None, - ): + node_one_id: int, + node_two_id: int, + interface_one_id: int = None, + interface_two_id: int = None, + link_options: LinkOptions = None, + ) -> None: """ Update link information between nodes. @@ -623,7 +644,13 @@ class Session: if node_two: node_two.lock.release() - def add_node(self, _type=NodeTypes.DEFAULT, _id=None, options=None, _cls=None): + def add_node( + self, + _type: NodeTypes = NodeTypes.DEFAULT, + _id: int = None, + options: NodeOptions = None, + _cls: Type[NodeBase] = None, + ) -> NodeBase: """ Add a node to the session, based on the provided node data. @@ -717,14 +744,14 @@ class Session: return node - def edit_node(self, node_id, options): + def edit_node(self, node_id: int, options: NodeOptions) -> None: """ Edit node information. :param int node_id: id of node to update :param core.emulator.emudata.NodeOptions options: data to update node with :return: True if node updated, False otherwise - :rtype: bool + :rtype: nothing :raises core.CoreError: when node to update does not exist """ # get node to update @@ -737,7 +764,7 @@ class Session: node.canvas = options.canvas node.icon = options.icon - def set_node_position(self, node, options): + def set_node_position(self, node: NodeBase, options: NodeOptions) -> None: """ Set position for a node, use lat/lon/alt if needed. @@ -767,7 +794,7 @@ class Session: if using_lat_lon_alt: self.broadcast_node_location(node) - def broadcast_node_location(self, node): + def broadcast_node_location(self, node: NodeBase) -> None: """ Broadcast node location to all listeners. @@ -782,7 +809,7 @@ class Session: ) self.broadcast_node(node_data) - def start_mobility(self, node_ids=None): + def start_mobility(self, node_ids: List[int] = None) -> None: """ Start mobility for the provided node ids. @@ -791,7 +818,7 @@ class Session: """ self.mobility.startup(node_ids) - def is_active(self): + def is_active(self) -> bool: """ Determine if this session is considered to be active. (Runtime or Data collect states) @@ -804,7 +831,7 @@ class Session: logging.info("session(%s) checking if active: %s", self.id, result) return result - def open_xml(self, file_name, start=False): + def open_xml(self, file_name: str, start: bool = False) -> None: """ Import a session from the EmulationScript XML format. @@ -832,7 +859,7 @@ class Session: if start: self.instantiate() - def save_xml(self, file_name): + def save_xml(self, file_name: str) -> None: """ Export a session to the EmulationScript XML format. @@ -841,7 +868,7 @@ class Session: """ CoreXmlWriter(self).write(file_name) - def add_hook(self, state, file_name, source_name, data): + def add_hook(self, state: int, file_name: str, source_name: str, data: str) -> None: """ Store a hook from a received file message. @@ -855,7 +882,9 @@ class Session: state = f":{state}" self.set_hook(state, file_name, source_name, data) - def add_node_file(self, node_id, source_name, file_name, data): + def add_node_file( + self, node_id: int, source_name: str, file_name: str, data: str + ) -> None: """ Add a file to a node. @@ -873,7 +902,7 @@ class Session: elif data is not None: node.nodefile(file_name, data) - def clear(self): + def clear(self) -> None: """ Clear all CORE session data. (nodes, hooks, etc) @@ -889,7 +918,7 @@ class Session: self.services.reset() self.mobility.config_reset() - def start_events(self): + def start_events(self) -> None: """ Start event loop. @@ -897,7 +926,7 @@ class Session: """ self.event_loop.run() - def mobility_event(self, event_data): + def mobility_event(self, event_data: EventData) -> None: """ Handle a mobility event. @@ -906,7 +935,7 @@ class Session: """ self.mobility.handleevent(event_data) - def set_location(self, lat, lon, alt, scale): + def set_location(self, lat: float, lon: float, alt: float, scale: float) -> None: """ Set session geospatial location. @@ -919,7 +948,7 @@ class Session: self.location.setrefgeo(lat, lon, alt) self.location.refscale = scale - def shutdown(self): + def shutdown(self) -> None: """ Shutdown all session nodes and remove the session directory. """ @@ -942,7 +971,7 @@ class Session: for handler in self.shutdown_handlers: handler(self) - def broadcast_event(self, event_data): + def broadcast_event(self, event_data: EventData) -> None: """ Handle event data that should be provided to event handler. @@ -953,7 +982,7 @@ class Session: for handler in self.event_handlers: handler(event_data) - def broadcast_exception(self, exception_data): + def broadcast_exception(self, exception_data: ExceptionData) -> None: """ Handle exception data that should be provided to exception handlers. @@ -964,7 +993,7 @@ class Session: for handler in self.exception_handlers: handler(exception_data) - def broadcast_node(self, node_data): + def broadcast_node(self, node_data: NodeData) -> None: """ Handle node data that should be provided to node handlers. @@ -975,7 +1004,7 @@ class Session: for handler in self.node_handlers: handler(node_data) - def broadcast_file(self, file_data): + def broadcast_file(self, file_data: FileData) -> None: """ Handle file data that should be provided to file handlers. @@ -986,7 +1015,7 @@ class Session: for handler in self.file_handlers: handler(file_data) - def broadcast_config(self, config_data): + def broadcast_config(self, config_data: ConfigData) -> None: """ Handle config data that should be provided to config handlers. @@ -997,7 +1026,7 @@ class Session: for handler in self.config_handlers: handler(config_data) - def broadcast_link(self, link_data): + def broadcast_link(self, link_data: LinkData) -> None: """ Handle link data that should be provided to link handlers. @@ -1008,7 +1037,7 @@ class Session: for handler in self.link_handlers: handler(link_data) - def set_state(self, state, send_event=False): + def set_state(self, state: EventTypes, send_event: bool = False) -> None: """ Set the session's current state. @@ -1039,7 +1068,7 @@ class Session: event_data = EventData(event_type=state_value, time=str(time.monotonic())) self.broadcast_event(event_data) - def write_state(self, state): + def write_state(self, state: int) -> None: """ Write the current state to a state file in the session dir. @@ -1053,9 +1082,10 @@ class Session: except IOError: logging.exception("error writing state file: %s", state) - def run_hooks(self, state): + def run_hooks(self, state: int) -> None: """ - Run hook scripts upon changing states. If hooks is not specified, run all hooks in the given state. + Run hook scripts upon changing states. If hooks is not specified, run all hooks + in the given state. :param int state: state to run hooks for :return: nothing @@ -1075,7 +1105,9 @@ class Session: else: logging.info("no state hooks for %s", state) - def set_hook(self, hook_type, file_name, source_name, data): + def set_hook( + self, hook_type: str, file_name: str, source_name: str, data: str + ) -> None: """ Store a hook from a received file message. @@ -1107,13 +1139,13 @@ class Session: logging.info("immediately running new state hook") self.run_hook(hook) - def del_hooks(self): + def del_hooks(self) -> None: """ Clear the hook scripts dict. """ self._hooks.clear() - def run_hook(self, hook): + def run_hook(self, hook: Tuple[str, str]) -> None: """ Run a hook. @@ -1154,7 +1186,7 @@ class Session: except (OSError, subprocess.CalledProcessError): logging.exception("error running hook: %s", file_name) - def run_state_hooks(self, state): + def run_state_hooks(self, state: int) -> None: """ Run state hooks. @@ -1174,7 +1206,7 @@ class Session: ExceptionLevels.ERROR, "Session.run_state_hooks", None, message ) - def add_state_hook(self, state, hook): + def add_state_hook(self, state: int, hook: Callable[[int], None]) -> None: """ Add a state hook. @@ -1190,18 +1222,18 @@ class Session: if self.state == state: hook(state) - def del_state_hook(self, state, hook): + def del_state_hook(self, state: int, hook: Callable[[int], None]) -> None: """ Delete a state hook. :param int state: state to delete hook for :param func hook: hook to delete - :return: + :return: nothing """ hooks = self._state_hooks.setdefault(state, []) hooks.remove(hook) - def runtime_state_hook(self, state): + def runtime_state_hook(self, state: int) -> None: """ Runtime state hook check. @@ -1217,7 +1249,7 @@ class Session: corexmldeployment.CoreXmlDeployment(self, xml_writer.scenario) xml_writer.write(xml_file_name) - def get_environment(self, state=True): + def get_environment(self, state: bool = True) -> Dict[str, str]: """ Get an environment suitable for a subprocess.Popen call. This is the current process environment with some session-specific @@ -1265,7 +1297,7 @@ class Session: return env - def set_thumbnail(self, thumb_file): + def set_thumbnail(self, thumb_file: str) -> None: """ Set the thumbnail filename. Move files from /tmp to session dir. @@ -1281,7 +1313,7 @@ class Session: shutil.copy(thumb_file, destination_file) self.thumbnail = destination_file - def set_user(self, user): + def set_user(self, user: str) -> None: """ Set the username for this session. Update the permissions of the session dir to allow the user write access. @@ -1299,7 +1331,7 @@ class Session: self.user = user - def get_node_id(self): + def get_node_id(self) -> int: """ Return a unique, new node id. """ @@ -1308,10 +1340,9 @@ class Session: node_id = random.randint(1, 0xFFFF) if node_id not in self.nodes: break - return node_id - def create_node(self, cls, *args, **kwargs): + def create_node(self, cls: Type[NodeBase], *args: Any, **kwargs: Any) -> NodeBase: """ Create an emulation node. @@ -1322,29 +1353,27 @@ class Session: :raises core.CoreError: when id of the node to create already exists """ node = cls(self, *args, **kwargs) - with self._nodes_lock: if node.id in self.nodes: node.shutdown() raise CoreError(f"duplicate node id {node.id} for {node.name}") self.nodes[node.id] = node - return node - def get_node(self, _id): + def get_node(self, _id: int) -> NodeBase: """ Get a session node. :param int _id: node id to retrieve :return: node for the given id - :rtype: core.nodes.base.CoreNode + :rtype: core.nodes.base.NodeBase :raises core.CoreError: when node does not exist """ if _id not in self.nodes: raise CoreError(f"unknown node id {_id}") return self.nodes[_id] - def delete_node(self, _id): + def delete_node(self, _id: int) -> bool: """ Delete a node from the session and check if session should shutdown, if no nodes are left. @@ -1365,7 +1394,7 @@ class Session: return node is not None - def delete_nodes(self): + def delete_nodes(self) -> None: """ Clear the nodes dictionary, and call shutdown for each node. """ @@ -1377,7 +1406,7 @@ class Session: utils.threadpool(funcs) self.node_id_gen.id = 0 - def write_nodes(self): + def write_nodes(self) -> None: """ Write nodes to a 'nodes' file in the session dir. The 'nodes' file lists: number, name, api-type, class-type @@ -1392,7 +1421,7 @@ class Session: except IOError: logging.exception("error writing nodes file") - def dump_session(self): + def dump_session(self) -> None: """ Log information about the session in its current state. """ @@ -1405,7 +1434,9 @@ class Session: len(self.nodes), ) - def exception(self, level, source, node_id, text): + def exception( + self, level: ExceptionLevels, source: str, node_id: int, text: str + ) -> None: """ Generate and broadcast an exception event. @@ -1425,27 +1456,28 @@ class Session: ) self.broadcast_exception(exception_data) - def instantiate(self): + def instantiate(self) -> List[ServiceBootError]: """ We have entered the instantiation state, invoke startup methods of various managers and boot the nodes. Validate nodes and check for transition to the runtime state. - """ + :return: list of service boot errors during startup + """ # write current nodes out to session directory file self.write_nodes() # create control net interfaces and network tunnels # which need to exist for emane to sync on location events # in distributed scenarios - self.add_remove_control_interface(node=None, remove=False) + self.add_remove_control_net(0, remove=False) # initialize distributed tunnels self.distributed.start() - # instantiate will be invoked again upon Emane configure + # instantiate will be invoked again upon emane configure if self.emane.startup() == self.emane.NOT_READY: - return + return [] # boot node services and then start mobility exceptions = self.boot_nodes() @@ -1462,12 +1494,13 @@ class Session: self.check_runtime() return exceptions - def get_node_count(self): + def get_node_count(self) -> int: """ Returns the number of CoreNodes and CoreNets, except for those that are not considered in the GUI's node count. - """ + :return: created node count + """ with self._nodes_lock: count = 0 for node_id in self.nodes: @@ -1480,14 +1513,15 @@ class Session: continue count += 1 - return count - def check_runtime(self): + def check_runtime(self) -> None: """ Check if we have entered the runtime state, that all nodes have been started and the emulation is running. Start the event loop once we have entered runtime (time=0). + + :return: nothing """ # this is called from instantiate() after receiving an event message # for the instantiation state @@ -1504,10 +1538,12 @@ class Session: self.event_loop.run() self.set_state(EventTypes.RUNTIME_STATE, send_event=True) - def data_collect(self): + def data_collect(self) -> None: """ Tear down a running session. Stop the event loop and any running nodes, and perform clean-up. + + :return: nothing """ # stop event loop self.event_loop.stop() @@ -1528,58 +1564,59 @@ class Session: # update control interface hosts self.update_control_interface_hosts(remove=True) - # remove all four possible control networks. Does nothing if ctrlnet is not - # installed. - self.add_remove_control_interface(node=None, net_index=0, remove=True) - self.add_remove_control_interface(node=None, net_index=1, remove=True) - self.add_remove_control_interface(node=None, net_index=2, remove=True) - self.add_remove_control_interface(node=None, net_index=3, remove=True) + # remove all four possible control networks + self.add_remove_control_net(0, remove=True) + self.add_remove_control_net(1, remove=True) + self.add_remove_control_net(2, remove=True) + self.add_remove_control_net(3, remove=True) - def check_shutdown(self): + def check_shutdown(self) -> bool: """ Check if we have entered the shutdown state, when no running nodes and links remain. + + :return: True if should shutdown, False otherwise """ node_count = self.get_node_count() logging.debug( "session(%s) checking shutdown: %s nodes remaining", self.id, node_count ) - shutdown = False if node_count == 0: shutdown = True self.set_state(EventTypes.SHUTDOWN_STATE) - return shutdown - def short_session_id(self): + def short_session_id(self) -> str: """ Return a shorter version of the session ID, appropriate for interface names, where length may be limited. + + :return: short session id """ ssid = (self.id >> 8) ^ (self.id & ((1 << 8) - 1)) return f"{ssid:x}" - def boot_node(self, node): + def boot_node(self, node: CoreNode) -> None: """ Boot node by adding a control interface when necessary and starting node services. - :param core.nodes.base.CoreNodeBase node: node to boot + :param core.nodes.base.CoreNode node: node to boot :return: nothing """ logging.info("booting node(%s): %s", node.name, [x.name for x in node.services]) self.add_remove_control_interface(node=node, remove=False) self.services.boot_services(node) - def boot_nodes(self): + def boot_nodes(self) -> List[Exception]: """ Invoke the boot() procedure for all nodes and send back node messages to the GUI for node messages that had the status request flag. :return: service boot exceptions - :rtype: list[core.services.coreservices.ServiceBootError] + :rtype: list[Exception] """ with self._nodes_lock: funcs = [] @@ -1596,7 +1633,7 @@ class Session: self.update_control_interface_hosts() return exceptions - def get_control_net_prefixes(self): + def get_control_net_prefixes(self) -> List[str]: """ Retrieve control net prefixes. @@ -1608,13 +1645,11 @@ class Session: p1 = self.options.get_config("controlnet1") p2 = self.options.get_config("controlnet2") p3 = self.options.get_config("controlnet3") - if not p0 and p: p0 = p - return [p0, p1, p2, p3] - def get_control_net_server_interfaces(self): + def get_control_net_server_interfaces(self) -> List[str]: """ Retrieve control net server interfaces. @@ -1629,7 +1664,7 @@ class Session: d3 = self.options.get_config("controlnetif3") return [None, d1, d2, d3] - def get_control_net_index(self, dev): + def get_control_net_index(self, dev: str) -> int: """ Retrieve control net index. @@ -1645,10 +1680,22 @@ class Session: return index return -1 - def get_control_net(self, net_index): - return self.get_node(CTRL_NET_ID + net_index) + def get_control_net(self, net_index: int) -> CtrlNet: + """ + Retrieve a control net based on index. - def add_remove_control_net(self, net_index, remove=False, conf_required=True): + :param net_index: control net index + :return: control net + :raises CoreError: when control net is not found + """ + node = self.get_node(CTRL_NET_ID + net_index) + if not isinstance(node, CtrlNet): + raise CoreError("node is not a valid CtrlNet: %s", node.name) + return node + + def add_remove_control_net( + self, net_index: int, remove: bool = False, conf_required: bool = True + ) -> Optional[CtrlNet]: """ Create a control network bridge as necessary. When the remove flag is True, remove the bridge that connects control @@ -1682,11 +1729,9 @@ class Session: # return any existing controlnet bridge try: control_net = self.get_control_net(net_index) - if remove: self.delete_node(control_net.id) return None - return control_net except CoreError: if remove: @@ -1730,12 +1775,15 @@ class Session: updown_script=updown_script, serverintf=server_interface, ) - return control_net def add_remove_control_interface( - self, node, net_index=0, remove=False, conf_required=True - ): + self, + node: CoreNode, + net_index: int = 0, + remove: bool = False, + conf_required: bool = True, + ) -> None: """ Add a control interface to a node when a 'controlnet' prefix is listed in the config file or session options. Uses @@ -1782,7 +1830,9 @@ class Session: ) node.netif(interface1).control = True - def update_control_interface_hosts(self, net_index=0, remove=False): + def update_control_interface_hosts( + self, net_index: int = 0, remove: bool = False + ) -> None: """ Add the IP addresses of control interfaces to the /etc/hosts file. @@ -1813,10 +1863,9 @@ class Session: entries.append(f"{address} {name}") logging.info("Adding %d /etc/hosts file entries.", len(entries)) - utils.file_munge("/etc/hosts", header, "\n".join(entries) + "\n") - def runtime(self): + def runtime(self) -> float: """ Return the current time we have been in the runtime state, or zero if not in runtime. @@ -1826,7 +1875,13 @@ class Session: else: return 0.0 - def add_event(self, event_time, node=None, name=None, data=None): + def add_event( + self, + event_time: float, + node: CoreNode = None, + name: str = None, + data: str = None, + ) -> None: """ Add an event to the event queue, with a start time relative to the start of the runtime state. @@ -1865,7 +1920,9 @@ class Session: # TODO: if data is None, this blows up, but this ties into how event functions # are ran, need to clean that up - def run_event(self, node_id=None, name=None, data=None): + def run_event( + self, node_id: int = None, name: str = None, data: str = None + ) -> None: """ Run a scheduled event, executing commands in the data string. diff --git a/daemon/core/emulator/sessionconfig.py b/daemon/core/emulator/sessionconfig.py index eb38474b..b403b8d6 100644 --- a/daemon/core/emulator/sessionconfig.py +++ b/daemon/core/emulator/sessionconfig.py @@ -1,3 +1,5 @@ +from typing import Any + from core.config import ConfigurableManager, ConfigurableOptions, Configuration from core.emulator.enumerations import ConfigDataTypes, RegisterTlvs from core.plugins.sdt import Sdt @@ -60,29 +62,53 @@ class SessionConfig(ConfigurableManager, ConfigurableOptions): ] config_type = RegisterTlvs.UTILITY.value - def __init__(self): + def __init__(self) -> None: super().__init__() self.set_configs(self.default_values()) def get_config( self, - _id, - node_id=ConfigurableManager._default_node, - config_type=ConfigurableManager._default_type, - default=None, - ): + _id: str, + node_id: int = ConfigurableManager._default_node, + config_type: str = ConfigurableManager._default_type, + default: Any = None, + ) -> str: + """ + Retrieves a specific configuration for a node and configuration type. + + :param str _id: specific configuration to retrieve + :param int node_id: node id to store configuration for + :param str config_type: configuration type to store configuration for + :param default: default value to return when value is not found + :return: configuration value + :rtype str + """ value = super().get_config(_id, node_id, config_type, default) if value == "": value = default return value - def get_config_bool(self, name, default=None): + def get_config_bool(self, name: str, default: Any = None) -> bool: + """ + Get configuration value as a boolean. + + :param name: configuration name + :param default: default value if not found + :return: boolean for configuration value + """ value = self.get_config(name) if value is None: return default return value.lower() == "true" - def get_config_int(self, name, default=None): + def get_config_int(self, name: str, default: Any = None) -> int: + """ + Get configuration value as int. + + :param name: configuration name + :param default: default value if not found + :return: int for configuration value + """ value = self.get_config(name, default=default) if value is not None: value = int(value) diff --git a/daemon/core/errors.py b/daemon/core/errors.py index f5c38b5b..319bd190 100644 --- a/daemon/core/errors.py +++ b/daemon/core/errors.py @@ -9,7 +9,7 @@ class CoreCommandError(subprocess.CalledProcessError): Used when encountering internal CORE command errors. """ - def __str__(self): + def __str__(self) -> str: return ( f"Command({self.cmd}), Status({self.returncode}):\n" f"stdout: {self.output}\nstderr: {self.stderr}" diff --git a/daemon/core/gui/app.py b/daemon/core/gui/app.py index 2596a411..e70b9f52 100644 --- a/daemon/core/gui/app.py +++ b/daemon/core/gui/app.py @@ -17,7 +17,7 @@ HEIGHT = 800 class Application(tk.Frame): - def __init__(self, proxy): + def __init__(self, proxy: bool): super().__init__(master=None) # load node icons NodeUtils.setup() diff --git a/daemon/core/gui/coreclient.py b/daemon/core/gui/coreclient.py index e367608a..46009324 100644 --- a/daemon/core/gui/coreclient.py +++ b/daemon/core/gui/coreclient.py @@ -5,6 +5,7 @@ import json import logging import os from pathlib import Path +from typing import TYPE_CHECKING, Dict, List import grpc @@ -14,11 +15,16 @@ from core.gui.dialogs.mobilityplayer import MobilityPlayer from core.gui.dialogs.sessions import SessionsDialog from core.gui.errors import show_grpc_error from core.gui.graph import tags +from core.gui.graph.edges import CanvasEdge +from core.gui.graph.node import CanvasNode from core.gui.graph.shape import AnnotationData, Shape from core.gui.graph.shapeutils import ShapeType from core.gui.interface import InterfaceManager from core.gui.nodeutils import NodeDraw, NodeUtils +if TYPE_CHECKING: + from core.gui.app import Application + GUI_SOURCE = "gui" OBSERVERS = { "processes": "ps", @@ -34,20 +40,20 @@ OBSERVERS = { class CoreServer: - def __init__(self, name, address, port): + def __init__(self, name: str, address: str, port: int): self.name = name self.address = address self.port = port class Observer: - def __init__(self, name, cmd): + def __init__(self, name: str, cmd: str): self.name = name self.cmd = cmd class CoreClient: - def __init__(self, app, proxy): + def __init__(self, app: "Application", proxy: bool): """ Create a CoreGrpc instance """ @@ -110,7 +116,7 @@ class CoreClient: self.handling_events.cancel() self.handling_events = None - def set_observer(self, value): + def set_observer(self, value: str): self.observer = value def read_config(self): @@ -132,7 +138,7 @@ class CoreClient: observer = Observer(config["name"], config["cmd"]) self.custom_observers[observer.name] = observer - def handle_events(self, event): + def handle_events(self, event: core_pb2.Event): if event.session_id != self.session_id: logging.warning( "ignoring event session(%s) current(%s)", @@ -170,7 +176,7 @@ class CoreClient: else: logging.info("unhandled event: %s", event) - def handle_link_event(self, event): + def handle_link_event(self, event: core_pb2.LinkEvent): node_one_id = event.link.node_one_id node_two_id = event.link.node_two_id canvas_node_one = self.canvas_nodes[node_one_id] @@ -183,7 +189,7 @@ class CoreClient: else: logging.warning("unknown link event: %s", event.message_type) - def handle_node_event(self, event): + def handle_node_event(self, event: core_pb2.NodeEvent): if event.source == GUI_SOURCE: return node_id = event.node.id @@ -201,7 +207,7 @@ class CoreClient: self.handling_throughputs.cancel() self.handling_throughputs = None - def handle_throughputs(self, event): + def handle_throughputs(self, event: core_pb2.ThroughputsEvent): if event.session_id != self.session_id: logging.warning( "ignoring throughput event session(%s) current(%s)", @@ -212,11 +218,11 @@ class CoreClient: logging.info("handling throughputs event: %s", event) self.app.canvas.set_throughputs(event) - def handle_exception_event(self, event): + def handle_exception_event(self, event: core_pb2.ExceptionEvent): logging.info("exception event: %s", event) self.app.statusbar.core_alarms.append(event) - def join_session(self, session_id, query_location=True): + def join_session(self, session_id: int, query_location: bool = True): # update session and title self.session_id = session_id self.master.title(f"CORE Session({self.session_id})") @@ -297,10 +303,10 @@ class CoreClient: # update ui to represent current state self.app.after(0, self.app.joined_session_update) - def is_runtime(self): + def is_runtime(self) -> bool: return self.state == core_pb2.SessionState.RUNTIME - def parse_metadata(self, config): + def parse_metadata(self, config: Dict[str, str]): # canvas setting canvas_config = config.get("canvas") logging.info("canvas metadata: %s", canvas_config) @@ -364,8 +370,6 @@ class CoreClient: def create_new_session(self): """ Create a new session - - :return: nothing """ try: response = self.client.create_session() @@ -384,7 +388,7 @@ class CoreClient: except grpc.RpcError as e: self.app.after(0, show_grpc_error, e) - def delete_session(self, session_id=None): + def delete_session(self, session_id: int = None): if session_id is None: session_id = self.session_id try: @@ -396,8 +400,6 @@ class CoreClient: def set_up(self): """ Query sessions, if there exist any, prompt whether to join one - - :return: existing sessions """ try: self.client.connect() @@ -425,7 +427,7 @@ class CoreClient: self.app.after(0, show_grpc_error, e) self.app.close() - def edit_node(self, core_node): + def edit_node(self, core_node: core_pb2.Node): try: self.client.edit_node( self.session_id, core_node.id, core_node.position, source=GUI_SOURCE @@ -433,7 +435,7 @@ class CoreClient: except grpc.RpcError as e: self.app.after(0, show_grpc_error, e) - def start_session(self): + def start_session(self) -> core_pb2.StartSessionResponse: nodes = [x.core_node for x in self.canvas_nodes.values()] links = [x.link for x in self.links.values()] wlan_configs = self.get_wlan_configs_proto() @@ -476,7 +478,7 @@ class CoreClient: self.app.after(0, show_grpc_error, e) return response - def stop_session(self, session_id=None): + def stop_session(self, session_id: int = None) -> core_pb2.StartSessionResponse: if not session_id: session_id = self.session_id response = core_pb2.StopSessionResponse(result=False) @@ -518,21 +520,19 @@ class CoreClient: response = self.client.set_session_metadata(self.session_id, metadata) logging.info("set session metadata: %s", response) - def launch_terminal(self, node_id): + def launch_terminal(self, node_id: int): try: terminal = self.app.guiconfig["preferences"]["terminal"] response = self.client.get_node_terminal(self.session_id, node_id) - logging.info("get terminal %s", response.terminal) - os.system(f"{terminal} {response.terminal} &") + cmd = f'{terminal} "{response.terminal}" &' + logging.info("launching terminal %s", cmd) + os.system(cmd) except grpc.RpcError as e: self.app.after(0, show_grpc_error, e) - def save_xml(self, file_path): + def save_xml(self, file_path: str): """ Save core session as to an xml file - - :param str file_path: file path that user pick - :return: nothing """ try: if self.state != core_pb2.SessionState.RUNTIME: @@ -545,12 +545,9 @@ class CoreClient: except grpc.RpcError as e: self.app.after(0, show_grpc_error, e) - def open_xml(self, file_path): + def open_xml(self, file_path: str): """ Open core xml - - :param str file_path: file to open - :return: session id """ try: response = self.client.open_xml(file_path) @@ -559,12 +556,21 @@ class CoreClient: except grpc.RpcError as e: self.app.after(0, show_grpc_error, e) - def get_node_service(self, node_id, service_name): + def get_node_service( + self, node_id: int, service_name: str + ) -> core_pb2.NodeServiceData: response = self.client.get_node_service(self.session_id, node_id, service_name) logging.debug("get node service %s", response) return response.service - def set_node_service(self, node_id, service_name, startups, validations, shutdowns): + def set_node_service( + self, + node_id: int, + service_name: str, + startups: List[str], + validations: List[str], + shutdowns: List[str], + ) -> core_pb2.NodeServiceData: response = self.client.set_node_service( self.session_id, node_id, service_name, startups, validations, shutdowns ) @@ -573,14 +579,18 @@ class CoreClient: logging.debug("get node service : %s", response) return response.service - def get_node_service_file(self, node_id, service_name, file_name): + def get_node_service_file( + self, node_id: int, service_name: str, file_name: str + ) -> str: response = self.client.get_node_service_file( self.session_id, node_id, service_name, file_name ) logging.debug("get service file %s", response) return response.data - def set_node_service_file(self, node_id, service_name, file_name, data): + def set_node_service_file( + self, node_id: int, service_name: str, file_name: str, data: bytes + ): response = self.client.set_node_service_file( self.session_id, node_id, service_name, file_name, data ) @@ -589,8 +599,6 @@ class CoreClient: def create_nodes_and_links(self): """ create nodes and links that have not been created yet - - :return: nothing """ node_protos = [x.core_node for x in self.canvas_nodes.values()] link_protos = [x.link for x in self.links.values()] @@ -617,8 +625,6 @@ class CoreClient: def send_data(self): """ send to daemon all session info, but don't start the session - - :return: nothing """ self.create_nodes_and_links() for config_proto in self.get_wlan_configs_proto(): @@ -663,18 +669,13 @@ class CoreClient: def close(self): """ Clean ups when done using grpc - - :return: nothing """ logging.debug("close grpc") self.client.close() - def next_node_id(self): + def next_node_id(self) -> int: """ Get the next usable node id. - - :return: the next id to be used - :rtype: int """ i = 1 while True: @@ -683,15 +684,11 @@ class CoreClient: i += 1 return i - def create_node(self, x, y, node_type, model): + def create_node( + self, x: int, y: int, node_type: core_pb2.NodeType, model: str + ) -> core_pb2.Node: """ Add node, with information filled in, to grpc manager - - :param int x: x coord - :param int y: y coord - :param core_pb2.NodeType node_type: node type - :param str model: node model - :return: nothing """ node_id = self.next_node_id() position = core_pb2.Position(x=x, y=y) @@ -726,13 +723,10 @@ class CoreClient: ) return node - def delete_graph_nodes(self, canvas_nodes): + def delete_graph_nodes(self, canvas_nodes: List[core_pb2.Node]): """ remove the nodes selected by the user and anything related to that node such as link, configurations, interfaces - - :param list canvas_nodes: list of nodes to delete - :return: nothing """ edges = set() for canvas_node in canvas_nodes: @@ -754,12 +748,9 @@ class CoreClient: if edge in edges: continue edges.add(edge) - # - # if edge.token not in self.links: - # logging.error("unknown edge: %s", edge.token) self.links.pop(edge.token, None) - def create_interface(self, canvas_node): + def create_interface(self, canvas_node: CanvasNode) -> core_pb2.Interface: node = canvas_node.core_node ip4, ip6, prefix = self.interfaces_manager.get_ips(node.id) interface_id = len(canvas_node.interfaces) @@ -776,16 +767,12 @@ class CoreClient: ) return interface - def create_link(self, edge, canvas_src_node, canvas_dst_node): + def create_link( + self, edge: CanvasEdge, canvas_src_node: CanvasNode, canvas_dst_node: CanvasNode + ): """ Create core link for a pair of canvas nodes, with token referencing the canvas edge. - - :param edge: edge for link - :param canvas_src_node: canvas node one - :param canvas_dst_node: canvas node two - - :return: nothing """ src_node = canvas_src_node.core_node dst_node = canvas_dst_node.core_node @@ -815,7 +802,7 @@ class CoreClient: edge.set_link(link) self.links[edge.token] = edge - def get_wlan_configs_proto(self): + def get_wlan_configs_proto(self) -> List[core_pb2.WlanConfig]: configs = [] for node_id, config in self.wlan_configs.items(): config = {x: config[x].value for x in config} @@ -823,7 +810,7 @@ class CoreClient: configs.append(wlan_config) return configs - def get_mobility_configs_proto(self): + def get_mobility_configs_proto(self) -> List[core_pb2.MobilityConfig]: configs = [] for node_id, config in self.mobility_configs.items(): config = {x: config[x].value for x in config} @@ -831,7 +818,7 @@ class CoreClient: configs.append(mobility_config) return configs - def get_emane_model_configs_proto(self): + def get_emane_model_configs_proto(self) -> List[core_pb2.EmaneModelConfig]: configs = [] for key, config in self.emane_model_configs.items(): node_id, model, interface = key @@ -844,7 +831,7 @@ class CoreClient: configs.append(config_proto) return configs - def get_service_configs_proto(self): + def get_service_configs_proto(self) -> List[core_pb2.ServiceConfig]: configs = [] for node_id, services in self.service_configs.items(): for name, config in services.items(): @@ -858,7 +845,7 @@ class CoreClient: configs.append(config_proto) return configs - def get_service_file_configs_proto(self): + def get_service_file_configs_proto(self) -> List[core_pb2.ServiceFileConfig]: configs = [] for (node_id, file_configs) in self.file_configs.items(): for service, file_config in file_configs.items(): @@ -869,25 +856,27 @@ class CoreClient: configs.append(config_proto) return configs - def run(self, node_id): + def run(self, node_id: int) -> str: logging.info("running node(%s) cmd: %s", node_id, self.observer) return self.client.node_command(self.session_id, node_id, self.observer).output - def get_wlan_config(self, node_id): + def get_wlan_config(self, node_id: int) -> Dict[str, core_pb2.ConfigOption]: config = self.wlan_configs.get(node_id) if not config: response = self.client.get_wlan_config(self.session_id, node_id) config = response.config return config - def get_mobility_config(self, node_id): + def get_mobility_config(self, node_id: int) -> Dict[str, core_pb2.ConfigOption]: config = self.mobility_configs.get(node_id) if not config: response = self.client.get_mobility_config(self.session_id, node_id) config = response.config return config - def get_emane_model_config(self, node_id, model, interface=None): + def get_emane_model_config( + self, node_id: int, model: str, interface: int = None + ) -> Dict[str, core_pb2.ConfigOption]: logging.info("getting emane model config: %s %s %s", node_id, model, interface) config = self.emane_model_configs.get((node_id, model, interface)) if not config: @@ -899,15 +888,21 @@ class CoreClient: config = response.config return config - def set_emane_model_config(self, node_id, model, config, interface=None): + def set_emane_model_config( + self, + node_id: int, + model: str, + config: Dict[str, core_pb2.ConfigOption], + interface: int = None, + ): logging.info("setting emane model config: %s %s %s", node_id, model, interface) self.emane_model_configs[(node_id, model, interface)] = config - def copy_node_service(self, _from, _to): + def copy_node_service(self, _from: int, _to: int): services = self.canvas_nodes[_from].core_node.services self.canvas_nodes[_to].core_node.services[:] = services - def copy_node_config(self, _from, _to): + def copy_node_config(self, _from: int, _to: int): node_type = self.canvas_nodes[_from].core_node.type if node_type == core_pb2.NodeType.DEFAULT: services = self.canvas_nodes[_from].core_node.services diff --git a/daemon/core/gui/dialogs/about.py b/daemon/core/gui/dialogs/about.py index d54266f4..bf498bb8 100644 --- a/daemon/core/gui/dialogs/about.py +++ b/daemon/core/gui/dialogs/about.py @@ -1,9 +1,13 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.dialogs.dialog import Dialog from core.gui.widgets import CodeText +if TYPE_CHECKING: + from core.gui.app import Application + LICENSE = """\ Copyright (c) 2005-2020, the Boeing Company. @@ -31,7 +35,7 @@ THE POSSIBILITY OF SUCH DAMAGE.\ class AboutDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "About CORE", modal=True) self.draw() diff --git a/daemon/core/gui/dialogs/alerts.py b/daemon/core/gui/dialogs/alerts.py index 7e82da73..6c07f214 100644 --- a/daemon/core/gui/dialogs/alerts.py +++ b/daemon/core/gui/dialogs/alerts.py @@ -3,15 +3,19 @@ check engine light """ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.api.grpc.core_pb2 import ExceptionLevel from core.gui.dialogs.dialog import Dialog from core.gui.themes import PADX, PADY from core.gui.widgets import CodeText +if TYPE_CHECKING: + from core.gui.app import Application + class AlertsDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Alerts", modal=True) self.app = app self.tree = None @@ -110,7 +114,7 @@ class AlertsDialog(Dialog): dialog = DaemonLog(self, self.app) dialog.show() - def click_select(self, event): + def click_select(self, event: tk.Event): current = self.tree.selection()[0] alarm = self.alarm_map[current] self.codetext.text.config(state=tk.NORMAL) @@ -120,7 +124,7 @@ class AlertsDialog(Dialog): class DaemonLog(Dialog): - def __init__(self, master, app): + def __init__(self, master: tk.Widget, app: "Application"): super().__init__(master, app, "core-daemon log", modal=True) self.columnconfigure(0, weight=1) self.path = tk.StringVar(value="/var/log/core-daemon.log") diff --git a/daemon/core/gui/dialogs/canvassizeandscale.py b/daemon/core/gui/dialogs/canvassizeandscale.py index 11cc97b4..f04b991c 100644 --- a/daemon/core/gui/dialogs/canvassizeandscale.py +++ b/daemon/core/gui/dialogs/canvassizeandscale.py @@ -3,19 +3,21 @@ size and scale """ import tkinter as tk from tkinter import font, ttk +from typing import TYPE_CHECKING from core.gui.dialogs.dialog import Dialog from core.gui.themes import FRAME_PAD, PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + PIXEL_SCALE = 100 class SizeAndScaleDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): """ create an instance for size and scale object - - :param app: main application """ super().__init__(master, app, "Canvas Size and Scale", modal=True) self.canvas = self.app.canvas diff --git a/daemon/core/gui/dialogs/canvaswallpaper.py b/daemon/core/gui/dialogs/canvaswallpaper.py index 62fa8fe5..093d93b0 100644 --- a/daemon/core/gui/dialogs/canvaswallpaper.py +++ b/daemon/core/gui/dialogs/canvaswallpaper.py @@ -4,6 +4,7 @@ set wallpaper import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.appconfig import BACKGROUNDS_PATH from core.gui.dialogs.dialog import Dialog @@ -11,13 +12,14 @@ from core.gui.images import Images from core.gui.themes import PADX, PADY from core.gui.widgets import image_chooser +if TYPE_CHECKING: + from core.gui.app import Application + class CanvasWallpaperDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): """ create an instance of CanvasWallpaper object - - :param coretk.app.Application app: root application """ super().__init__(master, app, "Canvas Background", modal=True) self.canvas = self.app.canvas @@ -140,8 +142,6 @@ class CanvasWallpaperDialog(Dialog): def click_clear(self): """ delete like shown in image link entry if there is any - - :return: nothing """ # delete entry self.filename.set("") diff --git a/daemon/core/gui/dialogs/colorpicker.py b/daemon/core/gui/dialogs/colorpicker.py index 28d21f42..8962ea9f 100644 --- a/daemon/core/gui/dialogs/colorpicker.py +++ b/daemon/core/gui/dialogs/colorpicker.py @@ -4,12 +4,16 @@ custom color picker import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Any from core.gui.dialogs.dialog import Dialog +if TYPE_CHECKING: + from core.gui.app import Application + class ColorPickerDialog(Dialog): - def __init__(self, master, app, initcolor="#000000"): + def __init__(self, master: Any, app: "Application", initcolor: str = "#000000"): super().__init__(master, app, "color picker", modal=True) self.red_entry = None self.blue_entry = None @@ -31,7 +35,7 @@ class ColorPickerDialog(Dialog): self.draw() self.set_bindings() - def askcolor(self): + def askcolor(self) -> str: self.show() return self.color @@ -175,19 +179,16 @@ class ColorPickerDialog(Dialog): self.color = self.hex.get() self.destroy() - def get_hex(self): + def get_hex(self) -> str: """ convert current RGB values into hex color - - :rtype: str - :return: hex color """ red = self.red_entry.get() blue = self.blue_entry.get() green = self.green_entry.get() return "#%02x%02x%02x" % (int(red), int(green), int(blue)) - def current_focus(self, focus): + def current_focus(self, focus: str): self.focus = focus def update_color(self, arg1=None, arg2=None, arg3=None): @@ -210,35 +211,31 @@ class ColorPickerDialog(Dialog): self.set_entry(red, green, blue) self.set_scale(red, green, blue) self.display.config(background=hex_code) - self.set_label(red, green, blue) + self.set_label(str(red), str(green), str(blue)) - def scale_callback(self, var, color_var): + def scale_callback(self, var: tk.IntVar, color_var: tk.IntVar): color_var.set(var.get()) self.focus = "rgb" self.update_color() - def set_scale(self, red, green, blue): + def set_scale(self, red: int, green: int, blue: int): self.red_scale.set(red) self.green_scale.set(green) self.blue_scale.set(blue) - def set_entry(self, red, green, blue): + def set_entry(self, red: int, green: int, blue: int): self.red.set(red) self.green.set(green) self.blue.set(blue) - def set_label(self, red, green, blue): + def set_label(self, red: str, green: str, blue: str): self.red_label.configure(background="#%02x%02x%02x" % (int(red), 0, 0)) self.green_label.configure(background="#%02x%02x%02x" % (0, int(green), 0)) self.blue_label.configure(background="#%02x%02x%02x" % (0, 0, int(blue))) - def get_rgb(self, hex_code): + def get_rgb(self, hex_code: str) -> [int, int, int]: """ convert a valid hex code to RGB values - - :param string hex_code: color in hex - :rtype: tuple(int, int, int) - :return: the RGB values """ if len(hex_code) == 4: red = hex_code[1] diff --git a/daemon/core/gui/dialogs/copyserviceconfig.py b/daemon/core/gui/dialogs/copyserviceconfig.py index 39306db7..994058fc 100644 --- a/daemon/core/gui/dialogs/copyserviceconfig.py +++ b/daemon/core/gui/dialogs/copyserviceconfig.py @@ -5,14 +5,18 @@ copy service config dialog import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Any, Tuple from core.gui.dialogs.dialog import Dialog from core.gui.themes import FRAME_PAD, PADX from core.gui.widgets import CodeText +if TYPE_CHECKING: + from core.gui.app import Application + class CopyServiceConfigDialog(Dialog): - def __init__(self, master, app, node_id): + def __init__(self, master: Any, app: "Application", node_id: int): super().__init__(master, app, f"Copy services to node {node_id}", modal=True) self.parent = master self.app = app @@ -128,6 +132,7 @@ class CopyServiceConfigDialog(Dialog): def click_view(self): selected = self.tree.selection() + data = "" if selected: item = self.tree.item(selected[0]) if "file" in item["tags"]: @@ -157,7 +162,7 @@ class CopyServiceConfigDialog(Dialog): ) dialog.show() - def get_node_service(self, selected): + def get_node_service(self, selected: Tuple[str]) -> [int, str]: service_tree_id = self.tree.parent(selected[0]) service_name = self.tree.item(service_tree_id)["text"] node_tree_id = self.tree.parent(service_tree_id) @@ -166,7 +171,14 @@ class CopyServiceConfigDialog(Dialog): class ViewConfigDialog(Dialog): - def __init__(self, master, app, node_id, data, filename=None): + def __init__( + self, + master: Any, + app: "Application", + node_id: int, + data: str, + filename: str = None, + ): super().__init__(master, app, f"n{node_id} config data", modal=True) self.data = data self.service_data = None diff --git a/daemon/core/gui/dialogs/customnodes.py b/daemon/core/gui/dialogs/customnodes.py index 86303a69..fe694651 100644 --- a/daemon/core/gui/dialogs/customnodes.py +++ b/daemon/core/gui/dialogs/customnodes.py @@ -2,6 +2,7 @@ import logging import tkinter as tk from pathlib import Path from tkinter import ttk +from typing import TYPE_CHECKING, Any, Set from core.gui import nodeutils from core.gui.appconfig import ICONS_PATH @@ -11,9 +12,12 @@ from core.gui.nodeutils import NodeDraw from core.gui.themes import FRAME_PAD, PADX, PADY from core.gui.widgets import CheckboxList, ListboxScroll, image_chooser +if TYPE_CHECKING: + from core.gui.app import Application + class ServicesSelectDialog(Dialog): - def __init__(self, master, app, current_services): + def __init__(self, master: Any, app: "Application", current_services: Set[str]): super().__init__(master, app, "Node Services", modal=True) self.groups = None self.services = None @@ -71,7 +75,7 @@ class ServicesSelectDialog(Dialog): # trigger group change self.groups.listbox.event_generate("<>") - def handle_group_change(self, event): + def handle_group_change(self, event: tk.Event): selection = self.groups.listbox.curselection() if selection: index = selection[0] @@ -81,7 +85,7 @@ class ServicesSelectDialog(Dialog): checked = name in self.current_services self.services.add(name, checked) - def service_clicked(self, name, var): + def service_clicked(self, name: str, var: tk.BooleanVar): if var.get() and name not in self.current_services: self.current_services.add(name) elif not var.get() and name in self.current_services: @@ -96,7 +100,7 @@ class ServicesSelectDialog(Dialog): class CustomNodesDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Custom Nodes", modal=True) self.edit_button = None self.delete_button = None @@ -241,7 +245,7 @@ class CustomNodesDialog(Dialog): self.nodes_list.listbox.selection_clear(0, tk.END) self.nodes_list.listbox.event_generate("<>") - def handle_node_select(self, event): + def handle_node_select(self, event: tk.Event): selection = self.nodes_list.listbox.curselection() if selection: self.selected_index = selection[0] diff --git a/daemon/core/gui/dialogs/dialog.py b/daemon/core/gui/dialogs/dialog.py index 3e6d54f6..00532793 100644 --- a/daemon/core/gui/dialogs/dialog.py +++ b/daemon/core/gui/dialogs/dialog.py @@ -1,12 +1,18 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.images import ImageEnum, Images from core.gui.themes import DIALOG_PAD +if TYPE_CHECKING: + from core.gui.app import Application + class Dialog(tk.Toplevel): - def __init__(self, master, app, title, modal=False): + def __init__( + self, master: tk.Widget, app: "Application", title: str, modal: bool = False + ): super().__init__(master) self.withdraw() self.app = app @@ -30,7 +36,7 @@ class Dialog(tk.Toplevel): self.grab_set() self.wait_window() - def draw_spacer(self, row=None): + def draw_spacer(self, row: int = None): frame = ttk.Frame(self.top) frame.grid(row=row, sticky="nsew") frame.rowconfigure(0, weight=1) diff --git a/daemon/core/gui/dialogs/emaneconfig.py b/daemon/core/gui/dialogs/emaneconfig.py index 6f3dedd8..6ef0be78 100644 --- a/daemon/core/gui/dialogs/emaneconfig.py +++ b/daemon/core/gui/dialogs/emaneconfig.py @@ -5,18 +5,24 @@ import logging import tkinter as tk import webbrowser from tkinter import ttk +from typing import TYPE_CHECKING, Any import grpc +from core.api.grpc import core_pb2 from core.gui.dialogs.dialog import Dialog from core.gui.errors import show_grpc_error from core.gui.images import ImageEnum, Images from core.gui.themes import PADX, PADY from core.gui.widgets import ConfigFrame +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode + class GlobalEmaneDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: Any, app: "Application"): super().__init__(master, app, "EMANE Configuration", modal=True) self.config_frame = None self.draw() @@ -47,7 +53,14 @@ class GlobalEmaneDialog(Dialog): class EmaneModelDialog(Dialog): - def __init__(self, master, app, node, model, interface=None): + def __init__( + self, + master: Any, + app: "Application", + node: core_pb2.Node, + model: str, + interface: int = None, + ): super().__init__(master, app, f"{node.name} {model} Configuration", modal=True) self.node = node self.model = f"emane_{model}" @@ -91,7 +104,9 @@ class EmaneModelDialog(Dialog): class EmaneConfigDialog(Dialog): - def __init__(self, master, app, canvas_node): + def __init__( + self, master: "Application", app: "Application", canvas_node: "CanvasNode" + ): super().__init__( master, app, f"{canvas_node.core_node.name} EMANE Configuration", modal=True ) @@ -116,8 +131,6 @@ class EmaneConfigDialog(Dialog): def draw_emane_configuration(self): """ draw the main frame for emane configuration - - :return: nothing """ label = ttk.Label( self.top, @@ -143,8 +156,6 @@ class EmaneConfigDialog(Dialog): def draw_emane_models(self): """ create a combobox that has all the known emane models - - :return: nothing """ frame = ttk.Frame(self.top) frame.grid(sticky="ew", pady=PADY) @@ -210,8 +221,6 @@ class EmaneConfigDialog(Dialog): def click_model_config(self): """ draw emane model configuration - - :return: nothing """ model_name = self.emane_model.get() logging.info("configuring emane model: %s", model_name) @@ -220,12 +229,9 @@ class EmaneConfigDialog(Dialog): ) dialog.show() - def emane_model_change(self, event): + def emane_model_change(self, event: tk.Event): """ update emane model options button - - :param event: - :return: nothing """ model_name = self.emane_model.get() self.emane_model_button.config(text=f"{model_name} options") diff --git a/daemon/core/gui/dialogs/hooks.py b/daemon/core/gui/dialogs/hooks.py index 79741add..ad8ad533 100644 --- a/daemon/core/gui/dialogs/hooks.py +++ b/daemon/core/gui/dialogs/hooks.py @@ -1,14 +1,18 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Any from core.api.grpc import core_pb2 from core.gui.dialogs.dialog import Dialog from core.gui.themes import PADX, PADY from core.gui.widgets import CodeText, ListboxScroll +if TYPE_CHECKING: + from core.gui.app import Application + class HookDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: Any, app: "Application"): super().__init__(master, app, "Hook", modal=True) self.name = tk.StringVar() self.codetext = None @@ -62,11 +66,11 @@ class HookDialog(Dialog): button = ttk.Button(frame, text="Cancel", command=lambda: self.destroy()) button.grid(row=0, column=1, sticky="ew") - def state_change(self, event): + def state_change(self, event: tk.Event): state_name = self.state.get() self.name.set(f"{state_name.lower()}_hook.sh") - def set(self, hook): + def set(self, hook: core_pb2.Hook): self.hook = hook self.name.set(hook.file) self.codetext.text.delete(1.0, tk.END) @@ -84,7 +88,7 @@ class HookDialog(Dialog): class HooksDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Hooks", modal=True) self.listbox = None self.edit_button = None @@ -140,7 +144,7 @@ class HooksDialog(Dialog): self.edit_button.config(state=tk.DISABLED) self.delete_button.config(state=tk.DISABLED) - def select(self, event): + def select(self, event: tk.Event): if self.listbox.curselection(): index = self.listbox.curselection()[0] self.selected = self.listbox.get(index) diff --git a/daemon/core/gui/dialogs/linkconfig.py b/daemon/core/gui/dialogs/linkconfig.py index 9fd9130b..5f16a586 100644 --- a/daemon/core/gui/dialogs/linkconfig.py +++ b/daemon/core/gui/dialogs/linkconfig.py @@ -4,14 +4,19 @@ link configuration import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Union from core.api.grpc import core_pb2 from core.gui.dialogs.colorpicker import ColorPickerDialog from core.gui.dialogs.dialog import Dialog from core.gui.themes import PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.graph import CanvasGraph, CanvasEdge -def get_int(var): + +def get_int(var: tk.StringVar) -> Union[int, None]: value = var.get() if value != "": return int(value) @@ -19,7 +24,7 @@ def get_int(var): return None -def get_float(var): +def get_float(var: tk.StringVar) -> Union[float, None]: value = var.get() if value != "": return float(value) @@ -28,7 +33,7 @@ def get_float(var): class LinkConfigurationDialog(Dialog): - def __init__(self, master, app, edge): + def __init__(self, master: "CanvasGraph", app: "Application", edge: "CanvasEdge"): super().__init__(master, app, "Link Configuration", modal=True) self.app = app self.edge = edge @@ -103,7 +108,7 @@ class LinkConfigurationDialog(Dialog): button = ttk.Button(frame, text="Cancel", command=self.destroy) button.grid(row=0, column=1, sticky="ew") - def get_frame(self): + def get_frame(self) -> ttk.Frame: frame = ttk.Frame(self.top) frame.columnconfigure(1, weight=1) if self.is_symmetric: @@ -339,8 +344,6 @@ class LinkConfigurationDialog(Dialog): def load_link_config(self): """ populate link config to the table - - :return: nothing """ width = self.app.canvas.itemcget(self.edge.id, "width") self.width.set(width) diff --git a/daemon/core/gui/dialogs/marker.py b/daemon/core/gui/dialogs/marker.py index 159abd7f..1db9ca49 100644 --- a/daemon/core/gui/dialogs/marker.py +++ b/daemon/core/gui/dialogs/marker.py @@ -4,15 +4,21 @@ marker dialog import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.dialogs.colorpicker import ColorPickerDialog from core.gui.dialogs.dialog import Dialog +if TYPE_CHECKING: + from core.gui.app import Application + MARKER_THICKNESS = [3, 5, 8, 10] class MarkerDialog(Dialog): - def __init__(self, master, app, initcolor="#000000"): + def __init__( + self, master: "Application", app: "Application", initcolor: str = "#000000" + ): super().__init__(master, app, "marker tool", modal=False) self.app = app self.color = initcolor @@ -53,13 +59,13 @@ class MarkerDialog(Dialog): for i in canvas.find_withtag("marker"): canvas.delete(i) - def change_color(self, event): + def change_color(self, event: tk.Event): color_picker = ColorPickerDialog(self, self.app, self.color) color = color_picker.askcolor() event.widget.configure(background=color) self.color = color - def change_thickness(self, event): + def change_thickness(self, event: tk.Event): self.radius = self.marker_thickness.get() def show(self): diff --git a/daemon/core/gui/dialogs/mobilityconfig.py b/daemon/core/gui/dialogs/mobilityconfig.py index 3b9c1ca6..18e62a17 100644 --- a/daemon/core/gui/dialogs/mobilityconfig.py +++ b/daemon/core/gui/dialogs/mobilityconfig.py @@ -2,6 +2,7 @@ mobility configuration """ from tkinter import ttk +from typing import TYPE_CHECKING import grpc @@ -10,9 +11,15 @@ from core.gui.errors import show_grpc_error from core.gui.themes import PADX, PADY from core.gui.widgets import ConfigFrame +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode + class MobilityConfigDialog(Dialog): - def __init__(self, master, app, canvas_node): + def __init__( + self, master: "Application", app: "Application", canvas_node: "CanvasNode" + ): super().__init__( master, app, diff --git a/daemon/core/gui/dialogs/mobilityplayer.py b/daemon/core/gui/dialogs/mobilityplayer.py index 9c2848d8..873a2b37 100644 --- a/daemon/core/gui/dialogs/mobilityplayer.py +++ b/daemon/core/gui/dialogs/mobilityplayer.py @@ -1,5 +1,6 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Any import grpc @@ -9,11 +10,21 @@ from core.gui.errors import show_grpc_error from core.gui.images import ImageEnum, Images from core.gui.themes import PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode + ICON_SIZE = 16 class MobilityPlayer: - def __init__(self, master, app, canvas_node, config): + def __init__( + self, + master: "Application", + app: "Application", + canvas_node: "CanvasNode", + config, + ): self.master = master self.app = app self.canvas_node = canvas_node @@ -57,7 +68,9 @@ class MobilityPlayer: class MobilityPlayerDialog(Dialog): - def __init__(self, master, app, canvas_node, config): + def __init__( + self, master: Any, app: "Application", canvas_node: "CanvasNode", config + ): super().__init__( master, app, f"{canvas_node.core_node.name} Mobility Player", modal=False ) diff --git a/daemon/core/gui/dialogs/nodeconfig.py b/daemon/core/gui/dialogs/nodeconfig.py index c3dd646a..7db65dc7 100644 --- a/daemon/core/gui/dialogs/nodeconfig.py +++ b/daemon/core/gui/dialogs/nodeconfig.py @@ -2,6 +2,7 @@ import logging import tkinter as tk from functools import partial from tkinter import ttk +from typing import TYPE_CHECKING from core.gui import nodeutils from core.gui.appconfig import ICONS_PATH @@ -12,20 +13,32 @@ from core.gui.nodeutils import NodeUtils from core.gui.themes import FRAME_PAD, PADX, PADY from core.gui.widgets import ListboxScroll, image_chooser +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode -def mac_auto(is_auto, entry): + +def mac_auto(is_auto: tk.BooleanVar, entry: ttk.Entry): logging.info("mac auto clicked") if is_auto.get(): logging.info("disabling mac") - entry.var.set("") + entry.delete(0, tk.END) + entry.insert(tk.END, "") entry.config(state=tk.DISABLED) else: - entry.var.set("00:00:00:00:00:00") + entry.delete(0, tk.END) + entry.insert(tk.END, "00:00:00:00:00:00") entry.config(state=tk.NORMAL) class InterfaceData: - def __init__(self, is_auto, mac, ip4, ip6): + def __init__( + self, + is_auto: tk.BooleanVar, + mac: tk.StringVar, + ip4: tk.StringVar, + ip6: tk.StringVar, + ): self.is_auto = is_auto self.mac = mac self.ip4 = ip4 @@ -33,13 +46,11 @@ class InterfaceData: class NodeConfigDialog(Dialog): - def __init__(self, master, app, canvas_node): + def __init__( + self, master: "Application", app: "Application", canvas_node: "CanvasNode" + ): """ create an instance of node configuration - - :param master: dialog master - :param coretk.app.Application: main app - :param coretk.graph.CanvasNode canvas_node: canvas node object """ super().__init__( master, app, f"{canvas_node.core_node.name} Configuration", modal=True @@ -217,7 +228,7 @@ class NodeConfigDialog(Dialog): button = ttk.Button(frame, text="Cancel", command=self.destroy) button.grid(row=0, column=1, sticky="ew") - def click_emane_config(self, emane_model, interface_id): + def click_emane_config(self, emane_model: str, interface_id: int): dialog = EmaneModelDialog(self, self.app, self.node, emane_model, interface_id) dialog.show() @@ -248,7 +259,7 @@ class NodeConfigDialog(Dialog): self.canvas_node.redraw() self.destroy() - def interface_select(self, event): + def interface_select(self, event: tk.Event): listbox = event.widget cur = listbox.curselection() if cur: diff --git a/daemon/core/gui/dialogs/nodeservice.py b/daemon/core/gui/dialogs/nodeservice.py index e0d36412..a3928c9c 100644 --- a/daemon/core/gui/dialogs/nodeservice.py +++ b/daemon/core/gui/dialogs/nodeservice.py @@ -3,15 +3,26 @@ core node services """ import tkinter as tk from tkinter import messagebox, ttk +from typing import TYPE_CHECKING, Any, Set from core.gui.dialogs.dialog import Dialog from core.gui.dialogs.serviceconfig import ServiceConfigDialog from core.gui.themes import FRAME_PAD, PADX, PADY from core.gui.widgets import CheckboxList, ListboxScroll +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode + class NodeServiceDialog(Dialog): - def __init__(self, master, app, canvas_node, services=None): + def __init__( + self, + master: Any, + app: "Application", + canvas_node: "CanvasNode", + services: Set[str] = None, + ): title = f"{canvas_node.core_node.name} Services" super().__init__(master, app, title, modal=True) self.app = app @@ -87,7 +98,7 @@ class NodeServiceDialog(Dialog): # trigger group change self.groups.listbox.event_generate("<>") - def handle_group_change(self, event=None): + def handle_group_change(self, event: tk.Event = None): selection = self.groups.listbox.curselection() if selection: index = selection[0] @@ -97,7 +108,7 @@ class NodeServiceDialog(Dialog): checked = name in self.current_services self.services.add(name, checked) - def service_clicked(self, name, var): + def service_clicked(self, name: str, var: tk.IntVar): if var.get() and name not in self.current_services: self.current_services.add(name) elif not var.get() and name in self.current_services: @@ -150,7 +161,7 @@ class NodeServiceDialog(Dialog): checkbutton.invoke() return - def is_custom_service(self, service): + def is_custom_service(self, service: str) -> bool: service_configs = self.app.core.service_configs file_configs = self.app.core.file_configs if self.node_id in service_configs and service in service_configs[self.node_id]: diff --git a/daemon/core/gui/dialogs/observers.py b/daemon/core/gui/dialogs/observers.py index 5f0f1b1e..9fe3f79e 100644 --- a/daemon/core/gui/dialogs/observers.py +++ b/daemon/core/gui/dialogs/observers.py @@ -1,14 +1,18 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.coreclient import Observer from core.gui.dialogs.dialog import Dialog from core.gui.themes import PADX, PADY from core.gui.widgets import ListboxScroll +if TYPE_CHECKING: + from core.gui.app import Application + class ObserverDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Observer Widgets", modal=True) self.observers = None self.save_button = None @@ -126,7 +130,7 @@ class ObserverDialog(Dialog): self.save_button.config(state=tk.DISABLED) self.delete_button.config(state=tk.DISABLED) - def handle_observer_change(self, event): + def handle_observer_change(self, event: tk.Event): selection = self.observers.curselection() if selection: self.selected_index = selection[0] diff --git a/daemon/core/gui/dialogs/preferences.py b/daemon/core/gui/dialogs/preferences.py index 6208990a..f60da652 100644 --- a/daemon/core/gui/dialogs/preferences.py +++ b/daemon/core/gui/dialogs/preferences.py @@ -1,14 +1,18 @@ import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui import appconfig from core.gui.dialogs.dialog import Dialog from core.gui.themes import FRAME_PAD, PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + class PreferencesDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Preferences", modal=True) preferences = self.app.guiconfig["preferences"] self.editor = tk.StringVar(value=preferences["editor"]) @@ -72,7 +76,7 @@ class PreferencesDialog(Dialog): button = ttk.Button(frame, text="Cancel", command=self.destroy) button.grid(row=0, column=1, sticky="ew") - def theme_change(self, event): + def theme_change(self, event: tk.Event): theme = self.theme.get() logging.info("changing theme: %s", theme) self.app.style.theme_use(theme) diff --git a/daemon/core/gui/dialogs/servers.py b/daemon/core/gui/dialogs/servers.py index a0eadec2..c57e97d3 100644 --- a/daemon/core/gui/dialogs/servers.py +++ b/daemon/core/gui/dialogs/servers.py @@ -1,18 +1,22 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.coreclient import CoreServer from core.gui.dialogs.dialog import Dialog from core.gui.themes import FRAME_PAD, PADX, PADY from core.gui.widgets import ListboxScroll +if TYPE_CHECKING: + from core.gui.app import Application + DEFAULT_NAME = "example" DEFAULT_ADDRESS = "127.0.0.1" DEFAULT_PORT = 50051 class ServersDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "CORE Servers", modal=True) self.name = tk.StringVar(value=DEFAULT_NAME) self.address = tk.StringVar(value=DEFAULT_ADDRESS) @@ -155,7 +159,7 @@ class ServersDialog(Dialog): self.save_button.config(state=tk.DISABLED) self.delete_button.config(state=tk.DISABLED) - def handle_server_change(self, event): + def handle_server_change(self, event: tk.Event): selection = self.servers.curselection() if selection: self.selected_index = selection[0] diff --git a/daemon/core/gui/dialogs/serviceconfig.py b/daemon/core/gui/dialogs/serviceconfig.py index 56c0a4c0..804e7e3f 100644 --- a/daemon/core/gui/dialogs/serviceconfig.py +++ b/daemon/core/gui/dialogs/serviceconfig.py @@ -1,6 +1,9 @@ -"Service configuration dialog" +""" +Service configuration dialog +""" import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Any, List import grpc @@ -12,9 +15,14 @@ from core.gui.images import ImageEnum, Images from core.gui.themes import FRAME_PAD, PADX, PADY from core.gui.widgets import CodeText, ListboxScroll +if TYPE_CHECKING: + from core.gui.app import Application + class ServiceConfigDialog(Dialog): - def __init__(self, master, app, service_name, node_id): + def __init__( + self, master: Any, app: "Application", service_name: str, node_id: int + ): title = f"{service_name} Service" super().__init__(master, app, title, modal=True) self.master = master @@ -225,7 +233,7 @@ class ServiceConfigDialog(Dialog): for i in range(3): tab.rowconfigure(i, weight=1) self.notebook.add(tab, text="Startup/Shutdown") - + commands = [] # tab 3 for i in range(3): label_frame = None @@ -345,7 +353,7 @@ class ServiceConfigDialog(Dialog): button = ttk.Button(frame, text="Cancel", command=self.destroy) button.grid(row=0, column=3, sticky="ew") - def add_filename(self, event): + def add_filename(self, event: tk.Event): # not worry about it for now return frame_contains_button = event.widget.master @@ -354,7 +362,7 @@ class ServiceConfigDialog(Dialog): if filename not in combobox["values"]: combobox["values"] += (filename,) - def delete_filename(self, event): + def delete_filename(self, event: tk.Event): # not worry about it for now return frame_comntains_button = event.widget.master @@ -364,7 +372,7 @@ class ServiceConfigDialog(Dialog): combobox["values"] = tuple([x for x in combobox["values"] if x != filename]) combobox.set("") - def add_command(self, event): + def add_command(self, event: tk.Event): frame_contains_button = event.widget.master listbox = frame_contains_button.master.grid_slaves(row=1, column=0)[0].listbox command_to_add = frame_contains_button.grid_slaves(row=0, column=0)[0].get() @@ -375,7 +383,7 @@ class ServiceConfigDialog(Dialog): return listbox.insert(tk.END, command_to_add) - def update_entry(self, event): + def update_entry(self, event: tk.Event): listbox = event.widget current_selection = listbox.curselection() if len(current_selection) > 0: @@ -386,7 +394,7 @@ class ServiceConfigDialog(Dialog): entry.delete(0, "end") entry.insert(0, cmd) - def delete_command(self, event): + def delete_command(self, event: tk.Event): button = event.widget frame_contains_button = button.master listbox = frame_contains_button.master.grid_slaves(row=1, column=0)[0].listbox @@ -439,13 +447,13 @@ class ServiceConfigDialog(Dialog): show_grpc_error(e) self.destroy() - def display_service_file_data(self, event): + def display_service_file_data(self, event: tk.Event): combobox = event.widget filename = combobox.get() self.service_file_data.text.delete(1.0, "end") self.service_file_data.text.insert("end", self.temp_service_files[filename]) - def update_temp_service_file_data(self, event): + def update_temp_service_file_data(self, event: tk.Event): scrolledtext = event.widget filename = self.filename_combobox.get() self.temp_service_files[filename] = scrolledtext.get(1.0, "end") @@ -490,7 +498,9 @@ class ServiceConfigDialog(Dialog): dialog = CopyServiceConfigDialog(self, self.app, self.node_id) dialog.show() - def append_commands(self, commands, listbox, to_add): + def append_commands( + self, commands: List[str], listbox: tk.Listbox, to_add: List[str] + ): for cmd in to_add: commands.append(cmd) listbox.insert(tk.END, cmd) diff --git a/daemon/core/gui/dialogs/sessionoptions.py b/daemon/core/gui/dialogs/sessionoptions.py index 040629b5..ffd61340 100644 --- a/daemon/core/gui/dialogs/sessionoptions.py +++ b/daemon/core/gui/dialogs/sessionoptions.py @@ -1,5 +1,6 @@ import logging from tkinter import ttk +from typing import TYPE_CHECKING import grpc @@ -8,9 +9,12 @@ from core.gui.errors import show_grpc_error from core.gui.themes import PADX, PADY from core.gui.widgets import ConfigFrame +if TYPE_CHECKING: + from core.gui.app import Application + class SessionOptionsDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Session Options", modal=True) self.config_frame = None self.config = self.get_config() diff --git a/daemon/core/gui/dialogs/sessions.py b/daemon/core/gui/dialogs/sessions.py index 0ec1a8cf..f717462d 100644 --- a/daemon/core/gui/dialogs/sessions.py +++ b/daemon/core/gui/dialogs/sessions.py @@ -1,6 +1,7 @@ import logging import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING, Iterable import grpc @@ -11,9 +12,14 @@ from core.gui.images import ImageEnum, Images from core.gui.task import BackgroundTask from core.gui.themes import PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + class SessionsDialog(Dialog): - def __init__(self, master, app, is_start_app=False): + def __init__( + self, master: "Application", app: "Application", is_start_app: bool = False + ): super().__init__(master, app, "Sessions", modal=True) self.is_start_app = is_start_app self.selected = False @@ -22,7 +28,7 @@ class SessionsDialog(Dialog): self.sessions = self.get_sessions() self.draw() - def get_sessions(self): + def get_sessions(self) -> Iterable[core_pb2.SessionSummary]: try: response = self.app.core.client.get_sessions() logging.info("sessions: %s", response) @@ -41,7 +47,6 @@ class SessionsDialog(Dialog): def draw_description(self): """ write a short description - :return: nothing """ label = ttk.Label( self.top, @@ -154,7 +159,7 @@ class SessionsDialog(Dialog): self.app.core.create_new_session() self.destroy() - def click_select(self, event): + def click_select(self, event: tk.Event): item = self.tree.selection() session_id = int(self.tree.item(item, "text")) self.selected = True @@ -163,8 +168,6 @@ class SessionsDialog(Dialog): def click_connect(self): """ if no session is selected yet, create a new one else join that session - - :return: nothing """ if self.selected and self.selected_id is not None: self.join_session(self.selected_id) @@ -177,8 +180,6 @@ class SessionsDialog(Dialog): """ if no session is currently selected create a new session else shut the selected session down. - - :return: nothing """ if self.selected and self.selected_id is not None: self.shutdown_session(self.selected_id) @@ -187,18 +188,18 @@ class SessionsDialog(Dialog): else: logging.error("querysessiondrawing.py invalid state") - def join_session(self, session_id): + def join_session(self, session_id: int): self.app.statusbar.progress_bar.start(5) task = BackgroundTask(self.app, self.app.core.join_session, args=(session_id,)) task.start() self.destroy() - def on_selected(self, event): + def on_selected(self, event: tk.Event): item = self.tree.selection() sid = int(self.tree.item(item, "text")) self.join_session(sid) - def shutdown_session(self, sid): + def shutdown_session(self, sid: int): self.app.core.stop_session(sid) self.click_new() self.destroy() diff --git a/daemon/core/gui/dialogs/shapemod.py b/daemon/core/gui/dialogs/shapemod.py index ba799220..791e1f71 100644 --- a/daemon/core/gui/dialogs/shapemod.py +++ b/daemon/core/gui/dialogs/shapemod.py @@ -3,6 +3,7 @@ shape input dialog """ import tkinter as tk from tkinter import font, ttk +from typing import TYPE_CHECKING, List, Union from core.gui.dialogs.colorpicker import ColorPickerDialog from core.gui.dialogs.dialog import Dialog @@ -10,12 +11,16 @@ from core.gui.graph import tags from core.gui.graph.shapeutils import is_draw_shape, is_shape_text from core.gui.themes import FRAME_PAD, PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.shape import Shape + FONT_SIZES = [8, 9, 10, 11, 12, 14, 16, 18, 20, 22, 24, 26, 28, 36, 48, 72] BORDER_WIDTH = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] class ShapeDialog(Dialog): - def __init__(self, master, app, shape): + def __init__(self, master: "Application", app: "Application", shape: "Shape"): if is_draw_shape(shape.shape_type): title = "Add Shape" else: @@ -162,10 +167,9 @@ class ShapeDialog(Dialog): self.add_text() self.destroy() - def make_font(self): + def make_font(self) -> List[Union[int, str]]: """ create font for text or shape label - :return: list(font specifications) """ size = int(self.font_size.get()) text_font = [self.font.get(), size] @@ -180,8 +184,6 @@ class ShapeDialog(Dialog): def save_text(self): """ save info related to text or shape label - - :return: nothing """ data = self.shape.shape_data data.text = self.shape_text.get() @@ -195,8 +197,6 @@ class ShapeDialog(Dialog): def save_shape(self): """ save info related to shape - - :return: nothing """ data = self.shape.shape_data data.fill_color = self.fill_color @@ -206,8 +206,6 @@ class ShapeDialog(Dialog): def add_text(self): """ add text to canvas - - :return: nothing """ text = self.shape_text.get() text_font = self.make_font() diff --git a/daemon/core/gui/dialogs/throughput.py b/daemon/core/gui/dialogs/throughput.py index 150a3c5a..96aa3bc5 100644 --- a/daemon/core/gui/dialogs/throughput.py +++ b/daemon/core/gui/dialogs/throughput.py @@ -3,14 +3,18 @@ throughput dialog """ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.dialogs.colorpicker import ColorPickerDialog from core.gui.dialogs.dialog import Dialog from core.gui.themes import FRAME_PAD, PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + class ThroughputDialog(Dialog): - def __init__(self, master, app): + def __init__(self, master: "Application", app: "Application"): super().__init__(master, app, "Throughput Config", modal=False) self.app = app self.canvas = app.canvas diff --git a/daemon/core/gui/dialogs/wlanconfig.py b/daemon/core/gui/dialogs/wlanconfig.py index 3ceaa7e8..264b9e2e 100644 --- a/daemon/core/gui/dialogs/wlanconfig.py +++ b/daemon/core/gui/dialogs/wlanconfig.py @@ -3,6 +3,7 @@ wlan configuration """ from tkinter import ttk +from typing import TYPE_CHECKING import grpc @@ -11,9 +12,15 @@ from core.gui.errors import show_grpc_error from core.gui.themes import PADX, PADY from core.gui.widgets import ConfigFrame +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.node import CanvasNode + class WlanConfigDialog(Dialog): - def __init__(self, master, app, canvas_node): + def __init__( + self, master: "Application", app: "Application", canvas_node: "CanvasNode" + ): super().__init__( master, app, f"{canvas_node.core_node.name} Wlan Configuration", modal=True ) @@ -38,8 +45,6 @@ class WlanConfigDialog(Dialog): def draw_apply_buttons(self): """ create node configuration options - - :return: nothing """ frame = ttk.Frame(self.top) frame.grid(sticky="ew") @@ -55,8 +60,6 @@ class WlanConfigDialog(Dialog): def click_apply(self): """ retrieve user's wlan configuration and store the new configuration values - - :return: nothing """ config = self.config_frame.parse_config() self.app.core.wlan_configs[self.node.id] = self.config diff --git a/daemon/core/gui/errors.py b/daemon/core/gui/errors.py index 936968ad..18e025db 100644 --- a/daemon/core/gui/errors.py +++ b/daemon/core/gui/errors.py @@ -1,7 +1,11 @@ from tkinter import messagebox +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import grpc -def show_grpc_error(e): +def show_grpc_error(e: "grpc.RpcError"): title = [x.capitalize() for x in e.code().name.lower().split("_")] title = " ".join(title) title = f"GRPC {title}" diff --git a/daemon/core/gui/graph/edges.py b/daemon/core/gui/graph/edges.py index 323309de..2ca30031 100644 --- a/daemon/core/gui/graph/edges.py +++ b/daemon/core/gui/graph/edges.py @@ -1,19 +1,30 @@ import logging import tkinter as tk from tkinter.font import Font +from typing import TYPE_CHECKING, Any, Tuple from core.gui import themes from core.gui.dialogs.linkconfig import LinkConfigurationDialog from core.gui.graph import tags from core.gui.nodeutils import NodeUtils +if TYPE_CHECKING: + from core.gui.graph.graph import CanvasGraph + TEXT_DISTANCE = 0.30 EDGE_WIDTH = 3 EDGE_COLOR = "#ff0000" class CanvasWirelessEdge: - def __init__(self, token, position, src, dst, canvas): + def __init__( + self, + token: Tuple[Any, ...], + position: Tuple[float, float, float, float], + src: int, + dst: int, + canvas: "CanvasGraph", + ): self.token = token self.src = src self.dst = dst @@ -31,15 +42,17 @@ class CanvasEdge: Canvas edge class """ - def __init__(self, x1, y1, x2, y2, src, canvas): + def __init__( + self, + x1: float, + y1: float, + x2: float, + y2: float, + src: int, + canvas: "CanvasGraph", + ): """ Create an instance of canvas edge object - :param int x1: source x-coord - :param int y1: source y-coord - :param int x2: destination x-coord - :param int y2: destination y-coord - :param int src: source id - :param coretk.graph.graph.GraphCanvas canvas: canvas object """ self.src = src self.dst = None @@ -66,7 +79,7 @@ class CanvasEdge: self.link = link self.draw_labels() - def get_coordinates(self): + def get_coordinates(self) -> [float, float, float, float]: x1, y1, x2, y2 = self.canvas.coords(self.id) v1 = x2 - x1 v2 = y2 - y1 @@ -78,7 +91,7 @@ class CanvasEdge: y2 = y2 - uy return x1, y1, x2, y2 - def get_midpoint(self): + def get_midpoint(self) -> [float, float]: x1, y1, x2, y2 = self.canvas.coords(self.id) x = (x1 + x2) / 2 y = (y1 + y2) / 2 @@ -118,8 +131,6 @@ class CanvasEdge: def update_labels(self): """ Move edge labels based on current position. - - :return: nothing """ x1, y1, x2, y2 = self.get_coordinates() self.canvas.coords(self.text_src, x1, y1) @@ -128,7 +139,7 @@ class CanvasEdge: x, y = self.get_midpoint() self.canvas.coords(self.text_middle, x, y) - def set_throughput(self, throughput): + def set_throughput(self, throughput: float): throughput = 0.001 * throughput value = f"{throughput:.3f} kbps" if self.text_middle is None: @@ -147,7 +158,7 @@ class CanvasEdge: width = EDGE_WIDTH self.canvas.itemconfig(self.id, fill=color, width=width) - def complete(self, dst): + def complete(self, dst: int): self.dst = dst self.token = tuple(sorted((self.src, self.dst))) x, y = self.canvas.coords(self.dst) @@ -157,7 +168,7 @@ class CanvasEdge: self.canvas.tag_raise(self.src) self.canvas.tag_raise(self.dst) - def is_wireless(self): + def is_wireless(self) -> [bool, bool]: src_node = self.canvas.nodes[self.src] dst_node = self.canvas.nodes[self.dst] src_node_type = src_node.core_node.type @@ -183,7 +194,6 @@ class CanvasEdge: dst_node.add_antenna() elif not is_src_wireless and is_dst_wireless: src_node.add_antenna() - # TODO: remove this? dont allow linking wireless nodes? else: src_node.add_antenna() @@ -199,7 +209,7 @@ class CanvasEdge: self.text_middle = None self.canvas.itemconfig(self.id, fill=EDGE_COLOR, width=EDGE_WIDTH) - def create_context(self, event): + def create_context(self, event: tk.Event): logging.debug("create link context") context = tk.Menu(self.canvas) themes.style_menu(context) diff --git a/daemon/core/gui/graph/graph.py b/daemon/core/gui/graph/graph.py index fb3279a5..1a354209 100644 --- a/daemon/core/gui/graph/graph.py +++ b/daemon/core/gui/graph/graph.py @@ -1,5 +1,6 @@ import logging import tkinter as tk +from typing import TYPE_CHECKING, List, Tuple from PIL import Image, ImageTk @@ -15,12 +16,18 @@ from core.gui.graph.shapeutils import ShapeType, is_draw_shape, is_marker from core.gui.images import Images from core.gui.nodeutils import NodeUtils +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.coreclient import CoreClient + ZOOM_IN = 1.1 ZOOM_OUT = 0.9 class CanvasGraph(tk.Canvas): - def __init__(self, master, core, width, height): + def __init__( + self, master: "Application", core: "CoreClient", width: int, height: int + ): super().__init__(master, highlightthickness=0, background="#cccccc") self.app = master self.core = core @@ -67,7 +74,7 @@ class CanvasGraph(tk.Canvas): self.draw_canvas() self.draw_grid() - def draw_canvas(self, dimensions=None): + def draw_canvas(self, dimensions: Tuple[int, int] = None): if self.grid is not None: self.delete(self.grid) if not dimensions: @@ -84,13 +91,11 @@ class CanvasGraph(tk.Canvas): ) self.configure(scrollregion=self.bbox(tk.ALL)) - def reset_and_redraw(self, session): + def reset_and_redraw(self, session: core_pb2.Session): """ Reset the private variables CanvasGraph object, redraw nodes given the new grpc client. - :param core.api.grpc.core_pb2.Session session: session to draw - :return: nothing """ # hide context self.hide_context() @@ -114,8 +119,6 @@ class CanvasGraph(tk.Canvas): def setup_bindings(self): """ Bind any mouse events or hot keys to the matching action - - :return: nothing """ self.bind("", self.click_press) self.bind("", self.click_release) @@ -135,28 +138,28 @@ class CanvasGraph(tk.Canvas): self.context.unpost() self.context = None - def get_actual_coords(self, x, y): + def get_actual_coords(self, x: float, y: float) -> [float, float]: actual_x = (x - self.offset[0]) / self.ratio actual_y = (y - self.offset[1]) / self.ratio return actual_x, actual_y - def get_scaled_coords(self, x, y): + def get_scaled_coords(self, x: float, y: float) -> [float, float]: scaled_x = (x * self.ratio) + self.offset[0] scaled_y = (y * self.ratio) + self.offset[1] return scaled_x, scaled_y - def inside_canvas(self, x, y): + def inside_canvas(self, x: float, y: float) -> [bool, bool]: x1, y1, x2, y2 = self.bbox(self.grid) valid_x = x1 <= x <= x2 valid_y = y1 <= y <= y2 return valid_x and valid_y - def valid_position(self, x1, y1, x2, y2): + def valid_position(self, x1: int, y1: int, x2: int, y2: int) -> [bool, bool]: valid_topleft = self.inside_canvas(x1, y1) valid_bottomright = self.inside_canvas(x2, y2) return valid_topleft and valid_bottomright - def set_throughputs(self, throughputs_event): + def set_throughputs(self, throughputs_event: core_pb2.ThroughputsEvent): for interface_throughput in throughputs_event.interface_throughputs: node_id = interface_throughput.node_id interface_id = interface_throughput.interface_id @@ -174,8 +177,6 @@ class CanvasGraph(tk.Canvas): def draw_grid(self): """ Create grid. - - :return: nothing """ width, height = self.width_and_height() width = int(width) @@ -187,13 +188,9 @@ class CanvasGraph(tk.Canvas): self.tag_lower(tags.GRIDLINE) self.tag_lower(self.grid) - def add_wireless_edge(self, src, dst): + def add_wireless_edge(self, src: CanvasNode, dst: CanvasNode): """ add a wireless edge between 2 canvas nodes - - :param CanvasNode src: source node - :param CanvasNode dst: destination node - :return: nothing """ token = tuple(sorted((src.id, dst.id))) x1, y1 = self.coords(src.id) @@ -206,18 +203,16 @@ class CanvasGraph(tk.Canvas): self.tag_raise(src.id) self.tag_raise(dst.id) - def delete_wireless_edge(self, src, dst): + def delete_wireless_edge(self, src: CanvasNode, dst: CanvasNode): token = tuple(sorted((src.id, dst.id))) edge = self.wireless_edges.pop(token) edge.delete() src.wireless_edges.remove(edge) dst.wireless_edges.remove(edge) - def draw_session(self, session): + def draw_session(self, session: core_pb2.Session): """ Draw existing session. - - :return: nothing """ # draw existing nodes for core_node in session.nodes: @@ -296,25 +291,17 @@ class CanvasGraph(tk.Canvas): for edge in self.edges.values(): edge.reset() - def canvas_xy(self, event): + def canvas_xy(self, event: tk.Event) -> [float, float]: """ Convert window coordinate to canvas coordinate - - :param event: - :rtype: (int, int) - :return: x, y canvas coordinate """ x = self.canvasx(event.x) y = self.canvasy(event.y) return x, y - def get_selected(self, event): + def get_selected(self, event: tk.Event) -> int: """ Retrieve the item id that is on the mouse position - - :param event: mouse event - :rtype: int - :return: the item that the mouse point to """ x, y = self.canvas_xy(event) overlapping = self.find_overlapping(x, y, x, y) @@ -332,12 +319,9 @@ class CanvasGraph(tk.Canvas): return selected - def click_release(self, event): + def click_release(self, event: tk.Event): """ Draw a node or finish drawing an edge according to the current graph mode - - :param event: mouse event - :return: nothing """ logging.debug("click release") x, y = self.canvas_xy(event) @@ -380,7 +364,7 @@ class CanvasGraph(tk.Canvas): self.mode = GraphMode.NODE self.selected = None - def handle_edge_release(self, event): + def handle_edge_release(self, event: tk.Event): edge = self.drawing_edge self.drawing_edge = None @@ -417,7 +401,7 @@ class CanvasGraph(tk.Canvas): node_dst.edges.add(edge) self.core.create_link(edge, node_src, node_dst) - def select_object(self, object_id, choose_multiple=False): + def select_object(self, object_id: int, choose_multiple: bool = False): """ create a bounding box when a node is selected """ @@ -441,19 +425,17 @@ class CanvasGraph(tk.Canvas): def clear_selection(self): """ Clear current selection boxes. - - :return: nothing """ for _id in self.selection.values(): self.delete(_id) self.selection.clear() - def move_selection(self, object_id, x_offset, y_offset): + def move_selection(self, object_id: int, x_offset: float, y_offset: float): select_id = self.selection.get(object_id) if select_id is not None: self.move(select_id, x_offset, y_offset) - def delete_selection_objects(self): + def delete_selection_objects(self) -> List[CanvasNode]: edges = set() nodes = [] for object_id in self.selection: @@ -499,7 +481,7 @@ class CanvasGraph(tk.Canvas): self.selection.clear() return nodes - def zoom(self, event, factor=None): + def zoom(self, event: tk.Event, factor: float = None): if not factor: factor = ZOOM_IN if event.delta > 0 else ZOOM_OUT event.x, event.y = self.canvasx(event.x), self.canvasy(event.y) @@ -517,12 +499,9 @@ class CanvasGraph(tk.Canvas): if self.wallpaper: self.redraw_wallpaper() - def click_press(self, event): + def click_press(self, event: tk.Event): """ Start drawing an edge if mouse click is on a node - - :param event: mouse event - :return: nothing """ x, y = self.canvas_xy(event) if not self.inside_canvas(x, y): @@ -581,7 +560,7 @@ class CanvasGraph(tk.Canvas): self.select_box = shape self.clear_selection() - def ctrl_click(self, event): + def ctrl_click(self, event: tk.Event): # update cursor location x, y = self.canvas_xy(event) if not self.inside_canvas(x, y): @@ -599,12 +578,9 @@ class CanvasGraph(tk.Canvas): ): self.select_object(selected, choose_multiple=True) - def click_motion(self, event): + def click_motion(self, event: tk.Event): """ Redraw drawing edge according to the current position of the mouse - - :param event: mouse event - :return: nothing """ x, y = self.canvas_xy(event) if not self.inside_canvas(x, y): @@ -658,7 +634,7 @@ class CanvasGraph(tk.Canvas): if self.select_box and self.mode == GraphMode.SELECT: self.select_box.shape_motion(x, y) - def click_context(self, event): + def click_context(self, event: tk.Event): logging.info("context event: %s", self.context) if not self.context: selected = self.get_selected(event) @@ -670,24 +646,22 @@ class CanvasGraph(tk.Canvas): else: self.hide_context() - def press_delete(self, event): + def press_delete(self, event: tk.Event): """ delete selected nodes and any data that relates to it - :param event: - :return: """ logging.debug("press delete key") nodes = self.delete_selection_objects() self.core.delete_graph_nodes(nodes) - def double_click(self, event): + def double_click(self, event: tk.Event): selected = self.get_selected(event) if selected is not None and selected in self.shapes: shape = self.shapes[selected] dialog = ShapeDialog(self.app, self.app, shape) dialog.show() - def add_node(self, x, y): + def add_node(self, x: float, y: float) -> CanvasNode: if self.selected is None or self.selected in self.shapes: actual_x, actual_y = self.get_actual_coords(x, y) core_node = self.core.create_node( @@ -701,26 +675,25 @@ class CanvasGraph(tk.Canvas): def width_and_height(self): """ retrieve canvas width and height in pixels - - :return: nothing """ x0, y0, x1, y1 = self.coords(self.grid) canvas_w = abs(x0 - x1) canvas_h = abs(y0 - y1) return canvas_w, canvas_h - def get_wallpaper_image(self): + def get_wallpaper_image(self) -> Image.Image: width = int(self.wallpaper.width * self.ratio) height = int(self.wallpaper.height * self.ratio) image = self.wallpaper.resize((width, height), Image.ANTIALIAS) return image - def draw_wallpaper(self, image, x=None, y=None): + def draw_wallpaper( + self, image: ImageTk.PhotoImage, x: float = None, y: float = None + ): if x is None and y is None: x1, y1, x2, y2 = self.bbox(self.grid) x = (x1 + x2) / 2 y = (y1 + y2) / 2 - self.wallpaper_id = self.create_image((x, y), image=image, tags=tags.WALLPAPER) self.wallpaper_drawn = image @@ -748,8 +721,6 @@ class CanvasGraph(tk.Canvas): def wallpaper_center(self): """ place the image at the center of canvas - - :return: nothing """ self.delete(self.wallpaper_id) @@ -773,8 +744,6 @@ class CanvasGraph(tk.Canvas): def wallpaper_scaled(self): """ scale image based on canvas dimension - - :return: nothing """ self.delete(self.wallpaper_id) canvas_w, canvas_h = self.width_and_height() @@ -788,7 +757,7 @@ class CanvasGraph(tk.Canvas): self.redraw_canvas((image.width(), image.height())) self.draw_wallpaper(image) - def redraw_canvas(self, dimensions=None): + def redraw_canvas(self, dimensions: Tuple[int, int] = None): logging.info("redrawing canvas to dimensions: %s", dimensions) # reset scale and move back to original position @@ -836,7 +805,7 @@ class CanvasGraph(tk.Canvas): else: self.itemconfig(tags.GRIDLINE, state=tk.HIDDEN) - def set_wallpaper(self, filename): + def set_wallpaper(self, filename: str): logging.info("setting wallpaper: %s", filename) if filename: img = Image.open(filename) @@ -849,16 +818,12 @@ class CanvasGraph(tk.Canvas): self.wallpaper = None self.wallpaper_file = None - def is_selection_mode(self): + def is_selection_mode(self) -> bool: return self.mode == GraphMode.SELECT - def create_edge(self, source, dest): + def create_edge(self, source: CanvasNode, dest: CanvasNode): """ create an edge between source node and destination node - - :param CanvasNode source: source node - :param CanvasNode dest: destination node - :return: nothing """ if (source.id, dest.id) not in self.edges: pos0 = source.core_node.position diff --git a/daemon/core/gui/graph/node.py b/daemon/core/gui/graph/node.py index c43cbe9c..c1d8e075 100644 --- a/daemon/core/gui/graph/node.py +++ b/daemon/core/gui/graph/node.py @@ -1,5 +1,6 @@ import tkinter as tk from tkinter import font +from typing import TYPE_CHECKING import grpc @@ -16,11 +17,22 @@ from core.gui.graph import tags from core.gui.graph.tooltip import CanvasTooltip from core.gui.nodeutils import NodeUtils +if TYPE_CHECKING: + from core.gui.app import Application + from PIL.ImageTk import PhotoImage + NODE_TEXT_OFFSET = 5 class CanvasNode: - def __init__(self, app, x, y, core_node, image): + def __init__( + self, + app: "Application", + x: float, + y: float, + core_node: core_pb2.Node, + image: "PhotoImage", + ): self.app = app self.canvas = app.canvas self.image = image @@ -70,8 +82,6 @@ class CanvasNode: def delete_antenna(self): """ delete one antenna - - :return: nothing """ if self.antennae: antenna_id = self.antennae.pop() @@ -80,8 +90,6 @@ class CanvasNode: def delete_antennae(self): """ delete all antennas - - :return: nothing """ for antenna_id in self.antennae: self.canvas.delete(antenna_id) @@ -95,14 +103,14 @@ class CanvasNode: image_box = self.canvas.bbox(self.id) return image_box[3] + NODE_TEXT_OFFSET - def move(self, x, y): + def move(self, x: int, y: int): x, y = self.canvas.get_scaled_coords(x, y) current_x, current_y = self.canvas.coords(self.id) x_offset = x - current_x y_offset = y - current_y self.motion(x_offset, y_offset, update=False) - def motion(self, x_offset, y_offset, update=True): + def motion(self, x_offset: int, y_offset: int, update: bool = True): original_position = self.canvas.coords(self.id) self.canvas.move(self.id, x_offset, y_offset) x, y = self.canvas.coords(self.id) @@ -144,7 +152,7 @@ class CanvasNode: if self.app.core.is_runtime() and update: self.app.core.edit_node(self.core_node) - def on_enter(self, event): + def on_enter(self, event: tk.Event): if self.app.core.is_runtime() and self.app.core.observer: self.tooltip.text.set("waiting...") self.tooltip.on_enter(event) @@ -154,16 +162,16 @@ class CanvasNode: except grpc.RpcError as e: show_grpc_error(e) - def on_leave(self, event): + def on_leave(self, event: tk.Event): self.tooltip.on_leave(event) - def double_click(self, event): + def double_click(self, event: tk.Event): if self.app.core.is_runtime(): self.canvas.core.launch_terminal(self.core_node.id) else: self.show_config() - def create_context(self): + def create_context(self) -> tk.Menu: is_wlan = self.core_node.type == NodeType.WIRELESS_LAN is_emane = self.core_node.type == NodeType.EMANE context = tk.Menu(self.canvas) @@ -245,7 +253,7 @@ class CanvasNode: dialog = NodeServiceDialog(self.app.master, self.app, self) dialog.show() - def has_emane_link(self, interface_id): + def has_emane_link(self, interface_id: int) -> core_pb2.Node: result = None for edge in self.edges: if self.id == edge.src: diff --git a/daemon/core/gui/graph/shape.py b/daemon/core/gui/graph/shape.py index b23db2e9..011f7560 100644 --- a/daemon/core/gui/graph/shape.py +++ b/daemon/core/gui/graph/shape.py @@ -1,23 +1,28 @@ import logging +from typing import TYPE_CHECKING, Dict, List, Union from core.gui.dialogs.shapemod import ShapeDialog from core.gui.graph import tags from core.gui.graph.shapeutils import ShapeType +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.graph.graph import CanvasGraph + class AnnotationData: def __init__( self, - text="", - font="Arial", - font_size=12, - text_color="#000000", - fill_color="", - border_color="#000000", - border_width=1, - bold=False, - italic=False, - underline=False, + text: str = "", + font: str = "Arial", + font_size: int = 12, + text_color: str = "#000000", + fill_color: str = "", + border_color: str = "#000000", + border_width: int = 1, + bold: bool = False, + italic: bool = False, + underline: bool = False, ): self.text = text self.font = font @@ -32,7 +37,17 @@ class AnnotationData: class Shape: - def __init__(self, app, canvas, shape_type, x1, y1, x2=None, y2=None, data=None): + def __init__( + self, + app: "Application", + canvas: "CanvasGraph", + shape_type: ShapeType, + x1: float, + y1: float, + x2: float = None, + y2: float = None, + data: AnnotationData = None, + ): self.app = app self.canvas = canvas self.shape_type = shape_type @@ -99,7 +114,7 @@ class Shape: logging.error("unknown shape type: %s", self.shape_type) self.created = True - def get_font(self): + def get_font(self) -> List[Union[int, str]]: font = [self.shape_data.font, self.shape_data.font_size] if self.shape_data.bold: font.append("bold") @@ -123,10 +138,10 @@ class Shape: font=font, ) - def shape_motion(self, x1, y1): + def shape_motion(self, x1: float, y1: float): self.canvas.coords(self.id, self.x1, self.y1, x1, y1) - def shape_complete(self, x, y): + def shape_complete(self, x: float, y: float): for component in tags.ABOVE_SHAPE: self.canvas.tag_raise(component) s = ShapeDialog(self.app, self.app, self) @@ -135,7 +150,7 @@ class Shape: def disappear(self): self.canvas.delete(self.id) - def motion(self, x_offset, y_offset): + def motion(self, x_offset: float, y_offset: float): original_position = self.canvas.coords(self.id) self.canvas.move(self.id, x_offset, y_offset) coords = self.canvas.coords(self.id) @@ -151,7 +166,7 @@ class Shape: self.canvas.delete(self.id) self.canvas.delete(self.text_id) - def metadata(self): + def metadata(self) -> Dict[str, Union[str, int, bool]]: coords = self.canvas.coords(self.id) # update coords to actual positions if len(coords) == 4: diff --git a/daemon/core/gui/graph/shapeutils.py b/daemon/core/gui/graph/shapeutils.py index 0e2cc29c..ce2b7f96 100644 --- a/daemon/core/gui/graph/shapeutils.py +++ b/daemon/core/gui/graph/shapeutils.py @@ -11,13 +11,13 @@ class ShapeType(enum.Enum): SHAPES = {ShapeType.OVAL, ShapeType.RECTANGLE} -def is_draw_shape(shape_type): +def is_draw_shape(shape_type: ShapeType) -> bool: return shape_type in SHAPES -def is_shape_text(shape_type): +def is_shape_text(shape_type: ShapeType) -> bool: return shape_type == ShapeType.TEXT -def is_marker(shape_type): +def is_marker(shape_type: ShapeType) -> bool: return shape_type == ShapeType.MARKER diff --git a/daemon/core/gui/graph/tooltip.py b/daemon/core/gui/graph/tooltip.py index 3cc5825c..a2193901 100644 --- a/daemon/core/gui/graph/tooltip.py +++ b/daemon/core/gui/graph/tooltip.py @@ -1,8 +1,12 @@ import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.themes import Styles +if TYPE_CHECKING: + from core.gui.graph.graph import CanvasGraph + class CanvasTooltip: """ @@ -19,7 +23,14 @@ class CanvasTooltip: Alberto Vassena on 2016.12.10. """ - def __init__(self, canvas, *, pad=(5, 3, 5, 3), waittime=400, wraplength=600): + def __init__( + self, + canvas: "CanvasGraph", + *, + pad=(5, 3, 5, 3), + waittime: int = 400, + wraplength: int = 600 + ): # in miliseconds, originally 500 self.waittime = waittime # in pixels, originally 180 @@ -30,10 +41,10 @@ class CanvasTooltip: self.id = None self.tw = None - def on_enter(self, event=None): + def on_enter(self, event: tk.Event = None): self.schedule() - def on_leave(self, event=None): + def on_leave(self, event: tk.Event = None): self.unschedule() self.hide() @@ -47,7 +58,7 @@ class CanvasTooltip: if id_: self.canvas.after_cancel(id_) - def show(self, event=None): + def show(self, event: tk.Event = None): def tip_pos_calculator(canvas, label, *, tip_delta=(10, 5), pad=(5, 3, 5, 3)): c = canvas s_width, s_height = c.winfo_screenwidth(), c.winfo_screenheight() diff --git a/daemon/core/gui/images.py b/daemon/core/gui/images.py index a1ac922b..9b51b96c 100644 --- a/daemon/core/gui/images.py +++ b/daemon/core/gui/images.py @@ -9,7 +9,7 @@ class Images: images = {} @classmethod - def create(cls, file_path, width, height=None): + def create(cls, file_path: str, width: int, height: int = None): if height is None: height = width image = Image.open(file_path) @@ -22,12 +22,12 @@ class Images: cls.images[image.stem] = str(image) @classmethod - def get(cls, image_enum, width, height=None): + def get(cls, image_enum: Enum, width: int, height: int = None): file_path = cls.images[image_enum.value] return cls.create(file_path, width, height) @classmethod - def get_custom(cls, name, width, height=None): + def get_custom(cls, name: str, width: int, height: int = None): file_path = cls.images[name] return cls.create(file_path, width, height) diff --git a/daemon/core/gui/interface.py b/daemon/core/gui/interface.py index fe08b0dd..0d8573ba 100644 --- a/daemon/core/gui/interface.py +++ b/daemon/core/gui/interface.py @@ -1,24 +1,30 @@ import logging import random +from typing import TYPE_CHECKING, Set, Union from netaddr import IPNetwork from core.gui.nodeutils import NodeUtils +if TYPE_CHECKING: + from core.gui.app import Application + from core.api.grpc import core_pb2 + from core.gui.graph.node import CanvasNode + def random_mac(): return ("{:02x}" * 6).format(*[random.randrange(256) for _ in range(6)]) class InterfaceManager: - def __init__(self, app, address="10.0.0.0", mask=24): + def __init__(self, app: "Application", address: str = "10.0.0.0", mask: int = 24): self.app = app self.mask = mask self.base_prefix = max(self.mask - 8, 0) self.subnets = IPNetwork(f"{address}/{self.base_prefix}") self.current_subnet = None - def next_subnet(self): + def next_subnet(self) -> IPNetwork: # define currently used subnets used_subnets = set() for edge in self.app.core.links.values(): @@ -38,17 +44,19 @@ class InterfaceManager: def reset(self): self.current_subnet = None - def get_ips(self, node_id): + def get_ips(self, node_id: int) -> [str, str, int]: ip4 = self.current_subnet[node_id] ip6 = ip4.ipv6() prefix = self.current_subnet.prefixlen return str(ip4), str(ip6), prefix @classmethod - def get_subnet(cls, interface): + def get_subnet(cls, interface: "core_pb2.Interface") -> IPNetwork: return IPNetwork(f"{interface.ip4}/{interface.ip4mask}").cidr - def determine_subnet(self, canvas_src_node, canvas_dst_node): + def determine_subnet( + self, canvas_src_node: "CanvasNode", canvas_dst_node: "CanvasNode" + ): src_node = canvas_src_node.core_node dst_node = canvas_dst_node.core_node is_src_container = NodeUtils.is_container_node(src_node.type) @@ -70,7 +78,9 @@ class InterfaceManager: else: logging.info("ignoring subnet change for link between network nodes") - def find_subnet(self, canvas_node, visited=None): + def find_subnet( + self, canvas_node: "CanvasNode", visited: Set[int] = None + ) -> Union[IPNetwork, None]: logging.info("finding subnet for node: %s", canvas_node.core_node.name) canvas = self.app.canvas cidr = None diff --git a/daemon/core/gui/menuaction.py b/daemon/core/gui/menuaction.py index c48f82ff..4edc4797 100644 --- a/daemon/core/gui/menuaction.py +++ b/daemon/core/gui/menuaction.py @@ -3,8 +3,10 @@ The actions taken when each menubar option is clicked """ import logging +import tkinter as tk import webbrowser from tkinter import filedialog, messagebox +from typing import TYPE_CHECKING from core.gui.appconfig import XMLS_PATH from core.gui.dialogs.about import AboutDialog @@ -19,29 +21,24 @@ from core.gui.dialogs.sessions import SessionsDialog from core.gui.dialogs.throughput import ThroughputDialog from core.gui.task import BackgroundTask +if TYPE_CHECKING: + from core.gui.app import Application + class MenuAction: - """ - Actions performed when choosing menu items - """ - - def __init__(self, app, master): + def __init__(self, app: "Application", master: tk.Tk): self.master = master self.app = app self.canvas = app.canvas - def cleanup_old_session(self, quitapp=False): + def cleanup_old_session(self): logging.info("cleaning up old session") self.app.core.stop_session() self.app.core.delete_session() - # if quitapp: - # self.app.quit() - def prompt_save_running_session(self, quitapp=False): + def prompt_save_running_session(self, quitapp: bool = False): """ Prompt use to stop running session before application is closed - - :return: nothing """ result = True if self.app.core.is_runtime(): @@ -56,15 +53,13 @@ class MenuAction: elif quitapp: self.app.quit() - def on_quit(self, event=None): + def on_quit(self, event: tk.Event = None): """ Prompt user whether so save running session, and then close the application - - :return: nothing """ self.prompt_save_running_session(quitapp=True) - def file_save_as_xml(self, event=None): + def file_save_as_xml(self, event: tk.Event = None): logging.info("menuaction.py file_save_as_xml()") file_path = filedialog.asksaveasfilename( initialdir=str(XMLS_PATH), @@ -75,7 +70,7 @@ class MenuAction: if file_path: self.app.core.save_xml(file_path) - def file_open_xml(self, event=None): + def file_open_xml(self, event: tk.Event = None): logging.info("menuaction.py file_open_xml()") file_path = filedialog.askopenfilename( initialdir=str(XMLS_PATH), @@ -141,11 +136,11 @@ class MenuAction: else: self.app.core.cancel_throughputs() - def copy(self, event=None): + def copy(self, event: tk.Event = None): logging.debug("copy") self.app.canvas.copy() - def paste(self, event=None): + def paste(self, event: tk.Event = None): logging.debug("paste") self.app.canvas.paste() diff --git a/daemon/core/gui/menubar.py b/daemon/core/gui/menubar.py index f8020866..c8908333 100644 --- a/daemon/core/gui/menubar.py +++ b/daemon/core/gui/menubar.py @@ -1,22 +1,22 @@ import tkinter as tk from functools import partial +from typing import TYPE_CHECKING import core.gui.menuaction as action from core.gui.coreclient import OBSERVERS +if TYPE_CHECKING: + from core.gui.app import Application + class Menubar(tk.Menu): """ Core menubar """ - def __init__(self, master, app, cnf={}, **kwargs): + def __init__(self, master: tk.Tk, app: "Application", cnf={}, **kwargs): """ Create a CoreMenubar instance - - :param master: - :param tkinter.Menu menubar: menubar object - :param coretk.app.Application app: application object """ super().__init__(master, cnf, **kwargs) self.master.config(menu=self) @@ -27,8 +27,6 @@ class Menubar(tk.Menu): def draw(self): """ Create core menubar and bind the hot keys to their matching command - - :return: nothing """ self.draw_file_menu() self.draw_edit_menu() @@ -42,8 +40,6 @@ class Menubar(tk.Menu): def draw_file_menu(self): """ Create file menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command( @@ -81,8 +77,6 @@ class Menubar(tk.Menu): def draw_edit_menu(self): """ Create edit menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command(label="Preferences", command=self.menuaction.gui_preferences) @@ -112,8 +106,6 @@ class Menubar(tk.Menu): def draw_canvas_menu(self): """ Create canvas menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command( @@ -136,8 +128,6 @@ class Menubar(tk.Menu): def draw_view_menu(self): """ Create view menu - - :return: nothing """ view_menu = tk.Menu(self) self.create_show_menu(view_menu) @@ -149,12 +139,9 @@ class Menubar(tk.Menu): view_menu.add_command(label="Zoom out", accelerator="-", state=tk.DISABLED) self.add_cascade(label="View", menu=view_menu) - def create_show_menu(self, view_menu): + def create_show_menu(self, view_menu: tk.Menu): """ Create the menu items in View/Show - - :param tkinter.Menu view_menu: the view menu - :return: nothing """ menu = tk.Menu(view_menu) menu.add_command(label="All", state=tk.DISABLED) @@ -169,12 +156,9 @@ class Menubar(tk.Menu): menu.add_command(label="API Messages", state=tk.DISABLED) view_menu.add_cascade(label="Show", menu=menu) - def create_experimental_menu(self, tools_menu): + def create_experimental_menu(self, tools_menu: tk.Menu): """ Create experimental menu item and the sub menu items inside - - :param tkinter.Menu tools_menu: tools menu - :return: nothing """ menu = tk.Menu(tools_menu) menu.add_command(label="Plugins...", state=tk.DISABLED) @@ -182,12 +166,9 @@ class Menubar(tk.Menu): menu.add_command(label="Topology partitioning...", state=tk.DISABLED) tools_menu.add_cascade(label="Experimental", menu=menu) - def create_random_menu(self, topology_generator_menu): + def create_random_menu(self, topology_generator_menu: tk.Menu): """ Create random menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) # list of number of random nodes to create @@ -197,12 +178,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Random", menu=menu) - def create_grid_menu(self, topology_generator_menu): + def create_grid_menu(self, topology_generator_menu: tk.Menu): """ Create grid menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology_generator_menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) # list of number of nodes to create @@ -212,12 +190,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Grid", menu=menu) - def create_connected_grid_menu(self, topology_generator_menu): + def create_connected_grid_menu(self, topology_generator_menu: tk.Menu): """ Create connected grid menu items and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(1, 11, 1): @@ -229,12 +204,9 @@ class Menubar(tk.Menu): menu.add_cascade(label=label, menu=submenu) topology_generator_menu.add_cascade(label="Connected Grid", menu=menu) - def create_chain_menu(self, topology_generator_menu): + def create_chain_menu(self, topology_generator_menu: tk.Menu): """ Create chain menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) # number of nodes to create @@ -244,12 +216,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Chain", menu=menu) - def create_star_menu(self, topology_generator_menu): + def create_star_menu(self, topology_generator_menu: tk.Menu): """ Create star menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(3, 26, 1): @@ -257,12 +226,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Star", menu=menu) - def create_cycle_menu(self, topology_generator_menu): + def create_cycle_menu(self, topology_generator_menu: tk.Menu): """ Create cycle menu item and the sub items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(3, 25, 1): @@ -270,12 +236,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Cycle", menu=menu) - def create_wheel_menu(self, topology_generator_menu): + def create_wheel_menu(self, topology_generator_menu: tk.Menu): """ Create wheel menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(4, 26, 1): @@ -283,12 +246,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Wheel", menu=menu) - def create_cube_menu(self, topology_generator_menu): + def create_cube_menu(self, topology_generator_menu: tk.Menu): """ Create cube menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(2, 7, 1): @@ -296,12 +256,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Cube", menu=menu) - def create_clique_menu(self, topology_generator_menu): + def create_clique_menu(self, topology_generator_menu: tk.Menu): """ Create clique menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology generator menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) for i in range(3, 25, 1): @@ -309,12 +266,9 @@ class Menubar(tk.Menu): menu.add_command(label=label, state=tk.DISABLED) topology_generator_menu.add_cascade(label="Clique", menu=menu) - def create_bipartite_menu(self, topology_generator_menu): + def create_bipartite_menu(self, topology_generator_menu: tk.Menu): """ Create bipartite menu item and the sub menu items inside - - :param tkinter.Menu topology_generator_menu: topology_generator_menu - :return: nothing """ menu = tk.Menu(topology_generator_menu) temp = 24 @@ -328,13 +282,9 @@ class Menubar(tk.Menu): temp = temp - 1 topology_generator_menu.add_cascade(label="Bipartite", menu=menu) - def create_topology_generator_menu(self, tools_menu): + def create_topology_generator_menu(self, tools_menu: tk.Menu): """ Create topology menu item and its sub menu items - - :param tkinter.Menu tools_menu: tools menu - - :return: nothing """ menu = tk.Menu(tools_menu) self.create_random_menu(menu) @@ -352,8 +302,6 @@ class Menubar(tk.Menu): def draw_tools_menu(self): """ Create tools menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command(label="Auto rearrange all", state=tk.DISABLED) @@ -371,12 +319,9 @@ class Menubar(tk.Menu): menu.add_command(label="Debugger...", state=tk.DISABLED) self.add_cascade(label="Tools", menu=menu) - def create_observer_widgets_menu(self, widget_menu): + def create_observer_widgets_menu(self, widget_menu: tk.Menu): """ Create observer widget menu item and create the sub menu items inside - - :param tkinter.Menu widget_menu: widget_menu - :return: nothing """ var = tk.StringVar(value="none") menu = tk.Menu(widget_menu) @@ -409,12 +354,9 @@ class Menubar(tk.Menu): ) widget_menu.add_cascade(label="Observer Widgets", menu=menu) - def create_adjacency_menu(self, widget_menu): + def create_adjacency_menu(self, widget_menu: tk.Menu): """ Create adjacency menu item and the sub menu items inside - - :param tkinter.Menu widget_menu: widget menu - :return: nothing """ menu = tk.Menu(widget_menu) menu.add_command(label="OSPFv2", state=tk.DISABLED) @@ -426,8 +368,6 @@ class Menubar(tk.Menu): def draw_widgets_menu(self): """ Create widget menu - - :return: nothing """ menu = tk.Menu(self) self.create_observer_widgets_menu(menu) @@ -443,8 +383,6 @@ class Menubar(tk.Menu): def draw_session_menu(self): """ Create session menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command( @@ -461,8 +399,6 @@ class Menubar(tk.Menu): def draw_help_menu(self): """ Create help menu - - :return: nothing """ menu = tk.Menu(self) menu.add_command( diff --git a/daemon/core/gui/nodeutils.py b/daemon/core/gui/nodeutils.py index 9867998a..81d4894d 100644 --- a/daemon/core/gui/nodeutils.py +++ b/daemon/core/gui/nodeutils.py @@ -1,22 +1,34 @@ +from typing import TYPE_CHECKING, Optional, Set + from core.api.grpc.core_pb2 import NodeType from core.gui.images import ImageEnum, Images +if TYPE_CHECKING: + from core.api.grpc import core_pb2 + ICON_SIZE = 48 ANTENNA_SIZE = 32 class NodeDraw: def __init__(self): - self.custom = False + self.custom: bool = False self.image = None - self.image_enum = None + self.image_enum: Optional[ImageEnum] = None self.image_file = None - self.node_type = None - self.model = None - self.services = set() + self.node_type: core_pb2.NodeType = None + self.model: Optional[str] = None + self.services: Set[str] = set() @classmethod - def from_setup(cls, image_enum, node_type, label, model=None, tooltip=None): + def from_setup( + cls, + image_enum: ImageEnum, + node_type: "core_pb2.NodeType", + label: str, + model: str = None, + tooltip=None, + ): node_draw = NodeDraw() node_draw.image_enum = image_enum node_draw.image = Images.get(image_enum, ICON_SIZE) @@ -27,7 +39,7 @@ class NodeDraw: return node_draw @classmethod - def from_custom(cls, name, image_file, services): + def from_custom(cls, name: str, image_file: str, services: Set[str]): node_draw = NodeDraw() node_draw.custom = True node_draw.image_file = image_file @@ -53,31 +65,31 @@ class NodeUtils: ANTENNA_ICON = None @classmethod - def is_ignore_node(cls, node_type): + def is_ignore_node(cls, node_type: NodeType) -> bool: return node_type in cls.IGNORE_NODES @classmethod - def is_container_node(cls, node_type): + def is_container_node(cls, node_type: NodeType) -> bool: return node_type in cls.CONTAINER_NODES @classmethod - def is_model_node(cls, node_type): + def is_model_node(cls, node_type: NodeType) -> bool: return node_type == NodeType.DEFAULT @classmethod - def is_image_node(cls, node_type): + def is_image_node(cls, node_type: NodeType) -> bool: return node_type in cls.IMAGE_NODES @classmethod - def is_wireless_node(cls, node_type): + def is_wireless_node(cls, node_type: NodeType) -> bool: return node_type in cls.WIRELESS_NODES @classmethod - def is_rj45_node(cls, node_type): + def is_rj45_node(cls, node_type: NodeType) -> bool: return node_type in cls.RJ45_NODES @classmethod - def node_icon(cls, node_type, model): + def node_icon(cls, node_type: NodeType, model: str) -> bool: if model == "": model = None return cls.NODE_ICONS[(node_type, model)] diff --git a/daemon/core/gui/statusbar.py b/daemon/core/gui/statusbar.py index 1567e799..6de511c4 100644 --- a/daemon/core/gui/statusbar.py +++ b/daemon/core/gui/statusbar.py @@ -1,13 +1,19 @@ -"status bar" +""" +status bar +""" import tkinter as tk from tkinter import ttk +from typing import TYPE_CHECKING from core.gui.dialogs.alerts import AlertsDialog from core.gui.themes import Styles +if TYPE_CHECKING: + from core.gui.app import Application + class StatusBar(ttk.Frame): - def __init__(self, master, app, **kwargs): + def __init__(self, master: "Application", app: "Application", **kwargs): super().__init__(master, **kwargs) self.app = app self.status = None @@ -68,9 +74,5 @@ class StatusBar(ttk.Frame): dialog = AlertsDialog(self.app, self.app) dialog.show() - def set_status(self, message): + def set_status(self, message: str): self.statusvar.set(message) - - def stop_session_callback(self, cleanup_time): - self.progress_bar.stop() - self.statusvar.set(f"Stopped in {cleanup_time:.3f} seconds") diff --git a/daemon/core/gui/task.py b/daemon/core/gui/task.py index bf731dd4..bd7423ee 100644 --- a/daemon/core/gui/task.py +++ b/daemon/core/gui/task.py @@ -1,9 +1,10 @@ import logging import threading +from typing import Any, Callable class BackgroundTask: - def __init__(self, master, task, callback=None, args=()): + def __init__(self, master: Any, task: Callable, callback: Callable = None, args=()): self.master = master self.args = args self.task = task diff --git a/daemon/core/gui/themes.py b/daemon/core/gui/themes.py index 54ae3ce0..26ee5379 100644 --- a/daemon/core/gui/themes.py +++ b/daemon/core/gui/themes.py @@ -33,7 +33,7 @@ class Colors: listboxbg = "#f2f1f0" -def load(style): +def load(style: ttk.Style): style.theme_create( THEME_DARK, "clam", @@ -139,13 +139,13 @@ def load(style): ) -def theme_change_menu(event): +def theme_change_menu(event: tk.Event): if not isinstance(event.widget, tk.Menu): return style_menu(event.widget) -def style_menu(widget): +def style_menu(widget: tk.Widget): style = ttk.Style() bg = style.lookup(".", "background") fg = style.lookup(".", "foreground") @@ -157,7 +157,7 @@ def style_menu(widget): ) -def style_listbox(widget): +def style_listbox(widget: tk.Widget): style = ttk.Style() bg = style.lookup(".", "background") fg = style.lookup(".", "foreground") @@ -174,7 +174,7 @@ def style_listbox(widget): ) -def theme_change(event): +def theme_change(event: tk.Event): style = ttk.Style() style.configure(Styles.picker_button, font=("TkDefaultFont", 8, "normal")) style.configure( diff --git a/daemon/core/gui/toolbar.py b/daemon/core/gui/toolbar.py index 5404d9e5..021c2bc1 100644 --- a/daemon/core/gui/toolbar.py +++ b/daemon/core/gui/toolbar.py @@ -4,17 +4,23 @@ import tkinter as tk from functools import partial from tkinter import messagebox, ttk from tkinter.font import Font +from typing import TYPE_CHECKING, Callable +from core.api.grpc import core_pb2 from core.gui.dialogs.customnodes import CustomNodesDialog from core.gui.dialogs.marker import MarkerDialog from core.gui.graph.enums import GraphMode from core.gui.graph.shapeutils import ShapeType, is_marker from core.gui.images import ImageEnum, Images -from core.gui.nodeutils import NodeUtils +from core.gui.nodeutils import NodeDraw, NodeUtils from core.gui.task import BackgroundTask from core.gui.themes import Styles from core.gui.tooltip import Tooltip +if TYPE_CHECKING: + from core.gui.app import Application + from PIL import ImageTk + TOOLBAR_SIZE = 32 PICKER_SIZE = 24 @@ -28,11 +34,9 @@ class Toolbar(ttk.Frame): Core toolbar class """ - def __init__(self, master, app, **kwargs): + def __init__(self, master: "Application", app: "Application", **kwargs): """ Create a CoreToolbar instance - - :param tkinter.Frame edit_frame: edit frame """ super().__init__(master, **kwargs) self.app = app @@ -100,7 +104,7 @@ class Toolbar(ttk.Frame): self.create_network_button() self.create_annotation_button() - def design_select(self, button): + def design_select(self, button: ttk.Button): logging.info("selecting design button: %s", button) self.select_button.state(["!pressed"]) self.link_button.state(["!pressed"]) @@ -109,7 +113,7 @@ class Toolbar(ttk.Frame): self.annotation_button.state(["!pressed"]) button.state(["pressed"]) - def runtime_select(self, button): + def runtime_select(self, button: ttk.Button): logging.info("selecting runtime button: %s", button) self.runtime_select_button.state(["!pressed"]) self.stop_button.state(["!pressed"]) @@ -185,7 +189,7 @@ class Toolbar(ttk.Frame): 0, lambda: self.show_picker(self.node_button, self.node_picker) ) - def show_picker(self, button, picker): + def show_picker(self, button: ttk.Button, picker: ttk.Frame): x = self.winfo_width() + 1 y = button.winfo_rooty() - picker.master.winfo_rooty() - 1 picker.place(x=x, y=y) @@ -195,7 +199,9 @@ class Toolbar(ttk.Frame): self.wait_window(picker) self.app.unbind_all("") - def create_picker_button(self, image, func, frame, label): + def create_picker_button( + self, image: "ImageTk.PhotoImage", func: Callable, frame: ttk.Frame, label: str + ): """ Create button and put it on the frame @@ -203,7 +209,6 @@ class Toolbar(ttk.Frame): :param func: the command that is executed when button is clicked :param tkinter.Frame frame: frame that contains the button :param str label: button label - :return: nothing """ button = ttk.Button( frame, image=image, text=label, compound=tk.TOP, style=Styles.picker_button @@ -212,7 +217,13 @@ class Toolbar(ttk.Frame): button.bind("", lambda e: func()) button.grid(pady=1) - def create_button(self, frame, image, func, tooltip): + def create_button( + self, + frame: ttk.Frame, + image: "ImageTk.PhotoImage", + func: Callable, + tooltip: str, + ): button = ttk.Button(frame, image=image, command=func) button.image = image button.grid(sticky="ew") @@ -233,8 +244,6 @@ class Toolbar(ttk.Frame): """ Start session handler redraw buttons, send node and link messages to grpc server. - - :return: nothing """ self.app.canvas.hide_context() self.app.statusbar.progress_bar.start(5) @@ -243,7 +252,7 @@ class Toolbar(ttk.Frame): task = BackgroundTask(self, self.app.core.start_session, self.start_callback) task.start() - def start_callback(self, response): + def start_callback(self, response: core_pb2.StartSessionResponse): self.app.statusbar.progress_bar.stop() total = time.perf_counter() - self.time message = f"Start ran for {total:.3f} seconds" @@ -275,7 +284,7 @@ class Toolbar(ttk.Frame): dialog = CustomNodesDialog(self.app, self.app) dialog.show() - def update_button(self, button, image, node_draw): + def update_button(self, button: ttk.Button, image: "ImageTk", node_draw: NodeDraw): logging.info("update button(%s): %s", button, node_draw) self.hide_pickers() button.configure(image=image) @@ -298,8 +307,6 @@ class Toolbar(ttk.Frame): def create_node_button(self): """ Create network layer button - - :return: nothing """ image = icon(ImageEnum.ROUTER) self.node_button = ttk.Button( @@ -312,8 +319,6 @@ class Toolbar(ttk.Frame): def draw_network_picker(self): """ Draw the options for link-layer button. - - :return: nothing """ self.hide_pickers() self.network_picker = ttk.Frame(self.master) @@ -337,8 +342,6 @@ class Toolbar(ttk.Frame): """ Create link-layer node button and the options that represent different link-layer node types. - - :return: nothing """ image = icon(ImageEnum.HUB) self.network_button = ttk.Button( @@ -351,8 +354,6 @@ class Toolbar(ttk.Frame): def draw_annotation_picker(self): """ Draw the options for marker button. - - :return: nothing """ self.hide_pickers() self.annotation_picker = ttk.Frame(self.master) @@ -379,8 +380,6 @@ class Toolbar(ttk.Frame): def create_annotation_button(self): """ Create marker button and options that represent different marker types - - :return: nothing """ image = icon(ImageEnum.MARKER) self.annotation_button = ttk.Button( @@ -417,8 +416,6 @@ class Toolbar(ttk.Frame): def click_stop(self): """ redraw buttons on the toolbar, send node and link messages to grpc server - - :return: nothing """ self.app.canvas.hide_context() self.app.statusbar.progress_bar.start(5) @@ -426,7 +423,7 @@ class Toolbar(ttk.Frame): task = BackgroundTask(self, self.app.core.stop_session, self.stop_callback) task.start() - def stop_callback(self, response): + def stop_callback(self, response: core_pb2.StopSessionResponse): self.app.statusbar.progress_bar.stop() self.set_design() total = time.perf_counter() - self.time @@ -436,7 +433,7 @@ class Toolbar(ttk.Frame): if not response.result: messagebox.showerror("Stop Error", "Errors stopping session") - def update_annotation(self, image, shape_type): + def update_annotation(self, image: "ImageTk.PhotoImage", shape_type: ShapeType): logging.info("clicked annotation: ") self.hide_pickers() self.annotation_button.configure(image=image) @@ -446,7 +443,7 @@ class Toolbar(ttk.Frame): if is_marker(shape_type): if self.marker_tool: self.marker_tool.destroy() - self.marker_tool = MarkerDialog(self.master, self.app) + self.marker_tool = MarkerDialog(self.app, self.app) self.marker_tool.show() def click_run_button(self): @@ -462,7 +459,7 @@ class Toolbar(ttk.Frame): self.app.canvas.annotation_type = ShapeType.MARKER if self.marker_tool: self.marker_tool.destroy() - self.marker_tool = MarkerDialog(self.master, self.app) + self.marker_tool = MarkerDialog(self.app, self.app) self.marker_tool.show() def click_two_node_button(self): diff --git a/daemon/core/gui/tooltip.py b/daemon/core/gui/tooltip.py index 4fe2b467..bc1ed9b5 100644 --- a/daemon/core/gui/tooltip.py +++ b/daemon/core/gui/tooltip.py @@ -9,7 +9,7 @@ class Tooltip(object): Create tool tip for a given widget """ - def __init__(self, widget, text="widget info"): + def __init__(self, widget: tk.Widget, text: str = "widget info"): self.widget = widget self.text = text self.widget.bind("", self.on_enter) @@ -18,10 +18,10 @@ class Tooltip(object): self.id = None self.tw = None - def on_enter(self, event=None): + def on_enter(self, event: tk.Event = None): self.schedule() - def on_leave(self, event=None): + def on_leave(self, event: tk.Event = None): self.unschedule() self.close(event) @@ -35,7 +35,7 @@ class Tooltip(object): if id_: self.widget.after_cancel(id_) - def enter(self, event=None): + def enter(self, event: tk.Event = None): x, y, cx, cy = self.widget.bbox("insert") x += self.widget.winfo_rootx() y += self.widget.winfo_rooty() + 32 @@ -50,6 +50,6 @@ class Tooltip(object): label = ttk.Label(frame, text=self.text, style=Styles.tooltip) label.grid() - def close(self, event=None): + def close(self, event: tk.Event = None): if self.tw: self.tw.destroy() diff --git a/daemon/core/gui/validation.py b/daemon/core/gui/validation.py index 955a7faf..acb3d553 100644 --- a/daemon/core/gui/validation.py +++ b/daemon/core/gui/validation.py @@ -3,13 +3,17 @@ input validation """ import re import tkinter as tk +from typing import TYPE_CHECKING import netaddr from netaddr import IPNetwork +if TYPE_CHECKING: + from core.gui.app import Application + class InputValidation: - def __init__(self, app): + def __init__(self, app: "Application"): self.master = app.master self.positive_int = None self.positive_float = None @@ -27,7 +31,7 @@ class InputValidation: self.rgb = self.master.register(self.check_rbg) self.hex = self.master.register(self.check_hex) - def ip_focus_out(self, event): + def ip_focus_out(self, event: tk.Event): value = event.widget.get() try: IPNetwork(value) @@ -35,12 +39,12 @@ class InputValidation: event.widget.delete(0, tk.END) event.widget.insert(tk.END, "invalid") - def focus_out(self, event, default): + def focus_out(self, event: tk.Event, default: str): value = event.widget.get() if value == "": event.widget.insert(tk.END, default) - def check_positive_int(self, s): + def check_positive_int(self, s: str) -> bool: if len(s) == 0: return True try: @@ -51,7 +55,7 @@ class InputValidation: except ValueError: return False - def check_positive_float(self, s): + def check_positive_float(self, s: str) -> bool: if len(s) == 0: return True try: @@ -62,7 +66,7 @@ class InputValidation: except ValueError: return False - def check_node_name(self, s): + def check_node_name(self, s: str) -> bool: if len(s) < 0: return False if len(s) == 0: @@ -72,7 +76,7 @@ class InputValidation: return False return True - def check_canvas_int(sefl, s): + def check_canvas_int(self, s: str) -> bool: if len(s) == 0: return True try: @@ -83,7 +87,7 @@ class InputValidation: except ValueError: return False - def check_canvas_float(self, s): + def check_canvas_float(self, s: str) -> bool: if not s: return True try: @@ -94,7 +98,7 @@ class InputValidation: except ValueError: return False - def check_ip4(self, s): + def check_ip4(self, s: str) -> bool: if not s: return True pat = re.compile("^([0-9]+[.])*[0-9]*$") @@ -113,7 +117,7 @@ class InputValidation: else: return False - def check_rbg(self, s): + def check_rbg(self, s: str) -> bool: if not s: return True if s.startswith("0") and len(s) >= 2: @@ -127,7 +131,7 @@ class InputValidation: except ValueError: return False - def check_hex(self, s): + def check_hex(self, s: str) -> bool: if not s: return True pat = re.compile("^([#]([0-9]|[a-f])+)$|^[#]$") diff --git a/daemon/core/gui/widgets.py b/daemon/core/gui/widgets.py index d5257533..8dc163ab 100644 --- a/daemon/core/gui/widgets.py +++ b/daemon/core/gui/widgets.py @@ -1,12 +1,18 @@ import logging import tkinter as tk from functools import partial +from pathlib import PosixPath from tkinter import filedialog, font, ttk +from typing import TYPE_CHECKING, Dict from core.api.grpc import core_pb2 from core.gui import themes from core.gui.themes import FRAME_PAD, PADX, PADY +if TYPE_CHECKING: + from core.gui.app import Application + from core.gui.dialogs.dialog import Dialog + INT_TYPES = { core_pb2.ConfigOptionType.UINT8, core_pb2.ConfigOptionType.UINT16, @@ -19,14 +25,14 @@ INT_TYPES = { } -def file_button_click(value, parent): +def file_button_click(value: tk.StringVar, parent: tk.Widget): file_path = filedialog.askopenfilename(title="Select File", parent=parent) if file_path: value.set(file_path) class FrameScroll(ttk.Frame): - def __init__(self, master, app, _cls=ttk.Frame, **kw): + def __init__(self, master: tk.Widget, app: "Application", _cls=ttk.Frame, **kw): super().__init__(master, **kw) self.app = app self.rowconfigure(0, weight=1) @@ -49,13 +55,13 @@ class FrameScroll(ttk.Frame): self.frame.bind("", self._configure_frame) self.canvas.bind("", self._configure_canvas) - def _configure_frame(self, event): + def _configure_frame(self, event: tk.Event): req_width = self.frame.winfo_reqwidth() if req_width != self.canvas.winfo_reqwidth(): self.canvas.configure(width=req_width) self.canvas.configure(scrollregion=self.canvas.bbox("all")) - def _configure_canvas(self, event): + def _configure_canvas(self, event: tk.Event): self.canvas.itemconfig(self.frame_id, width=event.width) def clear(self): @@ -64,7 +70,13 @@ class FrameScroll(ttk.Frame): class ConfigFrame(ttk.Notebook): - def __init__(self, master, app, config, **kw): + def __init__( + self, + master: tk.Widget, + app: "Application", + config: Dict[str, core_pb2.ConfigOption], + **kw + ): super().__init__(master, **kw) self.app = app self.config = config @@ -174,7 +186,7 @@ class ConfigFrame(ttk.Notebook): class ListboxScroll(ttk.Frame): - def __init__(self, master=None, **kw): + def __init__(self, master: tk.Widget = None, **kw): super().__init__(master, **kw) self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1) @@ -189,12 +201,12 @@ class ListboxScroll(ttk.Frame): class CheckboxList(FrameScroll): - def __init__(self, master, app, clicked=None, **kw): + def __init__(self, master: ttk.Widget, app: "Application", clicked=None, **kw): super().__init__(master, app, **kw) self.clicked = clicked self.frame.columnconfigure(0, weight=1) - def add(self, name, checked): + def add(self, name: str, checked: bool): var = tk.BooleanVar(value=checked) func = partial(self.clicked, name, var) checkbox = ttk.Checkbutton(self.frame, text=name, variable=var, command=func) @@ -207,7 +219,7 @@ class CodeFont(font.Font): class CodeText(ttk.Frame): - def __init__(self, master, **kwargs): + def __init__(self, master: tk.Widget, **kwargs): super().__init__(master, **kwargs) self.rowconfigure(0, weight=1) self.columnconfigure(0, weight=1) @@ -231,14 +243,14 @@ class CodeText(ttk.Frame): class Spinbox(ttk.Entry): - def __init__(self, master=None, **kwargs): + def __init__(self, master: tk.Widget = None, **kwargs): super().__init__(master, "ttk::spinbox", **kwargs) def set(self, value): self.tk.call(self._w, "set", value) -def image_chooser(parent, path): +def image_chooser(parent: "Dialog", path: PosixPath): return filedialog.askopenfilename( parent=parent, initialdir=str(path), diff --git a/daemon/core/location/corelocation.py b/daemon/core/location/corelocation.py index aeab8896..fc803fac 100644 --- a/daemon/core/location/corelocation.py +++ b/daemon/core/location/corelocation.py @@ -6,6 +6,7 @@ https://pypi.python.org/pypi/utm (version 0.3.0). """ import logging +from typing import Optional, Tuple from core.emulator.enumerations import RegisterTlvs from core.location import utm @@ -21,7 +22,7 @@ class CoreLocation: name = "location" config_type = RegisterTlvs.UTILITY.value - def __init__(self): + def __init__(self) -> None: """ Creates a MobilityManager instance. @@ -37,7 +38,7 @@ class CoreLocation: for n, l in utm.ZONE_LETTERS: self.zonemap[l] = n - def reset(self): + def reset(self) -> None: """ Reset to initial state. """ @@ -50,7 +51,7 @@ class CoreLocation: # cached distance to refpt in other zones self.zoneshifts = {} - def px2m(self, val): + def px2m(self, val: float) -> float: """ Convert the specified value in pixels to meters using the configured scale. The scale is given as s, where @@ -61,7 +62,7 @@ class CoreLocation: """ return (val / 100.0) * self.refscale - def m2px(self, val): + def m2px(self, val: float) -> float: """ Convert the specified value in meters to pixels using the configured scale. The scale is given as s, where @@ -74,7 +75,7 @@ class CoreLocation: return 0.0 return 100.0 * (val / self.refscale) - def setrefgeo(self, lat, lon, alt): + def setrefgeo(self, lat: float, lon: float, alt: float) -> None: """ Record the geographical reference point decimal (lat, lon, alt) and convert and store its UTM equivalent for later use. @@ -89,7 +90,7 @@ class CoreLocation: e, n, zonen, zonel = utm.from_latlon(lat, lon) self.refutm = ((zonen, zonel), e, n, alt) - def getgeo(self, x, y, z): + def getgeo(self, x: float, y: float, z: float) -> Tuple[float, float, float]: """ Given (x, y, z) Cartesian coordinates, convert them to latitude, longitude, and altitude based on the configured reference point @@ -130,7 +131,7 @@ class CoreLocation: lat, lon = self.refgeo[:2] return lat, lon, alt - def getxyz(self, lat, lon, alt): + def getxyz(self, lat: float, lon: float, alt: float) -> Tuple[float, float, float]: """ Given latitude, longitude, and altitude location data, convert them to (x, y, z) Cartesian coordinates based on the configured @@ -165,7 +166,7 @@ class CoreLocation: z = self.m2px(zm) + self.refxyz[2] return x, y, z - def geteastingshift(self, zonen, zonel): + def geteastingshift(self, zonen: float, zonel: float) -> Optional[float]: """ If the lat, lon coordinates being converted are located in a different UTM zone than the canvas reference point, the UTM meters @@ -201,7 +202,7 @@ class CoreLocation: self.zoneshifts[z] = (xshift, yshift) return xshift - def getnorthingshift(self, zonen, zonel): + def getnorthingshift(self, zonen: float, zonel: float) -> Optional[float]: """ If the lat, lon coordinates being converted are located in a different UTM zone than the canvas reference point, the UTM meters @@ -238,7 +239,9 @@ class CoreLocation: self.zoneshifts[z] = (xshift, yshift) return yshift - def getutmzoneshift(self, e, n): + def getutmzoneshift( + self, e: float, n: float + ) -> Tuple[float, float, Tuple[float, str]]: """ Given UTM easting and northing values, check if they fall outside the reference point's zone boundary. Return the UTM coordinates in a diff --git a/daemon/core/location/event.py b/daemon/core/location/event.py index f930d9b7..d553e4ee 100644 --- a/daemon/core/location/event.py +++ b/daemon/core/location/event.py @@ -6,6 +6,7 @@ import heapq import threading import time from functools import total_ordering +from typing import Any, Callable class Timer(threading.Thread): @@ -14,7 +15,9 @@ class Timer(threading.Thread): already running. """ - def __init__(self, interval, function, args=None, kwargs=None): + def __init__( + self, interval: float, function: Callable, args: Any = None, kwargs: Any = None + ) -> None: """ Create a Timer instance. @@ -42,7 +45,7 @@ class Timer(threading.Thread): else: self.kwargs = {} - def cancel(self): + def cancel(self) -> bool: """ Stop the timer if it hasn't finished yet. Return False if the timer was already running. @@ -56,7 +59,7 @@ class Timer(threading.Thread): self._running.release() return locked - def run(self): + def run(self) -> None: """ Run the timer. @@ -75,7 +78,9 @@ class Event: Provides event objects that can be used within the EventLoop class. """ - def __init__(self, eventnum, event_time, func, *args, **kwds): + def __init__( + self, eventnum: int, event_time: float, func: Callable, *args: Any, **kwds: Any + ) -> None: """ Create an Event instance. @@ -92,13 +97,13 @@ class Event: self.kwds = kwds self.canceled = False - def __lt__(self, other): + def __lt__(self, other: "Event") -> bool: result = self.time < other.time if result: result = self.eventnum < other.eventnum return result - def run(self): + def run(self) -> None: """ Run an event. @@ -108,7 +113,7 @@ class Event: return self.func(*self.args, **self.kwds) - def cancel(self): + def cancel(self) -> None: """ Cancel event. @@ -123,7 +128,7 @@ class EventLoop: Provides an event loop for running events. """ - def __init__(self): + def __init__(self) -> None: """ Creates a EventLoop instance. """ @@ -134,7 +139,7 @@ class EventLoop: self.running = False self.start = None - def __run_events(self): + def __run_events(self) -> None: """ Run events. @@ -159,7 +164,7 @@ class EventLoop: if schedule: self.__schedule_event() - def __schedule_event(self): + def __schedule_event(self) -> None: """ Schedule event. @@ -177,7 +182,7 @@ class EventLoop: self.timer.daemon = True self.timer.start() - def run(self): + def run(self) -> None: """ Start event loop. @@ -192,7 +197,7 @@ class EventLoop: event.time += self.start self.__schedule_event() - def stop(self): + def stop(self) -> None: """ Stop event loop. @@ -209,7 +214,7 @@ class EventLoop: self.running = False self.start = None - def add_event(self, delaysec, func, *args, **kwds): + def add_event(self, delaysec: float, func: Callable, *args: Any, **kwds: Any): """ Add an event to the event loop. diff --git a/daemon/core/location/mobility.py b/daemon/core/location/mobility.py index 41678d43..e1917636 100644 --- a/daemon/core/location/mobility.py +++ b/daemon/core/location/mobility.py @@ -9,6 +9,7 @@ import os import threading import time from functools import total_ordering +from typing import TYPE_CHECKING, Dict, List, Tuple from core import utils from core.config import ConfigGroup, ConfigurableOptions, Configuration, ModelManager @@ -21,6 +22,11 @@ from core.emulator.enumerations import ( RegisterTlvs, ) from core.errors import CoreError +from core.nodes.base import CoreNode, NodeBase +from core.nodes.interface import CoreInterface + +if TYPE_CHECKING: + from core.emulator.session import Session class MobilityManager(ModelManager): @@ -32,7 +38,7 @@ class MobilityManager(ModelManager): name = "MobilityManager" config_type = RegisterTlvs.WIRELESS.value - def __init__(self, session): + def __init__(self, session: "Session") -> None: """ Creates a MobilityManager instance. @@ -43,7 +49,7 @@ class MobilityManager(ModelManager): self.models[BasicRangeModel.name] = BasicRangeModel self.models[Ns2ScriptedMobility.name] = Ns2ScriptedMobility - def reset(self): + def reset(self) -> None: """ Clear out all current configurations. @@ -51,7 +57,7 @@ class MobilityManager(ModelManager): """ self.config_reset() - def startup(self, node_ids=None): + def startup(self, node_ids: List[int] = None) -> None: """ Session is transitioning from instantiation to runtime state. Instantiate any mobility models that have been configured for a WLAN. @@ -86,7 +92,7 @@ class MobilityManager(ModelManager): if node.mobility: self.session.event_loop.add_event(0.0, node.mobility.startup) - def handleevent(self, event_data): + def handleevent(self, event_data: EventData) -> None: """ Handle an Event Message used to start, stop, or pause mobility scripts for a given WlanNode. @@ -149,7 +155,7 @@ class MobilityManager(ModelManager): if event_type == EventTypes.PAUSE.value: model.pause() - def sendevent(self, model): + def sendevent(self, model: "WayPointMobility") -> None: """ Send an event message on behalf of a mobility model. This communicates the current and end (max) times to the GUI. @@ -179,7 +185,9 @@ class MobilityManager(ModelManager): self.session.broadcast_event(event_data) - def updatewlans(self, moved, moved_netifs): + def updatewlans( + self, moved: List[NodeBase], moved_netifs: List[CoreInterface] + ) -> None: """ A mobility script has caused nodes in the 'moved' list to move. Update every WlanNode. This saves range calculations if the model @@ -208,7 +216,7 @@ class WirelessModel(ConfigurableOptions): bitmap = None position_callback = None - def __init__(self, session, _id): + def __init__(self, session: "Session", _id: int): """ Create a WirelessModel instance. @@ -218,7 +226,7 @@ class WirelessModel(ConfigurableOptions): self.session = session self.id = _id - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List: """ May be used if the model can populate the GUI with wireless (green) link lines. @@ -229,7 +237,7 @@ class WirelessModel(ConfigurableOptions): """ return [] - def update(self, moved, moved_netifs): + def update(self, moved: bool, moved_netifs: List[CoreInterface]) -> None: """ Update this wireless model. @@ -239,10 +247,10 @@ class WirelessModel(ConfigurableOptions): """ raise NotImplementedError - def update_config(self, config): + def update_config(self, config: Dict[str, str]) -> None: """ - For run-time updates of model config. Returns True when position callback and set link - parameters should be invoked. + For run-time updates of model config. Returns True when position callback and + set link parameters should be invoked. :param dict config: configuration values to update :return: nothing @@ -295,7 +303,7 @@ class BasicRangeModel(WirelessModel): def config_groups(cls): return [ConfigGroup("Basic Range Parameters", 1, len(cls.configurations()))] - def __init__(self, session, _id): + def __init__(self, session: "Session", _id: int) -> None: """ Create a BasicRangeModel instance. @@ -314,7 +322,7 @@ class BasicRangeModel(WirelessModel): self.loss = None self.jitter = None - def values_from_config(self, config): + def values_from_config(self, config: Dict[str, str]) -> None: """ Values to convert to link parameters. @@ -340,7 +348,7 @@ class BasicRangeModel(WirelessModel): if self.jitter == 0: self.jitter = None - def setlinkparams(self): + def setlinkparams(self) -> None: """ Apply link parameters to all interfaces. This is invoked from WlanNode.setmodel() after the position callback has been set. @@ -356,7 +364,7 @@ class BasicRangeModel(WirelessModel): jitter=self.jitter, ) - def get_position(self, netif): + def get_position(self, netif: CoreInterface) -> Tuple[float, float, float]: """ Retrieve network interface position. @@ -366,7 +374,9 @@ class BasicRangeModel(WirelessModel): with self._netifslock: return self._netifs[netif] - def set_position(self, netif, x=None, y=None, z=None): + def set_position( + self, netif: CoreInterface, x: float = None, y: float = None, z: float = None + ) -> None: """ A node has moved; given an interface, a new (x,y,z) position has been set; calculate the new distance between other nodes and link or @@ -389,7 +399,7 @@ class BasicRangeModel(WirelessModel): position_callback = set_position - def update(self, moved, moved_netifs): + def update(self, moved: bool, moved_netifs: List[CoreInterface]) -> None: """ Node positions have changed without recalc. Update positions from node.position, then re-calculate links for those that have moved. @@ -411,7 +421,7 @@ class BasicRangeModel(WirelessModel): continue self.calclink(netif, netif2) - def calclink(self, netif, netif2): + def calclink(self, netif: CoreInterface, netif2: CoreInterface) -> None: """ Helper used by set_position() and update() to calculate distance between two interfaces and perform @@ -455,7 +465,9 @@ class BasicRangeModel(WirelessModel): logging.exception("error getting interfaces during calclinkS") @staticmethod - def calcdistance(p1, p2): + def calcdistance( + p1: Tuple[float, float, float], p2: Tuple[float, float, float] + ) -> float: """ Calculate the distance between two three-dimensional points. @@ -471,7 +483,7 @@ class BasicRangeModel(WirelessModel): c = p1[2] - p2[2] return math.hypot(math.hypot(a, b), c) - def update_config(self, config): + def update_config(self, config: Dict[str, str]) -> None: """ Configuration has changed during runtime. @@ -482,12 +494,14 @@ class BasicRangeModel(WirelessModel): self.setlinkparams() return True - def create_link_data(self, interface1, interface2, message_type): + def create_link_data( + self, interface1: CoreInterface, interface2: CoreInterface, message_type: int + ) -> LinkData: """ Create a wireless link/unlink data message. - :param core.coreobj.PyCoreNetIf interface1: interface one - :param core.coreobj.PyCoreNetIf interface2: interface two + :param core.nodes.interface.CoreInterface interface1: interface one + :param core.nodes.interface.CoreInterface interface2: interface two :param message_type: link message type :return: link data :rtype: LinkData @@ -500,7 +514,9 @@ class BasicRangeModel(WirelessModel): link_type=LinkTypes.WIRELESS.value, ) - def sendlinkmsg(self, netif, netif2, unlink=False): + def sendlinkmsg( + self, netif: CoreInterface, netif2: CoreInterface, unlink: bool = False + ) -> None: """ Send a wireless link/unlink API message to the GUI. @@ -517,7 +533,7 @@ class BasicRangeModel(WirelessModel): link_data = self.create_link_data(netif, netif2, message_type) self.session.broadcast_link(link_data) - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List[LinkData]: """ Return a list of wireless link messages for when the GUI reconnects. @@ -540,7 +556,7 @@ class WayPoint: Maintains information regarding waypoints. """ - def __init__(self, time, nodenum, coords, speed): + def __init__(self, time: float, nodenum: int, coords, speed: float): """ Creates a WayPoint instance. @@ -554,13 +570,13 @@ class WayPoint: self.coords = coords self.speed = speed - def __eq__(self, other): - return (self.time, self.nodenum) == (other.time, other.nodedum) + def __eq__(self, other: "WayPoint") -> bool: + return (self.time, self.nodenum) == (other.time, other.nodenum) - def __ne__(self, other): + def __ne__(self, other: "WayPoint") -> bool: return not self == other - def __lt__(self, other): + def __lt__(self, other: "WayPoint") -> bool: result = self.time < other.time if result: result = self.nodenum < other.nodenum @@ -579,7 +595,7 @@ class WayPointMobility(WirelessModel): STATE_RUNNING = 1 STATE_PAUSED = 2 - def __init__(self, session, _id): + def __init__(self, session: "Session", _id: int) -> None: """ Create a WayPointMobility instance. @@ -603,7 +619,7 @@ class WayPointMobility(WirelessModel): # (ns-3 sets this to False as new waypoints may be added from trace) self.empty_queue_stop = True - def runround(self): + def runround(self) -> None: """ Advance script time and move nodes. @@ -657,7 +673,7 @@ class WayPointMobility(WirelessModel): # TODO: check session state self.session.event_loop.add_event(0.001 * self.refresh_ms, self.runround) - def run(self): + def run(self) -> None: """ Run the waypoint mobility scenario. @@ -670,7 +686,7 @@ class WayPointMobility(WirelessModel): self.runround() self.session.mobility.sendevent(self) - def movenode(self, node, dt): + def movenode(self, node: CoreNode, dt: float) -> bool: """ Calculate next node location and update its coordinates. Returns True if the node's position has changed. @@ -723,7 +739,7 @@ class WayPointMobility(WirelessModel): self.setnodeposition(node, x1 + dx, y1 + dy, z1) return True - def movenodesinitial(self): + def movenodesinitial(self) -> None: """ Move nodes to their initial positions. Then calculate the ranges. @@ -741,11 +757,13 @@ class WayPointMobility(WirelessModel): moved_netifs.append(netif) self.session.mobility.updatewlans(moved, moved_netifs) - def addwaypoint(self, time, nodenum, x, y, z, speed): + def addwaypoint( + self, _time: float, nodenum: int, x: float, y: float, z: float, speed: float + ) -> None: """ Waypoints are pushed to a heapq, sorted by time. - :param time: waypoint time + :param _time: waypoint time :param int nodenum: node id :param x: x position :param y: y position @@ -753,10 +771,10 @@ class WayPointMobility(WirelessModel): :param speed: speed :return: nothing """ - wp = WayPoint(time, nodenum, coords=(x, y, z), speed=speed) + wp = WayPoint(_time, nodenum, coords=(x, y, z), speed=speed) heapq.heappush(self.queue, wp) - def addinitial(self, nodenum, x, y, z): + def addinitial(self, nodenum: int, x: float, y: float, z: float) -> None: """ Record initial position in a dict. @@ -769,11 +787,11 @@ class WayPointMobility(WirelessModel): wp = WayPoint(0, nodenum, coords=(x, y, z), speed=0) self.initial[nodenum] = wp - def updatepoints(self, now): + def updatepoints(self, now: float) -> None: """ Move items from self.queue to self.points when their time has come. - :param int now: current timestamp + :param float now: current timestamp :return: nothing """ while len(self.queue): @@ -782,7 +800,7 @@ class WayPointMobility(WirelessModel): wp = heapq.heappop(self.queue) self.points[wp.nodenum] = wp - def copywaypoints(self): + def copywaypoints(self) -> None: """ Store backup copy of waypoints for looping and stopping. @@ -790,7 +808,7 @@ class WayPointMobility(WirelessModel): """ self.queue_copy = list(self.queue) - def loopwaypoints(self): + def loopwaypoints(self) -> None: """ Restore backup copy of waypoints when looping. @@ -799,13 +817,13 @@ class WayPointMobility(WirelessModel): self.queue = list(self.queue_copy) return self.loop - def setnodeposition(self, node, x, y, z): + def setnodeposition(self, node: CoreNode, x: float, y: float, z: float) -> None: """ Helper to move a node, notify any GUI (connected session handlers), without invoking the interface poshook callback that may perform range calculation. - :param core.netns.vnode.CoreNode node: node to set position for + :param core.nodes.base.CoreNode node: node to set position for :param x: x position :param y: y position :param z: z position @@ -815,7 +833,7 @@ class WayPointMobility(WirelessModel): node_data = node.data(message_type=0) self.session.broadcast_node(node_data) - def setendtime(self): + def setendtime(self) -> None: """ Set self.endtime to the time of the last waypoint in the queue of waypoints. This is just an estimate. The endtime will later be @@ -829,7 +847,7 @@ class WayPointMobility(WirelessModel): except IndexError: self.endtime = 0 - def start(self): + def start(self) -> None: """ Run the script from the beginning or unpause from where it was before. @@ -849,11 +867,12 @@ class WayPointMobility(WirelessModel): self.lasttime = now - (0.001 * self.refresh_ms) self.runround() - def stop(self, move_initial=True): + def stop(self, move_initial: bool = True) -> None: """ Stop the script and move nodes to initial positions. - :param bool move_initial: flag to check if we should move nodes to initial position + :param bool move_initial: flag to check if we should move nodes to initial + position :return: nothing """ self.state = self.STATE_STOPPED @@ -864,7 +883,7 @@ class WayPointMobility(WirelessModel): self.movenodesinitial() self.session.mobility.sendevent(self) - def pause(self): + def pause(self) -> None: """ Pause the script; pause time is stored to self.lasttime. @@ -926,12 +945,12 @@ class Ns2ScriptedMobility(WayPointMobility): ] @classmethod - def config_groups(cls): + def config_groups(cls) -> List[ConfigGroup]: return [ ConfigGroup("ns-2 Mobility Script Parameters", 1, len(cls.configurations())) ] - def __init__(self, session, _id): + def __init__(self, session: "Session", _id: int): """ Creates a Ns2ScriptedMobility instance. @@ -951,7 +970,7 @@ class Ns2ScriptedMobility(WayPointMobility): self.script_pause = None self.script_stop = None - def update_config(self, config): + def update_config(self, config: Dict[str, str]) -> None: self.file = config["file"] logging.info( "ns-2 scripted mobility configured for WLAN %d using file: %s", @@ -969,7 +988,7 @@ class Ns2ScriptedMobility(WayPointMobility): self.copywaypoints() self.setendtime() - def readscriptfile(self): + def readscriptfile(self) -> None: """ Read in mobility script from a file. This adds waypoints to a priority queue, sorted by waypoint time. Initial waypoints are @@ -1012,7 +1031,6 @@ class Ns2ScriptedMobility(WayPointMobility): # initial position (time=0, speed=0): # $node_(6) set X_ 780.0 parts = line.split() - time = 0.0 nodenum = parts[0][1 + parts[0].index("(") : parts[0].index(")")] if parts[2] == "X_": if ix is not None and iy is not None: @@ -1036,7 +1054,7 @@ class Ns2ScriptedMobility(WayPointMobility): if ix is not None and iy is not None: self.addinitial(self.map(inodenum), ix, iy, iz) - def findfile(self, file_name): + def findfile(self, file_name: str) -> str: """ Locate a script file. If the specified file doesn't exist, look in the same directory as the scenario file, or in the default @@ -1065,7 +1083,7 @@ class Ns2ScriptedMobility(WayPointMobility): return file_name - def parsemap(self, mapstr): + def parsemap(self, mapstr: str) -> None: """ Parse a node mapping string, given as a configuration parameter. @@ -1085,18 +1103,18 @@ class Ns2ScriptedMobility(WayPointMobility): except ValueError: logging.exception("ns-2 mobility node map error") - def map(self, nodenum): + def map(self, nodenum: str) -> int: """ Map one node number (from a script file) to another. - :param str nodenum: node id to map + :param int nodenum: node id to map :return: mapped value or the node id itself :rtype: int """ nodenum = int(nodenum) return self.nodemap.get(nodenum, nodenum) - def startup(self): + def startup(self) -> None: """ Start running the script if autostart is enabled. Move node to initial positions when any autostart time is specified. @@ -1122,7 +1140,7 @@ class Ns2ScriptedMobility(WayPointMobility): self.state = self.STATE_RUNNING self.session.event_loop.add_event(t, self.run) - def start(self): + def start(self) -> None: """ Handle the case when un-paused. @@ -1134,7 +1152,7 @@ class Ns2ScriptedMobility(WayPointMobility): if laststate == self.STATE_PAUSED: self.statescript("unpause") - def run(self): + def run(self) -> None: """ Start is pressed or autostart is triggered. @@ -1143,7 +1161,7 @@ class Ns2ScriptedMobility(WayPointMobility): super().run() self.statescript("run") - def pause(self): + def pause(self) -> None: """ Pause the mobility script. @@ -1152,17 +1170,18 @@ class Ns2ScriptedMobility(WayPointMobility): super().pause() self.statescript("pause") - def stop(self, move_initial=True): + def stop(self, move_initial: bool = True) -> None: """ Stop the mobility script. - :param bool move_initial: flag to check if we should move node to initial position + :param bool move_initial: flag to check if we should move node to initial + position :return: nothing """ super().stop(move_initial=move_initial) self.statescript("stop") - def statescript(self, typestr): + def statescript(self, typestr: str) -> None: """ State of the mobility script. diff --git a/daemon/core/nodes/base.py b/daemon/core/nodes/base.py index c34a42b4..ccfc2c1c 100644 --- a/daemon/core/nodes/base.py +++ b/daemon/core/nodes/base.py @@ -6,6 +6,7 @@ import logging import os import shutil import threading +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import netaddr @@ -15,8 +16,12 @@ from core.emulator.data import LinkData, NodeData from core.emulator.enumerations import LinkTypes, NodeTypes from core.errors import CoreCommandError from core.nodes import client -from core.nodes.interface import TunTap, Veth -from core.nodes.netclient import get_net_client +from core.nodes.interface import CoreInterface, TunTap, Veth +from core.nodes.netclient import LinuxNetClient, get_net_client + +if TYPE_CHECKING: + from core.emulator.distributed import DistributedServer + from core.emulator.session import Session _DEFAULT_MTU = 1500 @@ -29,9 +34,16 @@ class NodeBase: apitype = None # TODO: appears start has no usage, verify and remove - def __init__(self, session, _id=None, name=None, start=True, server=None): + def __init__( + self, + session: "Session", + _id: int = None, + name: str = None, + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ - Creates a PyCoreObj instance. + Creates a NodeBase instance. :param core.emulator.session.Session session: CORE session object :param int _id: id @@ -63,7 +75,7 @@ class NodeBase: use_ovs = session.options.get_config("ovs") == "True" self.net_client = get_net_client(use_ovs, self.host_cmd) - def startup(self): + def startup(self) -> None: """ Each object implements its own startup method. @@ -71,7 +83,7 @@ class NodeBase: """ raise NotImplementedError - def shutdown(self): + def shutdown(self) -> None: """ Each object implements its own shutdown method. @@ -79,7 +91,14 @@ class NodeBase: """ raise NotImplementedError - def host_cmd(self, args, env=None, cwd=None, wait=True, shell=False): + def host_cmd( + self, + args: str, + env: Dict[str, str] = None, + cwd: str = None, + wait: bool = True, + shell: bool = False, + ) -> str: """ Runs a command on the host system or distributed server. @@ -97,7 +116,7 @@ class NodeBase: else: return self.server.remote_cmd(args, env, cwd, wait) - def setposition(self, x=None, y=None, z=None): + def setposition(self, x: float = None, y: float = None, z: float = None) -> bool: """ Set the (x,y,z) position of the object. @@ -109,7 +128,7 @@ class NodeBase: """ return self.position.set(x=x, y=y, z=z) - def getposition(self): + def getposition(self) -> Tuple[float, float, float]: """ Return an (x,y,z) tuple representing this object's position. @@ -118,7 +137,7 @@ class NodeBase: """ return self.position.get() - def ifname(self, ifindex): + def ifname(self, ifindex: int) -> str: """ Retrieve interface name for index. @@ -128,7 +147,7 @@ class NodeBase: """ return self._netif[ifindex].name - def netifs(self, sort=False): + def netifs(self, sort: bool = False) -> List[CoreInterface]: """ Retrieve network interfaces, sorted if desired. @@ -141,7 +160,7 @@ class NodeBase: else: return list(self._netif.values()) - def numnetif(self): + def numnetif(self) -> int: """ Return the attached interface count. @@ -150,7 +169,7 @@ class NodeBase: """ return len(self._netif) - def getifindex(self, netif): + def getifindex(self, netif: CoreInterface) -> int: """ Retrieve index for an interface. @@ -163,7 +182,7 @@ class NodeBase: return ifindex return -1 - def newifindex(self): + def newifindex(self) -> int: """ Create a new interface index. @@ -176,7 +195,14 @@ class NodeBase: self.ifindex += 1 return ifindex - def data(self, message_type, lat=None, lon=None, alt=None, source=None): + def data( + self, + message_type: int, + lat: float = None, + lon: float = None, + alt: float = None, + source: str = None, + ) -> NodeData: """ Build a data object for this node. @@ -223,7 +249,7 @@ class NodeBase: return node_data - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List: """ Build CORE Link data for this object. There is no default method for PyCoreObjs as PyCoreNodes do not implement this but @@ -241,7 +267,14 @@ class CoreNodeBase(NodeBase): Base class for CORE nodes. """ - def __init__(self, session, _id=None, name=None, start=True, server=None): + def __init__( + self, + session: "Session", + _id: int = None, + name: str = None, + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ Create a CoreNodeBase instance. @@ -257,7 +290,7 @@ class CoreNodeBase(NodeBase): self.nodedir = None self.tmpnodedir = False - def makenodedir(self): + def makenodedir(self) -> None: """ Create the node directory. @@ -270,7 +303,7 @@ class CoreNodeBase(NodeBase): else: self.tmpnodedir = False - def rmnodedir(self): + def rmnodedir(self) -> None: """ Remove the node directory, unless preserve directory has been set. @@ -283,7 +316,7 @@ class CoreNodeBase(NodeBase): if self.tmpnodedir: self.host_cmd(f"rm -rf {self.nodedir}") - def addnetif(self, netif, ifindex): + def addnetif(self, netif: CoreInterface, ifindex: int) -> None: """ Add network interface to node and set the network interface index if successful. @@ -296,7 +329,7 @@ class CoreNodeBase(NodeBase): self._netif[ifindex] = netif netif.netindex = ifindex - def delnetif(self, ifindex): + def delnetif(self, ifindex: int) -> None: """ Delete a network interface @@ -309,7 +342,7 @@ class CoreNodeBase(NodeBase): netif.shutdown() del netif - def netif(self, ifindex): + def netif(self, ifindex: int) -> Optional[CoreInterface]: """ Retrieve network interface. @@ -322,7 +355,7 @@ class CoreNodeBase(NodeBase): else: return None - def attachnet(self, ifindex, net): + def attachnet(self, ifindex: int, net: "CoreNetworkBase") -> None: """ Attach a network. @@ -334,7 +367,7 @@ class CoreNodeBase(NodeBase): raise ValueError(f"ifindex {ifindex} does not exist") self._netif[ifindex].attachnet(net) - def detachnet(self, ifindex): + def detachnet(self, ifindex: int) -> None: """ Detach network interface. @@ -345,7 +378,7 @@ class CoreNodeBase(NodeBase): raise ValueError(f"ifindex {ifindex} does not exist") self._netif[ifindex].detachnet() - def setposition(self, x=None, y=None, z=None): + def setposition(self, x: float = None, y: float = None, z: float = None) -> None: """ Set position. @@ -359,7 +392,9 @@ class CoreNodeBase(NodeBase): for netif in self.netifs(sort=True): netif.setposition(x, y, z) - def commonnets(self, obj, want_ctrl=False): + def commonnets( + self, obj: "CoreNodeBase", want_ctrl: bool = False + ) -> List[Tuple[NodeBase, CoreInterface, CoreInterface]]: """ Given another node or net object, return common networks between this node and that object. A list of tuples is returned, with each tuple @@ -377,10 +412,9 @@ class CoreNodeBase(NodeBase): for netif2 in obj.netifs(): if netif1.net == netif2.net: common.append((netif1.net, netif1, netif2)) - return common - def cmd(self, args, wait=True, shell=False): + def cmd(self, args: str, wait: bool = True, shell: bool = False) -> str: """ Runs a command within a node container. @@ -393,7 +427,7 @@ class CoreNodeBase(NodeBase): """ raise NotImplementedError - def termcmdstring(self, sh): + def termcmdstring(self, sh: str) -> str: """ Create a terminal command string. @@ -413,14 +447,14 @@ class CoreNode(CoreNodeBase): def __init__( self, - session, - _id=None, - name=None, - nodedir=None, - bootsh="boot.sh", - start=True, - server=None, - ): + session: "Session", + _id: int = None, + name: str = None, + nodedir: str = None, + bootsh: str = "boot.sh", + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ Create a CoreNode instance. @@ -451,17 +485,17 @@ class CoreNode(CoreNodeBase): if start: self.startup() - def create_node_net_client(self, use_ovs): + def create_node_net_client(self, use_ovs: bool) -> LinuxNetClient: """ Create node network client for running network commands within the nodes container. :param bool use_ovs: True for OVS bridges, False for Linux bridges - :return:node network client + :return: node network client """ return get_net_client(use_ovs, self.cmd) - def alive(self): + def alive(self) -> bool: """ Check if the node is alive. @@ -475,7 +509,7 @@ class CoreNode(CoreNodeBase): return True - def startup(self): + def startup(self) -> None: """ Start a new namespace node by invoking the vnoded process that allocates a new namespace. Bring up the loopback device and set @@ -521,7 +555,7 @@ class CoreNode(CoreNodeBase): self.privatedir("/var/run") self.privatedir("/var/log") - def shutdown(self): + def shutdown(self) -> None: """ Shutdown logic for simple lxc nodes. @@ -562,7 +596,7 @@ class CoreNode(CoreNodeBase): finally: self.rmnodedir() - def cmd(self, args, wait=True, shell=False): + def cmd(self, args: str, wait: bool = True, shell: bool = False) -> str: """ Runs a command that is used to configure and setup the network within a node. @@ -580,7 +614,7 @@ class CoreNode(CoreNodeBase): args = self.client.create_cmd(args) return self.server.remote_cmd(args, wait=wait) - def termcmdstring(self, sh="/bin/sh"): + def termcmdstring(self, sh: str = "/bin/sh") -> str: """ Create a terminal command string. @@ -593,7 +627,7 @@ class CoreNode(CoreNodeBase): else: return f"ssh -X -f {self.server.host} xterm -e {terminal}" - def privatedir(self, path): + def privatedir(self, path: str) -> None: """ Create a private directory. @@ -608,7 +642,7 @@ class CoreNode(CoreNodeBase): self.host_cmd(f"mkdir -p {hostpath}") self.mount(hostpath, path) - def mount(self, source, target): + def mount(self, source: str, target: str) -> None: """ Create and mount a directory. @@ -623,7 +657,7 @@ class CoreNode(CoreNodeBase): self.cmd(f"{MOUNT_BIN} -n --bind {source} {target}") self._mounts.append((source, target)) - def newifindex(self): + def newifindex(self) -> int: """ Retrieve a new interface index. @@ -633,7 +667,7 @@ class CoreNode(CoreNodeBase): with self.lock: return super().newifindex() - def newveth(self, ifindex=None, ifname=None): + def newveth(self, ifindex: int = None, ifname: str = None) -> int: """ Create a new interface. @@ -690,7 +724,7 @@ class CoreNode(CoreNodeBase): return ifindex - def newtuntap(self, ifindex=None, ifname=None): + def newtuntap(self, ifindex: int = None, ifname: str = None) -> int: """ Create a new tunnel tap. @@ -720,7 +754,7 @@ class CoreNode(CoreNodeBase): return ifindex - def sethwaddr(self, ifindex, addr): + def sethwaddr(self, ifindex: int, addr: str) -> None: """ Set hardware addres for an interface. @@ -735,7 +769,7 @@ class CoreNode(CoreNodeBase): if self.up: self.node_net_client.device_mac(interface.name, addr) - def addaddr(self, ifindex, addr): + def addaddr(self, ifindex: int, addr: str) -> None: """ Add interface address. @@ -753,7 +787,7 @@ class CoreNode(CoreNodeBase): broadcast = "+" self.node_net_client.create_address(interface.name, addr, broadcast) - def deladdr(self, ifindex, addr): + def deladdr(self, ifindex: int, addr: str) -> None: """ Delete address from an interface. @@ -772,7 +806,7 @@ class CoreNode(CoreNodeBase): if self.up: self.node_net_client.delete_address(interface.name, addr) - def ifup(self, ifindex): + def ifup(self, ifindex: int) -> None: """ Bring an interface up. @@ -783,7 +817,14 @@ class CoreNode(CoreNodeBase): interface_name = self.ifname(ifindex) self.node_net_client.device_up(interface_name) - def newnetif(self, net=None, addrlist=None, hwaddr=None, ifindex=None, ifname=None): + def newnetif( + self, + net: "CoreNetworkBase" = None, + addrlist: List[str] = None, + hwaddr: str = None, + ifindex: int = None, + ifname: str = None, + ) -> int: """ Create a new network interface. @@ -827,7 +868,7 @@ class CoreNode(CoreNodeBase): self.ifup(ifindex) return ifindex - def addfile(self, srcname, filename): + def addfile(self, srcname: str, filename: str) -> None: """ Add a file. @@ -846,7 +887,7 @@ class CoreNode(CoreNodeBase): self.host_cmd(f"mkdir -p {directory}") self.server.remote_put(srcname, filename) - def hostfilename(self, filename): + def hostfilename(self, filename: str) -> str: """ Return the name of a node"s file on the host filesystem. @@ -862,7 +903,7 @@ class CoreNode(CoreNodeBase): dirname = os.path.join(self.nodedir, dirname) return os.path.join(dirname, basename) - def nodefile(self, filename, contents, mode=0o644): + def nodefile(self, filename: str, contents: str, mode: int = 0o644) -> None: """ Create a node file with a given mode. @@ -887,7 +928,7 @@ class CoreNode(CoreNodeBase): "node(%s) added file: %s; mode: 0%o", self.name, hostfilename, mode ) - def nodefilecopy(self, filename, srcfilename, mode=None): + def nodefilecopy(self, filename: str, srcfilename: str, mode: int = None) -> None: """ Copy a file to a node, following symlinks and preserving metadata. Change file mode if specified. @@ -917,7 +958,14 @@ class CoreNetworkBase(NodeBase): linktype = LinkTypes.WIRED.value is_emane = False - def __init__(self, session, _id, name, start=True, server=None): + def __init__( + self, + session: "Session", + _id: int, + name: str, + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ Create a CoreNetworkBase instance. @@ -932,7 +980,7 @@ class CoreNetworkBase(NodeBase): self._linked = {} self._linked_lock = threading.Lock() - def startup(self): + def startup(self) -> None: """ Each object implements its own startup method. @@ -940,7 +988,7 @@ class CoreNetworkBase(NodeBase): """ raise NotImplementedError - def shutdown(self): + def shutdown(self) -> None: """ Each object implements its own shutdown method. @@ -948,7 +996,30 @@ class CoreNetworkBase(NodeBase): """ raise NotImplementedError - def attach(self, netif): + def linknet(self, net: "CoreNetworkBase") -> CoreInterface: + """ + Link network to another. + + :param core.nodes.base.CoreNetworkBase net: network to link with + :return: created interface + :rtype: core.nodes.interface.Veth + """ + pass + + def getlinknetif(self, net: "CoreNetworkBase") -> CoreInterface: + """ + Return the interface of that links this net with another net. + + :param core.nodes.base.CoreNetworkBase net: interface to get link for + :return: interface the provided network is linked to + :rtype: core.nodes.interface.CoreInterface + """ + for netif in self.netifs(): + if hasattr(netif, "othernet") and netif.othernet == net: + return netif + return None + + def attach(self, netif: CoreInterface) -> None: """ Attach network interface. @@ -961,7 +1032,7 @@ class CoreNetworkBase(NodeBase): with self._linked_lock: self._linked[netif] = {} - def detach(self, netif): + def detach(self, netif: CoreInterface) -> None: """ Detach network interface. @@ -973,7 +1044,7 @@ class CoreNetworkBase(NodeBase): with self._linked_lock: del self._linked[netif] - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List[LinkData]: """ Build link data objects for this network. Each link object describes a link between this network and a node. @@ -981,7 +1052,6 @@ class CoreNetworkBase(NodeBase): :param int flags: message type :return: list of link data :rtype: list[core.data.LinkData] - """ all_links = [] @@ -1072,7 +1142,7 @@ class Position: Helper class for Cartesian coordinate position """ - def __init__(self, x=None, y=None, z=None): + def __init__(self, x: float = None, y: float = None, z: float = None) -> None: """ Creates a Position instance. @@ -1085,7 +1155,7 @@ class Position: self.y = y self.z = z - def set(self, x=None, y=None, z=None): + def set(self, x: float = None, y: float = None, z: float = None) -> bool: """ Returns True if the position has actually changed. @@ -1102,7 +1172,7 @@ class Position: self.z = z return True - def get(self): + def get(self) -> Tuple[float, float, float]: """ Retrieve x,y,z position. diff --git a/daemon/core/nodes/client.py b/daemon/core/nodes/client.py index 66b61c37..c2624d6b 100644 --- a/daemon/core/nodes/client.py +++ b/daemon/core/nodes/client.py @@ -13,7 +13,7 @@ class VnodeClient: Provides client functionality for interacting with a virtual node. """ - def __init__(self, name, ctrlchnlname): + def __init__(self, name: str, ctrlchnlname: str) -> None: """ Create a VnodeClient instance. @@ -23,7 +23,7 @@ class VnodeClient: self.name = name self.ctrlchnlname = ctrlchnlname - def _verify_connection(self): + def _verify_connection(self) -> None: """ Checks that the vcmd client is properly connected. @@ -33,7 +33,7 @@ class VnodeClient: if not self.connected(): raise IOError("vcmd not connected") - def connected(self): + def connected(self) -> bool: """ Check if node is connected or not. @@ -42,7 +42,7 @@ class VnodeClient: """ return True - def close(self): + def close(self) -> None: """ Close the client connection. @@ -50,10 +50,10 @@ class VnodeClient: """ pass - def create_cmd(self, args): + def create_cmd(self, args: str) -> str: return f"{VCMD_BIN} -c {self.ctrlchnlname} -- {args}" - def check_cmd(self, args, wait=True, shell=False): + def check_cmd(self, args: str, wait: bool = True, shell: bool = False) -> str: """ Run command and return exit status and combined stdout and stderr. diff --git a/daemon/core/nodes/docker.py b/daemon/core/nodes/docker.py index b56fcc5c..1d7f3f02 100644 --- a/daemon/core/nodes/docker.py +++ b/daemon/core/nodes/docker.py @@ -2,22 +2,27 @@ import json import logging import os from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Callable, Dict from core import utils +from core.emulator.distributed import DistributedServer from core.emulator.enumerations import NodeTypes from core.errors import CoreCommandError from core.nodes.base import CoreNode -from core.nodes.netclient import get_net_client +from core.nodes.netclient import LinuxNetClient, get_net_client + +if TYPE_CHECKING: + from core.emulator.session import Session class DockerClient: - def __init__(self, name, image, run): + def __init__(self, name: str, image: str, run: Callable[..., str]) -> None: self.name = name self.image = image self.run = run self.pid = None - def create_container(self): + def create_container(self) -> str: self.run( f"docker run -td --init --net=none --hostname {self.name} --name {self.name} " f"--sysctl net.ipv6.conf.all.disable_ipv6=0 {self.image} /bin/bash" @@ -25,7 +30,7 @@ class DockerClient: self.pid = self.get_pid() return self.pid - def get_info(self): + def get_info(self) -> Dict: args = f"docker inspect {self.name}" output = self.run(args) data = json.loads(output) @@ -33,35 +38,35 @@ class DockerClient: raise CoreCommandError(-1, args, f"docker({self.name}) not present") return data[0] - def is_alive(self): + def is_alive(self) -> bool: try: data = self.get_info() return data["State"]["Running"] except CoreCommandError: return False - def stop_container(self): + def stop_container(self) -> None: self.run(f"docker rm -f {self.name}") - def check_cmd(self, cmd, wait=True, shell=False): + def check_cmd(self, cmd: str, wait: bool = True, shell: bool = False) -> str: logging.info("docker cmd output: %s", cmd) return utils.cmd(f"docker exec {self.name} {cmd}", wait=wait, shell=shell) - def create_ns_cmd(self, cmd): + def create_ns_cmd(self, cmd: str) -> str: return f"nsenter -t {self.pid} -u -i -p -n {cmd}" - def ns_cmd(self, cmd, wait): + def ns_cmd(self, cmd: str, wait: bool) -> str: args = f"nsenter -t {self.pid} -u -i -p -n {cmd}" return utils.cmd(args, wait=wait) - def get_pid(self): + def get_pid(self) -> str: args = f"docker inspect -f '{{{{.State.Pid}}}}' {self.name}" output = self.run(args) self.pid = output logging.debug("node(%s) pid: %s", self.name, self.pid) return output - def copy_file(self, source, destination): + def copy_file(self, source: str, destination: str) -> str: args = f"docker cp {source} {self.name}:{destination}" return self.run(args) @@ -71,15 +76,15 @@ class DockerNode(CoreNode): def __init__( self, - session, - _id=None, - name=None, - nodedir=None, - bootsh="boot.sh", - start=True, - server=None, - image=None - ): + session: "Session", + _id: int = None, + name: str = None, + nodedir: str = None, + bootsh: str = "boot.sh", + start: bool = True, + server: DistributedServer = None, + image: str = None + ) -> None: """ Create a DockerNode instance. @@ -98,7 +103,7 @@ class DockerNode(CoreNode): self.image = image super().__init__(session, _id, name, nodedir, bootsh, start, server) - def create_node_net_client(self, use_ovs): + def create_node_net_client(self, use_ovs: bool) -> LinuxNetClient: """ Create node network client for running network commands within the nodes container. @@ -108,7 +113,7 @@ class DockerNode(CoreNode): """ return get_net_client(use_ovs, self.nsenter_cmd) - def alive(self): + def alive(self) -> bool: """ Check if the node is alive. @@ -117,7 +122,7 @@ class DockerNode(CoreNode): """ return self.client.is_alive() - def startup(self): + def startup(self) -> None: """ Start a new namespace node by invoking the vnoded process that allocates a new namespace. Bring up the loopback device and set @@ -133,7 +138,7 @@ class DockerNode(CoreNode): self.pid = self.client.create_container() self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Shutdown logic. @@ -148,7 +153,7 @@ class DockerNode(CoreNode): self.client.stop_container() self.up = False - def nsenter_cmd(self, args, wait=True, shell=False): + def nsenter_cmd(self, args: str, wait: bool = True, shell: bool = False) -> str: if self.server is None: args = self.client.create_ns_cmd(args) return utils.cmd(args, wait=wait, shell=shell) @@ -156,7 +161,7 @@ class DockerNode(CoreNode): args = self.client.create_ns_cmd(args) return self.server.remote_cmd(args, wait=wait) - def termcmdstring(self, sh="/bin/sh"): + def termcmdstring(self, sh: str = "/bin/sh") -> str: """ Create a terminal command string. @@ -165,7 +170,7 @@ class DockerNode(CoreNode): """ return f"docker exec -it {self.name} bash" - def privatedir(self, path): + def privatedir(self, path: str) -> None: """ Create a private directory. @@ -176,7 +181,7 @@ class DockerNode(CoreNode): args = f"mkdir -p {path}" self.cmd(args) - def mount(self, source, target): + def mount(self, source: str, target: str) -> None: """ Create and mount a directory. @@ -188,7 +193,7 @@ class DockerNode(CoreNode): logging.debug("mounting source(%s) target(%s)", source, target) raise Exception("not supported") - def nodefile(self, filename, contents, mode=0o644): + def nodefile(self, filename: str, contents: str, mode: int = 0o644) -> None: """ Create a node file with a given mode. @@ -216,7 +221,7 @@ class DockerNode(CoreNode): "node(%s) added file: %s; mode: 0%o", self.name, filename, mode ) - def nodefilecopy(self, filename, srcfilename, mode=None): + def nodefilecopy(self, filename: str, srcfilename: str, mode: int = None) -> None: """ Copy a file to a node, following symlinks and preserving metadata. Change file mode if specified. diff --git a/daemon/core/nodes/interface.py b/daemon/core/nodes/interface.py index 236bdd5c..b583d3a8 100644 --- a/daemon/core/nodes/interface.py +++ b/daemon/core/nodes/interface.py @@ -4,18 +4,31 @@ virtual ethernet classes that implement the interfaces available under Linux. import logging import time +from typing import TYPE_CHECKING, Callable, Dict, List, Tuple from core import utils from core.errors import CoreCommandError from core.nodes.netclient import get_net_client +if TYPE_CHECKING: + from core.emulator.distributed import DistributedServer + from core.emulator.session import Session + from core.nodes.base import CoreNetworkBase, CoreNode + class CoreInterface: """ Base class for network interfaces. """ - def __init__(self, session, node, name, mtu, server=None): + def __init__( + self, + session: "Session", + node: "CoreNode", + name: str, + mtu: int, + server: "DistributedServer" = None, + ) -> None: """ Creates a CoreInterface instance. @@ -50,7 +63,14 @@ class CoreInterface: use_ovs = session.options.get_config("ovs") == "True" self.net_client = get_net_client(use_ovs, self.host_cmd) - def host_cmd(self, args, env=None, cwd=None, wait=True, shell=False): + def host_cmd( + self, + args: str, + env: Dict[str, str] = None, + cwd: str = None, + wait: bool = True, + shell: bool = False, + ) -> str: """ Runs a command on the host system or distributed server. @@ -68,7 +88,7 @@ class CoreInterface: else: return self.server.remote_cmd(args, env, cwd, wait) - def startup(self): + def startup(self) -> None: """ Startup method for the interface. @@ -76,7 +96,7 @@ class CoreInterface: """ pass - def shutdown(self): + def shutdown(self) -> None: """ Shutdown method for the interface. @@ -84,7 +104,7 @@ class CoreInterface: """ pass - def attachnet(self, net): + def attachnet(self, net: "CoreNetworkBase") -> None: """ Attach network. @@ -98,7 +118,7 @@ class CoreInterface: net.attach(self) self.net = net - def detachnet(self): + def detachnet(self) -> None: """ Detach from a network. @@ -107,7 +127,7 @@ class CoreInterface: if self.net is not None: self.net.detach(self) - def addaddr(self, addr): + def addaddr(self, addr: str) -> None: """ Add address. @@ -117,7 +137,7 @@ class CoreInterface: addr = utils.validate_ip(addr) self.addrlist.append(addr) - def deladdr(self, addr): + def deladdr(self, addr: str) -> None: """ Delete address. @@ -126,17 +146,18 @@ class CoreInterface: """ self.addrlist.remove(addr) - def sethwaddr(self, addr): + def sethwaddr(self, addr: str) -> None: """ Set hardware address. :param str addr: hardware address to set to. :return: nothing """ - addr = utils.validate_mac(addr) + if addr is not None: + addr = utils.validate_mac(addr) self.hwaddr = addr - def getparam(self, key): + def getparam(self, key: str) -> float: """ Retrieve a parameter from the, or None if the parameter does not exist. @@ -145,7 +166,7 @@ class CoreInterface: """ return self._params.get(key) - def getparams(self): + def getparams(self) -> List[Tuple[str, float]]: """ Return (key, value) pairs for parameters. """ @@ -154,7 +175,7 @@ class CoreInterface: parameters.append((k, self._params[k])) return parameters - def setparam(self, key, value): + def setparam(self, key: str, value: float) -> bool: """ Set a parameter value, returns True if the parameter has changed. @@ -174,7 +195,7 @@ class CoreInterface: self._params[key] = value return True - def swapparams(self, name): + def swapparams(self, name: str) -> None: """ Swap out parameters dict for name. If name does not exist, intialize it. This is for supporting separate upstream/downstream @@ -189,7 +210,7 @@ class CoreInterface: self._params = getattr(self, name) setattr(self, name, tmp) - def setposition(self, x, y, z): + def setposition(self, x: float, y: float, z: float) -> None: """ Dispatch position hook handler. @@ -200,7 +221,7 @@ class CoreInterface: """ self.poshook(self, x, y, z) - def __lt__(self, other): + def __lt__(self, other: "CoreInterface") -> bool: """ Used for comparisons of this object. @@ -217,8 +238,15 @@ class Veth(CoreInterface): """ def __init__( - self, session, node, name, localname, mtu=1500, server=None, start=True - ): + self, + session: "Session", + node: "CoreNode", + name: str, + localname: str, + mtu: int = 1500, + server: "DistributedServer" = None, + start: bool = True, + ) -> None: """ Creates a VEth instance. @@ -239,7 +267,7 @@ class Veth(CoreInterface): if start: self.startup() - def startup(self): + def startup(self) -> None: """ Interface startup logic. @@ -250,7 +278,7 @@ class Veth(CoreInterface): self.net_client.device_up(self.localname) self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Interface shutdown logic. @@ -280,8 +308,15 @@ class TunTap(CoreInterface): """ def __init__( - self, session, node, name, localname, mtu=1500, server=None, start=True - ): + self, + session: "Session", + node: "CoreNode", + name: str, + localname: str, + mtu: int = 1500, + server: "DistributedServer" = None, + start: bool = True, + ) -> None: """ Create a TunTap instance. @@ -301,7 +336,7 @@ class TunTap(CoreInterface): if start: self.startup() - def startup(self): + def startup(self) -> None: """ Startup logic for a tunnel tap. @@ -315,7 +350,7 @@ class TunTap(CoreInterface): # self.install() self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Shutdown functionality for a tunnel tap. @@ -331,7 +366,9 @@ class TunTap(CoreInterface): self.up = False - def waitfor(self, func, attempts=10, maxretrydelay=0.25): + def waitfor( + self, func: Callable[[], int], attempts: int = 10, maxretrydelay: float = 0.25 + ) -> bool: """ Wait for func() to return zero with exponential backoff. @@ -362,7 +399,7 @@ class TunTap(CoreInterface): return result - def waitfordevicelocal(self): + def waitfordevicelocal(self) -> None: """ Check for presence of a local device - tap device may not appear right away waits @@ -381,7 +418,7 @@ class TunTap(CoreInterface): self.waitfor(localdevexists) - def waitfordevicenode(self): + def waitfordevicenode(self) -> None: """ Check for presence of a node device - tap device may not appear right away waits. @@ -412,7 +449,7 @@ class TunTap(CoreInterface): else: raise RuntimeError("node device failed to exist") - def install(self): + def install(self) -> None: """ Install this TAP into its namespace. This is not done from the startup() method but called at a later time when a userspace @@ -428,7 +465,7 @@ class TunTap(CoreInterface): self.node.node_net_client.device_name(self.localname, self.name) self.node.node_net_client.device_up(self.name) - def setaddrs(self): + def setaddrs(self) -> None: """ Set interface addresses based on self.addrlist. @@ -448,18 +485,18 @@ class GreTap(CoreInterface): def __init__( self, - node=None, - name=None, - session=None, - mtu=1458, - remoteip=None, - _id=None, - localip=None, - ttl=255, - key=None, - start=True, - server=None, - ): + node: "CoreNode" = None, + name: str = None, + session: "Session" = None, + mtu: int = 1458, + remoteip: str = None, + _id: int = None, + localip: str = None, + ttl: int = 255, + key: int = None, + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ Creates a GreTap instance. @@ -497,7 +534,7 @@ class GreTap(CoreInterface): self.net_client.device_up(self.localname) self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Shutdown logic for a GreTap. @@ -512,7 +549,7 @@ class GreTap(CoreInterface): self.localname = None - def data(self, message_type): + def data(self, message_type: int) -> None: """ Data for a gre tap. @@ -521,7 +558,7 @@ class GreTap(CoreInterface): """ return None - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List: """ Retrieve link data. diff --git a/daemon/core/nodes/lxd.py b/daemon/core/nodes/lxd.py index 2d7a6d3d..20e2bc43 100644 --- a/daemon/core/nodes/lxd.py +++ b/daemon/core/nodes/lxd.py @@ -3,27 +3,33 @@ import logging import os import time from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Callable, Dict from core import utils +from core.emulator.distributed import DistributedServer from core.emulator.enumerations import NodeTypes from core.errors import CoreCommandError from core.nodes.base import CoreNode +from core.nodes.interface import CoreInterface + +if TYPE_CHECKING: + from core.emulator.session import Session class LxdClient: - def __init__(self, name, image, run): + def __init__(self, name: str, image: str, run: Callable[..., str]) -> None: self.name = name self.image = image self.run = run self.pid = None - def create_container(self): + def create_container(self) -> int: self.run(f"lxc launch {self.image} {self.name}") data = self.get_info() self.pid = data["state"]["pid"] return self.pid - def get_info(self): + def get_info(self) -> Dict: args = f"lxc list {self.name} --format json" output = self.run(args) data = json.loads(output) @@ -31,27 +37,27 @@ class LxdClient: raise CoreCommandError(-1, args, f"LXC({self.name}) not present") return data[0] - def is_alive(self): + def is_alive(self) -> bool: try: data = self.get_info() return data["state"]["status"] == "Running" except CoreCommandError: return False - def stop_container(self): + def stop_container(self) -> None: self.run(f"lxc delete --force {self.name}") - def create_cmd(self, cmd): + def create_cmd(self, cmd: str) -> str: return f"lxc exec -nT {self.name} -- {cmd}" - def create_ns_cmd(self, cmd): + def create_ns_cmd(self, cmd: str) -> str: return f"nsenter -t {self.pid} -m -u -i -p -n {cmd}" - def check_cmd(self, cmd, wait=True, shell=False): + def check_cmd(self, cmd: str, wait: bool = True, shell: bool = False) -> str: args = self.create_cmd(cmd) return utils.cmd(args, wait=wait, shell=shell) - def copy_file(self, source, destination): + def copy_file(self, source: str, destination: str) -> None: if destination[0] != "/": destination = os.path.join("/root/", destination) @@ -64,15 +70,15 @@ class LxcNode(CoreNode): def __init__( self, - session, - _id=None, - name=None, - nodedir=None, - bootsh="boot.sh", - start=True, - server=None, - image=None, - ): + session: "Session", + _id: int = None, + name: str = None, + nodedir: str = None, + bootsh: str = "boot.sh", + start: bool = True, + server: DistributedServer = None, + image: str = None, + ) -> None: """ Create a LxcNode instance. @@ -91,7 +97,7 @@ class LxcNode(CoreNode): self.image = image super().__init__(session, _id, name, nodedir, bootsh, start, server) - def alive(self): + def alive(self) -> bool: """ Check if the node is alive. @@ -100,7 +106,7 @@ class LxcNode(CoreNode): """ return self.client.is_alive() - def startup(self): + def startup(self) -> None: """ Startup logic. @@ -114,7 +120,7 @@ class LxcNode(CoreNode): self.pid = self.client.create_container() self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Shutdown logic. @@ -129,7 +135,7 @@ class LxcNode(CoreNode): self.client.stop_container() self.up = False - def termcmdstring(self, sh="/bin/sh"): + def termcmdstring(self, sh: str = "/bin/sh") -> str: """ Create a terminal command string. @@ -138,7 +144,7 @@ class LxcNode(CoreNode): """ return f"lxc exec {self.name} -- {sh}" - def privatedir(self, path): + def privatedir(self, path: str) -> None: """ Create a private directory. @@ -147,9 +153,9 @@ class LxcNode(CoreNode): """ logging.info("creating node dir: %s", path) args = f"mkdir -p {path}" - return self.cmd(args) + self.cmd(args) - def mount(self, source, target): + def mount(self, source: str, target: str) -> None: """ Create and mount a directory. @@ -161,7 +167,7 @@ class LxcNode(CoreNode): logging.debug("mounting source(%s) target(%s)", source, target) raise Exception("not supported") - def nodefile(self, filename, contents, mode=0o644): + def nodefile(self, filename: str, contents: str, mode: int = 0o644) -> None: """ Create a node file with a given mode. @@ -188,7 +194,7 @@ class LxcNode(CoreNode): os.unlink(temp.name) logging.debug("node(%s) added file: %s; mode: 0%o", self.name, filename, mode) - def nodefilecopy(self, filename, srcfilename, mode=None): + def nodefilecopy(self, filename: str, srcfilename: str, mode: int = None) -> None: """ Copy a file to a node, following symlinks and preserving metadata. Change file mode if specified. @@ -214,7 +220,7 @@ class LxcNode(CoreNode): self.client.copy_file(source, filename) self.cmd(f"chmod {mode:o} {filename}") - def addnetif(self, netif, ifindex): + def addnetif(self, netif: CoreInterface, ifindex: int) -> None: super().addnetif(netif, ifindex) # adding small delay to allow time for adding addresses to work correctly time.sleep(0.5) diff --git a/daemon/core/nodes/netclient.py b/daemon/core/nodes/netclient.py index 8053a0e8..1236468f 100644 --- a/daemon/core/nodes/netclient.py +++ b/daemon/core/nodes/netclient.py @@ -2,30 +2,17 @@ Clients for dealing with bridge/interface commands. """ import json +from typing import Callable from core.constants import ETHTOOL_BIN, IP_BIN, OVS_BIN, TC_BIN -def get_net_client(use_ovs, run): - """ - Retrieve desired net client for running network commands. - - :param bool use_ovs: True for OVS bridges, False for Linux bridges - :param func run: function used to run net client commands - :return: net client class - """ - if use_ovs: - return OvsNetClient(run) - else: - return LinuxNetClient(run) - - class LinuxNetClient: """ Client for creating Linux bridges and ip interfaces for nodes. """ - def __init__(self, run): + def __init__(self, run: Callable[..., str]) -> None: """ Create LinuxNetClient instance. @@ -33,7 +20,7 @@ class LinuxNetClient: """ self.run = run - def set_hostname(self, name): + def set_hostname(self, name: str) -> None: """ Set network hostname. @@ -42,7 +29,7 @@ class LinuxNetClient: """ self.run(f"hostname {name}") - def create_route(self, route, device): + def create_route(self, route: str, device: str) -> None: """ Create a new route for a device. @@ -52,7 +39,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} route add {route} dev {device}") - def device_up(self, device): + def device_up(self, device: str) -> None: """ Bring a device up. @@ -61,7 +48,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set {device} up") - def device_down(self, device): + def device_down(self, device: str) -> None: """ Bring a device down. @@ -70,7 +57,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set {device} down") - def device_name(self, device, name): + def device_name(self, device: str, name: str) -> None: """ Set a device name. @@ -80,7 +67,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set {device} name {name}") - def device_show(self, device): + def device_show(self, device: str) -> str: """ Show information for a device. @@ -90,7 +77,7 @@ class LinuxNetClient: """ return self.run(f"{IP_BIN} link show {device}") - def get_mac(self, device): + def get_mac(self, device: str) -> str: """ Retrieve MAC address for a given device. @@ -100,7 +87,7 @@ class LinuxNetClient: """ return self.run(f"cat /sys/class/net/{device}/address") - def get_ifindex(self, device): + def get_ifindex(self, device: str) -> str: """ Retrieve ifindex for a given device. @@ -110,7 +97,7 @@ class LinuxNetClient: """ return self.run(f"cat /sys/class/net/{device}/ifindex") - def device_ns(self, device, namespace): + def device_ns(self, device: str, namespace: str) -> None: """ Set netns for a device. @@ -120,7 +107,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set {device} netns {namespace}") - def device_flush(self, device): + def device_flush(self, device: str) -> None: """ Flush device addresses. @@ -132,7 +119,7 @@ class LinuxNetClient: shell=True, ) - def device_mac(self, device, mac): + def device_mac(self, device: str, mac: str) -> None: """ Set MAC address for a device. @@ -142,7 +129,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set dev {device} address {mac}") - def delete_device(self, device): + def delete_device(self, device: str) -> None: """ Delete device. @@ -151,7 +138,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link delete {device}") - def delete_tc(self, device): + def delete_tc(self, device: str) -> None: """ Remove traffic control settings for a device. @@ -160,7 +147,7 @@ class LinuxNetClient: """ self.run(f"{TC_BIN} qdisc delete dev {device} root") - def checksums_off(self, interface_name): + def checksums_off(self, interface_name: str) -> None: """ Turns interface checksums off. @@ -169,7 +156,7 @@ class LinuxNetClient: """ self.run(f"{ETHTOOL_BIN} -K {interface_name} rx off tx off") - def create_address(self, device, address, broadcast=None): + def create_address(self, device: str, address: str, broadcast: str = None) -> None: """ Create address for a device. @@ -185,7 +172,7 @@ class LinuxNetClient: else: self.run(f"{IP_BIN} address add {address} dev {device}") - def delete_address(self, device, address): + def delete_address(self, device: str, address: str) -> None: """ Delete an address from a device. @@ -195,7 +182,7 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} address delete {address} dev {device}") - def create_veth(self, name, peer): + def create_veth(self, name: str, peer: str) -> None: """ Create a veth pair. @@ -205,7 +192,9 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link add name {name} type veth peer name {peer}") - def create_gretap(self, device, address, local, ttl, key): + def create_gretap( + self, device: str, address: str, local: str, ttl: int, key: int + ) -> None: """ Create a GRE tap on a device. @@ -225,7 +214,7 @@ class LinuxNetClient: cmd += f" key {key}" self.run(cmd) - def create_bridge(self, name): + def create_bridge(self, name: str) -> None: """ Create a Linux bridge and bring it up. @@ -238,7 +227,7 @@ class LinuxNetClient: self.run(f"{IP_BIN} link set {name} type bridge mcast_snooping 0") self.device_up(name) - def delete_bridge(self, name): + def delete_bridge(self, name: str) -> None: """ Bring down and delete a Linux bridge. @@ -248,7 +237,7 @@ class LinuxNetClient: self.device_down(name) self.run(f"{IP_BIN} link delete {name} type bridge") - def create_interface(self, bridge_name, interface_name): + def create_interface(self, bridge_name: str, interface_name: str) -> None: """ Create an interface associated with a Linux bridge. @@ -259,7 +248,7 @@ class LinuxNetClient: self.run(f"{IP_BIN} link set dev {interface_name} master {bridge_name}") self.device_up(interface_name) - def delete_interface(self, bridge_name, interface_name): + def delete_interface(self, bridge_name: str, interface_name: str) -> None: """ Delete an interface associated with a Linux bridge. @@ -269,11 +258,12 @@ class LinuxNetClient: """ self.run(f"{IP_BIN} link set dev {interface_name} nomaster") - def existing_bridges(self, _id): + def existing_bridges(self, _id: int) -> bool: """ Checks if there are any existing Linux bridges for a node. :param _id: node id to check bridges for + :return: True if there are existing bridges, False otherwise """ output = self.run(f"{IP_BIN} -j link show type bridge") bridges = json.loads(output) @@ -286,7 +276,7 @@ class LinuxNetClient: return True return False - def disable_mac_learning(self, name): + def disable_mac_learning(self, name: str) -> None: """ Disable mac learning for a Linux bridge. @@ -301,7 +291,7 @@ class OvsNetClient(LinuxNetClient): Client for creating OVS bridges and ip interfaces for nodes. """ - def create_bridge(self, name): + def create_bridge(self, name: str) -> None: """ Create a OVS bridge and bring it up. @@ -314,7 +304,7 @@ class OvsNetClient(LinuxNetClient): self.run(f"{OVS_BIN} set bridge {name} other_config:stp-forward-delay=4") self.device_up(name) - def delete_bridge(self, name): + def delete_bridge(self, name: str) -> None: """ Bring down and delete a OVS bridge. @@ -324,7 +314,7 @@ class OvsNetClient(LinuxNetClient): self.device_down(name) self.run(f"{OVS_BIN} del-br {name}") - def create_interface(self, bridge_name, interface_name): + def create_interface(self, bridge_name: str, interface_name: str) -> None: """ Create an interface associated with a network bridge. @@ -335,7 +325,7 @@ class OvsNetClient(LinuxNetClient): self.run(f"{OVS_BIN} add-port {bridge_name} {interface_name}") self.device_up(interface_name) - def delete_interface(self, bridge_name, interface_name): + def delete_interface(self, bridge_name: str, interface_name: str) -> None: """ Delete an interface associated with a OVS bridge. @@ -345,11 +335,12 @@ class OvsNetClient(LinuxNetClient): """ self.run(f"{OVS_BIN} del-port {bridge_name} {interface_name}") - def existing_bridges(self, _id): + def existing_bridges(self, _id: int) -> bool: """ Checks if there are any existing OVS bridges for a node. :param _id: node id to check bridges for + :return: True if there are existing bridges, False otherwise """ output = self.run(f"{OVS_BIN} list-br") if output: @@ -359,7 +350,7 @@ class OvsNetClient(LinuxNetClient): return True return False - def disable_mac_learning(self, name): + def disable_mac_learning(self, name: str) -> None: """ Disable mac learning for a OVS bridge. @@ -367,3 +358,17 @@ class OvsNetClient(LinuxNetClient): :return: nothing """ self.run(f"{OVS_BIN} set bridge {name} other_config:mac-aging-time=0") + + +def get_net_client(use_ovs: bool, run: Callable[..., str]) -> LinuxNetClient: + """ + Retrieve desired net client for running network commands. + + :param bool use_ovs: True for OVS bridges, False for Linux bridges + :param func run: function used to run net client commands + :return: net client class + """ + if use_ovs: + return OvsNetClient(run) + else: + return LinuxNetClient(run) diff --git a/daemon/core/nodes/network.py b/daemon/core/nodes/network.py index b5199062..c7c36a1e 100644 --- a/daemon/core/nodes/network.py +++ b/daemon/core/nodes/network.py @@ -5,18 +5,26 @@ Defines network nodes used within core. import logging import threading import time +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type import netaddr from core import utils from core.constants import EBTABLES_BIN, TC_BIN -from core.emulator.data import LinkData +from core.emulator.data import LinkData, NodeData from core.emulator.enumerations import LinkTypes, NodeTypes, RegisterTlvs from core.errors import CoreCommandError, CoreError from core.nodes.base import CoreNetworkBase -from core.nodes.interface import GreTap, Veth +from core.nodes.interface import CoreInterface, GreTap, Veth from core.nodes.netclient import get_net_client +if TYPE_CHECKING: + from core.emulator.distributed import DistributedServer + from core.emulator.session import Session + from core.location.mobility import WirelessModel + + WirelessModelType = Type[WirelessModel] + ebtables_lock = threading.Lock() @@ -32,7 +40,7 @@ class EbtablesQueue: # ebtables atomic_file = "/tmp/pycore.ebtables.atomic" - def __init__(self): + def __init__(self) -> None: """ Initialize the helper class, but don't start the update thread until a WLAN is instantiated. @@ -49,7 +57,7 @@ class EbtablesQueue: # using this queue self.last_update_time = {} - def startupdateloop(self, wlan): + def startupdateloop(self, wlan: "CoreNetwork") -> None: """ Kick off the update loop; only needs to be invoked once. @@ -66,7 +74,7 @@ class EbtablesQueue: self.updatethread.daemon = True self.updatethread.start() - def stopupdateloop(self, wlan): + def stopupdateloop(self, wlan: "CoreNetwork") -> None: """ Kill the update loop thread if there are no more WLANs using it. @@ -88,17 +96,17 @@ class EbtablesQueue: self.updatethread.join() self.updatethread = None - def ebatomiccmd(self, cmd): + def ebatomiccmd(self, cmd: str) -> str: """ Helper for building ebtables atomic file command list. :param str cmd: ebtable command :return: ebtable atomic command - :rtype: list[str] + :rtype: str """ return f"{EBTABLES_BIN} --atomic-file {self.atomic_file} {cmd}" - def lastupdate(self, wlan): + def lastupdate(self, wlan: "CoreNetwork") -> float: """ Return the time elapsed since this WLAN was last updated. @@ -114,7 +122,7 @@ class EbtablesQueue: return elapsed - def updated(self, wlan): + def updated(self, wlan: "CoreNetwork") -> None: """ Keep track of when this WLAN was last updated. @@ -124,7 +132,7 @@ class EbtablesQueue: self.last_update_time[wlan] = time.monotonic() self.updates.remove(wlan) - def updateloop(self): + def updateloop(self) -> None: """ Thread target that looks for WLANs needing update, and rate limits the amount of ebtables activity. Only one userspace program @@ -153,7 +161,7 @@ class EbtablesQueue: time.sleep(self.rate) - def ebcommit(self, wlan): + def ebcommit(self, wlan: "CoreNetwork") -> None: """ Perform ebtables atomic commit using commands built in the self.cmds list. @@ -178,7 +186,7 @@ class EbtablesQueue: except CoreCommandError: logging.exception("error removing atomic file: %s", self.atomic_file) - def ebchange(self, wlan): + def ebchange(self, wlan: "CoreNetwork") -> None: """ Flag a change to the given WLAN's _linked dict, so the ebtables chain will be rebuilt at the next interval. @@ -189,7 +197,7 @@ class EbtablesQueue: if wlan not in self.updates: self.updates.append(wlan) - def buildcmds(self, wlan): + def buildcmds(self, wlan: "CoreNetwork") -> None: """ Inspect a _linked dict from a wlan, and rebuild the ebtables chain for that WLAN. @@ -231,7 +239,7 @@ class EbtablesQueue: ebq = EbtablesQueue() -def ebtablescmds(call, cmds): +def ebtablescmds(call: Callable[..., str], cmds: List[str]) -> None: """ Run ebtable commands. @@ -252,8 +260,14 @@ class CoreNetwork(CoreNetworkBase): policy = "DROP" def __init__( - self, session, _id=None, name=None, start=True, server=None, policy=None - ): + self, + session: "Session", + _id: int = None, + name: str = None, + start: bool = True, + server: "DistributedServer" = None, + policy: str = None, + ) -> None: """ Creates a LxBrNet instance. @@ -279,7 +293,14 @@ class CoreNetwork(CoreNetworkBase): self.startup() ebq.startupdateloop(self) - def host_cmd(self, args, env=None, cwd=None, wait=True, shell=False): + def host_cmd( + self, + args: str, + env: Dict[str, str] = None, + cwd: str = None, + wait: bool = True, + shell: bool = False, + ) -> str: """ Runs a command that is used to configure and setup the network on the host system and all configured distributed servers. @@ -298,7 +319,7 @@ class CoreNetwork(CoreNetworkBase): self.session.distributed.execute(lambda x: x.remote_cmd(args, env, cwd, wait)) return output - def startup(self): + def startup(self) -> None: """ Linux bridge starup logic. @@ -309,7 +330,7 @@ class CoreNetwork(CoreNetworkBase): self.has_ebtables_chain = False self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Linux bridge shutdown logic. @@ -340,18 +361,18 @@ class CoreNetwork(CoreNetworkBase): del self.session self.up = False - def attach(self, netif): + def attach(self, netif: CoreInterface) -> None: """ Attach a network interface. - :param core.nodes.interface.Veth netif: network interface to attach + :param core.nodes.interface.CoreInterface netif: network interface to attach :return: nothing """ if self.up: netif.net_client.create_interface(self.brname, netif.localname) super().attach(netif) - def detach(self, netif): + def detach(self, netif: CoreInterface) -> None: """ Detach a network interface. @@ -362,7 +383,7 @@ class CoreNetwork(CoreNetworkBase): netif.net_client.delete_interface(self.brname, netif.localname) super().detach(netif) - def linked(self, netif1, netif2): + def linked(self, netif1: CoreInterface, netif2: CoreInterface) -> bool: """ Determine if the provided network interfaces are linked. @@ -391,9 +412,9 @@ class CoreNetwork(CoreNetworkBase): return linked - def unlink(self, netif1, netif2): + def unlink(self, netif1: CoreInterface, netif2: CoreInterface) -> None: """ - Unlink two PyCoreNetIfs, resulting in adding or removing ebtables + Unlink two interfaces, resulting in adding or removing ebtables filtering rules. :param core.nodes.interface.CoreInterface netif1: interface one @@ -407,9 +428,9 @@ class CoreNetwork(CoreNetworkBase): ebq.ebchange(self) - def link(self, netif1, netif2): + def link(self, netif1: CoreInterface, netif2: CoreInterface) -> None: """ - Link two PyCoreNetIfs together, resulting in adding or removing + Link two interfaces together, resulting in adding or removing ebtables filtering rules. :param core.nodes.interface.CoreInterface netif1: interface one @@ -425,19 +446,19 @@ class CoreNetwork(CoreNetworkBase): def linkconfig( self, - netif, - bw=None, - delay=None, - loss=None, - duplicate=None, - jitter=None, - netif2=None, - devname=None, - ): + netif: CoreInterface, + bw: float = None, + delay: float = None, + loss: float = None, + duplicate: float = None, + jitter: float = None, + netif2: float = None, + devname: str = None, + ) -> None: """ Configure link parameters by applying tc queuing disciplines on the interface. - :param core.nodes.interface.Veth netif: interface one + :param core.nodes.interface.CoreInterface netif: interface one :param bw: bandwidth to set to :param delay: packet delay to set to :param loss: packet loss to set to @@ -520,14 +541,14 @@ class CoreNetwork(CoreNetworkBase): netif.host_cmd(cmd) netif.setparam("has_netem", True) - def linknet(self, net): + def linknet(self, net: CoreNetworkBase) -> CoreInterface: """ Link this bridge with another by creating a veth pair and installing each device into each bridge. - :param core.netns.vnet.LxBrNet net: network to link with + :param core.nodes.base.CoreNetworkBase net: network to link with :return: created interface - :rtype: Veth + :rtype: core.nodes.interface.CoreInterface """ sessionid = self.session.short_session_id() try: @@ -561,7 +582,7 @@ class CoreNetwork(CoreNetworkBase): netif.othernet = net return netif - def getlinknetif(self, net): + def getlinknetif(self, net: CoreNetworkBase) -> Optional[CoreInterface]: """ Return the interface of that links this net with another net (that were linked using linknet()). @@ -573,10 +594,9 @@ class CoreNetwork(CoreNetworkBase): for netif in self.netifs(): if hasattr(netif, "othernet") and netif.othernet == net: return netif - return None - def addrconfig(self, addrlist): + def addrconfig(self, addrlist: List[str]) -> None: """ Set addresses on the bridge. @@ -598,17 +618,17 @@ class GreTapBridge(CoreNetwork): def __init__( self, - session, - remoteip=None, - _id=None, - name=None, - policy="ACCEPT", - localip=None, - ttl=255, - key=None, - start=True, - server=None, - ): + session: "Session", + remoteip: str = None, + _id: int = None, + name: str = None, + policy: str = "ACCEPT", + localip: str = None, + ttl: int = 255, + key: int = None, + start: bool = True, + server: "DistributedServer" = None, + ) -> None: """ Create a GreTapBridge instance. @@ -647,7 +667,7 @@ class GreTapBridge(CoreNetwork): if start: self.startup() - def startup(self): + def startup(self) -> None: """ Creates a bridge and adds the gretap device to it. @@ -657,7 +677,7 @@ class GreTapBridge(CoreNetwork): if self.gretap: self.attach(self.gretap) - def shutdown(self): + def shutdown(self) -> None: """ Detach the gretap device and remove the bridge. @@ -669,7 +689,7 @@ class GreTapBridge(CoreNetwork): self.gretap = None super().shutdown() - def addrconfig(self, addrlist): + def addrconfig(self, addrlist: List[str]) -> None: """ Set the remote tunnel endpoint. This is a one-time method for creating the GreTap device, which requires the remoteip at startup. @@ -694,7 +714,7 @@ class GreTapBridge(CoreNetwork): ) self.attach(self.gretap) - def setkey(self, key): + def setkey(self, key: int) -> None: """ Set the GRE key used for the GreTap device. This needs to be set prior to instantiating the GreTap device (before addrconfig). @@ -722,17 +742,17 @@ class CtrlNet(CoreNetwork): def __init__( self, - session, - _id=None, - name=None, - prefix=None, - hostid=None, - start=True, - server=None, - assign_address=True, - updown_script=None, - serverintf=None, - ): + session: "Session", + _id: int = None, + name: str = None, + prefix: str = None, + hostid: int = None, + start: bool = True, + server: "DistributedServer" = None, + assign_address: bool = True, + updown_script: str = None, + serverintf: CoreInterface = None, + ) -> None: """ Creates a CtrlNet instance. @@ -756,7 +776,7 @@ class CtrlNet(CoreNetwork): self.serverintf = serverintf super().__init__(session, _id, name, start, server) - def add_addresses(self, index): + def add_addresses(self, index: int) -> None: """ Add addresses used for created control networks, @@ -777,7 +797,7 @@ class CtrlNet(CoreNetwork): net_client = get_net_client(use_ovs, server.remote_cmd) net_client.create_address(self.brname, current) - def startup(self): + def startup(self) -> None: """ Startup functionality for the control network. @@ -806,7 +826,7 @@ class CtrlNet(CoreNetwork): if self.serverintf: self.net_client.create_interface(self.brname, self.serverintf) - def shutdown(self): + def shutdown(self) -> None: """ Control network shutdown. @@ -835,7 +855,7 @@ class CtrlNet(CoreNetwork): super().shutdown() - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List[LinkData]: """ Do not include CtrlNet in link messages describing this session. @@ -853,11 +873,11 @@ class PtpNet(CoreNetwork): policy = "ACCEPT" - def attach(self, netif): + def attach(self, netif: CoreInterface) -> None: """ Attach a network interface, but limit attachment to two interfaces. - :param core.netns.vif.VEth netif: network interface + :param core.nodes.interface.CoreInterface netif: network interface :return: nothing """ if len(self._netif) >= 2: @@ -866,7 +886,14 @@ class PtpNet(CoreNetwork): ) super().attach(netif) - def data(self, message_type, lat=None, lon=None, alt=None): + def data( + self, + message_type: int, + lat: float = None, + lon: float = None, + alt: float = None, + source: str = None, + ) -> NodeData: """ Do not generate a Node Message for point-to-point links. They are built using a link message instead. @@ -875,12 +902,13 @@ class PtpNet(CoreNetwork): :param float lat: latitude :param float lon: longitude :param float alt: altitude + :param str source: source of node data :return: node data object :rtype: core.emulator.data.NodeData """ return None - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List[LinkData]: """ Build CORE API TLVs for a point-to-point link. One Link message describes this network. @@ -997,7 +1025,7 @@ class HubNode(CoreNetwork): policy = "ACCEPT" type = "hub" - def startup(self): + def startup(self) -> None: """ Startup for a hub node, that disables mac learning after normal startup. @@ -1018,8 +1046,14 @@ class WlanNode(CoreNetwork): type = "wlan" def __init__( - self, session, _id=None, name=None, start=True, server=None, policy=None - ): + self, + session: "Session", + _id: int = None, + name: str = None, + start: bool = True, + server: "DistributedServer" = None, + policy: str = None, + ) -> None: """ Create a WlanNode instance. @@ -1036,7 +1070,7 @@ class WlanNode(CoreNetwork): self.model = None self.mobility = None - def startup(self): + def startup(self) -> None: """ Startup for a wlan node, that disables mac learning after normal startup. @@ -1045,11 +1079,11 @@ class WlanNode(CoreNetwork): super().startup() self.net_client.disable_mac_learning(self.brname) - def attach(self, netif): + def attach(self, netif: CoreInterface) -> None: """ Attach a network interface. - :param core.nodes.interface.Veth netif: network interface + :param core.nodes.interface.CoreInterface netif: network interface :return: nothing """ super().attach(netif) @@ -1061,7 +1095,7 @@ class WlanNode(CoreNetwork): # invokes any netif.poshook netif.setposition(x, y, z) - def setmodel(self, model, config): + def setmodel(self, model: "WirelessModelType", config: Dict[str, str]): """ Sets the mobility and wireless model. @@ -1082,12 +1116,12 @@ class WlanNode(CoreNetwork): self.mobility = model(session=self.session, _id=self.id) self.mobility.update_config(config) - def update_mobility(self, config): + def update_mobility(self, config: Dict[str, str]) -> None: if not self.mobility: raise ValueError(f"no mobility set to update for node({self.id})") self.mobility.update_config(config) - def updatemodel(self, config): + def updatemodel(self, config: Dict[str, str]) -> None: if not self.model: raise ValueError(f"no model set to update for node({self.id})") logging.debug( @@ -1099,7 +1133,7 @@ class WlanNode(CoreNetwork): x, y, z = netif.node.position.get() netif.poshook(netif, x, y, z) - def all_link_data(self, flags): + def all_link_data(self, flags: int) -> List[LinkData]: """ Retrieve all link data. diff --git a/daemon/core/nodes/physical.py b/daemon/core/nodes/physical.py index 1d470b98..a963721f 100644 --- a/daemon/core/nodes/physical.py +++ b/daemon/core/nodes/physical.py @@ -5,20 +5,31 @@ PhysicalNode class for including real systems in the emulated network. import logging import os import threading +from typing import IO, TYPE_CHECKING, List, Optional from core import utils from core.constants import MOUNT_BIN, UMOUNT_BIN +from core.emulator.distributed import DistributedServer from core.emulator.enumerations import NodeTypes from core.errors import CoreCommandError, CoreError -from core.nodes.base import CoreNodeBase -from core.nodes.interface import CoreInterface +from core.nodes.base import CoreNetworkBase, CoreNodeBase +from core.nodes.interface import CoreInterface, Veth from core.nodes.network import CoreNetwork, GreTap +if TYPE_CHECKING: + from core.emulator.session import Session + class PhysicalNode(CoreNodeBase): def __init__( - self, session, _id=None, name=None, nodedir=None, start=True, server=None - ): + self, + session, + _id: int = None, + name: str = None, + nodedir: str = None, + start: bool = True, + server: DistributedServer = None, + ) -> None: super().__init__(session, _id, name, start, server) if not self.server: raise CoreError("physical nodes must be assigned to a remote server") @@ -29,11 +40,11 @@ class PhysicalNode(CoreNodeBase): if start: self.startup() - def startup(self): + def startup(self) -> None: with self.lock: self.makenodedir() - def shutdown(self): + def shutdown(self) -> None: if not self.up: return @@ -47,7 +58,7 @@ class PhysicalNode(CoreNodeBase): self.rmnodedir() - def termcmdstring(self, sh="/bin/sh"): + def termcmdstring(self, sh: str = "/bin/sh") -> str: """ Create a terminal command string. @@ -56,7 +67,7 @@ class PhysicalNode(CoreNodeBase): """ return sh - def sethwaddr(self, ifindex, addr): + def sethwaddr(self, ifindex: int, addr: str) -> None: """ Set hardware address for an interface. @@ -71,7 +82,7 @@ class PhysicalNode(CoreNodeBase): if self.up: self.net_client.device_mac(interface.name, addr) - def addaddr(self, ifindex, addr): + def addaddr(self, ifindex: int, addr: str) -> None: """ Add an address to an interface. @@ -85,9 +96,13 @@ class PhysicalNode(CoreNodeBase): self.net_client.create_address(interface.name, addr) interface.addaddr(addr) - def deladdr(self, ifindex, addr): + def deladdr(self, ifindex: int, addr: str) -> None: """ Delete an address from an interface. + + :param int ifindex: index of interface to delete + :param str addr: address to delete + :return: nothing """ interface = self._netif[ifindex] @@ -99,7 +114,9 @@ class PhysicalNode(CoreNodeBase): if self.up: self.net_client.delete_address(interface.name, str(addr)) - def adoptnetif(self, netif, ifindex, hwaddr, addrlist): + def adoptnetif( + self, netif: CoreInterface, ifindex: int, hwaddr: str, addrlist: List[str] + ) -> None: """ When a link message is received linking this node to another part of the emulation, no new interface is created; instead, adopt the @@ -127,18 +144,17 @@ class PhysicalNode(CoreNodeBase): def linkconfig( self, - netif, - bw=None, - delay=None, - loss=None, - duplicate=None, - jitter=None, - netif2=None, - ): + netif: CoreInterface, + bw: float = None, + delay: float = None, + loss: float = None, + duplicate: float = None, + jitter: float = None, + netif2: CoreInterface = None, + ) -> None: """ - Apply tc queing disciplines using LxBrNet.linkconfig() + Apply tc queing disciplines using linkconfig. """ - # borrow the tc qdisc commands from LxBrNet.linkconfig() linux_bridge = CoreNetwork(session=self.session, start=False) linux_bridge.up = True linux_bridge.linkconfig( @@ -152,7 +168,7 @@ class PhysicalNode(CoreNodeBase): ) del linux_bridge - def newifindex(self): + def newifindex(self) -> int: with self.lock: while self.ifindex in self._netif: self.ifindex += 1 @@ -160,7 +176,14 @@ class PhysicalNode(CoreNodeBase): self.ifindex += 1 return ifindex - def newnetif(self, net=None, addrlist=None, hwaddr=None, ifindex=None, ifname=None): + def newnetif( + self, + net: Veth = None, + addrlist: List[str] = None, + hwaddr: str = None, + ifindex: int = None, + ifname: str = None, + ) -> int: logging.info("creating interface") if not addrlist: addrlist = [] @@ -186,7 +209,7 @@ class PhysicalNode(CoreNodeBase): self.adoptnetif(netif, ifindex, hwaddr, addrlist) return ifindex - def privatedir(self, path): + def privatedir(self, path: str) -> None: if path[0] != "/": raise ValueError(f"path not fully qualified: {path}") hostpath = os.path.join( @@ -195,21 +218,21 @@ class PhysicalNode(CoreNodeBase): os.mkdir(hostpath) self.mount(hostpath, path) - def mount(self, source, target): + def mount(self, source: str, target: str) -> None: source = os.path.abspath(source) logging.info("mounting %s at %s", source, target) os.makedirs(target) self.host_cmd(f"{MOUNT_BIN} --bind {source} {target}", cwd=self.nodedir) self._mounts.append((source, target)) - def umount(self, target): + def umount(self, target: str) -> None: logging.info("unmounting '%s'", target) try: self.host_cmd(f"{UMOUNT_BIN} -l {target}", cwd=self.nodedir) except CoreCommandError: logging.exception("unmounting failed for %s", target) - def opennodefile(self, filename, mode="w"): + def opennodefile(self, filename: str, mode: str = "w") -> IO: dirname, basename = os.path.split(filename) if not basename: raise ValueError("no basename for filename: " + filename) @@ -225,13 +248,13 @@ class PhysicalNode(CoreNodeBase): hostfilename = os.path.join(dirname, basename) return open(hostfilename, mode) - def nodefile(self, filename, contents, mode=0o644): + def nodefile(self, filename: str, contents: str, mode: int = 0o644) -> None: with self.opennodefile(filename, "w") as node_file: node_file.write(contents) os.chmod(node_file.name, mode) logging.info("created nodefile: '%s'; mode: 0%o", node_file.name, mode) - def cmd(self, args, wait=True): + def cmd(self, args: str, wait: bool = True, shell: bool = False) -> str: return self.host_cmd(args, wait=wait) @@ -244,7 +267,15 @@ class Rj45Node(CoreNodeBase, CoreInterface): apitype = NodeTypes.RJ45.value type = "rj45" - def __init__(self, session, _id=None, name=None, mtu=1500, start=True, server=None): + def __init__( + self, + session: "Session", + _id: int = None, + name: str = None, + mtu: int = 1500, + start: bool = True, + server: DistributedServer = None, + ) -> None: """ Create an RJ45Node instance. @@ -270,7 +301,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): if start: self.startup() - def startup(self): + def startup(self) -> None: """ Set the interface in the up state. @@ -282,7 +313,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): self.net_client.device_up(self.localname) self.up = True - def shutdown(self): + def shutdown(self) -> None: """ Bring the interface down. Remove any addresses and queuing disciplines. @@ -304,18 +335,18 @@ class Rj45Node(CoreNodeBase, CoreInterface): # TODO: issue in that both classes inherited from provide the same method with # different signatures - def attachnet(self, net): + def attachnet(self, net: CoreNetworkBase) -> None: """ Attach a network. - :param core.coreobj.PyCoreNet net: network to attach + :param core.nodes.base.CoreNetworkBase net: network to attach :return: nothing """ CoreInterface.attachnet(self, net) # TODO: issue in that both classes inherited from provide the same method with # different signatures - def detachnet(self): + def detachnet(self) -> None: """ Detach a network. @@ -323,7 +354,14 @@ class Rj45Node(CoreNodeBase, CoreInterface): """ CoreInterface.detachnet(self) - def newnetif(self, net=None, addrlist=None, hwaddr=None, ifindex=None, ifname=None): + def newnetif( + self, + net: CoreNetworkBase = None, + addrlist: List[str] = None, + hwaddr: str = None, + ifindex: int = None, + ifname: str = None, + ) -> int: """ This is called when linking with another node. Since this node represents an interface, we do not create another object here, @@ -359,7 +397,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): return ifindex - def delnetif(self, ifindex): + def delnetif(self, ifindex: int) -> None: """ Delete a network interface. @@ -376,7 +414,9 @@ class Rj45Node(CoreNodeBase, CoreInterface): else: raise ValueError(f"ifindex {ifindex} does not exist") - def netif(self, ifindex, net=None): + def netif( + self, ifindex: int, net: CoreNetworkBase = None + ) -> Optional[CoreInterface]: """ This object is considered the network interface, so we only return self here. This keeps the RJ45Node compatible with @@ -398,20 +438,20 @@ class Rj45Node(CoreNodeBase, CoreInterface): return None - def getifindex(self, netif): + def getifindex(self, netif: CoreInterface) -> Optional[int]: """ Retrieve network interface index. - :param core.nodes.interface.CoreInterface netif: network interface to retrieve index for + :param core.nodes.interface.CoreInterface netif: network interface to retrieve + index for :return: interface index, None otherwise :rtype: int """ if netif != self: return None - return self.ifindex - def addaddr(self, addr): + def addaddr(self, addr: str) -> None: """ Add address to to network interface. @@ -424,7 +464,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): self.net_client.create_address(self.name, addr) CoreInterface.addaddr(self, addr) - def deladdr(self, addr): + def deladdr(self, addr: str) -> None: """ Delete address from network interface. @@ -434,10 +474,9 @@ class Rj45Node(CoreNodeBase, CoreInterface): """ if self.up: self.net_client.delete_address(self.name, str(addr)) - CoreInterface.deladdr(self, addr) - def savestate(self): + def savestate(self) -> None: """ Save the addresses and other interface state before using the interface for emulation purposes. TODO: save/restore the PROMISC flag @@ -464,7 +503,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): continue self.old_addrs.append((items[1], None)) - def restorestate(self): + def restorestate(self) -> None: """ Restore the addresses and other interface state after using it. @@ -482,7 +521,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): if self.old_up: self.net_client.device_up(self.localname) - def setposition(self, x=None, y=None, z=None): + def setposition(self, x: float = None, y: float = None, z: float = None) -> bool: """ Uses setposition from both parent classes. @@ -496,7 +535,7 @@ class Rj45Node(CoreNodeBase, CoreInterface): CoreInterface.setposition(self, x, y, z) return result - def termcmdstring(self, sh): + def termcmdstring(self, sh: str) -> str: """ Create a terminal command string. diff --git a/daemon/core/plugins/sdt.py b/daemon/core/plugins/sdt.py index 410bba9d..575cbcda 100644 --- a/daemon/core/plugins/sdt.py +++ b/daemon/core/plugins/sdt.py @@ -4,17 +4,19 @@ sdt.py: Scripted Display Tool (SDT3D) helper import logging import socket +from typing import TYPE_CHECKING, Any, Optional from urllib.parse import urlparse from core import constants +from core.api.tlv.coreapi import CoreLinkMessage, CoreMessage, CoreNodeMessage from core.constants import CORE_DATA_DIR from core.emane.nodes import EmaneNet +from core.emulator.data import LinkData, NodeData from core.emulator.enumerations import ( EventTypes, LinkTlvs, LinkTypes, MessageFlags, - MessageTypes, NodeTlvs, NodeTypes, ) @@ -22,6 +24,9 @@ from core.errors import CoreError from core.nodes.base import CoreNetworkBase, NodeBase from core.nodes.network import WlanNode +if TYPE_CHECKING: + from core.emulator.session import Session + # TODO: A named tuple may be more appropriate, than abusing a class dict like this class Bunch: @@ -29,7 +34,7 @@ class Bunch: Helper class for recording a collection of attributes. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """ Create a Bunch instance. @@ -62,7 +67,7 @@ class Sdt: ("tunnel", "tunnel.gif"), ] - def __init__(self, session): + def __init__(self, session: "Session") -> None: """ Creates a Sdt instance. @@ -83,7 +88,7 @@ class Sdt: # add handler for link updates self.session.link_handlers.append(self.handle_link_update) - def handle_node_update(self, node_data): + def handle_node_update(self, node_data: NodeData) -> None: """ Handler for node updates, specifically for updating their location. @@ -108,7 +113,7 @@ class Sdt: # TODO: z is not currently supported by node messages self.updatenode(node_data.id, 0, x, y, 0) - def handle_link_update(self, link_data): + def handle_link_update(self, link_data: LinkData) -> None: """ Handler for link updates, checking for wireless link/unlink messages. @@ -123,7 +128,7 @@ class Sdt: wireless=True, ) - def is_enabled(self): + def is_enabled(self) -> bool: """ Check for "enablesdt" session option. Return False by default if the option is missing. @@ -133,7 +138,7 @@ class Sdt: """ return self.session.options.get_config("enablesdt") == "1" - def seturl(self): + def seturl(self) -> None: """ Read "sdturl" from session options, or use the default value. Set self.url, self.address, self.protocol @@ -147,7 +152,7 @@ class Sdt: self.address = (self.url.hostname, self.url.port) self.protocol = self.url.scheme - def connect(self, flags=0): + def connect(self, flags: int = 0) -> bool: """ Connect to the SDT address/port if enabled. @@ -185,7 +190,7 @@ class Sdt: return True - def initialize(self): + def initialize(self) -> bool: """ Load icon sprites, and fly to the reference point location on the virtual globe. @@ -202,7 +207,7 @@ class Sdt: lat, long = self.session.location.refgeo[:2] return self.cmd(f"flyto {long:.6f},{lat:.6f},{self.DEFAULT_ALT}") - def disconnect(self): + def disconnect(self) -> None: """ Disconnect from SDT. @@ -218,7 +223,7 @@ class Sdt: self.connected = False - def shutdown(self): + def shutdown(self) -> None: """ Invoked from Session.shutdown() and Session.checkshutdown(). @@ -228,7 +233,7 @@ class Sdt: self.disconnect() self.showerror = True - def cmd(self, cmdstr): + def cmd(self, cmdstr: str) -> bool: """ Send an SDT command over a UDP socket. socket.sendall() is used as opposed to socket.sendto() because an exception is raised when @@ -250,7 +255,17 @@ class Sdt: self.connected = False return False - def updatenode(self, nodenum, flags, x, y, z, name=None, node_type=None, icon=None): + def updatenode( + self, + nodenum: int, + flags: int, + x: Optional[float], + y: Optional[float], + z: Optional[float], + name: str = None, + node_type: str = None, + icon: str = None, + ) -> None: """ Node is updated from a Node Message or mobility script. @@ -283,13 +298,13 @@ class Sdt: else: self.cmd(f"node {nodenum} {pos}") - def updatenodegeo(self, nodenum, lat, long, alt): + def updatenodegeo(self, nodenum: int, lat: float, lon: float, alt: float) -> None: """ Node is updated upon receiving an EMANE Location Event. :param int nodenum: node id to update geospatial for :param lat: latitude - :param long: longitude + :param lon: longitude :param alt: altitude :return: nothing """ @@ -297,10 +312,12 @@ class Sdt: # TODO: received Node Message with lat/long/alt. if not self.connect(): return - pos = f"pos {long:.6f},{lat:.6f},{alt:.6f}" + pos = f"pos {lon:.6f},{lat:.6f},{alt:.6f}" self.cmd(f"node {nodenum} {pos}") - def updatelink(self, node1num, node2num, flags, wireless=False): + def updatelink( + self, node1num: int, node2num: int, flags: int, wireless: bool = False + ) -> None: """ Link is updated from a Link Message or by a wireless model. @@ -323,7 +340,7 @@ class Sdt: attr = " line red,2" self.cmd(f"link {node1num},{node2num}{attr}") - def sendobjs(self): + def sendobjs(self) -> None: """ Session has already started, and the SDT3D GUI later connects. Send all node and link objects for display. Otherwise, nodes and @@ -379,21 +396,21 @@ class Sdt: for n2num, wireless_link in r.links: self.updatelink(n1num, n2num, MessageFlags.ADD.value, wireless_link) - def handle_distributed(self, message): + def handle_distributed(self, message: CoreMessage) -> None: """ Broker handler for processing CORE API messages as they are received. This is used to snoop the Node messages and update node positions. :param message: message to handle - :return: replies + :return: nothing """ - if message.message_type == MessageTypes.LINK.value: - return self.handlelinkmsg(message) - elif message.message_type == MessageTypes.NODE.value: - return self.handlenodemsg(message) + if isinstance(message, CoreLinkMessage): + self.handlelinkmsg(message) + elif isinstance(message, CoreNodeMessage): + self.handlenodemsg(message) - def handlenodemsg(self, msg): + def handlenodemsg(self, msg: CoreNodeMessage) -> None: """ Process a Node Message to add/delete or move a node on the SDT display. Node properties are found in a session or @@ -405,7 +422,7 @@ class Sdt: # for distributed sessions to work properly, the SDT option should be # enabled prior to starting the session if not self.is_enabled(): - return False + return # node.(_id, type, icon, name) are used. nodenum = msg.get_tlv(NodeTlvs.NUMBER.value) if not nodenum: @@ -461,7 +478,7 @@ class Sdt: remote.pos = (x, y, z) self.updatenode(nodenum, msg.flags, x, y, z, name, nodetype, icon) - def handlelinkmsg(self, msg): + def handlelinkmsg(self, msg: CoreLinkMessage) -> None: """ Process a Link Message to add/remove links on the SDT display. Links are recorded in the remotes[nodenum1].links set for updating @@ -471,7 +488,7 @@ class Sdt: :return: nothing """ if not self.is_enabled(): - return False + return nodenum1 = msg.get_tlv(LinkTlvs.N1_NUMBER.value) nodenum2 = msg.get_tlv(LinkTlvs.N2_NUMBER.value) link_msg_type = msg.get_tlv(LinkTlvs.TYPE.value) @@ -488,7 +505,7 @@ class Sdt: r.links.add((nodenum2, wl)) self.updatelink(nodenum1, nodenum2, msg.flags, wireless=wl) - def wlancheck(self, nodenum): + def wlancheck(self, nodenum: int) -> bool: """ Helper returns True if a node number corresponds to a WLAN or EMANE node. diff --git a/daemon/core/services/coreservices.py b/daemon/core/services/coreservices.py index 686023a2..9abbc977 100644 --- a/daemon/core/services/coreservices.py +++ b/daemon/core/services/coreservices.py @@ -10,12 +10,17 @@ services. import enum import logging import time +from typing import TYPE_CHECKING, Iterable, List, Tuple, Type from core import utils from core.constants import which from core.emulator.data import FileData from core.emulator.enumerations import ExceptionLevels, MessageFlags, RegisterTlvs from core.errors import CoreCommandError +from core.nodes.base import CoreNode + +if TYPE_CHECKING: + from core.emulator.session import Session class ServiceBootError(Exception): @@ -34,7 +39,7 @@ class ServiceDependencies: that all services will be booted and that all dependencies exist within the services provided. """ - def __init__(self, services): + def __init__(self, services: List["CoreService"]) -> None: # helpers to check validity self.dependents = {} self.booted = set() @@ -50,7 +55,7 @@ class ServiceDependencies: self.visited = set() self.visiting = set() - def boot_paths(self): + def boot_paths(self) -> List[List["CoreService"]]: """ Generates the boot paths for the services provided to the class. @@ -78,17 +83,17 @@ class ServiceDependencies: return paths - def _reset(self): + def _reset(self) -> None: self.path = [] self.visited.clear() self.visiting.clear() - def _start(self, service): + def _start(self, service: "CoreService") -> List["CoreService"]: logging.debug("starting service dependency check: %s", service.name) self._reset() return self._visit(service) - def _visit(self, current_service): + def _visit(self, current_service: "CoreService") -> List["CoreService"]: logging.debug("visiting service(%s): %s", current_service.name, self.path) self.visited.add(current_service.name) self.visiting.add(current_service.name) @@ -139,7 +144,7 @@ class ServiceShim: ] @classmethod - def tovaluelist(cls, node, service): + def tovaluelist(cls, node: CoreNode, service: "CoreService") -> str: """ Convert service properties into a string list of key=value pairs, separated by "|". @@ -168,7 +173,7 @@ class ServiceShim: return "|".join(vals) @classmethod - def fromvaluelist(cls, service, values): + def fromvaluelist(cls, service: "CoreService", values: None): """ Convert list of values into properties for this instantiated (customized) service. @@ -186,7 +191,7 @@ class ServiceShim: logging.exception("error indexing into key") @classmethod - def setvalue(cls, service, key, value): + def setvalue(cls, service: "CoreService", key: str, value: str) -> None: """ Set values for this service. @@ -220,7 +225,7 @@ class ServiceShim: service.meta = value @classmethod - def servicesfromopaque(cls, opaque): + def servicesfromopaque(cls, opaque: str) -> List[str]: """ Build a list of services from an opaque data string. @@ -242,7 +247,7 @@ class ServiceManager: services = {} @classmethod - def add(cls, service): + def add(cls, service: "CoreService") -> None: """ Add a service to manager. @@ -272,7 +277,7 @@ class ServiceManager: cls.services[name] = service @classmethod - def get(cls, name): + def get(cls, name: str) -> Type["CoreService"]: """ Retrieve a service from the manager. @@ -283,7 +288,7 @@ class ServiceManager: return cls.services.get(name) @classmethod - def add_services(cls, path): + def add_services(cls, path: str) -> List[str]: """ Method for retrieving all CoreServices from a given path. @@ -317,7 +322,7 @@ class CoreServices: name = "services" config_type = RegisterTlvs.UTILITY.value - def __init__(self, session): + def __init__(self, session: "Session") -> None: """ Creates a CoreServices instance. @@ -329,13 +334,13 @@ class CoreServices: # dict of node ids to dict of custom services by name self.custom_services = {} - def reset(self): + def reset(self) -> None: """ Called when config message with reset flag is received """ self.custom_services.clear() - def get_default_services(self, node_type): + def get_default_services(self, node_type: str) -> List[Type["CoreService"]]: """ Get the list of default services that should be enabled for a node for the given node type. @@ -356,16 +361,18 @@ class CoreServices: results.append(service) return results - def get_service(self, node_id, service_name, default_service=False): + def get_service( + self, node_id: int, service_name: str, default_service: bool = False + ) -> "CoreService": """ - Get any custom service configured for the given node that matches the specified service name. - If no custom service is found, return the specified service. + Get any custom service configured for the given node that matches the specified + service name. If no custom service is found, return the specified service. :param int node_id: object id to get service from :param str service_name: name of service to retrieve - :param bool default_service: True to return default service when custom does not exist, False returns None + :param bool default_service: True to return default service when custom does + not exist, False returns None :return: custom service from the node - :rtype: CoreService """ node_services = self.custom_services.setdefault(node_id, {}) default = None @@ -373,7 +380,7 @@ class CoreServices: default = ServiceManager.get(service_name) return node_services.get(service_name, default) - def set_service(self, node_id, service_name): + def set_service(self, node_id: int, service_name: str) -> None: """ Store service customizations in an instantiated service object using a list of values that came from a config message. @@ -392,7 +399,9 @@ class CoreServices: node_services = self.custom_services.setdefault(node_id, {}) node_services[service.name] = service - def add_services(self, node, node_type, services=None): + def add_services( + self, node: CoreNode, node_type: str, services: List[str] = None + ) -> None: """ Add services to a node. @@ -417,10 +426,10 @@ class CoreServices: continue node.services.append(service) - def all_configs(self): + def all_configs(self) -> List[Tuple[int, Type["CoreService"]]]: """ - Return (node_id, service) tuples for all stored configs. Used when reconnecting to a - session or opening XML. + Return (node_id, service) tuples for all stored configs. Used when reconnecting + to a session or opening XML. :return: list of tuples of node ids and services :rtype: list[tuple] @@ -433,7 +442,7 @@ class CoreServices: configs.append((node_id, service)) return configs - def all_files(self, service): + def all_files(self, service: "CoreService") -> List[Tuple[str, str]]: """ Return all customized files stored with a service. Used when reconnecting to a session or opening XML. @@ -454,7 +463,7 @@ class CoreServices: return files - def boot_services(self, node): + def boot_services(self, node: CoreNode) -> None: """ Start all services on a node. @@ -470,7 +479,7 @@ class CoreServices: if exceptions: raise ServiceBootError(*exceptions) - def _start_boot_paths(self, node, boot_path): + def _start_boot_paths(self, node: CoreNode, boot_path: List["CoreService"]) -> None: """ Start all service boot paths found, based on dependencies. @@ -491,7 +500,7 @@ class CoreServices: logging.exception("exception booting service: %s", service.name) raise - def boot_service(self, node, service): + def boot_service(self, node: CoreNode, service: "CoreService") -> None: """ Start a service on a node. Create private dirs, generate config files, and execute startup commands. @@ -555,7 +564,7 @@ class CoreServices: "node(%s) service(%s) failed validation" % (node.name, service.name) ) - def copy_service_file(self, node, filename, cfg): + def copy_service_file(self, node: CoreNode, filename: str, cfg: str) -> bool: """ Given a configured service filename and config, determine if the config references an existing file that should be copied. @@ -576,7 +585,7 @@ class CoreServices: return True return False - def validate_service(self, node, service): + def validate_service(self, node: CoreNode, service: "CoreService") -> int: """ Run the validation command(s) for a service. @@ -605,7 +614,7 @@ class CoreServices: return status - def stop_services(self, node): + def stop_services(self, node: CoreNode) -> None: """ Stop all services on a node. @@ -615,14 +624,13 @@ class CoreServices: for service in node.services: self.stop_service(node, service) - def stop_service(self, node, service): + def stop_service(self, node: CoreNode, service: "CoreService") -> int: """ Stop a service on a node. :param core.nodes.base.CoreNode node: node to stop a service on :param CoreService service: service to stop :return: status for stopping the services - :rtype: str """ status = 0 for args in service.shutdown: @@ -639,7 +647,7 @@ class CoreServices: status = -1 return status - def get_service_file(self, node, service_name, filename): + def get_service_file(self, node: CoreNode, service_name: str, filename: str) -> str: """ Send a File Message when the GUI has requested a service file. The file data is either auto-generated or comes from an existing config. @@ -681,7 +689,9 @@ class CoreServices: data=data, ) - def set_service_file(self, node_id, service_name, file_name, data): + def set_service_file( + self, node_id: int, service_name: str, file_name: str, data: str + ) -> None: """ Receive a File Message from the GUI and store the customized file in the service config. The filename must match one from the list of @@ -713,7 +723,9 @@ class CoreServices: # set custom service file data service.config_data[file_name] = data - def startup_service(self, node, service, wait=False): + def startup_service( + self, node: CoreNode, service: "CoreService", wait: bool = False + ) -> int: """ Startup a node service. @@ -737,7 +749,7 @@ class CoreServices: status = -1 return status - def create_service_files(self, node, service): + def create_service_files(self, node: CoreNode, service: "CoreService") -> None: """ Creates node service files. @@ -771,7 +783,7 @@ class CoreServices: node.nodefile(file_name, cfg) - def service_reconfigure(self, node, service): + def service_reconfigure(self, node: CoreNode, service: "CoreService") -> None: """ Reconfigure a node service. @@ -846,7 +858,7 @@ class CoreService: custom = False custom_needed = False - def __init__(self): + def __init__(self) -> None: """ Services are not necessarily instantiated. Classmethods may be used against their config. Services are instantiated when a custom @@ -856,11 +868,11 @@ class CoreService: self.config_data = self.__class__.config_data.copy() @classmethod - def on_load(cls): + def on_load(cls) -> None: pass @classmethod - def get_configs(cls, node): + def get_configs(cls, node: CoreNode) -> Iterable[str]: """ Return the tuple of configuration file filenames. This default method returns the cls._configs tuple, but this method may be overriden to @@ -873,7 +885,7 @@ class CoreService: return cls.configs @classmethod - def generate_config(cls, node, filename): + def generate_config(cls, node: CoreNode, filename: str) -> None: """ Generate configuration file given a node object. The filename is provided to allow for multiple config files. @@ -887,7 +899,7 @@ class CoreService: raise NotImplementedError @classmethod - def get_startup(cls, node): + def get_startup(cls, node: CoreNode) -> Iterable[str]: """ Return the tuple of startup commands. This default method returns the cls.startup tuple, but this method may be @@ -901,7 +913,7 @@ class CoreService: return cls.startup @classmethod - def get_validate(cls, node): + def get_validate(cls, node: CoreNode) -> Iterable[str]: """ Return the tuple of validate commands. This default method returns the cls.validate tuple, but this method may be diff --git a/daemon/core/utils.py b/daemon/core/utils.py index cf394a15..73e11cb8 100644 --- a/daemon/core/utils.py +++ b/daemon/core/utils.py @@ -15,15 +15,36 @@ import random import shlex import sys from subprocess import PIPE, STDOUT, Popen +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import netaddr from core.errors import CoreCommandError, CoreError +if TYPE_CHECKING: + from core.emulator.session import Session + from core.nodes.base import CoreNode +T = TypeVar("T") + DEVNULL = open(os.devnull, "wb") -def execute_file(path, exec_globals=None, exec_locals=None): +def execute_file( + path: str, exec_globals: Dict[str, str] = None, exec_locals: Dict[str, str] = None +) -> None: """ Provides an alternative way to run execfile to be compatible for both python2/3. @@ -41,7 +62,7 @@ def execute_file(path, exec_globals=None, exec_locals=None): exec(data, exec_globals, exec_locals) -def hashkey(value): +def hashkey(value: Union[str, int]) -> int: """ Provide a consistent hash that can be used in place of the builtin hash, that no longer behaves consistently @@ -57,7 +78,7 @@ def hashkey(value): return int(hashlib.sha256(value).hexdigest(), 16) -def _detach_init(): +def _detach_init() -> None: """ Fork a child process and exit. @@ -69,7 +90,7 @@ def _detach_init(): os.setsid() -def _valid_module(path, file_name): +def _valid_module(path: str, file_name: str) -> bool: """ Check if file is a valid python module. @@ -91,7 +112,7 @@ def _valid_module(path, file_name): return True -def _is_class(module, member, clazz): +def _is_class(module: Any, member: Type, clazz: Type) -> bool: """ Validates if a module member is a class and an instance of a CoreService. @@ -113,7 +134,7 @@ def _is_class(module, member, clazz): return True -def close_onexec(fd): +def close_onexec(fd: int) -> None: """ Close on execution of a shell process. @@ -124,7 +145,7 @@ def close_onexec(fd): fcntl.fcntl(fd, fcntl.F_SETFD, fdflags | fcntl.FD_CLOEXEC) -def which(command, required): +def which(command: str, required: bool) -> str: """ Find location of desired executable within current PATH. @@ -146,7 +167,7 @@ def which(command, required): return found_path -def make_tuple(obj): +def make_tuple(obj: Generic[T]) -> Tuple[T]: """ Create a tuple from an object, or return the object itself. @@ -160,7 +181,7 @@ def make_tuple(obj): return (obj,) -def make_tuple_fromstr(s, value_type): +def make_tuple_fromstr(s: str, value_type: Callable[[str], T]) -> Tuple[T]: """ Create a tuple from a string. @@ -179,11 +200,11 @@ def make_tuple_fromstr(s, value_type): return tuple(value_type(i) for i in values) -def mute_detach(args, **kwargs): +def mute_detach(args: str, **kwargs: Dict[str, Any]) -> int: """ Run a muted detached process by forking it. - :param list[str]|str args: arguments for the command + :param str args: arguments for the command :param dict kwargs: keyword arguments for the command :return: process id of the command :rtype: int @@ -195,7 +216,13 @@ def mute_detach(args, **kwargs): return Popen(args, **kwargs).pid -def cmd(args, env=None, cwd=None, wait=True, shell=False): +def cmd( + args: str, + env: Dict[str, str] = None, + cwd: str = None, + wait: bool = True, + shell: bool = False, +) -> str: """ Execute a command on the host and return a tuple containing the exit status and result string. stderr output is folded into the stdout result string. @@ -227,7 +254,7 @@ def cmd(args, env=None, cwd=None, wait=True, shell=False): raise CoreCommandError(-1, args) -def file_munge(pathname, header, text): +def file_munge(pathname: str, header: str, text: str) -> None: """ Insert text at the end of a file, surrounded by header comments. @@ -245,7 +272,7 @@ def file_munge(pathname, header, text): append_file.write(f"# END {header}\n") -def file_demunge(pathname, header): +def file_demunge(pathname: str, header: str) -> None: """ Remove text that was inserted in a file surrounded by header comments. @@ -273,7 +300,9 @@ def file_demunge(pathname, header): write_file.write("".join(lines)) -def expand_corepath(pathname, session=None, node=None): +def expand_corepath( + pathname: str, session: "Session" = None, node: "CoreNode" = None +) -> str: """ Expand a file path given session information. @@ -296,7 +325,7 @@ def expand_corepath(pathname, session=None, node=None): return pathname -def sysctl_devname(devname): +def sysctl_devname(devname: str) -> Optional[str]: """ Translate a device name to the name used with sysctl. @@ -309,7 +338,7 @@ def sysctl_devname(devname): return devname.replace(".", "/") -def load_config(filename, d): +def load_config(filename: str, d: Dict[str, str]) -> None: """ Read key=value pairs from a file, into a dict. Skip comments; strip newline characters and spacing. @@ -332,7 +361,7 @@ def load_config(filename, d): logging.exception("error reading file to dict: %s", filename) -def load_classes(path, clazz): +def load_classes(path: str, clazz: Generic[T]) -> T: """ Dynamically load classes for use within CORE. @@ -375,7 +404,7 @@ def load_classes(path, clazz): return classes -def load_logging_config(config_path): +def load_logging_config(config_path: str) -> None: """ Load CORE logging configuration file. @@ -387,7 +416,9 @@ def load_logging_config(config_path): logging.config.dictConfig(log_config) -def threadpool(funcs, workers=10): +def threadpool( + funcs: List[Tuple[Callable, Iterable[Any], Dict[Any, Any]]], workers: int = 10 +) -> Tuple[List[Any], List[Exception]]: """ Run provided functions, arguments, and keywords within a threadpool collecting results and exceptions. @@ -409,11 +440,12 @@ def threadpool(funcs, workers=10): result = future.result() results.append(result) except Exception as e: + logging.exception("thread pool exception") exceptions.append(e) return results, exceptions -def random_mac(): +def random_mac() -> str: """ Create a random mac address using Xen OID 00:16:3E. @@ -427,7 +459,7 @@ def random_mac(): return str(mac) -def validate_mac(value): +def validate_mac(value: str) -> str: """ Validate mac and return unix formatted version. @@ -443,7 +475,7 @@ def validate_mac(value): raise CoreError(f"invalid mac address {value}: {e}") -def validate_ip(value): +def validate_ip(value: str) -> str: """ Validate ip address with prefix and return formatted version. diff --git a/daemon/core/xml/corexml.py b/daemon/core/xml/corexml.py index 0266912d..df73901f 100644 --- a/daemon/core/xml/corexml.py +++ b/daemon/core/xml/corexml.py @@ -1,17 +1,30 @@ import logging +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar from lxml import etree import core.nodes.base import core.nodes.physical from core.emane.nodes import EmaneNet +from core.emulator.data import LinkData from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions from core.emulator.enumerations import NodeTypes -from core.nodes.base import CoreNetworkBase +from core.nodes.base import CoreNetworkBase, NodeBase from core.nodes.network import CtrlNet +from core.services.coreservices import CoreService + +if TYPE_CHECKING: + from core.emane.emanemanager import EmaneGlobalModel + from core.emane.emanemodel import EmaneModel + from core.emulator.session import Session + + EmaneModelType = Type[EmaneModel] +T = TypeVar("T") -def write_xml_file(xml_element, file_path, doctype=None): +def write_xml_file( + xml_element: etree.Element, file_path: str, doctype: str = None +) -> None: xml_data = etree.tostring( xml_element, xml_declaration=True, @@ -23,27 +36,27 @@ def write_xml_file(xml_element, file_path, doctype=None): xml_file.write(xml_data) -def get_type(element, name, _type): +def get_type(element: etree.Element, name: str, _type: Generic[T]) -> Optional[T]: value = element.get(name) if value is not None: value = _type(value) return value -def get_float(element, name): +def get_float(element: etree.Element, name: str) -> float: return get_type(element, name, float) -def get_int(element, name): +def get_int(element: etree.Element, name: str) -> int: return get_type(element, name, int) -def add_attribute(element, name, value): +def add_attribute(element: etree.Element, name: str, value: Any) -> None: if value is not None: element.set(name, str(value)) -def create_interface_data(interface_element): +def create_interface_data(interface_element: etree.Element) -> InterfaceData: interface_id = int(interface_element.get("id")) name = interface_element.get("name") mac = interface_element.get("mac") @@ -54,7 +67,9 @@ def create_interface_data(interface_element): return InterfaceData(interface_id, name, mac, ip4, ip4_mask, ip6, ip6_mask) -def create_emane_config(node_id, emane_config, config): +def create_emane_config( + node_id: int, emane_config: "EmaneGlobalModel", config: Dict[str, str] +) -> etree.Element: emane_configuration = etree.Element("emane_configuration") add_attribute(emane_configuration, "node", node_id) add_attribute(emane_configuration, "model", "emane") @@ -72,7 +87,9 @@ def create_emane_config(node_id, emane_config, config): return emane_configuration -def create_emane_model_config(node_id, model, config): +def create_emane_model_config( + node_id: int, model: "EmaneModelType", config: Dict[str, str] +) -> etree.Element: emane_element = etree.Element("emane_configuration") add_attribute(emane_element, "node", node_id) add_attribute(emane_element, "model", model.name) @@ -95,14 +112,14 @@ def create_emane_model_config(node_id, model, config): return emane_element -def add_configuration(parent, name, value): +def add_configuration(parent: etree.Element, name: str, value: str) -> None: config_element = etree.SubElement(parent, "configuration") add_attribute(config_element, "name", name) add_attribute(config_element, "value", value) class NodeElement: - def __init__(self, session, node, element_name): + def __init__(self, session: "Session", node: NodeBase, element_name: str) -> None: self.session = session self.node = node self.element = etree.Element(element_name) @@ -112,7 +129,7 @@ class NodeElement: add_attribute(self.element, "canvas", node.canvas) self.add_position() - def add_position(self): + def add_position(self) -> None: x = self.node.position.x y = self.node.position.y z = self.node.position.z @@ -129,7 +146,7 @@ class NodeElement: class ServiceElement: - def __init__(self, service): + def __init__(self, service: Type[CoreService]) -> None: self.service = service self.element = etree.Element("service") add_attribute(self.element, "name", service.name) @@ -139,7 +156,7 @@ class ServiceElement: self.add_shutdown() self.add_files() - def add_directories(self): + def add_directories(self) -> None: # get custom directories directories = etree.Element("directories") for directory in self.service.dirs: @@ -149,7 +166,7 @@ class ServiceElement: if directories.getchildren(): self.element.append(directories) - def add_files(self): + def add_files(self) -> None: # get custom files file_elements = etree.Element("files") for file_name in self.service.config_data: @@ -161,7 +178,7 @@ class ServiceElement: if file_elements.getchildren(): self.element.append(file_elements) - def add_startup(self): + def add_startup(self) -> None: # get custom startup startup_elements = etree.Element("startups") for startup in self.service.startup: @@ -171,7 +188,7 @@ class ServiceElement: if startup_elements.getchildren(): self.element.append(startup_elements) - def add_validate(self): + def add_validate(self) -> None: # get custom validate validate_elements = etree.Element("validates") for validate in self.service.validate: @@ -181,7 +198,7 @@ class ServiceElement: if validate_elements.getchildren(): self.element.append(validate_elements) - def add_shutdown(self): + def add_shutdown(self) -> None: # get custom shutdown shutdown_elements = etree.Element("shutdowns") for shutdown in self.service.shutdown: @@ -193,12 +210,12 @@ class ServiceElement: class DeviceElement(NodeElement): - def __init__(self, session, node): + def __init__(self, session: "Session", node: NodeBase) -> None: super().__init__(session, node, "device") add_attribute(self.element, "type", node.type) self.add_services() - def add_services(self): + def add_services(self) -> None: service_elements = etree.Element("services") for service in self.node.services: etree.SubElement(service_elements, "service", name=service.name) @@ -208,7 +225,7 @@ class DeviceElement(NodeElement): class NetworkElement(NodeElement): - def __init__(self, session, node): + def __init__(self, session: "Session", node: NodeBase) -> None: super().__init__(session, node, "network") model = getattr(self.node, "model", None) if model: @@ -221,7 +238,7 @@ class NetworkElement(NodeElement): add_attribute(self.element, "grekey", grekey) self.add_type() - def add_type(self): + def add_type(self) -> None: if self.node.apitype: node_type = NodeTypes(self.node.apitype).name else: @@ -230,14 +247,14 @@ class NetworkElement(NodeElement): class CoreXmlWriter: - def __init__(self, session): + def __init__(self, session: "Session") -> None: self.session = session self.scenario = etree.Element("scenario") self.networks = None self.devices = None self.write_session() - def write_session(self): + def write_session(self) -> None: # generate xml content links = self.write_nodes() self.write_links(links) @@ -250,7 +267,7 @@ class CoreXmlWriter: self.write_session_metadata() self.write_default_services() - def write(self, file_name): + def write(self, file_name: str) -> None: self.scenario.set("name", file_name) # write out generated xml @@ -259,7 +276,7 @@ class CoreXmlWriter: file_name, xml_declaration=True, pretty_print=True, encoding="UTF-8" ) - def write_session_origin(self): + def write_session_origin(self) -> None: # origin: geolocation of cartesian coordinate 0,0,0 lat, lon, alt = self.session.location.refgeo origin = etree.Element("session_origin") @@ -279,7 +296,7 @@ class CoreXmlWriter: add_attribute(origin, "y", y) add_attribute(origin, "z", z) - def write_session_hooks(self): + def write_session_hooks(self) -> None: # hook scripts hooks = etree.Element("session_hooks") for state in sorted(self.session._hooks.keys()): @@ -292,7 +309,7 @@ class CoreXmlWriter: if hooks.getchildren(): self.scenario.append(hooks) - def write_session_options(self): + def write_session_options(self) -> None: option_elements = etree.Element("session_options") options_config = self.session.options.get_configs() if not options_config: @@ -307,7 +324,7 @@ class CoreXmlWriter: if option_elements.getchildren(): self.scenario.append(option_elements) - def write_session_metadata(self): + def write_session_metadata(self) -> None: # metadata metadata_elements = etree.Element("session_metadata") config = self.session.metadata @@ -321,7 +338,7 @@ class CoreXmlWriter: if metadata_elements.getchildren(): self.scenario.append(metadata_elements) - def write_emane_configs(self): + def write_emane_configs(self) -> None: emane_configurations = etree.Element("emane_configurations") for node_id in self.session.emane.nodes(): all_configs = self.session.emane.get_all_configs(node_id) @@ -347,7 +364,7 @@ class CoreXmlWriter: if emane_configurations.getchildren(): self.scenario.append(emane_configurations) - def write_mobility_configs(self): + def write_mobility_configs(self) -> None: mobility_configurations = etree.Element("mobility_configurations") for node_id in self.session.mobility.nodes(): all_configs = self.session.mobility.get_all_configs(node_id) @@ -371,7 +388,7 @@ class CoreXmlWriter: if mobility_configurations.getchildren(): self.scenario.append(mobility_configurations) - def write_service_configs(self): + def write_service_configs(self) -> None: service_configurations = etree.Element("service_configurations") service_configs = self.session.services.all_configs() for node_id, service in service_configs: @@ -382,7 +399,7 @@ class CoreXmlWriter: if service_configurations.getchildren(): self.scenario.append(service_configurations) - def write_default_services(self): + def write_default_services(self) -> None: node_types = etree.Element("default_services") for node_type in self.session.services.default_services: services = self.session.services.default_services[node_type] @@ -393,7 +410,7 @@ class CoreXmlWriter: if node_types.getchildren(): self.scenario.append(node_types) - def write_nodes(self): + def write_nodes(self) -> List[LinkData]: self.networks = etree.SubElement(self.scenario, "networks") self.devices = etree.SubElement(self.scenario, "devices") @@ -416,7 +433,7 @@ class CoreXmlWriter: return links - def write_network(self, node): + def write_network(self, node: NodeBase) -> None: # ignore p2p and other nodes that are not part of the api if not node.apitype: return @@ -424,7 +441,7 @@ class CoreXmlWriter: network = NetworkElement(self.session, node) self.networks.append(network.element) - def write_links(self, links): + def write_links(self, links: List[LinkData]) -> None: link_elements = etree.Element("links") # add link data for link_data in links: @@ -438,13 +455,21 @@ class CoreXmlWriter: if link_elements.getchildren(): self.scenario.append(link_elements) - def write_device(self, node): + def write_device(self, node: NodeBase) -> None: device = DeviceElement(self.session, node) self.devices.append(device.element) def create_interface_element( - self, element_name, node_id, interface_id, mac, ip4, ip4_mask, ip6, ip6_mask - ): + self, + element_name: str, + node_id: int, + interface_id: int, + mac: str, + ip4: str, + ip4_mask: int, + ip6: str, + ip6_mask: int, + ) -> etree.Element: interface = etree.Element(element_name) node = self.session.get_node(node_id) interface_name = None @@ -467,7 +492,7 @@ class CoreXmlWriter: return interface - def create_link_element(self, link_data): + def create_link_element(self, link_data: LinkData) -> etree.Element: link_element = etree.Element("link") add_attribute(link_element, "node_one", link_data.node1_id) add_attribute(link_element, "node_two", link_data.node2_id) @@ -525,11 +550,11 @@ class CoreXmlWriter: class CoreXmlReader: - def __init__(self, session): + def __init__(self, session: "Session") -> None: self.session = session self.scenario = None - def read(self, file_name): + def read(self, file_name: str) -> None: xml_tree = etree.parse(file_name) self.scenario = xml_tree.getroot() @@ -545,7 +570,7 @@ class CoreXmlReader: self.read_nodes() self.read_links() - def read_default_services(self): + def read_default_services(self) -> None: default_services = self.scenario.find("default_services") if default_services is None: return @@ -560,7 +585,7 @@ class CoreXmlReader: ) self.session.services.default_services[node_type] = services - def read_session_metadata(self): + def read_session_metadata(self) -> None: session_metadata = self.scenario.find("session_metadata") if session_metadata is None: return @@ -573,7 +598,7 @@ class CoreXmlReader: logging.info("reading session metadata: %s", configs) self.session.metadata = configs - def read_session_options(self): + def read_session_options(self) -> None: session_options = self.scenario.find("session_options") if session_options is None: return @@ -586,7 +611,7 @@ class CoreXmlReader: logging.info("reading session options: %s", configs) self.session.options.set_configs(configs) - def read_session_hooks(self): + def read_session_hooks(self) -> None: session_hooks = self.scenario.find("session_hooks") if session_hooks is None: return @@ -601,7 +626,7 @@ class CoreXmlReader: hook_type, file_name=name, source_name=None, data=data ) - def read_session_origin(self): + def read_session_origin(self) -> None: session_origin = self.scenario.find("session_origin") if session_origin is None: return @@ -625,7 +650,7 @@ class CoreXmlReader: logging.info("reading session reference xyz: %s, %s, %s", x, y, z) self.session.location.refxyz = (x, y, z) - def read_service_configs(self): + def read_service_configs(self) -> None: service_configurations = self.scenario.find("service_configurations") if service_configurations is None: return @@ -669,7 +694,7 @@ class CoreXmlReader: files.add(name) service.configs = tuple(files) - def read_emane_configs(self): + def read_emane_configs(self) -> None: emane_configurations = self.scenario.find("emane_configurations") if emane_configurations is None: return @@ -702,7 +727,7 @@ class CoreXmlReader: ) self.session.emane.set_model_config(node_id, model_name, configs) - def read_mobility_configs(self): + def read_mobility_configs(self) -> None: mobility_configurations = self.scenario.find("mobility_configurations") if mobility_configurations is None: return @@ -722,7 +747,7 @@ class CoreXmlReader: ) self.session.mobility.set_model_config(node_id, model_name, configs) - def read_nodes(self): + def read_nodes(self) -> None: device_elements = self.scenario.find("devices") if device_elements is not None: for device_element in device_elements.iterchildren(): @@ -733,7 +758,7 @@ class CoreXmlReader: for network_element in network_elements.iterchildren(): self.read_network(network_element) - def read_device(self, device_element): + def read_device(self, device_element: etree.Element) -> None: node_id = get_int(device_element, "id") name = device_element.get("name") model = device_element.get("type") @@ -759,7 +784,7 @@ class CoreXmlReader: logging.info("reading node id(%s) model(%s) name(%s)", node_id, model, name) self.session.add_node(_id=node_id, options=options) - def read_network(self, network_element): + def read_network(self, network_element: etree.Element) -> None: node_id = get_int(network_element, "id") name = network_element.get("name") node_type = NodeTypes[network_element.get("type")] @@ -783,7 +808,7 @@ class CoreXmlReader: ) self.session.add_node(_type=node_type, _id=node_id, options=options) - def read_links(self): + def read_links(self) -> None: link_elements = self.scenario.find("links") if link_elements is None: return diff --git a/daemon/core/xml/corexmldeployment.py b/daemon/core/xml/corexmldeployment.py index 5c817fc2..5f340b69 100644 --- a/daemon/core/xml/corexmldeployment.py +++ b/daemon/core/xml/corexmldeployment.py @@ -1,5 +1,6 @@ import os import socket +from typing import TYPE_CHECKING, List, Tuple import netaddr from lxml import etree @@ -7,26 +8,40 @@ from lxml import etree from core import utils from core.constants import IP_BIN from core.emane.nodes import EmaneNet -from core.nodes.base import CoreNodeBase +from core.nodes.base import CoreNodeBase, NodeBase +from core.nodes.interface import CoreInterface + +if TYPE_CHECKING: + from core.emulator.session import Session -def add_type(parent_element, name): +def add_type(parent_element: etree.Element, name: str) -> None: type_element = etree.SubElement(parent_element, "type") type_element.text = name -def add_address(parent_element, address_type, address, interface_name=None): +def add_address( + parent_element: etree.Element, + address_type: str, + address: str, + interface_name: str = None, +) -> None: address_element = etree.SubElement(parent_element, "address", type=address_type) address_element.text = address if interface_name is not None: address_element.set("iface", interface_name) -def add_mapping(parent_element, maptype, mapref): +def add_mapping(parent_element: etree.Element, maptype: str, mapref: str) -> None: etree.SubElement(parent_element, "mapping", type=maptype, ref=mapref) -def add_emane_interface(host_element, netif, platform_name="p1", transport_name="t1"): +def add_emane_interface( + host_element: etree.Element, + netif: CoreInterface, + platform_name: str = "p1", + transport_name: str = "t1", +) -> etree.Element: nem_id = netif.net.nemidmap[netif] host_id = host_element.get("id") @@ -54,7 +69,7 @@ def add_emane_interface(host_element, netif, platform_name="p1", transport_name= return platform_element -def get_address_type(address): +def get_address_type(address: str) -> str: addr, _slash, _prefixlen = address.partition("/") if netaddr.valid_ipv4(addr): address_type = "IPv4" @@ -65,7 +80,7 @@ def get_address_type(address): return address_type -def get_ipv4_addresses(hostname): +def get_ipv4_addresses(hostname: str) -> List[Tuple[str, str]]: if hostname == "localhost": addresses = [] args = f"{IP_BIN} -o -f inet address show" @@ -85,7 +100,7 @@ def get_ipv4_addresses(hostname): class CoreXmlDeployment: - def __init__(self, session, scenario): + def __init__(self, session: "Session", scenario: etree.Element) -> None: self.session = session self.scenario = scenario self.root = etree.SubElement( @@ -93,17 +108,17 @@ class CoreXmlDeployment: ) self.add_deployment() - def find_device(self, name): + def find_device(self, name: str) -> etree.Element: device = self.scenario.find(f"devices/device[@name='{name}']") return device - def find_interface(self, device, name): + def find_interface(self, device: NodeBase, name: str) -> etree.Element: interface = self.scenario.find( f"devices/device[@name='{device.name}']/interfaces/interface[@name='{name}']" ) return interface - def add_deployment(self): + def add_deployment(self) -> None: physical_host = self.add_physical_host(socket.gethostname()) for node_id in self.session.nodes: @@ -111,7 +126,7 @@ class CoreXmlDeployment: if isinstance(node, CoreNodeBase): self.add_virtual_host(physical_host, node) - def add_physical_host(self, name): + def add_physical_host(self, name: str) -> etree.Element: # add host root_id = self.root.get("id") host_id = f"{root_id}/{name}" @@ -126,7 +141,7 @@ class CoreXmlDeployment: return host_element - def add_virtual_host(self, physical_host, node): + def add_virtual_host(self, physical_host: etree.Element, node: NodeBase) -> None: if not isinstance(node, CoreNodeBase): raise TypeError(f"invalid node type: {node}") diff --git a/daemon/core/xml/emanexml.py b/daemon/core/xml/emanexml.py index a62b54e5..da1e089e 100644 --- a/daemon/core/xml/emanexml.py +++ b/daemon/core/xml/emanexml.py @@ -1,16 +1,26 @@ import logging import os from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from lxml import etree from core import utils +from core.config import Configuration +from core.emane.nodes import EmaneNet +from core.emulator.distributed import DistributedServer +from core.nodes.interface import CoreInterface +from core.nodes.network import CtrlNet from core.xml import corexml +if TYPE_CHECKING: + from core.emane.emanemanager import EmaneManager + from core.emane.emanemodel import EmaneModel + _hwaddr_prefix = "02:02" -def is_external(config): +def is_external(config: Dict[str, str]) -> bool: """ Checks if the configuration is for an external transport. @@ -21,7 +31,7 @@ def is_external(config): return config.get("external") == "1" -def _value_to_params(value): +def _value_to_params(value: str) -> Optional[Tuple[str]]: """ Helper to convert a parameter to a parameter tuple. @@ -44,7 +54,12 @@ def _value_to_params(value): return None -def create_file(xml_element, doc_name, file_path, server=None): +def create_file( + xml_element: etree.Element, + doc_name: str, + file_path: str, + server: DistributedServer = None, +) -> None: """ Create xml file. @@ -68,7 +83,7 @@ def create_file(xml_element, doc_name, file_path, server=None): corexml.write_xml_file(xml_element, file_path, doctype=doctype) -def add_param(xml_element, name, value): +def add_param(xml_element: etree.Element, name: str, value: str) -> None: """ Add emane configuration parameter to xml element. @@ -80,7 +95,12 @@ def add_param(xml_element, name, value): etree.SubElement(xml_element, "param", name=name, value=value) -def add_configurations(xml_element, configurations, config, config_ignore): +def add_configurations( + xml_element: etree.Element, + configurations: List[Configuration], + config: Dict[str, str], + config_ignore: Set, +) -> None: """ Add emane model configurations to xml element. @@ -107,7 +127,13 @@ def add_configurations(xml_element, configurations, config, config_ignore): add_param(xml_element, name, value) -def build_node_platform_xml(emane_manager, control_net, node, nem_id, platform_xmls): +def build_node_platform_xml( + emane_manager: "EmaneManager", + control_net: CtrlNet, + node: EmaneNet, + nem_id: int, + platform_xmls: Dict[str, etree.Element], +) -> int: """ Create platform xml for a specific node. @@ -131,7 +157,7 @@ def build_node_platform_xml(emane_manager, control_net, node, nem_id, platform_x if node.model is None: logging.warning("warning: EMANE network %s has no associated model", node.name) - return nem_entries + return nem_id for netif in node.netifs(): logging.debug( @@ -228,7 +254,7 @@ def build_node_platform_xml(emane_manager, control_net, node, nem_id, platform_x return nem_id -def build_xml_files(emane_manager, node): +def build_xml_files(emane_manager: "EmaneManager", node: EmaneNet) -> None: """ Generate emane xml files required for node. @@ -276,7 +302,9 @@ def build_xml_files(emane_manager, node): build_transport_xml(emane_manager, node, rtype) -def build_transport_xml(emane_manager, node, transport_type): +def build_transport_xml( + emane_manager: "EmaneManager", node: EmaneNet, transport_type: str +) -> None: """ Build transport xml file for node and transport type. @@ -317,7 +345,12 @@ def build_transport_xml(emane_manager, node, transport_type): ) -def create_phy_xml(emane_model, config, file_path, server): +def create_phy_xml( + emane_model: "EmaneModel", + config: Dict[str, str], + file_path: str, + server: DistributedServer, +) -> None: """ Create the phy xml document. @@ -345,7 +378,12 @@ def create_phy_xml(emane_model, config, file_path, server): ) -def create_mac_xml(emane_model, config, file_path, server): +def create_mac_xml( + emane_model: "EmaneModel", + config: Dict[str, str], + file_path: str, + server: DistributedServer, +) -> None: """ Create the mac xml document. @@ -376,14 +414,14 @@ def create_mac_xml(emane_model, config, file_path, server): def create_nem_xml( - emane_model, - config, - nem_file, - transport_definition, - mac_definition, - phy_definition, - server, -): + emane_model: "EmaneModel", + config: Dict[str, str], + nem_file: str, + transport_definition: str, + mac_definition: str, + phy_definition: str, + server: DistributedServer, +) -> None: """ Create the nem xml document. @@ -413,7 +451,13 @@ def create_nem_xml( ) -def create_event_service_xml(group, port, device, file_directory, server=None): +def create_event_service_xml( + group: str, + port: str, + device: str, + file_directory: str, + server: DistributedServer = None, +) -> None: """ Create a emane event service xml file. @@ -440,7 +484,7 @@ def create_event_service_xml(group, port, device, file_directory, server=None): create_file(event_element, "emaneeventmsgsvc", file_path, server) -def transport_file_name(node_id, transport_type): +def transport_file_name(node_id: int, transport_type: str) -> str: """ Create name for a transport xml file. @@ -451,10 +495,11 @@ def transport_file_name(node_id, transport_type): return f"n{node_id}trans{transport_type}.xml" -def _basename(emane_model, interface=None): +def _basename(emane_model: "EmaneModel", interface: CoreInterface = None) -> str: """ Create name that is leveraged for configuration file creation. + :param emane_model: emane model to create name for :param interface: interface for this model :return: basename used for file creation :rtype: str @@ -469,7 +514,7 @@ def _basename(emane_model, interface=None): return f"{name}{emane_model.name}" -def nem_file_name(emane_model, interface=None): +def nem_file_name(emane_model: "EmaneModel", interface: CoreInterface = None) -> str: """ Return the string name for the NEM XML file, e.g. "n3rfpipenem.xml" @@ -485,7 +530,7 @@ def nem_file_name(emane_model, interface=None): return f"{basename}nem{append}.xml" -def shim_file_name(emane_model, interface=None): +def shim_file_name(emane_model: "EmaneModel", interface: CoreInterface = None) -> str: """ Return the string name for the SHIM XML file, e.g. "commeffectshim.xml" @@ -498,7 +543,7 @@ def shim_file_name(emane_model, interface=None): return f"{name}shim.xml" -def mac_file_name(emane_model, interface=None): +def mac_file_name(emane_model: "EmaneModel", interface: CoreInterface = None) -> str: """ Return the string name for the MAC XML file, e.g. "n3rfpipemac.xml" @@ -511,7 +556,7 @@ def mac_file_name(emane_model, interface=None): return f"{name}mac.xml" -def phy_file_name(emane_model, interface=None): +def phy_file_name(emane_model: "EmaneModel", interface: CoreInterface = None) -> str: """ Return the string name for the PHY XML file, e.g. "n3rfpipephy.xml" diff --git a/ns3/build/lib/corens3/__init__.py b/ns3/build/lib/corens3/__init__.py new file mode 100644 index 00000000..9bc2e7eb --- /dev/null +++ b/ns3/build/lib/corens3/__init__.py @@ -0,0 +1,9 @@ +""" +corens3 + +Python package containing CORE components for use +with the ns-3 simulator. + +See http://code.google.com/p/coreemu/ +for more information on CORE. +""" diff --git a/ns3/build/lib/corens3/obj.py b/ns3/build/lib/corens3/obj.py new file mode 100644 index 00000000..750dc7a5 --- /dev/null +++ b/ns3/build/lib/corens3/obj.py @@ -0,0 +1,550 @@ +""" +ns3.py: defines classes for running emulations with ns-3 simulated networks. +""" + +import logging +import subprocess +import threading +import time + +import ns.core +import ns.internet +import ns.lte +import ns.mobility +import ns.network +import ns.tap_bridge +import ns.wifi +import ns.wimax + +from core import constants +from core.emulator.enumerations import EventTypes +from core.emulator.enumerations import LinkTypes +from core.emulator.enumerations import NodeTypes +from core.utils import make_tuple +from core.location.mobility import WayPointMobility +from core.nodes.base import CoreNode, CoreNetworkBase +from core.emulator.session import Session + +ns.core.GlobalValue.Bind( + "SimulatorImplementationType", + ns.core.StringValue("ns3::RealtimeSimulatorImpl") +) +ns.core.GlobalValue.Bind( + "ChecksumEnabled", + ns.core.BooleanValue("true") +) + + +class CoreNs3Node(CoreNode, ns.network.Node): + """ + The CoreNs3Node is both a CoreNode backed by a network namespace and + an ns-3 Node simulator object. When linked to simulated networks, the TunTap + device will be used. + """ + + def __init__(self, *args, **kwds): + ns.network.Node.__init__(self) + # ns-3 ID starts at 0, CORE uses 1 + _id = self.GetId() + 1 + if '_id' not in kwds: + kwds['_id'] = _id + CoreNode.__init__(self, *args, **kwds) + + def newnetif(self, net=None, addrlist=None, hwaddr=None, ifindex=None, ifname=None): + """ + Add a network interface. If we are attaching to a CoreNs3Net, this + will be a TunTap. Otherwise dispatch to CoreNode.newnetif(). + """ + if not addrlist: + addrlist = [] + + if not isinstance(net, CoreNs3Net): + return CoreNode.newnetif(self, net, addrlist, hwaddr, ifindex, ifname) + ifindex = self.newtuntap(ifindex, ifname) + self.attachnet(ifindex, net) + netif = self.netif(ifindex) + netif.sethwaddr(hwaddr) + for addr in make_tuple(addrlist): + netif.addaddr(addr) + + addrstr = netif.addrlist[0] + addr, mask = addrstr.split('/') + tap = net._tapdevs[netif] + tap.SetAttribute( + "IpAddress", + ns.network.Ipv4AddressValue(ns.network.Ipv4Address(addr)) + ) + tap.SetAttribute( + "Netmask", + ns.network.Ipv4MaskValue(ns.network.Ipv4Mask(f"/{mask}")) + ) + ns.core.Simulator.Schedule(ns.core.Time("0"), netif.install) + return ifindex + + def getns3position(self): + """ + Return the ns-3 (x, y, z) position of a node. + """ + try: + mm = self.GetObject(ns.mobility.MobilityModel.GetTypeId()) + pos = mm.GetPosition() + return pos.x, pos.y, pos.z + except AttributeError: + self.warn("ns-3 mobility model not found") + return 0, 0, 0 + + def setns3position(self, x, y, z): + """ + Set the ns-3 (x, y, z) position of a node. + """ + try: + mm = self.GetObject(ns.mobility.MobilityModel.GetTypeId()) + if z is None: + z = 0.0 + mm.SetPosition(ns.core.Vector(x, y, z)) + except AttributeError: + self.warn("ns-3 mobility model not found, not setting position") + + +class CoreNs3Net(CoreNetworkBase): + """ + The CoreNs3Net is a helper PyCoreNet object. Networks are represented + entirely in simulation with the TunTap device bridging the emulated and + simulated worlds. + """ + apitype = NodeTypes.WIRELESS_LAN.value + linktype = LinkTypes.WIRELESS.value + # icon used + type = "wlan" + + def __init__( + self, session, _id=None, name=None, start=True, server=None + ): + CoreNetworkBase.__init__(self, session, _id, name, start, server) + self.tapbridge = ns.tap_bridge.TapBridgeHelper() + self._ns3devs = {} + self._tapdevs = {} + + def attach(self, netif): + """ + Invoked from netif.attach(). Create a TAP device using the TapBridge + object. Call getns3dev() to get model-specific device. + """ + self._netif[netif] = netif + self._linked[netif] = {} + ns3dev = self.getns3dev(netif.node) + tap = self.tapbridge.Install(netif.node, ns3dev) + tap.SetMode(ns.tap_bridge.TapBridge.CONFIGURE_LOCAL) + tap.SetAttribute( + "DeviceName", + ns.core.StringValue(netif.localname) + ) + self._ns3devs[netif] = ns3dev + self._tapdevs[netif] = tap + + def getns3dev(self, node): + """ + Implement depending on network helper. Install this network onto + the given node and return the device. Register the ns3 device into + self._ns3devs + """ + raise NotImplementedError + + def findns3dev(self, node): + """ + Given a node, return the interface and ns3 device associated with + this network. + """ + for netif in node.netifs(): + if netif in self._ns3devs: + return netif, self._ns3devs[netif] + return None, None + + def shutdown(self): + """ + Session.shutdown() will invoke this. + """ + pass + + def usecorepositions(self): + """ + Set position callbacks for interfaces on this net so the CORE GUI + can update the ns-3 node position when moved with the mouse. + """ + for netif in self.netifs(): + netif.poshook = self.setns3position + + def setns3position(self, netif, x, y, z): + logging.info("setns3position: %s (%s, %s, %s)", netif.node.name, x, y, z) + netif.node.setns3position(x, y, z) + + +class Ns3LteNet(CoreNs3Net): + def __init__(self, *args, **kwds): + """ + Uses a LteHelper to create an ns-3 based LTE network. + """ + CoreNs3Net.__init__(self, *args, **kwds) + self.lte = ns.lte.LteHelper() + # enhanced NodeB node list + self.enbnodes = [] + self.dlsubchannels = None + self.ulsubchannels = None + + def setsubchannels(self, downlink, uplink): + """ + Set the downlink/uplink subchannels, which are a list of ints. + These should be set prior to using CoreNs3Node.newnetif(). + """ + self.dlsubchannels = downlink + self.ulsubchannels = uplink + + def setnodeb(self, node): + """ + Mark the given node as a nodeb (base transceiver station) + """ + self.enbnodes.append(node) + + def linknodeb(self, node, nodeb, mob, mobb): + """ + Register user equipment with a nodeb. + Optionally install mobility model while we have the ns-3 devs handy. + """ + _tmp, nodebdev = self.findns3dev(nodeb) + _tmp, dev = self.findns3dev(node) + if nodebdev is None or dev is None: + raise KeyError("ns-3 device for node not found") + self.lte.RegisterUeToTheEnb(dev, nodebdev) + if mob: + self.lte.AddMobility(dev.GetPhy(), mob) + if mobb: + self.lte.AddDownlinkChannelRealization(mobb, mob, dev.GetPhy()) + + def getns3dev(self, node): + """ + Get the ns3 NetDevice using the LteHelper. + """ + if node in self.enbnodes: + devtype = ns.lte.LteHelper.DEVICE_TYPE_ENODEB + else: + devtype = ns.lte.LteHelper.DEVICE_TYPE_USER_EQUIPMENT + nodes = ns.network.NodeContainer(node) + devs = self.lte.Install(nodes, devtype) + devs.Get(0).GetPhy().SetDownlinkSubChannels(self.dlsubchannels) + devs.Get(0).GetPhy().SetUplinkSubChannels(self.ulsubchannels) + return devs.Get(0) + + def attach(self, netif): + """ + Invoked from netif.attach(). Create a TAP device using the TapBridge + object. Call getns3dev() to get model-specific device. + """ + self._netif[netif] = netif + self._linked[netif] = {} + ns3dev = self.getns3dev(netif.node) + self.tapbridge.SetAttribute("Mode", ns.core.StringValue("UseLocal")) + # self.tapbridge.SetAttribute("Mode", + # ns.core.IntegerValue(ns.tap_bridge.TapBridge.USE_LOCAL)) + tap = self.tapbridge.Install(netif.node, ns3dev) + # tap.SetMode(ns.tap_bridge.TapBridge.USE_LOCAL) + logging.info("using TAP device %s for %s/%s", netif.localname, netif.node.name, netif.name) + subprocess.check_call(['tunctl', '-t', netif.localname, '-n']) + # check_call([IP_BIN, 'link', 'set', 'dev', netif.localname, \ + # 'address', '%s' % netif.hwaddr]) + subprocess.check_call([constants.IP_BIN, 'link', 'set', netif.localname, 'up']) + tap.SetAttribute("DeviceName", ns.core.StringValue(netif.localname)) + self._ns3devs[netif] = ns3dev + self._tapdevs[netif] = tap + + +class Ns3WifiNet(CoreNs3Net): + def __init__(self, *args, **kwds): + """ + Uses a WifiHelper to create an ns-3 based Wifi network. + """ + rate = kwds.pop('rate', 'OfdmRate54Mbps') + CoreNs3Net.__init__(self, *args, **kwds) + self.wifi = ns.wifi.WifiHelper().Default() + self.wifi.SetStandard(ns.wifi.WIFI_PHY_STANDARD_80211a) + self.wifi.SetRemoteStationManager( + "ns3::ConstantRateWifiManager", + "DataMode", + ns.core.StringValue(rate), + "NonUnicastMode", + ns.core.StringValue(rate) + ) + self.mac = ns.wifi.NqosWifiMacHelper.Default() + self.mac.SetType("ns3::AdhocWifiMac") + + channel = ns.wifi.YansWifiChannelHelper.Default() + self.phy = ns.wifi.YansWifiPhyHelper.Default() + self.phy.SetChannel(channel.Create()) + + def getns3dev(self, node): + """ + Get the ns3 NetDevice using the WifiHelper. + """ + devs = self.wifi.Install(self.phy, self.mac, node) + return devs.Get(0) + + +class Ns3WimaxNet(CoreNs3Net): + def __init__(self, *args, **kwds): + CoreNs3Net.__init__(self, *args, **kwds) + self.wimax = ns.wimax.WimaxHelper() + self.scheduler = ns.wimax.WimaxHelper.SCHED_TYPE_SIMPLE + self.phy = ns.wimax.WimaxHelper.SIMPLE_PHY_TYPE_OFDM + # base station node list + self.bsnodes = [] + + def setbasestation(self, node): + self.bsnodes.append(node) + + def getns3dev(self, node): + if node in self.bsnodes: + devtype = ns.wimax.WimaxHelper.DEVICE_TYPE_BASE_STATION + else: + devtype = ns.wimax.WimaxHelper.DEVICE_TYPE_SUBSCRIBER_STATION + nodes = ns.network.NodeContainer(node) + devs = self.wimax.Install(nodes, devtype, self.phy, self.scheduler) + if node not in self.bsnodes: + devs.Get(0).SetModulationType(ns.wimax.WimaxPhy.MODULATION_TYPE_QAM16_12) + # debug + self.wimax.EnableAscii("wimax-device-%s" % node.name, devs) + return devs.Get(0) + + @staticmethod + def ipv4netifaddr(netif): + for addr in netif.addrlist: + if ':' in addr: + # skip ipv6 + continue + ip = ns.network.Ipv4Address(addr.split('/')[0]) + mask = ns.network.Ipv4Mask('/' + addr.split('/')[1]) + return ip, mask + return None, None + + def addflow(self, node1, node2, upclass, downclass): + """ + Add a Wimax service flow between two nodes. + """ + netif1, ns3dev1 = self.findns3dev(node1) + netif2, ns3dev2 = self.findns3dev(node2) + if not netif1 or not netif2: + raise ValueError("interface not found") + addr1, mask1 = self.ipv4netifaddr(netif1) + addr2, mask2 = self.ipv4netifaddr(netif2) + clargs1 = (addr1, mask1, addr2, mask2) + downclass + clargs2 = (addr2, mask2, addr1, mask1) + upclass + clrec1 = ns.wimax.IpcsClassifierRecord(*clargs1) + clrec2 = ns.wimax.IpcsClassifierRecord(*clargs2) + ns3dev1.AddServiceFlow(self.wimax.CreateServiceFlow( + ns.wimax.ServiceFlow.SF_DIRECTION_DOWN, + ns.wimax.ServiceFlow.SF_TYPE_RTPS, clrec1) + ) + ns3dev1.AddServiceFlow(self.wimax.CreateServiceFlow( + ns.wimax.ServiceFlow.SF_DIRECTION_UP, + ns.wimax.ServiceFlow.SF_TYPE_RTPS, clrec2) + ) + ns3dev2.AddServiceFlow(self.wimax.CreateServiceFlow( + ns.wimax.ServiceFlow.SF_DIRECTION_DOWN, + ns.wimax.ServiceFlow.SF_TYPE_RTPS, clrec2) + ) + ns3dev2.AddServiceFlow(self.wimax.CreateServiceFlow( + ns.wimax.ServiceFlow.SF_DIRECTION_UP, + ns.wimax.ServiceFlow.SF_TYPE_RTPS, clrec1) + ) + + +class Ns3Session(Session): + """ + A Session that starts an ns-3 simulation thread. + """ + + def __init__(self, _id, persistent=False, duration=600): + self.duration = duration + self.nodes = ns.network.NodeContainer() + self.mobhelper = ns.mobility.MobilityHelper() + Session.__init__(self, _id) + + def run(self, vis=False): + """ + Run the ns-3 simulation and return the simulator thread. + """ + + def runthread(): + ns.core.Simulator.Stop(ns.core.Seconds(self.duration)) + logging.info("running ns-3 simulation for %d seconds", self.duration) + if vis: + try: + import visualizer + except ImportError: + logging.exception("visualizer is not available") + ns.core.Simulator.Run() + else: + visualizer.start() + else: + ns.core.Simulator.Run() + + # self.evq.run() # event queue may have WayPointMobility events + self.set_state(EventTypes.RUNTIME_STATE, send_event=True) + t = threading.Thread(target=runthread) + t.daemon = True + t.start() + return t + + def shutdown(self): + # TODO: the following line tends to segfault ns-3 (and therefore core-daemon) + ns.core.Simulator.Destroy() + Session.shutdown(self) + + def addnode(self, name): + """ + A convenience helper for Session.addobj(), for adding CoreNs3Nodes + to this session. Keeps a NodeContainer for later use. + """ + n = self.create_node(cls=CoreNs3Node, name=name) + self.nodes.Add(n) + return n + + def setupconstantmobility(self): + """ + Install a ConstantPositionMobilityModel. + """ + palloc = ns.mobility.ListPositionAllocator() + for i in xrange(self.nodes.GetN()): + (x, y, z) = ((100.0 * i) + 50, 200.0, 0.0) + palloc.Add(ns.core.Vector(x, y, z)) + node = self.nodes.Get(i) + node.position.set(x, y, z) + self.mobhelper.SetPositionAllocator(palloc) + self.mobhelper.SetMobilityModel("ns3::ConstantPositionMobilityModel") + self.mobhelper.Install(self.nodes) + + def setuprandomwalkmobility(self, bounds, time=10, speed=25.0): + """ + Set up the random walk mobility model within a bounding box. + - bounds is the max (x, y, z) boundary + - time is the number of seconds to maintain the current speed + and direction + - speed is the maximum speed, with node speed randomly chosen + from [0, speed] + """ + x, y, z = map(float, bounds) + self.mobhelper.SetPositionAllocator( + "ns3::RandomBoxPositionAllocator", + "X", + ns.core.StringValue("ns3::UniformRandomVariable[Min=0|Max=%s]" % x), + "Y", + ns.core.StringValue("ns3::UniformRandomVariable[Min=0|Max=%s]" % y), + "Z", + ns.core.StringValue("ns3::UniformRandomVariable[Min=0|Max=%s]" % z) + ) + self.mobhelper.SetMobilityModel( + "ns3::RandomWalk2dMobilityModel", + "Mode", ns.core.StringValue("Time"), + "Time", ns.core.StringValue("%ss" % time), + "Speed", + ns.core.StringValue("ns3::UniformRandomVariable[Min=0|Max=%s]" % speed), + "Bounds", ns.core.StringValue("0|%s|0|%s" % (x, y)) + ) + self.mobhelper.Install(self.nodes) + + def startns3mobility(self, refresh_ms=300): + """ + Start a thread that updates CORE nodes based on their ns-3 + positions. + """ + self.set_state(EventTypes.INSTANTIATION_STATE) + self.mobilitythread = threading.Thread( + target=self.ns3mobilitythread, + args=(refresh_ms,)) + self.mobilitythread.daemon = True + self.mobilitythread.start() + + def ns3mobilitythread(self, refresh_ms): + """ + Thread target that updates CORE nodes every refresh_ms based on + their ns-3 positions. + """ + valid_states = ( + EventTypes.RUNTIME_STATE.value, + EventTypes.INSTANTIATION_STATE.value + ) + while self.state in valid_states: + for i in xrange(self.nodes.GetN()): + node = self.nodes.Get(i) + x, y, z = node.getns3position() + if (x, y, z) == node.position.get(): + continue + # from WayPointMobility.setnodeposition(node, x, y, z) + node.position.set(x, y, z) + node_data = node.data(0) + self.broadcast_node(node_data) + self.sdt.updatenode(node.id, flags=0, x=x, y=y, z=z) + time.sleep(0.001 * refresh_ms) + + def setupmobilitytracing(self, net, filename, nodes): + """ + Start a tracing thread using the ASCII output from the ns3 + mobility helper. + """ + net.mobility = WayPointMobility(session=self, _id=net.id) + net.mobility.setendtime() + net.mobility.refresh_ms = 300 + net.mobility.empty_queue_stop = False + of = ns.network.OutputStreamWrapper(filename, filemode=0o777) + self.mobhelper.EnableAsciiAll(of) + self.mobilitytracethread = threading.Thread( + target=self.mobilitytrace, + args=(net, filename, nodes) + ) + self.mobilitytracethread.daemon = True + self.mobilitytracethread.start() + + def mobilitytrace(self, net, filename, nodes, verbose): + nodemap = {} + # move nodes to initial positions + for node in nodes: + x, y, z = node.getns3position() + net.mobility.setnodeposition(node, x, y, z) + nodemap[node.GetId()] = node + + logging.info("mobilitytrace opening '%s'", filename) + + f = None + try: + f = open(filename) + f.seek(0, 2) + + sleep = 0.001 + kickstart = True + while True: + if self.state != EventTypes.RUNTIME_STATE.value: + break + line = f.readline() + if not line: + time.sleep(sleep) + if sleep < 1.0: + sleep += 0.001 + continue + sleep = 0.001 + items = dict(x.split("=") for x in line.split()) + logging.info("trace: %s %s %s", items['node'], items['pos'], items['vel']) + x, y, z = map(float, items['pos'].split(':')) + vel = map(float, items['vel'].split(':')) + node = nodemap[int(items['node'])] + net.mobility.addwaypoint(time=0, nodenum=node.id, x=x, y=y, z=z, speed=vel) + if kickstart: + kickstart = False + self.event_loop.add_event(0, net.mobility.start) + self.event_loop.run() + else: + if net.mobility.state != net.mobility.STATE_RUNNING: + net.mobility.state = net.mobility.STATE_RUNNING + self.event_loop.add_event(0, net.mobility.runround) + except IOError: + logging.exception("mobilitytrace error opening: %s", filename) + finally: + if f: + f.close() diff --git a/ns3/setup.py b/ns3/setup.py new file mode 100644 index 00000000..d1e022f1 --- /dev/null +++ b/ns3/setup.py @@ -0,0 +1,19 @@ +import glob + +from setuptools import setup + +_EXAMPLES_DIR = "share/corens3/examples" + +setup( + name="core-ns3", + version="5.5.2", + packages=[ + "corens3", + ], + data_files=[(_EXAMPLES_DIR, glob.glob("examples/*"))], + description="Python ns-3 components of CORE", + url="https://github.com/coreemu/core", + author="Boeing Research & Technology", + license="GPLv2", + long_description="Python scripts and modules for building virtual simulated networks." +)