diff --git a/daemon/core/api/grpc/events.py b/daemon/core/api/grpc/events.py index 172cec82..cf98a9a2 100644 --- a/daemon/core/api/grpc/events.py +++ b/daemon/core/api/grpc/events.py @@ -102,7 +102,7 @@ def handle_session_event(event: EventData) -> core_pb2.SessionEvent: event_time = float(event_time) return core_pb2.SessionEvent( node_id=event.node, - event=event.event_type, + event=event.event_type.value, name=event.name, data=event.data, time=event_time, diff --git a/daemon/core/api/grpc/server.py b/daemon/core/api/grpc/server.py index c71800f4..4c8d6640 100644 --- a/daemon/core/api/grpc/server.py +++ b/daemon/core/api/grpc/server.py @@ -173,7 +173,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): # add all hooks for hook in request.hooks: - session.add_hook(hook.state, hook.file, None, hook.data) + state = EventTypes(hook.state) + session.add_hook(state, hook.file, None, hook.data) # create nodes _, exceptions = grpcutils.create_nodes(session, request.nodes) @@ -279,7 +280,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.location.setrefgeo(47.57917, -122.13232, 2.0) session.location.refscale = 150000.0 return core_pb2.CreateSessionResponse( - session_id=session.id, state=session.state + session_id=session.id, state=session.state.value ) def DeleteSession( @@ -312,7 +313,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session = self.coreemu.sessions[session_id] session_summary = core_pb2.SessionSummary( id=session_id, - state=session.state, + state=session.state.value, nodes=session.get_node_count(), file=session.file_name, ) @@ -521,7 +522,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): node_links = get_links(session, node) links.extend(node_links) - session_proto = core_pb2.Session(state=session.state, nodes=nodes, links=links) + session_proto = core_pb2.Session( + state=session.state.value, nodes=nodes, links=links + ) return core_pb2.GetSessionResponse(session=session_proto) def AddSessionServer( @@ -896,7 +899,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): for state in session._hooks: state_hooks = session._hooks[state] for file_name, file_data in state_hooks: - hook = core_pb2.Hook(state=state, file=file_name, data=file_data) + hook = core_pb2.Hook(state=state.value, file=file_name, data=file_data) hooks.append(hook) return core_pb2.GetHooksResponse(hooks=hooks) @@ -913,7 +916,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): logging.debug("add hook: %s", request) session = self.get_session(request.session_id, context) hook = request.hook - session.add_hook(hook.state, hook.file, None, hook.data) + state = EventTypes(hook.state) + session.add_hook(state, hook.file, None, hook.data) return core_pb2.AddHookResponse(result=True) def GetMobilityConfigs( @@ -1267,7 +1271,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer): session.mobility.set_model_config( wlan_config.node_id, BasicRangeModel.name, wlan_config.config ) - if session.state == EventTypes.RUNTIME_STATE.value: + if session.state == EventTypes.RUNTIME_STATE: node = self.get_node(session, wlan_config.node_id, context) node.updatemodel(wlan_config.config) return core_pb2.SetWlanConfigResponse(result=True) diff --git a/daemon/core/api/tlv/corehandlers.py b/daemon/core/api/tlv/corehandlers.py index 1f3b24e9..09ac2444 100644 --- a/daemon/core/api/tlv/corehandlers.py +++ b/daemon/core/api/tlv/corehandlers.py @@ -228,7 +228,7 @@ class CoreHandler(socketserver.BaseRequestHandler): coreapi.CoreEventTlv, [ (EventTlvs.NODE, event_data.node), - (EventTlvs.TYPE, event_data.event_type), + (EventTlvs.TYPE, event_data.event_type.value), (EventTlvs.NAME, event_data.name), (EventTlvs.DATA, event_data.data), (EventTlvs.TIME, event_data.time), @@ -723,7 +723,7 @@ class CoreHandler(socketserver.BaseRequestHandler): if message.flags & MessageFlags.STRING.value: self.node_status_request[node.id] = True - if self.session.state == EventTypes.RUNTIME_STATE.value: + if self.session.state == EventTypes.RUNTIME_STATE: self.send_node_emulation_id(node.id) elif message.flags & MessageFlags.DELETE.value: with self._shutdown_lock: @@ -966,7 +966,7 @@ class CoreHandler(socketserver.BaseRequestHandler): retries = 10 # wait for session to enter RUNTIME state, to prevent GUI from # connecting while nodes are still being instantiated - while session.state != EventTypes.RUNTIME_STATE.value: + while session.state != EventTypes.RUNTIME_STATE: logging.debug( "waiting for session %d to enter RUNTIME state", sid ) @@ -1375,7 +1375,7 @@ class CoreHandler(socketserver.BaseRequestHandler): parsed_config = ConfigShim.str_to_dict(values_str) self.session.mobility.set_model_config(node_id, object_name, parsed_config) - if self.session.state == EventTypes.RUNTIME_STATE.value and parsed_config: + if self.session.state == EventTypes.RUNTIME_STATE and parsed_config: try: node = self.session.get_node(node_id) if object_name == BasicRangeModel.name: @@ -1502,6 +1502,7 @@ class CoreHandler(socketserver.BaseRequestHandler): logging.error("error setting hook having state '%s'", state) return () state = int(state) + state = EventTypes(state) self.session.add_hook(state, file_name, source_name, data) return () @@ -1538,9 +1539,11 @@ class CoreHandler(socketserver.BaseRequestHandler): :return: reply messages :raises core.CoreError: when event type <= SHUTDOWN_STATE and not a known node id """ + event_type_value = message.get_tlv(EventTlvs.TYPE.value) + event_type = EventTypes(event_type_value) event_data = EventData( node=message.get_tlv(EventTlvs.NODE.value), - event_type=message.get_tlv(EventTlvs.TYPE.value), + event_type=event_type, name=message.get_tlv(EventTlvs.NAME.value), data=message.get_tlv(EventTlvs.DATA.value), time=message.get_tlv(EventTlvs.TIME.value), @@ -1549,7 +1552,6 @@ class CoreHandler(socketserver.BaseRequestHandler): if event_data.event_type is None: raise NotImplementedError("Event message missing event type") - event_type = EventTypes(event_data.event_type) node_id = event_data.node logging.debug("handling event %s at %s", event_type.name, time.ctime()) @@ -1667,25 +1669,19 @@ class CoreHandler(socketserver.BaseRequestHandler): unknown.append(service_name) continue - if ( - event_type == EventTypes.STOP.value - or event_type == EventTypes.RESTART.value - ): + if event_type in [EventTypes.STOP, EventTypes.RESTART]: status = self.session.services.stop_service(node, service) if status: fail += f"Stop {service.name}," - if ( - event_type == EventTypes.START.value - or event_type == EventTypes.RESTART.value - ): + if event_type in [EventTypes.START, EventTypes.RESTART]: status = self.session.services.startup_service(node, service) if status: fail += f"Start ({service.name})," - if event_type == EventTypes.PAUSE.value: + if event_type == EventTypes.PAUSE: status = self.session.services.validate_service(node, service) if status: fail += f"{service.name}," - if event_type == EventTypes.RECONFIGURE.value: + if event_type == EventTypes.RECONFIGURE: self.session.services.service_reconfigure(node, service) fail_data = "" @@ -2052,7 +2048,7 @@ class CoreUdpHandler(CoreHandler): current_session = self.server.mainserver.coreemu.sessions[session_id] current_node_count = current_session.get_node_count() if ( - current_session.state == EventTypes.RUNTIME_STATE.value + current_session.state == EventTypes.RUNTIME_STATE and current_node_count > node_count ): node_count = current_node_count diff --git a/daemon/core/emulator/enumerations.py b/daemon/core/emulator/enumerations.py index f426774e..44f60877 100644 --- a/daemon/core/emulator/enumerations.py +++ b/daemon/core/emulator/enumerations.py @@ -293,6 +293,9 @@ class EventTypes(Enum): RECONFIGURE = 14 INSTANTIATION_COMPLETE = 15 + def should_start(self) -> bool: + return self.value > self.DEFINITION_STATE.value + class SessionTlvs(Enum): """ diff --git a/daemon/core/emulator/session.py b/daemon/core/emulator/session.py index b0e44cb5..cd78af8f 100644 --- a/daemon/core/emulator/session.py +++ b/daemon/core/emulator/session.py @@ -112,8 +112,7 @@ class Session: self.nodes = {} self._nodes_lock = threading.Lock() - # TODO: should the default state be definition? - self.state = EventTypes.NONE.value + self.state = EventTypes.DEFINITION_STATE self._state_time = time.monotonic() self._state_file = os.path.join(self.session_dir, "state") @@ -121,7 +120,7 @@ class Session: self._hooks = {} self._state_hooks = {} self.add_state_hook( - state=EventTypes.RUNTIME_STATE.value, hook=self.runtime_state_hook + state=EventTypes.RUNTIME_STATE, hook=self.runtime_state_hook ) # handlers for broadcasting information @@ -345,7 +344,7 @@ class Session: node_one.name, node_two.name, ) - start = self.state > EventTypes.DEFINITION_STATE.value + start = self.state.should_start() net_one = self.create_node(cls=PtpNet, start=start) # node to network @@ -680,7 +679,7 @@ class Session: node_class = _cls # set node start based on current session state, override and check when rj45 - start = self.state > EventTypes.DEFINITION_STATE.value + start = self.state.should_start() enable_rj45 = self.options.get_config("enablerj45") == "1" if _type == NodeTypes.RJ45 and not enable_rj45: start = False @@ -755,7 +754,7 @@ class Session: # boot nodes after runtime, CoreNodes, Physical, and RJ45 are all nodes is_boot_node = isinstance(node, CoreNodeBase) and not isinstance(node, Rj45Node) - if self.state == EventTypes.RUNTIME_STATE.value and is_boot_node: + if self.state == EventTypes.RUNTIME_STATE and is_boot_node: self.write_nodes() self.add_remove_control_interface(node=node, remove=False) self.services.boot_services(node) @@ -850,10 +849,7 @@ class Session: :return: True if active, False otherwise """ - result = self.state in { - EventTypes.RUNTIME_STATE.value, - EventTypes.DATACOLLECT_STATE.value, - } + result = self.state in {EventTypes.RUNTIME_STATE, EventTypes.DATACOLLECT_STATE} logging.info("session(%s) checking if active: %s", self.id, result) return result @@ -894,7 +890,9 @@ class Session: """ CoreXmlWriter(self).write(file_name) - def add_hook(self, state: int, file_name: str, source_name: str, data: str) -> None: + def add_hook( + self, state: EventTypes, file_name: str, source_name: str, data: str + ) -> None: """ Store a hook from a received file message. @@ -904,9 +902,17 @@ class Session: :param data: hook data :return: nothing """ - # hack to conform with old logic until updated - state = f":{state}" - self.set_hook(state, file_name, source_name, data) + logging.info( + "setting state hook: %s - %s from %s", state, file_name, source_name + ) + hook = file_name, data + state_hooks = self._hooks.setdefault(state, []) + state_hooks.append(hook) + + # immediately run a hook if it is in the current state + if self.state == state: + logging.info("immediately running new state hook") + self.run_hook(hook) def add_node_file( self, node_id: int, source_name: str, file_name: str, data: str @@ -1071,10 +1077,8 @@ class Session: :param send_event: if true, generate core API event messages :return: nothing """ - state_value = state.value state_name = state.name - - if self.state == state_value: + if self.state == state: logging.info( "session(%s) is already in state: %s, skipping change", self.id, @@ -1082,33 +1086,32 @@ class Session: ) return - self.state = state_value + self.state = state self._state_time = time.monotonic() logging.info("changing session(%s) to state %s", self.id, state_name) - - self.write_state(state_value) - self.run_hooks(state_value) - self.run_state_hooks(state_value) + self.write_state(state) + self.run_hooks(state) + self.run_state_hooks(state) if send_event: - event_data = EventData(event_type=state_value, time=str(time.monotonic())) + event_data = EventData(event_type=state, time=str(time.monotonic())) self.broadcast_event(event_data) - def write_state(self, state: int) -> None: + def write_state(self, state: EventTypes) -> None: """ - Write the current state to a state file in the session dir. + Write the state to a state file in the session dir. :param state: state to write to file :return: nothing """ try: state_file = open(self._state_file, "w") - state_file.write(f"{state} {EventTypes(self.state).name}\n") + state_file.write(f"{state.value} {state.name}\n") state_file.close() except IOError: - logging.exception("error writing state file: %s", state) + logging.exception("error writing state file: %s", state.name) - def run_hooks(self, state: int) -> None: + def run_hooks(self, state: EventTypes) -> None: """ Run hook scripts upon changing states. If hooks is not specified, run all hooks in the given state. @@ -1212,7 +1215,7 @@ class Session: except (OSError, subprocess.CalledProcessError): logging.exception("error running hook: %s", file_name) - def run_state_hooks(self, state: int) -> None: + def run_state_hooks(self, state: EventTypes) -> None: """ Run state hooks. @@ -1223,16 +1226,17 @@ class Session: try: hook(state) except Exception: - state_name = EventTypes(self.state).name message = ( - f"exception occured when running {state_name} state hook: {hook}" + f"exception occured when running {state.name} state hook: {hook}" ) logging.exception(message) self.exception( ExceptionLevels.ERROR, "Session.run_state_hooks", None, message ) - def add_state_hook(self, state: int, hook: Callable[[int], None]) -> None: + def add_state_hook( + self, state: EventTypes, hook: Callable[[EventTypes], None] + ) -> None: """ Add a state hook. @@ -1259,14 +1263,14 @@ class Session: hooks = self._state_hooks.setdefault(state, []) hooks.remove(hook) - def runtime_state_hook(self, state: int) -> None: + def runtime_state_hook(self, state: EventTypes) -> None: """ Runtime state hook check. :param state: state to check :return: nothing """ - if state == EventTypes.RUNTIME_STATE.value: + if state == EventTypes.RUNTIME_STATE: self.emane.poststartup() # create session deployed xml @@ -1510,7 +1514,7 @@ class Session: self.mobility.startup() # notify listeners that instantiation is complete - event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE.value) + event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE) self.broadcast_event(event) # assume either all nodes have booted already, or there are some @@ -1553,9 +1557,9 @@ class Session: logging.debug( "session(%s) checking if not in runtime state, current state: %s", self.id, - EventTypes(self.state).name, + self.state.name, ) - if self.state == EventTypes.RUNTIME_STATE.value: + if self.state == EventTypes.RUNTIME_STATE: logging.info("valid runtime state found, returning") return diff --git a/daemon/core/location/mobility.py b/daemon/core/location/mobility.py index 55af58d9..7b0ec8ac 100644 --- a/daemon/core/location/mobility.py +++ b/daemon/core/location/mobility.py @@ -142,17 +142,11 @@ class MobilityManager(ModelManager): ) continue - if ( - event_type == EventTypes.STOP.value - or event_type == EventTypes.RESTART.value - ): + if event_type in [EventTypes.STOP, EventTypes.RESTART]: model.stop(move_initial=True) - if ( - event_type == EventTypes.START.value - or event_type == EventTypes.RESTART.value - ): + if event_type in [EventTypes.START, EventTypes.RESTART]: model.start() - if event_type == EventTypes.PAUSE.value: + if event_type == EventTypes.PAUSE: model.pause() def sendevent(self, model: "WayPointMobility") -> None: @@ -163,13 +157,13 @@ class MobilityManager(ModelManager): :param model: mobility model to send event for :return: nothing """ - event_type = EventTypes.NONE.value + event_type = EventTypes.NONE if model.state == model.STATE_STOPPED: - event_type = EventTypes.STOP.value + event_type = EventTypes.STOP elif model.state == model.STATE_RUNNING: - event_type = EventTypes.START.value + event_type = EventTypes.START elif model.state == model.STATE_PAUSED: - event_type = EventTypes.PAUSE.value + event_type = EventTypes.PAUSE start_time = int(model.lasttime - model.timezero) end_time = int(model.endtime) diff --git a/daemon/core/plugins/sdt.py b/daemon/core/plugins/sdt.py index 1ccf40a5..658ee1e3 100644 --- a/daemon/core/plugins/sdt.py +++ b/daemon/core/plugins/sdt.py @@ -100,7 +100,7 @@ class Sdt: return False if self.connected: return True - if self.session.state == EventTypes.SHUTDOWN_STATE.value: + if self.session.state == EventTypes.SHUTDOWN_STATE: return False self.seturl() diff --git a/daemon/core/xml/corexml.py b/daemon/core/xml/corexml.py index 8eab98c2..074a6913 100644 --- a/daemon/core/xml/corexml.py +++ b/daemon/core/xml/corexml.py @@ -8,7 +8,7 @@ import core.nodes.physical from core.emane.nodes import EmaneNet from core.emulator.data import LinkData from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions -from core.emulator.enumerations import NodeTypes +from core.emulator.enumerations import EventTypes, NodeTypes from core.nodes.base import CoreNetworkBase, CoreNodeBase, NodeBase from core.nodes.docker import DockerNode from core.nodes.lxd import LxcNode @@ -324,7 +324,7 @@ class CoreXmlWriter: for file_name, data in self.session._hooks[state]: hook = etree.SubElement(hooks, "hook") add_attribute(hook, "name", file_name) - add_attribute(hook, "state", state) + add_attribute(hook, "state", state.value) hook.text = data if hooks.getchildren(): @@ -666,13 +666,11 @@ class CoreXmlReader: for hook in session_hooks.iterchildren(): name = hook.get("name") - state = hook.get("state") + state = get_int(hook, "state") + state = EventTypes(state) data = hook.text - hook_type = f"hook:{state}" logging.info("reading hook: state(%s) name(%s)", state, name) - self.session.set_hook( - hook_type, file_name=name, source_name=None, data=data - ) + self.session.add_hook(state, name, None, data) def read_session_origin(self) -> None: session_origin = self.scenario.find("session_origin") diff --git a/daemon/data/logging.conf b/daemon/data/logging.conf index 7f3d496f..46de6e92 100644 --- a/daemon/data/logging.conf +++ b/daemon/data/logging.conf @@ -14,7 +14,7 @@ } }, "root": { - "level": "INFO", + "level": "DEBUG", "handlers": ["console"] } } diff --git a/daemon/tests/test_grpc.py b/daemon/tests/test_grpc.py index d26c46e4..557265dc 100644 --- a/daemon/tests/test_grpc.py +++ b/daemon/tests/test_grpc.py @@ -127,7 +127,7 @@ class TestGrpc: assert wlan_node.id in session.nodes assert session.nodes[node_one.id].netif(0) is not None assert session.nodes[node_two.id].netif(0) is not None - hook_file, hook_data = session._hooks[core_pb2.SessionState.RUNTIME][0] + hook_file, hook_data = session._hooks[EventTypes.RUNTIME_STATE][0] assert hook_file == hook.file assert hook_data == hook.data assert session.location.refxyz == (location_x, location_y, location_z) @@ -169,7 +169,7 @@ class TestGrpc: assert isinstance(response.state, int) session = grpc_server.coreemu.sessions.get(response.session_id) assert session is not None - assert session.state == response.state + assert session.state == EventTypes(response.state) if session_id is not None: assert response.session_id == session_id assert session.id == session_id @@ -341,7 +341,7 @@ class TestGrpc: # then assert response.result is True - assert session.state == core_pb2.SessionState.DEFINITION + assert session.state == EventTypes.DEFINITION_STATE def test_add_node(self, grpc_server): # given @@ -447,7 +447,7 @@ class TestGrpc: session = grpc_server.coreemu.create_session() file_name = "test" file_data = "echo hello" - session.add_hook(EventTypes.RUNTIME_STATE.value, file_name, None, file_data) + session.add_hook(EventTypes.RUNTIME_STATE, file_name, None, file_data) # then with client.context_connect(): @@ -1065,7 +1065,7 @@ class TestGrpc: client.events(session.id, handle_event) time.sleep(0.1) event = EventData( - event_type=EventTypes.RUNTIME_STATE.value, time=str(time.monotonic()) + event_type=EventTypes.RUNTIME_STATE, time=str(time.monotonic()) ) session.broadcast_event(event) diff --git a/daemon/tests/test_gui.py b/daemon/tests/test_gui.py index a47aba75..8bf6f4b7 100644 --- a/daemon/tests/test_gui.py +++ b/daemon/tests/test_gui.py @@ -376,14 +376,14 @@ class TestGui: assert len(coretlv.coreemu.sessions) == 0 def test_file_hook_add(self, coretlv): - state = EventTypes.DATACOLLECT_STATE.value + state = EventTypes.DATACOLLECT_STATE assert coretlv.session._hooks.get(state) is None file_name = "test.sh" file_data = "echo hello" message = coreapi.CoreFileMessage.create( MessageFlags.ADD.value, [ - (FileTlvs.TYPE, f"hook:{state}"), + (FileTlvs.TYPE, f"hook:{state.value}"), (FileTlvs.NAME, file_name), (FileTlvs.DATA, file_data), ], @@ -514,7 +514,7 @@ class TestGui: coretlv.handle_message(message) - assert coretlv.session.state == state.value + assert coretlv.session.state == state def test_event_schedule(self, coretlv): coretlv.session.add_event = mock.MagicMock() diff --git a/daemon/tests/test_xml.py b/daemon/tests/test_xml.py index 496623a6..ebbb6b1e 100644 --- a/daemon/tests/test_xml.py +++ b/daemon/tests/test_xml.py @@ -3,7 +3,7 @@ from xml.etree import ElementTree import pytest from core.emulator.emudata import LinkOptions, NodeOptions -from core.emulator.enumerations import NodeTypes +from core.emulator.enumerations import EventTypes, NodeTypes from core.errors import CoreError from core.location.mobility import BasicRangeModel from core.services.utility import SshService @@ -20,7 +20,8 @@ class TestXml: # create hook file_name = "runtime_hook.sh" data = "#!/bin/sh\necho hello" - session.set_hook("hook:4", file_name, None, data) + state = EventTypes.RUNTIME_STATE + session.add_hook(state, file_name, None, data) # save xml xml_file = tmpdir.join("session.xml") @@ -38,7 +39,7 @@ class TestXml: session.open_xml(file_path, start=True) # verify nodes have been recreated - runtime_hooks = session._hooks.get(4) + runtime_hooks = session._hooks.get(state) assert runtime_hooks runtime_hook = runtime_hooks[0] assert file_name == runtime_hook[0]