changes for sessions to use EventTypes for state/hooks directly

This commit is contained in:
Blake Harnden 2020-03-06 22:35:23 -08:00
parent 0e299d5af4
commit 1e8d1ecd9f
12 changed files with 95 additions and 95 deletions

View file

@ -102,7 +102,7 @@ def handle_session_event(event: EventData) -> core_pb2.SessionEvent:
event_time = float(event_time)
return core_pb2.SessionEvent(
node_id=event.node,
event=event.event_type,
event=event.event_type.value,
name=event.name,
data=event.data,
time=event_time,

View file

@ -173,7 +173,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
# add all hooks
for hook in request.hooks:
session.add_hook(hook.state, hook.file, None, hook.data)
state = EventTypes(hook.state)
session.add_hook(state, hook.file, None, hook.data)
# create nodes
_, exceptions = grpcutils.create_nodes(session, request.nodes)
@ -279,7 +280,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session.location.setrefgeo(47.57917, -122.13232, 2.0)
session.location.refscale = 150000.0
return core_pb2.CreateSessionResponse(
session_id=session.id, state=session.state
session_id=session.id, state=session.state.value
)
def DeleteSession(
@ -312,7 +313,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session = self.coreemu.sessions[session_id]
session_summary = core_pb2.SessionSummary(
id=session_id,
state=session.state,
state=session.state.value,
nodes=session.get_node_count(),
file=session.file_name,
)
@ -521,7 +522,9 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
node_links = get_links(session, node)
links.extend(node_links)
session_proto = core_pb2.Session(state=session.state, nodes=nodes, links=links)
session_proto = core_pb2.Session(
state=session.state.value, nodes=nodes, links=links
)
return core_pb2.GetSessionResponse(session=session_proto)
def AddSessionServer(
@ -896,7 +899,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
for state in session._hooks:
state_hooks = session._hooks[state]
for file_name, file_data in state_hooks:
hook = core_pb2.Hook(state=state, file=file_name, data=file_data)
hook = core_pb2.Hook(state=state.value, file=file_name, data=file_data)
hooks.append(hook)
return core_pb2.GetHooksResponse(hooks=hooks)
@ -913,7 +916,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
logging.debug("add hook: %s", request)
session = self.get_session(request.session_id, context)
hook = request.hook
session.add_hook(hook.state, hook.file, None, hook.data)
state = EventTypes(hook.state)
session.add_hook(state, hook.file, None, hook.data)
return core_pb2.AddHookResponse(result=True)
def GetMobilityConfigs(
@ -1267,7 +1271,7 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
session.mobility.set_model_config(
wlan_config.node_id, BasicRangeModel.name, wlan_config.config
)
if session.state == EventTypes.RUNTIME_STATE.value:
if session.state == EventTypes.RUNTIME_STATE:
node = self.get_node(session, wlan_config.node_id, context)
node.updatemodel(wlan_config.config)
return core_pb2.SetWlanConfigResponse(result=True)

View file

@ -228,7 +228,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
coreapi.CoreEventTlv,
[
(EventTlvs.NODE, event_data.node),
(EventTlvs.TYPE, event_data.event_type),
(EventTlvs.TYPE, event_data.event_type.value),
(EventTlvs.NAME, event_data.name),
(EventTlvs.DATA, event_data.data),
(EventTlvs.TIME, event_data.time),
@ -723,7 +723,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
if message.flags & MessageFlags.STRING.value:
self.node_status_request[node.id] = True
if self.session.state == EventTypes.RUNTIME_STATE.value:
if self.session.state == EventTypes.RUNTIME_STATE:
self.send_node_emulation_id(node.id)
elif message.flags & MessageFlags.DELETE.value:
with self._shutdown_lock:
@ -966,7 +966,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
retries = 10
# wait for session to enter RUNTIME state, to prevent GUI from
# connecting while nodes are still being instantiated
while session.state != EventTypes.RUNTIME_STATE.value:
while session.state != EventTypes.RUNTIME_STATE:
logging.debug(
"waiting for session %d to enter RUNTIME state", sid
)
@ -1375,7 +1375,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
parsed_config = ConfigShim.str_to_dict(values_str)
self.session.mobility.set_model_config(node_id, object_name, parsed_config)
if self.session.state == EventTypes.RUNTIME_STATE.value and parsed_config:
if self.session.state == EventTypes.RUNTIME_STATE and parsed_config:
try:
node = self.session.get_node(node_id)
if object_name == BasicRangeModel.name:
@ -1502,6 +1502,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
logging.error("error setting hook having state '%s'", state)
return ()
state = int(state)
state = EventTypes(state)
self.session.add_hook(state, file_name, source_name, data)
return ()
@ -1538,9 +1539,11 @@ class CoreHandler(socketserver.BaseRequestHandler):
:return: reply messages
:raises core.CoreError: when event type <= SHUTDOWN_STATE and not a known node id
"""
event_type_value = message.get_tlv(EventTlvs.TYPE.value)
event_type = EventTypes(event_type_value)
event_data = EventData(
node=message.get_tlv(EventTlvs.NODE.value),
event_type=message.get_tlv(EventTlvs.TYPE.value),
event_type=event_type,
name=message.get_tlv(EventTlvs.NAME.value),
data=message.get_tlv(EventTlvs.DATA.value),
time=message.get_tlv(EventTlvs.TIME.value),
@ -1549,7 +1552,6 @@ class CoreHandler(socketserver.BaseRequestHandler):
if event_data.event_type is None:
raise NotImplementedError("Event message missing event type")
event_type = EventTypes(event_data.event_type)
node_id = event_data.node
logging.debug("handling event %s at %s", event_type.name, time.ctime())
@ -1667,25 +1669,19 @@ class CoreHandler(socketserver.BaseRequestHandler):
unknown.append(service_name)
continue
if (
event_type == EventTypes.STOP.value
or event_type == EventTypes.RESTART.value
):
if event_type in [EventTypes.STOP, EventTypes.RESTART]:
status = self.session.services.stop_service(node, service)
if status:
fail += f"Stop {service.name},"
if (
event_type == EventTypes.START.value
or event_type == EventTypes.RESTART.value
):
if event_type in [EventTypes.START, EventTypes.RESTART]:
status = self.session.services.startup_service(node, service)
if status:
fail += f"Start ({service.name}),"
if event_type == EventTypes.PAUSE.value:
if event_type == EventTypes.PAUSE:
status = self.session.services.validate_service(node, service)
if status:
fail += f"{service.name},"
if event_type == EventTypes.RECONFIGURE.value:
if event_type == EventTypes.RECONFIGURE:
self.session.services.service_reconfigure(node, service)
fail_data = ""
@ -2052,7 +2048,7 @@ class CoreUdpHandler(CoreHandler):
current_session = self.server.mainserver.coreemu.sessions[session_id]
current_node_count = current_session.get_node_count()
if (
current_session.state == EventTypes.RUNTIME_STATE.value
current_session.state == EventTypes.RUNTIME_STATE
and current_node_count > node_count
):
node_count = current_node_count

View file

@ -293,6 +293,9 @@ class EventTypes(Enum):
RECONFIGURE = 14
INSTANTIATION_COMPLETE = 15
def should_start(self) -> bool:
return self.value > self.DEFINITION_STATE.value
class SessionTlvs(Enum):
"""

View file

@ -112,8 +112,7 @@ class Session:
self.nodes = {}
self._nodes_lock = threading.Lock()
# TODO: should the default state be definition?
self.state = EventTypes.NONE.value
self.state = EventTypes.DEFINITION_STATE
self._state_time = time.monotonic()
self._state_file = os.path.join(self.session_dir, "state")
@ -121,7 +120,7 @@ class Session:
self._hooks = {}
self._state_hooks = {}
self.add_state_hook(
state=EventTypes.RUNTIME_STATE.value, hook=self.runtime_state_hook
state=EventTypes.RUNTIME_STATE, hook=self.runtime_state_hook
)
# handlers for broadcasting information
@ -345,7 +344,7 @@ class Session:
node_one.name,
node_two.name,
)
start = self.state > EventTypes.DEFINITION_STATE.value
start = self.state.should_start()
net_one = self.create_node(cls=PtpNet, start=start)
# node to network
@ -680,7 +679,7 @@ class Session:
node_class = _cls
# set node start based on current session state, override and check when rj45
start = self.state > EventTypes.DEFINITION_STATE.value
start = self.state.should_start()
enable_rj45 = self.options.get_config("enablerj45") == "1"
if _type == NodeTypes.RJ45 and not enable_rj45:
start = False
@ -755,7 +754,7 @@ class Session:
# boot nodes after runtime, CoreNodes, Physical, and RJ45 are all nodes
is_boot_node = isinstance(node, CoreNodeBase) and not isinstance(node, Rj45Node)
if self.state == EventTypes.RUNTIME_STATE.value and is_boot_node:
if self.state == EventTypes.RUNTIME_STATE and is_boot_node:
self.write_nodes()
self.add_remove_control_interface(node=node, remove=False)
self.services.boot_services(node)
@ -850,10 +849,7 @@ class Session:
:return: True if active, False otherwise
"""
result = self.state in {
EventTypes.RUNTIME_STATE.value,
EventTypes.DATACOLLECT_STATE.value,
}
result = self.state in {EventTypes.RUNTIME_STATE, EventTypes.DATACOLLECT_STATE}
logging.info("session(%s) checking if active: %s", self.id, result)
return result
@ -894,7 +890,9 @@ class Session:
"""
CoreXmlWriter(self).write(file_name)
def add_hook(self, state: int, file_name: str, source_name: str, data: str) -> None:
def add_hook(
self, state: EventTypes, file_name: str, source_name: str, data: str
) -> None:
"""
Store a hook from a received file message.
@ -904,9 +902,17 @@ class Session:
:param data: hook data
:return: nothing
"""
# hack to conform with old logic until updated
state = f":{state}"
self.set_hook(state, file_name, source_name, data)
logging.info(
"setting state hook: %s - %s from %s", state, file_name, source_name
)
hook = file_name, data
state_hooks = self._hooks.setdefault(state, [])
state_hooks.append(hook)
# immediately run a hook if it is in the current state
if self.state == state:
logging.info("immediately running new state hook")
self.run_hook(hook)
def add_node_file(
self, node_id: int, source_name: str, file_name: str, data: str
@ -1071,10 +1077,8 @@ class Session:
:param send_event: if true, generate core API event messages
:return: nothing
"""
state_value = state.value
state_name = state.name
if self.state == state_value:
if self.state == state:
logging.info(
"session(%s) is already in state: %s, skipping change",
self.id,
@ -1082,33 +1086,32 @@ class Session:
)
return
self.state = state_value
self.state = state
self._state_time = time.monotonic()
logging.info("changing session(%s) to state %s", self.id, state_name)
self.write_state(state_value)
self.run_hooks(state_value)
self.run_state_hooks(state_value)
self.write_state(state)
self.run_hooks(state)
self.run_state_hooks(state)
if send_event:
event_data = EventData(event_type=state_value, time=str(time.monotonic()))
event_data = EventData(event_type=state, time=str(time.monotonic()))
self.broadcast_event(event_data)
def write_state(self, state: int) -> None:
def write_state(self, state: EventTypes) -> None:
"""
Write the current state to a state file in the session dir.
Write the state to a state file in the session dir.
:param state: state to write to file
:return: nothing
"""
try:
state_file = open(self._state_file, "w")
state_file.write(f"{state} {EventTypes(self.state).name}\n")
state_file.write(f"{state.value} {state.name}\n")
state_file.close()
except IOError:
logging.exception("error writing state file: %s", state)
logging.exception("error writing state file: %s", state.name)
def run_hooks(self, state: int) -> None:
def run_hooks(self, state: EventTypes) -> None:
"""
Run hook scripts upon changing states. If hooks is not specified, run all hooks
in the given state.
@ -1212,7 +1215,7 @@ class Session:
except (OSError, subprocess.CalledProcessError):
logging.exception("error running hook: %s", file_name)
def run_state_hooks(self, state: int) -> None:
def run_state_hooks(self, state: EventTypes) -> None:
"""
Run state hooks.
@ -1223,16 +1226,17 @@ class Session:
try:
hook(state)
except Exception:
state_name = EventTypes(self.state).name
message = (
f"exception occured when running {state_name} state hook: {hook}"
f"exception occured when running {state.name} state hook: {hook}"
)
logging.exception(message)
self.exception(
ExceptionLevels.ERROR, "Session.run_state_hooks", None, message
)
def add_state_hook(self, state: int, hook: Callable[[int], None]) -> None:
def add_state_hook(
self, state: EventTypes, hook: Callable[[EventTypes], None]
) -> None:
"""
Add a state hook.
@ -1259,14 +1263,14 @@ class Session:
hooks = self._state_hooks.setdefault(state, [])
hooks.remove(hook)
def runtime_state_hook(self, state: int) -> None:
def runtime_state_hook(self, state: EventTypes) -> None:
"""
Runtime state hook check.
:param state: state to check
:return: nothing
"""
if state == EventTypes.RUNTIME_STATE.value:
if state == EventTypes.RUNTIME_STATE:
self.emane.poststartup()
# create session deployed xml
@ -1510,7 +1514,7 @@ class Session:
self.mobility.startup()
# notify listeners that instantiation is complete
event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE.value)
event = EventData(event_type=EventTypes.INSTANTIATION_COMPLETE)
self.broadcast_event(event)
# assume either all nodes have booted already, or there are some
@ -1553,9 +1557,9 @@ class Session:
logging.debug(
"session(%s) checking if not in runtime state, current state: %s",
self.id,
EventTypes(self.state).name,
self.state.name,
)
if self.state == EventTypes.RUNTIME_STATE.value:
if self.state == EventTypes.RUNTIME_STATE:
logging.info("valid runtime state found, returning")
return

View file

@ -142,17 +142,11 @@ class MobilityManager(ModelManager):
)
continue
if (
event_type == EventTypes.STOP.value
or event_type == EventTypes.RESTART.value
):
if event_type in [EventTypes.STOP, EventTypes.RESTART]:
model.stop(move_initial=True)
if (
event_type == EventTypes.START.value
or event_type == EventTypes.RESTART.value
):
if event_type in [EventTypes.START, EventTypes.RESTART]:
model.start()
if event_type == EventTypes.PAUSE.value:
if event_type == EventTypes.PAUSE:
model.pause()
def sendevent(self, model: "WayPointMobility") -> None:
@ -163,13 +157,13 @@ class MobilityManager(ModelManager):
:param model: mobility model to send event for
:return: nothing
"""
event_type = EventTypes.NONE.value
event_type = EventTypes.NONE
if model.state == model.STATE_STOPPED:
event_type = EventTypes.STOP.value
event_type = EventTypes.STOP
elif model.state == model.STATE_RUNNING:
event_type = EventTypes.START.value
event_type = EventTypes.START
elif model.state == model.STATE_PAUSED:
event_type = EventTypes.PAUSE.value
event_type = EventTypes.PAUSE
start_time = int(model.lasttime - model.timezero)
end_time = int(model.endtime)

View file

@ -100,7 +100,7 @@ class Sdt:
return False
if self.connected:
return True
if self.session.state == EventTypes.SHUTDOWN_STATE.value:
if self.session.state == EventTypes.SHUTDOWN_STATE:
return False
self.seturl()

View file

@ -8,7 +8,7 @@ import core.nodes.physical
from core.emane.nodes import EmaneNet
from core.emulator.data import LinkData
from core.emulator.emudata import InterfaceData, LinkOptions, NodeOptions
from core.emulator.enumerations import NodeTypes
from core.emulator.enumerations import EventTypes, NodeTypes
from core.nodes.base import CoreNetworkBase, CoreNodeBase, NodeBase
from core.nodes.docker import DockerNode
from core.nodes.lxd import LxcNode
@ -324,7 +324,7 @@ class CoreXmlWriter:
for file_name, data in self.session._hooks[state]:
hook = etree.SubElement(hooks, "hook")
add_attribute(hook, "name", file_name)
add_attribute(hook, "state", state)
add_attribute(hook, "state", state.value)
hook.text = data
if hooks.getchildren():
@ -666,13 +666,11 @@ class CoreXmlReader:
for hook in session_hooks.iterchildren():
name = hook.get("name")
state = hook.get("state")
state = get_int(hook, "state")
state = EventTypes(state)
data = hook.text
hook_type = f"hook:{state}"
logging.info("reading hook: state(%s) name(%s)", state, name)
self.session.set_hook(
hook_type, file_name=name, source_name=None, data=data
)
self.session.add_hook(state, name, None, data)
def read_session_origin(self) -> None:
session_origin = self.scenario.find("session_origin")

View file

@ -14,7 +14,7 @@
}
},
"root": {
"level": "INFO",
"level": "DEBUG",
"handlers": ["console"]
}
}

View file

@ -127,7 +127,7 @@ class TestGrpc:
assert wlan_node.id in session.nodes
assert session.nodes[node_one.id].netif(0) is not None
assert session.nodes[node_two.id].netif(0) is not None
hook_file, hook_data = session._hooks[core_pb2.SessionState.RUNTIME][0]
hook_file, hook_data = session._hooks[EventTypes.RUNTIME_STATE][0]
assert hook_file == hook.file
assert hook_data == hook.data
assert session.location.refxyz == (location_x, location_y, location_z)
@ -169,7 +169,7 @@ class TestGrpc:
assert isinstance(response.state, int)
session = grpc_server.coreemu.sessions.get(response.session_id)
assert session is not None
assert session.state == response.state
assert session.state == EventTypes(response.state)
if session_id is not None:
assert response.session_id == session_id
assert session.id == session_id
@ -341,7 +341,7 @@ class TestGrpc:
# then
assert response.result is True
assert session.state == core_pb2.SessionState.DEFINITION
assert session.state == EventTypes.DEFINITION_STATE
def test_add_node(self, grpc_server):
# given
@ -447,7 +447,7 @@ class TestGrpc:
session = grpc_server.coreemu.create_session()
file_name = "test"
file_data = "echo hello"
session.add_hook(EventTypes.RUNTIME_STATE.value, file_name, None, file_data)
session.add_hook(EventTypes.RUNTIME_STATE, file_name, None, file_data)
# then
with client.context_connect():
@ -1065,7 +1065,7 @@ class TestGrpc:
client.events(session.id, handle_event)
time.sleep(0.1)
event = EventData(
event_type=EventTypes.RUNTIME_STATE.value, time=str(time.monotonic())
event_type=EventTypes.RUNTIME_STATE, time=str(time.monotonic())
)
session.broadcast_event(event)

View file

@ -376,14 +376,14 @@ class TestGui:
assert len(coretlv.coreemu.sessions) == 0
def test_file_hook_add(self, coretlv):
state = EventTypes.DATACOLLECT_STATE.value
state = EventTypes.DATACOLLECT_STATE
assert coretlv.session._hooks.get(state) is None
file_name = "test.sh"
file_data = "echo hello"
message = coreapi.CoreFileMessage.create(
MessageFlags.ADD.value,
[
(FileTlvs.TYPE, f"hook:{state}"),
(FileTlvs.TYPE, f"hook:{state.value}"),
(FileTlvs.NAME, file_name),
(FileTlvs.DATA, file_data),
],
@ -514,7 +514,7 @@ class TestGui:
coretlv.handle_message(message)
assert coretlv.session.state == state.value
assert coretlv.session.state == state
def test_event_schedule(self, coretlv):
coretlv.session.add_event = mock.MagicMock()

View file

@ -3,7 +3,7 @@ from xml.etree import ElementTree
import pytest
from core.emulator.emudata import LinkOptions, NodeOptions
from core.emulator.enumerations import NodeTypes
from core.emulator.enumerations import EventTypes, NodeTypes
from core.errors import CoreError
from core.location.mobility import BasicRangeModel
from core.services.utility import SshService
@ -20,7 +20,8 @@ class TestXml:
# create hook
file_name = "runtime_hook.sh"
data = "#!/bin/sh\necho hello"
session.set_hook("hook:4", file_name, None, data)
state = EventTypes.RUNTIME_STATE
session.add_hook(state, file_name, None, data)
# save xml
xml_file = tmpdir.join("session.xml")
@ -38,7 +39,7 @@ class TestXml:
session.open_xml(file_path, start=True)
# verify nodes have been recreated
runtime_hooks = session._hooks.get(4)
runtime_hooks = session._hooks.get(state)
assert runtime_hooks
runtime_hook = runtime_hooks[0]
assert file_name == runtime_hook[0]