changes for sessions to use EventTypes for state/hooks directly

This commit is contained in:
Blake Harnden 2020-03-06 22:35:23 -08:00
parent 0e299d5af4
commit 1e8d1ecd9f
12 changed files with 95 additions and 95 deletions

View file

@ -102,7 +102,7 @@ def handle_session_event(event: EventData) -> core_pb2.SessionEvent:
event_time = float(event_time) event_time = float(event_time)
return core_pb2.SessionEvent( return core_pb2.SessionEvent(
node_id=event.node, node_id=event.node,
event=event.event_type, event=event.event_type.value,
name=event.name, name=event.name,
data=event.data, data=event.data,
time=event_time, time=event_time,

View file

@ -173,7 +173,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
# add all hooks # add all hooks
for hook in request.hooks: for hook in request.hooks:
session.add_hook(hook.state, hook.file, None, hook.data) state = EventTypes(hook.state)
session.add_hook(state, hook.file, None, hook.data)
# create nodes # create nodes
_, exceptions = grpcutils.create_nodes(session, request.nodes) _, exceptions = grpcutils.create_nodes(session, request.nodes)
@ -279,7 +280,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session.location.setrefgeo(47.57917, -122.13232, 2.0) session.location.setrefgeo(47.57917, -122.13232, 2.0)
session.location.refscale = 150000.0 session.location.refscale = 150000.0
return core_pb2.CreateSessionResponse( return core_pb2.CreateSessionResponse(
session_id=session.id, state=session.state session_id=session.id, state=session.state.value
) )
def DeleteSession( def DeleteSession(
@ -312,7 +313,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session = self.coreemu.sessions[session_id] session = self.coreemu.sessions[session_id]
session_summary = core_pb2.SessionSummary( session_summary = core_pb2.SessionSummary(
id=session_id, id=session_id,
state=session.state, state=session.state.value,
nodes=session.get_node_count(), nodes=session.get_node_count(),
file=session.file_name, file=session.file_name,
) )
@ -521,7 +522,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
node_links = get_links(session, node) node_links = get_links(session, node)
links.extend(node_links) links.extend(node_links)
session_proto = core_pb2.Session(state=session.state, nodes=nodes, links=links) session_proto = core_pb2.Session(
state=session.state.value, nodes=nodes, links=links
)
return core_pb2.GetSessionResponse(session=session_proto) return core_pb2.GetSessionResponse(session=session_proto)
def AddSessionServer( def AddSessionServer(
@ -896,7 +899,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
for state in session._hooks: for state in session._hooks:
state_hooks = session._hooks[state] state_hooks = session._hooks[state]
for file_name, file_data in state_hooks: for file_name, file_data in state_hooks:
hook = core_pb2.Hook(state=state, file=file_name, data=file_data) hook = core_pb2.Hook(state=state.value, file=file_name, data=file_data)
hooks.append(hook) hooks.append(hook)
return core_pb2.GetHooksResponse(hooks=hooks) return core_pb2.GetHooksResponse(hooks=hooks)
@ -913,7 +916,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
logging.debug("add hook: %s", request) logging.debug("add hook: %s", request)
session = self.get_session(request.session_id, context) session = self.get_session(request.session_id, context)
hook = request.hook hook = request.hook
session.add_hook(hook.state, hook.file, None, hook.data) state = EventTypes(hook.state)
session.add_hook(state, hook.file, None, hook.data)
return core_pb2.AddHookResponse(result=True) return core_pb2.AddHookResponse(result=True)
def GetMobilityConfigs( def GetMobilityConfigs(
@ -1267,7 +1271,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session.mobility.set_model_config( session.mobility.set_model_config(
wlan_config.node_id, BasicRangeModel.name, wlan_config.config wlan_config.node_id, BasicRangeModel.name, wlan_config.config
) )
if session.state == EventTypes.RUNTIME_STATE.value: if session.state == EventTypes.RUNTIME_STATE:
node = self.get_node(session, wlan_config.node_id, context) node = self.get_node(session, wlan_config.node_id, context)
node.updatemodel(wlan_config.config) node.updatemodel(wlan_config.config)
return core_pb2.SetWlanConfigResponse(result=True) return core_pb2.SetWlanConfigResponse(result=True)

View file

@ -228,7 +228,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
coreapi.CoreEventTlv, coreapi.CoreEventTlv,
[ [
(EventTlvs.NODE, event_data.node), (EventTlvs.NODE, event_data.node),
(EventTlvs.TYPE, event_data.event_type), (EventTlvs.TYPE, event_data.event_type.value),
(EventTlvs.NAME, event_data.name), (EventTlvs.NAME, event_data.name),
(EventTlvs.DATA, event_data.data), (EventTlvs.DATA, event_data.data),
(EventTlvs.TIME, event_data.time), (EventTlvs.TIME, event_data.time),
@ -723,7 +723,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
if message.flags & MessageFlags.STRING.value: if message.flags & MessageFlags.STRING.value:
self.node_status_request[node.id] = True self.node_status_request[node.id] = True
if self.session.state == EventTypes.RUNTIME_STATE.value: if self.session.state == EventTypes.RUNTIME_STATE:
self.send_node_emulation_id(node.id) self.send_node_emulation_id(node.id)
elif message.flags & MessageFlags.DELETE.value: elif message.flags & MessageFlags.DELETE.value:
with self._shutdown_lock: with self._shutdown_lock:
@ -966,7 +966,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
retries = 10 retries = 10
# wait for session to enter RUNTIME state, to prevent GUI from # wait for session to enter RUNTIME state, to prevent GUI from
# connecting while nodes are still being instantiated # connecting while nodes are still being instantiated
while session.state != EventTypes.RUNTIME_STATE.value: while session.state != EventTypes.RUNTIME_STATE:
logging.debug( logging.debug(
"waiting for session %d to enter RUNTIME state", sid "waiting for session %d to enter RUNTIME state", sid
) )
@ -1375,7 +1375,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
parsed_config = ConfigShim.str_to_dict(values_str) parsed_config = ConfigShim.str_to_dict(values_str)
self.session.mobility.set_model_config(node_id, object_name, parsed_config) self.session.mobility.set_model_config(node_id, object_name, parsed_config)
if self.session.state == EventTypes.RUNTIME_STATE.value and parsed_config: if self.session.state == EventTypes.RUNTIME_STATE and parsed_config:
try: try:
node = self.session.get_node(node_id) node = self.session.get_node(node_id)
if object_name == BasicRangeModel.name: if object_name == BasicRangeModel.name:
@ -1502,6 +1502,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
logging.error("error setting hook having state '%s'", state) logging.error("error setting hook having state '%s'", state)
return () return ()
state = int(state) state = int(state)
state = EventTypes(state)
self.session.add_hook(state, file_name, source_name, data) self.session.add_hook(state, file_name, source_name, data)
return () return ()
@ -1538,9 +1539,11 @@ class CoreHandler(socketserver.BaseRequestHandler):
:return: reply messages :return: reply messages
:raises core.CoreError: when event type <= SHUTDOWN_STATE and not a known node id :raises core.CoreError: when event type <= SHUTDOWN_STATE and not a known node id
""" """
event_type_value = message.get_tlv(EventTlvs.TYPE.value)
event_type = EventTypes(event_type_value)
event_data = EventData( event_data = EventData(
node=message.get_tlv(EventTlvs.NODE.value), node=message.get_tlv(EventTlvs.NODE.value),
event_type=message.get_tlv(EventTlvs.TYPE.value), event_type=event_type,
name=message.get_tlv(EventTlvs.NAME.value), name=message.get_tlv(EventTlvs.NAME.value),
data=message.get_tlv(EventTlvs.DATA.value), data=message.get_tlv(EventTlvs.DATA.value),
time=message.get_tlv(EventTlvs.TIME.value), time=message.get_tlv(EventTlvs.TIME.value),
@ -1549,7 +1552,6 @@ class CoreHandler(socketserver.BaseRequestHandler):
if event_data.event_type is None: if event_data.event_type is None:
raise NotImplementedError("Event message missing event type") raise NotImplementedError("Event message missing event type")
event_type = EventTypes(event_data.event_type)
node_id = event_data.node node_id = event_data.node
logging.debug("handling event %s at %s", event_type.name, time.ctime()) logging.debug("handling event %s at %s", event_type.name, time.ctime())
@ -1667,25 +1669,19 @@ class CoreHandler(socketserver.BaseRequestHandler):
unknown.append(service_name) unknown.append(service_name)
continue continue
if ( if event_type in [EventTypes.STOP, EventTypes.RESTART]:
event_type == EventTypes.STOP.value
or event_type == EventTypes.RESTART.value
):
status = self.session.services.stop_service(node, service) status = self.session.services.stop_service(node, service)
if status: if status:
fail += f"Stop {service.name}," fail += f"Stop {service.name},"
if ( if event_type in [EventTypes.START, EventTypes.RESTART]:
event_type == EventTypes.START.value
or event_type == EventTypes.RESTART.value
):
status = self.session.services.startup_service(node, service) status = self.session.services.startup_service(node, service)
if status: if status:
fail += f"Start ({service.name})," fail += f"Start ({service.name}),"
if event_type == EventTypes.PAUSE.value: if event_type == EventTypes.PAUSE:
status = self.session.services.validate_service(node, service) status = self.session.services.validate_service(node, service)
if status: if status:
fail += f"{service.name}," fail += f"{service.name},"
if event_type == EventTypes.RECONFIGURE.value: if event_type == EventTypes.RECONFIGURE:
self.session.services.service_reconfigure(node, service) self.session.services.service_reconfigure(node, service)
fail_data = "" fail_data = ""
@ -2052,7 +2048,7 @@ class CoreUdpHandler(CoreHandler):
current_session = self.server.mainserver.coreemu.sessions[session_id] current_session = self.server.mainserver.coreemu.sessions[session_id]
current_node_count = current_session.get_node_count() current_node_count = current_session.get_node_count()
if ( if (
current_session.state == EventTypes.RUNTIME_STATE.value current_session.state == EventTypes.RUNTIME_STATE
and current_node_count > node_count and current_node_count > node_count
): ):
node_count = current_node_count node_count = current_node_count

View file

@ -293,6 +293,9 @@ class EventTypes(Enum):
RECONFIGURE = 14 RECONFIGURE = 14
INSTANTIATION_COMPLETE = 15 INSTANTIATION_COMPLETE = 15
def should_start(self) -> bool:
return self.value > self.DEFINITION_STATE.value
class SessionTlvs(Enum): class SessionTlvs(Enum):
""" """

View file

@ -112,8 +112,7 @@ class Session:
self.nodes = {} self.nodes = {}
self._nodes_lock = threading.Lock() self._nodes_lock = threading.Lock()
# TODO: should the default state be definition? self.state = EventTypes.DEFINITION_STATE
self.state = EventTypes.NONE.value
self._state_time = time.monotonic() self._state_time = time.monotonic()
self._state_file = os.path.join(self.session_dir, "state") self._state_file = os.path.join(self.session_dir, "state")
@ -121,7 +120,7 @@ class Session:
self._hooks = {} self._hooks = {}
self._state_hooks = {} self._state_hooks = {}
self.add_state_hook( self.add_state_hook(
state=EventTypes.RUNTIME_STATE.value, hook=self.runtime_state_hook state=EventTypes.RUNTIME_STATE, hook=self.runtime_state_hook
) )
# handlers for broadcasting information # handlers for broadcasting information
@ -345,7 +344,7 @@ class Session:
node_one.name, node_one.name,
node_two.name, node_two.name,
) )
start = self.state > EventTypes.DEFINITION_STATE.value start = self.state.should_start()
net_one = self.create_node(cls=PtpNet, start=start) net_one = self.create_node(cls=PtpNet, start=start)
# node to network # node to network
@ -680,7 +679,7 @@ class Session:
node_class = _cls node_class = _cls
# set node start based on current session state, override and check when rj45 # set node start based on current session state, override and check when rj45
start = self.state > EventTypes.DEFINITION_STATE.value start = self.state.should_start()
enable_rj45 = self.options.get_config("enablerj45") == "1" enable_rj45 = self.options.get_config("enablerj45") == "1"
if _type == NodeTypes.RJ45 and not enable_rj45: if _type == NodeTypes.RJ45 and not enable_rj45:
start = False start = False
@ -755,7 +754,7 @@ class Session:
# boot nodes after runtime, CoreNodes, Physical, and RJ45 are all nodes # boot nodes after runtime, CoreNodes, Physical, and RJ45 are all nodes
is_boot_node = isinstance(node, CoreNodeBase) and not isinstance(node, Rj45Node) is_boot_node = isinstance(node, CoreNodeBase) and not isinstance(node, Rj45Node)
if self.state == EventTypes.RUNTIME_STATE.value and is_boot_node: if self.state == EventTypes.RUNTIME_STATE and is_boot_node:
self.write_nodes() self.write_nodes()
self.add_remove_control_interface(node=node, remove=False) self.add_remove_control_interface(node=node, remove=False)
self.services.boot_services(node) self.services.boot_services(node)
@ -850,10 +849,7 @@ class Session:
:return: True if active, False otherwise :return: True if active, False otherwise
""" """
result = self.state in { result = self.state in {EventTypes.RUNTIME_STATE, EventTypes.DATACOLLECT_STATE}
EventTypes.RUNTIME_STATE.value,
EventTypes.DATACOLLECT_STATE.value,
}
logging.info("session(%s) checking if active: %s", self.id, result) logging.info("session(%s) checking if active: %s", self.id, result)
return result return result
@ -894,7 +890,9 @@ class Session:
""" """
CoreXmlWriter(self).write(file_name) CoreXmlWriter(self).write(file_name)
def add_hook(self, state: int, file_name: str, source_name: str, data: str) -> None: def add_hook(
self, state: EventTypes, file_name: str, source_name: str, data: str
) -> None:
""" """
Store a hook from a received file message. Store a hook from a received file message.
@ -904,9 +902,17 @@ class Session:
:param data: hook data :param data: hook data
:return: nothing :return: nothing
""" """
# hack to conform with old logic until updated logging.info(
state = f":{state}" "setting state hook: %s - %s from %s", state, file_name, source_name
self.set_hook(state, file_name, source_name, data) )
hook = file_name, data
state_hooks = self._hooks.setdefault(state, [])
state_hooks.append(hook)
# immediately run a hook if it is in the current state
if self.state == state:
logging.info("immediately running new state hook")
self.run_hook(hook)
def add_node_file( def add_node_file(
self, node_id: int, source_name: str, file_name: str, data: str self, node_id: int, source_name: str, file_name: str, data: str
@ -1071,10 +1077,8 @@ class Session:
:param send_event: if true, generate core API event messages :param send_event: if true, generate core API event messages
:return: nothing :return: nothing
""" """
state_value = state.value
state_name = state.name state_name = state.name
if self.state == state:
if self.state == state_value:
logging.info( logging.info(
"session(%s) is already in state: %s, skipping change", "session(%s) is already in state: %s, skipping change",
self.id, self.id,
@ -1082,33 +1086,32 @@ class Session:
) )
return return
self.state = state_value self.state = state
self._state_time = time.monotonic() self._state_time = time.monotonic()
logging.info("changing session(%s) to state %s", self.id, state_name) logging.info("changing session(%s) to state %s", self.id, state_name)
self.write_state(state)
self.write_state(state_value) self.run_hooks(state)
self.run_hooks(state_value) self.run_state_hooks(state)
self.run_state_hooks(state_value)
if send_event: if send_event:
event_data = EventData(event_type=state_value, time=str(time.monotonic())) event_data = EventData(event_type=state, time=str(time.monotonic()))
self.broadcast_event(event_data) self.broadcast_event(event_data)
def write_state(self, state: int) -> None: def write_state(self, state: EventTypes) -> None:
""" """
Write the current state to a state file in the session dir. Write the state to a state file in the session dir.
:param state: state to write to file :param state: state to write to file
:return: nothing :return: nothing
""" """
try: try:
state_file = open(self._state_file, "w") state_file = open(self._state_file, "w")
state_file.write(f"{state} {EventTypes(self.state).name}\n") state_file.write(f"{state.value} {state.name}\n")
state_file.close() state_file.close()
except IOError: except IOError:
logging.exception("error writing state file: %s", state) logging.exception("error writing state file: %s", state.name)
def run_hooks(self, state: int) -> None: def run_hooks(self, state: EventTypes) -> None:
""" """
Run hook scripts upon changing states. If hooks is not specified, run all hooks Run hook scripts upon changing states. If hooks is not specified, run all hooks
in the given state. in the given state.
@ -1212,7 +1215,7 @@ class Session:
except (OSError, subprocess.CalledProcessError): except (OSError, subprocess.CalledProcessError):
logging.exception("error running hook: %s", file_name) logging.exception("error running hook: %s", file_name)
def run_state_hooks(self, state: int) -> None: def run_state_hooks(self, state: EventTypes) -> None:
""" """
Run state hooks. Run state hooks.
@ -1223,16 +1226,17 @@ class Session:
try: try:
hook(state) hook(state)
except Exception: except Exception:
state_name = EventTypes(self.state).name
message = ( message = (
f"exception occured when running {state_name} state hook: {hook}" f"exception occured when running {state.name} state hook: {hook}"
) )
logging.exception(message) logging.exception(message)
self.exception( self.exception(
ExceptionLevels.ERROR, "Session.run_state_hooks", None, message ExceptionLevels.ERROR, "Session.run_state_hooks", None, message
) )
def add_state_hook(self, state: int, hook: Callable[[int], None]) -> None: def add_state_hook(
self, state: EventTypes, hook: Callable[[EventTypes], None]
) -> None:
""" """
Add a state hook. Add a state hook.
@ -1259,14 +1263,14 @@ class Session:
hooks = self._state_hooks.setdefault(state, []) hooks = self._state_hooks.setdefault(state, [])
hooks.remove(hook) hooks.remove(hook)
def runtime_state_hook(self, state: int) -> None: def runtime_state_hook(self, state: EventTypes) -> None:
""" """
Runtime state hook check. Runtime state hook check.
:param state: state to check :param state: state to check
:return: nothing :return: nothing
""" """
if state == EventTypes.RUNTIME_STATE.value: if state == EventTypes.RUNTIME_STATE:
self.emane.poststartup() self.emane.poststartup()
# create session deployed xml # create session deployed xml
@ -1510,7 +1514,7 @@ class Session:
self.mobility.startup() self.mobility.startup()
# notify listeners that instantiation is complete # notify listeners that instantiation is complete
event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE.value) event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE)
self.broadcast_event(event) self.broadcast_event(event)
# assume either all nodes have booted already, or there are some # assume either all nodes have booted already, or there are some
@ -1553,9 +1557,9 @@ class Session:
logging.debug( logging.debug(
"session(%s) checking if not in runtime state, current state: %s", "session(%s) checking if not in runtime state, current state: %s",
self.id, self.id,
EventTypes(self.state).name, self.state.name,
) )
if self.state == EventTypes.RUNTIME_STATE.value: if self.state == EventTypes.RUNTIME_STATE:
logging.info("valid runtime state found, returning") logging.info("valid runtime state found, returning")
return return

View file

@ -142,17 +142,11 @@ class MobilityManager(ModelManager):
) )
continue continue
if ( if event_type in [EventTypes.STOP, EventTypes.RESTART]:
event_type == EventTypes.STOP.value
or event_type == EventTypes.RESTART.value
):
model.stop(move_initial=True) model.stop(move_initial=True)
if ( if event_type in [EventTypes.START, EventTypes.RESTART]:
event_type == EventTypes.START.value
or event_type == EventTypes.RESTART.value
):
model.start() model.start()
if event_type == EventTypes.PAUSE.value: if event_type == EventTypes.PAUSE:
model.pause() model.pause()
def sendevent(self, model: "WayPointMobility") -> None: def sendevent(self, model: "WayPointMobility") -> None:
@ -163,13 +157,13 @@ class MobilityManager(ModelManager):
:param model: mobility model to send event for :param model: mobility model to send event for
:return: nothing :return: nothing
""" """
event_type = EventTypes.NONE.value event_type = EventTypes.NONE
if model.state == model.STATE_STOPPED: if model.state == model.STATE_STOPPED:
event_type = EventTypes.STOP.value event_type = EventTypes.STOP
elif model.state == model.STATE_RUNNING: elif model.state == model.STATE_RUNNING:
event_type = EventTypes.START.value event_type = EventTypes.START
elif model.state == model.STATE_PAUSED: elif model.state == model.STATE_PAUSED:
event_type = EventTypes.PAUSE.value event_type = EventTypes.PAUSE
start_time = int(model.lasttime - model.timezero) start_time = int(model.lasttime - model.timezero)
end_time = int(model.endtime) end_time = int(model.endtime)

View file

@ -100,7 +100,7 @@ class Sdt:
return False return False
if self.connected: if self.connected:
return True return True
if self.session.state == EventTypes.SHUTDOWN_STATE.value: if self.session.state == EventTypes.SHUTDOWN_STATE:
return False return False
self.seturl() self.seturl()

View file

@ -8,7 +8,7 @@ import core.nodes.physical
from core.emane.nodes import EmaneNet from core.emane.nodes import EmaneNet
from core.emulator.data import LinkData from core.emulator.data import LinkData
from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions
from core.emulator.enumerations import NodeTypes from core.emulator.enumerations import EventTypes, NodeTypes
from core.nodes.base import CoreNetworkBase, CoreNodeBase, NodeBase from core.nodes.base import CoreNetworkBase, CoreNodeBase, NodeBase
from core.nodes.docker import DockerNode from core.nodes.docker import DockerNode
from core.nodes.lxd import LxcNode from core.nodes.lxd import LxcNode
@ -324,7 +324,7 @@ class CoreXmlWriter:
for file_name, data in self.session._hooks[state]: for file_name, data in self.session._hooks[state]:
hook = etree.SubElement(hooks, "hook") hook = etree.SubElement(hooks, "hook")
add_attribute(hook, "name", file_name) add_attribute(hook, "name", file_name)
add_attribute(hook, "state", state) add_attribute(hook, "state", state.value)
hook.text = data hook.text = data
if hooks.getchildren(): if hooks.getchildren():
@ -666,13 +666,11 @@ class CoreXmlReader:
for hook in session_hooks.iterchildren(): for hook in session_hooks.iterchildren():
name = hook.get("name") name = hook.get("name")
state = hook.get("state") state = get_int(hook, "state")
state = EventTypes(state)
data = hook.text data = hook.text
hook_type = f"hook:{state}"
logging.info("reading hook: state(%s) name(%s)", state, name) logging.info("reading hook: state(%s) name(%s)", state, name)
self.session.set_hook( self.session.add_hook(state, name, None, data)
hook_type, file_name=name, source_name=None, data=data
)
def read_session_origin(self) -> None: def read_session_origin(self) -> None:
session_origin = self.scenario.find("session_origin") session_origin = self.scenario.find("session_origin")

View file

@ -14,7 +14,7 @@
} }
}, },
"root": { "root": {
"level": "INFO", "level": "DEBUG",
"handlers": ["console"] "handlers": ["console"]
} }
} }

View file

@ -127,7 +127,7 @@ class TestGrpc:
assert wlan_node.id in session.nodes assert wlan_node.id in session.nodes
assert session.nodes[node_one.id].netif(0) is not None assert session.nodes[node_one.id].netif(0) is not None
assert session.nodes[node_two.id].netif(0) is not None assert session.nodes[node_two.id].netif(0) is not None
hook_file, hook_data = session._hooks[core_pb2.SessionState.RUNTIME][0] hook_file, hook_data = session._hooks[EventTypes.RUNTIME_STATE][0]
assert hook_file == hook.file assert hook_file == hook.file
assert hook_data == hook.data assert hook_data == hook.data
assert session.location.refxyz == (location_x, location_y, location_z) assert session.location.refxyz == (location_x, location_y, location_z)
@ -169,7 +169,7 @@ class TestGrpc:
assert isinstance(response.state, int) assert isinstance(response.state, int)
session = grpc_server.coreemu.sessions.get(response.session_id) session = grpc_server.coreemu.sessions.get(response.session_id)
assert session is not None assert session is not None
assert session.state == response.state assert session.state == EventTypes(response.state)
if session_id is not None: if session_id is not None:
assert response.session_id == session_id assert response.session_id == session_id
assert session.id == session_id assert session.id == session_id
@ -341,7 +341,7 @@ class TestGrpc:
# then # then
assert response.result is True assert response.result is True
assert session.state == core_pb2.SessionState.DEFINITION assert session.state == EventTypes.DEFINITION_STATE
def test_add_node(self, grpc_server): def test_add_node(self, grpc_server):
# given # given
@ -447,7 +447,7 @@ class TestGrpc:
session = grpc_server.coreemu.create_session() session = grpc_server.coreemu.create_session()
file_name = "test" file_name = "test"
file_data = "echo hello" file_data = "echo hello"
session.add_hook(EventTypes.RUNTIME_STATE.value, file_name, None, file_data) session.add_hook(EventTypes.RUNTIME_STATE, file_name, None, file_data)
# then # then
with client.context_connect(): with client.context_connect():
@ -1065,7 +1065,7 @@ class TestGrpc:
client.events(session.id, handle_event) client.events(session.id, handle_event)
time.sleep(0.1) time.sleep(0.1)
event = EventData( event = EventData(
event_type=EventTypes.RUNTIME_STATE.value, time=str(time.monotonic()) event_type=EventTypes.RUNTIME_STATE, time=str(time.monotonic())
) )
session.broadcast_event(event) session.broadcast_event(event)

View file

@ -376,14 +376,14 @@ class TestGui:
assert len(coretlv.coreemu.sessions) == 0 assert len(coretlv.coreemu.sessions) == 0
def test_file_hook_add(self, coretlv): def test_file_hook_add(self, coretlv):
state = EventTypes.DATACOLLECT_STATE.value state = EventTypes.DATACOLLECT_STATE
assert coretlv.session._hooks.get(state) is None assert coretlv.session._hooks.get(state) is None
file_name = "test.sh" file_name = "test.sh"
file_data = "echo hello" file_data = "echo hello"
message = coreapi.CoreFileMessage.create( message = coreapi.CoreFileMessage.create(
MessageFlags.ADD.value, MessageFlags.ADD.value,
[ [
(FileTlvs.TYPE, f"hook:{state}"), (FileTlvs.TYPE, f"hook:{state.value}"),
(FileTlvs.NAME, file_name), (FileTlvs.NAME, file_name),
(FileTlvs.DATA, file_data), (FileTlvs.DATA, file_data),
], ],
@ -514,7 +514,7 @@ class TestGui:
coretlv.handle_message(message) coretlv.handle_message(message)
assert coretlv.session.state == state.value assert coretlv.session.state == state
def test_event_schedule(self, coretlv): def test_event_schedule(self, coretlv):
coretlv.session.add_event = mock.MagicMock() coretlv.session.add_event = mock.MagicMock()

View file

@ -3,7 +3,7 @@ from xml.etree import ElementTree
import pytest import pytest
from core.emulator.emudata import LinkOptions, NodeOptions from core.emulator.emudata import LinkOptions, NodeOptions
from core.emulator.enumerations import NodeTypes from core.emulator.enumerations import EventTypes, NodeTypes
from core.errors import CoreError from core.errors import CoreError
from core.location.mobility import BasicRangeModel from core.location.mobility import BasicRangeModel
from core.services.utility import SshService from core.services.utility import SshService
@ -20,7 +20,8 @@ class TestXml:
# create hook # create hook
file_name = "runtime_hook.sh" file_name = "runtime_hook.sh"
data = "#!/bin/sh\necho hello" data = "#!/bin/sh\necho hello"
session.set_hook("hook:4", file_name, None, data) state = EventTypes.RUNTIME_STATE
session.add_hook(state, file_name, None, data)
# save xml # save xml
xml_file = tmpdir.join("session.xml") xml_file = tmpdir.join("session.xml")
@ -38,7 +39,7 @@ class TestXml:
session.open_xml(file_path, start=True) session.open_xml(file_path, start=True)
# verify nodes have been recreated # verify nodes have been recreated
runtime_hooks = session._hooks.get(4) runtime_hooks = session._hooks.get(state)
assert runtime_hooks assert runtime_hooks
runtime_hook = runtime_hooks[0] runtime_hook = runtime_hooks[0]
assert file_name == runtime_hook[0] assert file_name == runtime_hook[0]