added type hinting to core.xml

This commit is contained in:
Blake Harnden 2020-01-14 14:56:00 -08:00
parent 02156867e2
commit 8cd8b2ae2c
3 changed files with 178 additions and 93 deletions

View file

@ -1,17 +1,30 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar
from lxml import etree
import core.nodes.base
import core.nodes.physical
from core.emane.nodes import EmaneNet
from core.emulator.data import LinkData
from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions
from core.emulator.enumerations import NodeTypes
from core.nodes.base import CoreNetworkBase
from core.nodes.base import CoreNetworkBase, NodeBase
from core.nodes.network import CtrlNet
from core.services.coreservices import CoreService
if TYPE_CHECKING:
from core.emane.emanemanager import EmaneGlobalModel
from core.emane.emanemodel import EmaneModel
from core.emulator.session import Session
EmaneModelType = Type[EmaneModel]
T = TypeVar("T")
def write_xml_file(xml_element, file_path, doctype=None):
def write_xml_file(
xml_element: etree.Element, file_path: str, doctype: str = None
) -> None:
xml_data = etree.tostring(
xml_element,
xml_declaration=True,
@ -23,27 +36,27 @@ def write_xml_file(xml_element, file_path, doctype=None):
xml_file.write(xml_data)
def get_type(element, name, _type):
def get_type(element: etree.Element, name: str, _type: Generic[T]) -> Optional[T]:
value = element.get(name)
if value is not None:
value = _type(value)
return value
def get_float(element, name):
def get_float(element: etree.Element, name: str) -> float:
return get_type(element, name, float)
def get_int(element, name):
def get_int(element: etree.Element, name: str) -> int:
return get_type(element, name, int)
def add_attribute(element, name, value):
def add_attribute(element: etree.Element, name: str, value: Any) -> None:
if value is not None:
element.set(name, str(value))
def create_interface_data(interface_element):
def create_interface_data(interface_element: etree.Element) -> InterfaceData:
interface_id = int(interface_element.get("id"))
name = interface_element.get("name")
mac = interface_element.get("mac")
@ -54,7 +67,9 @@ def create_interface_data(interface_element):
return InterfaceData(interface_id, name, mac, ip4, ip4_mask, ip6, ip6_mask)
def create_emane_config(node_id, emane_config, config):
def create_emane_config(
node_id: int, emane_config: "EmaneGlobalModel", config: Dict[str, str]
) -> etree.Element:
emane_configuration = etree.Element("emane_configuration")
add_attribute(emane_configuration, "node", node_id)
add_attribute(emane_configuration, "model", "emane")
@ -72,7 +87,9 @@ def create_emane_config(node_id, emane_config, config):
return emane_configuration
def create_emane_model_config(node_id, model, config):
def create_emane_model_config(
node_id: int, model: "EmaneModelType", config: Dict[str, str]
) -> etree.Element:
emane_element = etree.Element("emane_configuration")
add_attribute(emane_element, "node", node_id)
add_attribute(emane_element, "model", model.name)
@ -95,14 +112,14 @@ def create_emane_model_config(node_id, model, config):
return emane_element
def add_configuration(parent, name, value):
def add_configuration(parent: etree.Element, name: str, value: str) -> None:
config_element = etree.SubElement(parent, "configuration")
add_attribute(config_element, "name", name)
add_attribute(config_element, "value", value)
class NodeElement:
def __init__(self, session, node, element_name):
def __init__(self, session: "Session", node: NodeBase, element_name: str) -> None:
self.session = session
self.node = node
self.element = etree.Element(element_name)
@ -112,7 +129,7 @@ class NodeElement:
add_attribute(self.element, "canvas", node.canvas)
self.add_position()
def add_position(self):
def add_position(self) -> None:
x = self.node.position.x
y = self.node.position.y
z = self.node.position.z
@ -129,7 +146,7 @@ class NodeElement:
class ServiceElement:
def __init__(self, service):
def __init__(self, service: Type[CoreService]) -> None:
self.service = service
self.element = etree.Element("service")
add_attribute(self.element, "name", service.name)
@ -139,7 +156,7 @@ class ServiceElement:
self.add_shutdown()
self.add_files()
def add_directories(self):
def add_directories(self) -> None:
# get custom directories
directories = etree.Element("directories")
for directory in self.service.dirs:
@ -149,7 +166,7 @@ class ServiceElement:
if directories.getchildren():
self.element.append(directories)
def add_files(self):
def add_files(self) -> None:
# get custom files
file_elements = etree.Element("files")
for file_name in self.service.config_data:
@ -161,7 +178,7 @@ class ServiceElement:
if file_elements.getchildren():
self.element.append(file_elements)
def add_startup(self):
def add_startup(self) -> None:
# get custom startup
startup_elements = etree.Element("startups")
for startup in self.service.startup:
@ -171,7 +188,7 @@ class ServiceElement:
if startup_elements.getchildren():
self.element.append(startup_elements)
def add_validate(self):
def add_validate(self) -> None:
# get custom validate
validate_elements = etree.Element("validates")
for validate in self.service.validate:
@ -181,7 +198,7 @@ class ServiceElement:
if validate_elements.getchildren():
self.element.append(validate_elements)
def add_shutdown(self):
def add_shutdown(self) -> None:
# get custom shutdown
shutdown_elements = etree.Element("shutdowns")
for shutdown in self.service.shutdown:
@ -193,12 +210,12 @@ class ServiceElement:
class DeviceElement(NodeElement):
def __init__(self, session, node):
def __init__(self, session: "Session", node: NodeBase) -> None:
super().__init__(session, node, "device")
add_attribute(self.element, "type", node.type)
self.add_services()
def add_services(self):
def add_services(self) -> None:
service_elements = etree.Element("services")
for service in self.node.services:
etree.SubElement(service_elements, "service", name=service.name)
@ -208,7 +225,7 @@ class DeviceElement(NodeElement):
class NetworkElement(NodeElement):
def __init__(self, session, node):
def __init__(self, session: "Session", node: NodeBase) -> None:
super().__init__(session, node, "network")
model = getattr(self.node, "model", None)
if model:
@ -221,7 +238,7 @@ class NetworkElement(NodeElement):
add_attribute(self.element, "grekey", grekey)
self.add_type()
def add_type(self):
def add_type(self) -> None:
if self.node.apitype:
node_type = NodeTypes(self.node.apitype).name
else:
@ -230,14 +247,14 @@ class NetworkElement(NodeElement):
class CoreXmlWriter:
def __init__(self, session):
def __init__(self, session: "Session") -> None:
self.session = session
self.scenario = etree.Element("scenario")
self.networks = None
self.devices = None
self.write_session()
def write_session(self):
def write_session(self) -> None:
# generate xml content
links = self.write_nodes()
self.write_links(links)
@ -250,7 +267,7 @@ class CoreXmlWriter:
self.write_session_metadata()
self.write_default_services()
def write(self, file_name):
def write(self, file_name: str) -> None:
self.scenario.set("name", file_name)
# write out generated xml
@ -259,7 +276,7 @@ class CoreXmlWriter:
file_name, xml_declaration=True, pretty_print=True, encoding="UTF-8"
)
def write_session_origin(self):
def write_session_origin(self) -> None:
# origin: geolocation of cartesian coordinate 0,0,0
lat, lon, alt = self.session.location.refgeo
origin = etree.Element("session_origin")
@ -279,7 +296,7 @@ class CoreXmlWriter:
add_attribute(origin, "y", y)
add_attribute(origin, "z", z)
def write_session_hooks(self):
def write_session_hooks(self) -> None:
# hook scripts
hooks = etree.Element("session_hooks")
for state in sorted(self.session._hooks.keys()):
@ -292,7 +309,7 @@ class CoreXmlWriter:
if hooks.getchildren():
self.scenario.append(hooks)
def write_session_options(self):
def write_session_options(self) -> None:
option_elements = etree.Element("session_options")
options_config = self.session.options.get_configs()
if not options_config:
@ -307,7 +324,7 @@ class CoreXmlWriter:
if option_elements.getchildren():
self.scenario.append(option_elements)
def write_session_metadata(self):
def write_session_metadata(self) -> None:
# metadata
metadata_elements = etree.Element("session_metadata")
config = self.session.metadata
@ -321,7 +338,7 @@ class CoreXmlWriter:
if metadata_elements.getchildren():
self.scenario.append(metadata_elements)
def write_emane_configs(self):
def write_emane_configs(self) -> None:
emane_configurations = etree.Element("emane_configurations")
for node_id in self.session.emane.nodes():
all_configs = self.session.emane.get_all_configs(node_id)
@ -347,7 +364,7 @@ class CoreXmlWriter:
if emane_configurations.getchildren():
self.scenario.append(emane_configurations)
def write_mobility_configs(self):
def write_mobility_configs(self) -> None:
mobility_configurations = etree.Element("mobility_configurations")
for node_id in self.session.mobility.nodes():
all_configs = self.session.mobility.get_all_configs(node_id)
@ -371,7 +388,7 @@ class CoreXmlWriter:
if mobility_configurations.getchildren():
self.scenario.append(mobility_configurations)
def write_service_configs(self):
def write_service_configs(self) -> None:
service_configurations = etree.Element("service_configurations")
service_configs = self.session.services.all_configs()
for node_id, service in service_configs:
@ -382,7 +399,7 @@ class CoreXmlWriter:
if service_configurations.getchildren():
self.scenario.append(service_configurations)
def write_default_services(self):
def write_default_services(self) -> None:
node_types = etree.Element("default_services")
for node_type in self.session.services.default_services:
services = self.session.services.default_services[node_type]
@ -393,7 +410,7 @@ class CoreXmlWriter:
if node_types.getchildren():
self.scenario.append(node_types)
def write_nodes(self):
def write_nodes(self) -> List[LinkData]:
self.networks = etree.SubElement(self.scenario, "networks")
self.devices = etree.SubElement(self.scenario, "devices")
@ -416,7 +433,7 @@ class CoreXmlWriter:
return links
def write_network(self, node):
def write_network(self, node: NodeBase) -> None:
# ignore p2p and other nodes that are not part of the api
if not node.apitype:
return
@ -424,7 +441,7 @@ class CoreXmlWriter:
network = NetworkElement(self.session, node)
self.networks.append(network.element)
def write_links(self, links):
def write_links(self, links: List[LinkData]) -> None:
link_elements = etree.Element("links")
# add link data
for link_data in links:
@ -438,13 +455,21 @@ class CoreXmlWriter:
if link_elements.getchildren():
self.scenario.append(link_elements)
def write_device(self, node):
def write_device(self, node: NodeBase) -> None:
device = DeviceElement(self.session, node)
self.devices.append(device.element)
def create_interface_element(
self, element_name, node_id, interface_id, mac, ip4, ip4_mask, ip6, ip6_mask
):
self,
element_name: str,
node_id: int,
interface_id: int,
mac: str,
ip4: str,
ip4_mask: int,
ip6: str,
ip6_mask: int,
) -> etree.Element:
interface = etree.Element(element_name)
node = self.session.get_node(node_id)
interface_name = None
@ -467,7 +492,7 @@ class CoreXmlWriter:
return interface
def create_link_element(self, link_data):
def create_link_element(self, link_data: LinkData) -> etree.Element:
link_element = etree.Element("link")
add_attribute(link_element, "node_one", link_data.node1_id)
add_attribute(link_element, "node_two", link_data.node2_id)
@ -525,11 +550,11 @@ class CoreXmlWriter:
class CoreXmlReader:
def __init__(self, session):
def __init__(self, session: "Session") -> None:
self.session = session
self.scenario = None
def read(self, file_name):
def read(self, file_name: str) -> None:
xml_tree = etree.parse(file_name)
self.scenario = xml_tree.getroot()
@ -545,7 +570,7 @@ class CoreXmlReader:
self.read_nodes()
self.read_links()
def read_default_services(self):
def read_default_services(self) -> None:
default_services = self.scenario.find("default_services")
if default_services is None:
return
@ -560,7 +585,7 @@ class CoreXmlReader:
)
self.session.services.default_services[node_type] = services
def read_session_metadata(self):
def read_session_metadata(self) -> None:
session_metadata = self.scenario.find("session_metadata")
if session_metadata is None:
return
@ -573,7 +598,7 @@ class CoreXmlReader:
logging.info("reading session metadata: %s", configs)
self.session.metadata = configs
def read_session_options(self):
def read_session_options(self) -> None:
session_options = self.scenario.find("session_options")
if session_options is None:
return
@ -586,7 +611,7 @@ class CoreXmlReader:
logging.info("reading session options: %s", configs)
self.session.options.set_configs(configs)
def read_session_hooks(self):
def read_session_hooks(self) -> None:
session_hooks = self.scenario.find("session_hooks")
if session_hooks is None:
return
@ -601,7 +626,7 @@ class CoreXmlReader:
hook_type, file_name=name, source_name=None, data=data
)
def read_session_origin(self):
def read_session_origin(self) -> None:
session_origin = self.scenario.find("session_origin")
if session_origin is None:
return
@ -625,7 +650,7 @@ class CoreXmlReader:
logging.info("reading session reference xyz: %s, %s, %s", x, y, z)
self.session.location.refxyz = (x, y, z)
def read_service_configs(self):
def read_service_configs(self) -> None:
service_configurations = self.scenario.find("service_configurations")
if service_configurations is None:
return
@ -669,7 +694,7 @@ class CoreXmlReader:
files.add(name)
service.configs = tuple(files)
def read_emane_configs(self):
def read_emane_configs(self) -> None:
emane_configurations = self.scenario.find("emane_configurations")
if emane_configurations is None:
return
@ -702,7 +727,7 @@ class CoreXmlReader:
)
self.session.emane.set_model_config(node_id, model_name, configs)
def read_mobility_configs(self):
def read_mobility_configs(self) -> None:
mobility_configurations = self.scenario.find("mobility_configurations")
if mobility_configurations is None:
return
@ -722,7 +747,7 @@ class CoreXmlReader:
)
self.session.mobility.set_model_config(node_id, model_name, configs)
def read_nodes(self):
def read_nodes(self) -> None:
device_elements = self.scenario.find("devices")
if device_elements is not None:
for device_element in device_elements.iterchildren():
@ -733,7 +758,7 @@ class CoreXmlReader:
for network_element in network_elements.iterchildren():
self.read_network(network_element)
def read_device(self, device_element):
def read_device(self, device_element: etree.Element) -> None:
node_id = get_int(device_element, "id")
name = device_element.get("name")
model = device_element.get("type")
@ -759,7 +784,7 @@ class CoreXmlReader:
logging.info("reading node id(%s) model(%s) name(%s)", node_id, model, name)
self.session.add_node(_id=node_id, options=options)
def read_network(self, network_element):
def read_network(self, network_element: etree.Element) -> None:
node_id = get_int(network_element, "id")
name = network_element.get("name")
node_type = NodeTypes[network_element.get("type")]
@ -783,7 +808,7 @@ class CoreXmlReader:
)
self.session.add_node(_type=node_type, _id=node_id, options=options)
def read_links(self):
def read_links(self) -> None:
link_elements = self.scenario.find("links")
if link_elements is None:
return