daemon: Session cleanup, removed unused functions, used context managers for writing files, made variables used externally no longer private

This commit is contained in:
Blake Harnden 2020-06-12 20:22:51 -07:00
parent 178d12b327
commit 23d957679e
11 changed files with 99 additions and 220 deletions

View file

@ -930,8 +930,8 @@ class CoreGrpcServer(core_pb2_grpc.CoreApiServicer):
logging.debug("get hooks: %s", request)
session = self.get_session(request.session_id, context)
hooks = []
for state in session._hooks:
state_hooks = session._hooks[state]
for state in session.hooks:
state_hooks = session.hooks[state]
for file_name, file_data in state_hooks:
hook = core_pb2.Hook(state=state.value, file=file_name, data=file_data)
hooks.append(hook)

View file

@ -12,6 +12,7 @@ import threading
import time
from itertools import repeat
from queue import Empty, Queue
from typing import Optional
from core import utils
from core.api.tlv import coreapi, dataconversion, structutils
@ -39,6 +40,7 @@ from core.emulator.enumerations import (
NodeTypes,
RegisterTlvs,
)
from core.emulator.session import Session
from core.errors import CoreCommandError, CoreError
from core.location.mobility import BasicRangeModel
from core.nodes.base import CoreNode, CoreNodeBase, NodeBase
@ -83,7 +85,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
thread.start()
self.handler_threads.append(thread)
self.session = None
self.session: Optional[Session] = None
self.coreemu = server.coreemu
utils.close_onexec(request.fileno())
socketserver.BaseRequestHandler.__init__(self, request, client_address, server)
@ -176,7 +178,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
node_count_list.append(str(session.get_node_count()))
date_list.append(time.ctime(session._state_time))
date_list.append(time.ctime(session.state_time))
thumb = session.thumbnail
if not thumb:
@ -1819,7 +1821,7 @@ class CoreHandler(socketserver.BaseRequestHandler):
"""
# find all nodes and links
links_data = []
with self.session._nodes_lock:
with self.session.nodes_lock:
for node_id in self.session.nodes:
node = self.session.nodes[node_id]
self.session.broadcast_node(node, MessageFlags.ADD)
@ -1897,8 +1899,8 @@ class CoreHandler(socketserver.BaseRequestHandler):
# TODO: send location info
# send hook scripts
for state in sorted(self.session._hooks.keys()):
for file_name, config_data in self.session._hooks[state]:
for state in sorted(self.session.hooks.keys()):
for file_name, config_data in self.session.hooks[state]:
file_data = FileData(
message_type=MessageFlags.ADD,
name=str(file_name),

View file

@ -279,7 +279,7 @@ class EmaneManager(ModelManager):
logging.debug("emane setup")
# TODO: drive this from the session object
with self.session._nodes_lock:
with self.session.nodes_lock:
for node_id in self.session.nodes:
node = self.session.nodes[node_id]
if isinstance(node, EmaneNet):

View file

@ -6,7 +6,6 @@ that manages a CORE session.
import logging
import os
import pwd
import random
import shutil
import subprocess
import tempfile
@ -113,15 +112,13 @@ class Session:
# dict of nodes: all nodes and nets
self.nodes: Dict[int, NodeBase] = {}
self._nodes_lock = threading.Lock()
self.nodes_lock = threading.Lock()
# states and hooks handlers
self.state: EventTypes = EventTypes.DEFINITION_STATE
self._state_time: float = time.monotonic()
self._state_file: str = os.path.join(self.session_dir, "state")
# hooks handlers
self._hooks: Dict[EventTypes, Tuple[str, str]] = {}
self._state_hooks: Dict[EventTypes, Callable[[int], None]] = {}
self.state_time: float = time.monotonic()
self.hooks: Dict[EventTypes, Tuple[str, str]] = {}
self.state_hooks: Dict[EventTypes, List[Callable[[EventTypes], None]]] = {}
self.add_state_hook(
state=EventTypes.RUNTIME_STATE, hook=self.runtime_state_hook
)
@ -154,15 +151,6 @@ class Session:
self.emane: EmaneManager = EmaneManager(self)
self.sdt: Sdt = Sdt(self)
# initialize default node services
self.services.default_services = {
"mdr": ("zebra", "OSPFv3MDR", "IPForward"),
"PC": ("DefaultRoute",),
"prouter": (),
"router": ("zebra", "OSPFv2", "OSPFv3", "IPForward"),
"host": ("DefaultRoute", "SSH"),
}
# config services
self.service_manager: Optional[ConfigServiceManager] = None
@ -473,7 +461,7 @@ class Session:
f"cannot update link node1({type(node1)}) node2({type(node2)})"
)
def _next_node_id(self) -> int:
def next_node_id(self) -> int:
"""
Find the next valid node id, starting from 1.
@ -506,7 +494,7 @@ class Session:
# determine node id
if not _id:
_id = self._next_node_id()
_id = self.next_node_id()
# generate name if not provided
if not options:
@ -692,7 +680,7 @@ class Session:
"setting state hook: %s - %s source(%s)", state, file_name, source_name
)
hook = file_name, data
state_hooks = self._hooks.setdefault(state, [])
state_hooks = self.hooks.setdefault(state, [])
state_hooks.append(hook)
# immediately run a hook if it is in the current state
@ -727,7 +715,7 @@ class Session:
self.emane.shutdown()
self.delete_nodes()
self.distributed.shutdown()
self.del_hooks()
self.hooks.clear()
self.emane.reset()
self.emane.config_reset()
self.location.reset()
@ -795,7 +783,6 @@ class Session:
:param event_data: event data to send out
:return: nothing
"""
for handler in self.event_handlers:
handler(event_data)
@ -806,7 +793,6 @@ class Session:
:param exception_data: exception data to send out
:return: nothing
"""
for handler in self.exception_handlers:
handler(exception_data)
@ -837,7 +823,6 @@ class Session:
:param file_data: file data to send out
:return: nothing
"""
for handler in self.file_handlers:
handler(file_data)
@ -848,7 +833,6 @@ class Session:
:param config_data: config data to send out
:return: nothing
"""
for handler in self.config_handlers:
handler(config_data)
@ -859,7 +843,6 @@ class Session:
:param link_data: link data to send out
:return: nothing
"""
for handler in self.link_handlers:
handler(link_data)
@ -871,22 +854,14 @@ class Session:
:param send_event: if true, generate core API event messages
:return: nothing
"""
state_name = state.name
if self.state == state:
logging.info(
"session(%s) is already in state: %s, skipping change",
self.id,
state_name,
)
return
self.state = state
self._state_time = time.monotonic()
logging.info("changing session(%s) to state %s", self.id, state_name)
self.state_time = time.monotonic()
logging.info("changing session(%s) to state %s", self.id, state.name)
self.write_state(state)
self.run_hooks(state)
self.run_state_hooks(state)
if send_event:
event_data = EventData(event_type=state, time=str(time.monotonic()))
self.broadcast_event(event_data)
@ -898,10 +873,10 @@ class Session:
:param state: state to write to file
:return: nothing
"""
state_file = os.path.join(self.session_dir, "state")
try:
state_file = open(self._state_file, "w")
state_file.write(f"{state.value} {state.name}\n")
state_file.close()
with open(state_file, "w") as f:
f.write(f"{state.value} {state.name}\n")
except IOError:
logging.exception("error writing state file: %s", state.name)
@ -913,61 +888,10 @@ class Session:
:param state: state to run hooks for
:return: nothing
"""
# check that state change hooks exist
if state not in self._hooks:
return
# retrieve all state hooks
hooks = self._hooks.get(state, [])
# execute all state hooks
if hooks:
for hook in hooks:
self.run_hook(hook)
else:
logging.info("no state hooks for %s", state)
def set_hook(
self, hook_type: str, file_name: str, source_name: str, data: str
) -> None:
"""
Store a hook from a received file message.
:param hook_type: hook type
:param file_name: file name for hook
:param source_name: source name
:param data: hook data
:return: nothing
"""
logging.info(
"setting state hook: %s - %s from %s", hook_type, file_name, source_name
)
_hook_id, state = hook_type.split(":")[:2]
if not state.isdigit():
logging.error("error setting hook having state '%s'", state)
return
state = int(state)
hook = file_name, data
# append hook to current state hooks
state_hooks = self._hooks.setdefault(state, [])
state_hooks.append(hook)
# immediately run a hook if it is in the current state
# (this allows hooks in the definition and configuration states)
if self.state == state:
logging.info("immediately running new state hook")
hooks = self.hooks.get(state, [])
for hook in hooks:
self.run_hook(hook)
def del_hooks(self) -> None:
"""
Clear the hook scripts dict.
"""
self._hooks.clear()
def run_hook(self, hook: Tuple[str, str]) -> None:
"""
Run a hook.
@ -977,37 +901,23 @@ class Session:
"""
file_name, data = hook
logging.info("running hook %s", file_name)
# write data to hook file
file_path = os.path.join(self.session_dir, file_name)
log_path = os.path.join(self.session_dir, f"{file_name}.log")
try:
hook_file = open(os.path.join(self.session_dir, file_name), "w")
hook_file.write(data)
hook_file.close()
except IOError:
logging.exception("error writing hook '%s'", file_name)
# setup hook stdout and stderr
try:
stdout = open(os.path.join(self.session_dir, file_name + ".log"), "w")
stderr = subprocess.STDOUT
except IOError:
logging.exception("error setting up hook stderr and stdout")
stdout = None
stderr = None
# execute hook file
try:
args = ["/bin/sh", file_name]
subprocess.check_call(
args,
stdout=stdout,
stderr=stderr,
close_fds=True,
cwd=self.session_dir,
env=self.get_environment(),
)
except (OSError, subprocess.CalledProcessError):
logging.exception("error running hook: %s", file_name)
with open(file_path, "w") as f:
f.write(data)
with open(log_path, "w") as f:
args = ["/bin/sh", file_name]
subprocess.check_call(
args,
stdout=f,
stderr=subprocess.STDOUT,
close_fds=True,
cwd=self.session_dir,
env=self.get_environment(),
)
except (IOError, subprocess.CalledProcessError):
logging.exception("error running hook: %s", file_path)
def run_state_hooks(self, state: EventTypes) -> None:
"""
@ -1016,17 +926,16 @@ class Session:
:param state: state to run hooks for
:return: nothing
"""
for hook in self._state_hooks.get(state, []):
try:
hook(state)
except Exception:
message = (
f"exception occured when running {state.name} state hook: {hook}"
)
logging.exception(message)
self.exception(
ExceptionLevels.ERROR, "Session.run_state_hooks", message
)
for hook in self.state_hooks.get(state, []):
self.run_state_hook(state, hook)
def run_state_hook(self, state: EventTypes, hook: Callable[[EventTypes], None]):
try:
hook(state)
except Exception:
message = f"exception occurred when running {state.name} state hook: {hook}"
logging.exception(message)
self.exception(ExceptionLevels.ERROR, "Session.run_state_hooks", message)
def add_state_hook(
self, state: EventTypes, hook: Callable[[EventTypes], None]
@ -1038,15 +947,16 @@ class Session:
:param hook: hook callback for the state
:return: nothing
"""
hooks = self._state_hooks.setdefault(state, [])
hooks = self.state_hooks.setdefault(state, [])
if hook in hooks:
raise CoreError("attempting to add duplicate state hook")
hooks.append(hook)
if self.state == state:
hook(state)
self.run_state_hook(state, hook)
def del_state_hook(self, state: int, hook: Callable[[int], None]) -> None:
def del_state_hook(
self, state: EventTypes, hook: Callable[[EventTypes], None]
) -> None:
"""
Delete a state hook.
@ -1054,24 +964,23 @@ class Session:
:param hook: hook to delete
:return: nothing
"""
hooks = self._state_hooks.setdefault(state, [])
hooks.remove(hook)
hooks = self.state_hooks.get(state, [])
if hook in hooks:
hooks.remove(hook)
def runtime_state_hook(self, state: EventTypes) -> None:
def runtime_state_hook(self, _state: EventTypes) -> None:
"""
Runtime state hook check.
:param state: state to check
:param _state: state to check
:return: nothing
"""
if state == EventTypes.RUNTIME_STATE:
self.emane.poststartup()
# create session deployed xml
xml_file_name = os.path.join(self.session_dir, "session-deployed.xml")
xml_writer = corexml.CoreXmlWriter(self)
corexmldeployment.CoreXmlDeployment(self, xml_writer.scenario)
xml_writer.write(xml_file_name)
self.emane.poststartup()
# create session deployed xml
xml_file_name = os.path.join(self.session_dir, "session-deployed.xml")
xml_writer = corexml.CoreXmlWriter(self)
corexmldeployment.CoreXmlDeployment(self, xml_writer.scenario)
xml_writer.write(xml_file_name)
def get_environment(self, state: bool = True) -> Dict[str, str]:
"""
@ -1090,10 +999,8 @@ class Session:
env["SESSION_FILENAME"] = str(self.file_name)
env["SESSION_USER"] = str(self.user)
env["SESSION_NODE_COUNT"] = str(self.get_node_count())
if state:
env["SESSION_STATE"] = str(self.state)
# attempt to read and add environment config file
environment_config_file = os.path.join(constants.CORE_CONF_DIR, "environment")
try:
@ -1104,7 +1011,6 @@ class Session:
"environment configuration file does not exist: %s",
environment_config_file,
)
# attempt to read and add user environment file
if self.user:
environment_user_file = os.path.join(
@ -1117,7 +1023,6 @@ class Session:
"user core environment settings file not present: %s",
environment_user_file,
)
return env
def set_thumbnail(self, thumb_file: str) -> None:
@ -1131,7 +1036,6 @@ class Session:
logging.error("thumbnail file to set does not exist: %s", thumb_file)
self.thumbnail = None
return
destination_file = os.path.join(self.session_dir, os.path.basename(thumb_file))
shutil.copy(thumb_file, destination_file)
self.thumbnail = destination_file
@ -1151,20 +1055,8 @@ class Session:
os.chown(self.session_dir, uid, gid)
except IOError:
logging.exception("failed to set permission on %s", self.session_dir)
self.user = user
def get_node_id(self) -> int:
"""
Return a unique, new node id.
"""
with self._nodes_lock:
while True:
node_id = random.randint(1, 0xFFFF)
if node_id not in self.nodes:
break
return node_id
def create_node(self, _class: Type[NT], *args: Any, **kwargs: Any) -> NT:
"""
Create an emulation node.
@ -1176,7 +1068,7 @@ class Session:
:raises core.CoreError: when id of the node to create already exists
"""
node = _class(self, *args, **kwargs)
with self._nodes_lock:
with self.nodes_lock:
if node.id in self.nodes:
node.shutdown()
raise CoreError(f"duplicate node id {node.id} for {node.name}")
@ -1192,9 +1084,9 @@ class Session:
:return: node for the given id
:raises core.CoreError: when node does not exist
"""
if _id not in self.nodes:
node = self.nodes.get(_id)
if node is None:
raise CoreError(f"unknown node id {_id}")
node = self.nodes[_id]
if not isinstance(node, _class):
actual = node.__class__.__name__
expected = _class.__name__
@ -1210,7 +1102,7 @@ class Session:
"""
# delete node and check for session shutdown if a node was removed
node = None
with self._nodes_lock:
with self.nodes_lock:
if _id in self.nodes:
node = self.nodes.pop(_id)
logging.info("deleted node(%s)", node.name)
@ -1224,7 +1116,7 @@ class Session:
"""
Clear the nodes dictionary, and call shutdown for each node.
"""
with self._nodes_lock:
with self.nodes_lock:
funcs = []
while self.nodes:
_, node = self.nodes.popitem()
@ -1237,29 +1129,15 @@ class Session:
Write nodes to a 'nodes' file in the session dir.
The 'nodes' file lists: number, name, api-type, class-type
"""
file_path = os.path.join(self.session_dir, "nodes")
try:
with self._nodes_lock:
file_path = os.path.join(self.session_dir, "nodes")
with self.nodes_lock:
with open(file_path, "w") as f:
for _id in self.nodes.keys():
node = self.nodes[_id]
for _id, node in self.nodes.items():
f.write(f"{_id} {node.name} {node.apitype} {type(node)}\n")
except IOError:
logging.exception("error writing nodes file")
def dump_session(self) -> None:
"""
Log information about the session in its current state.
"""
logging.info("session id=%s name=%s state=%s", self.id, self.name, self.state)
logging.info(
"file=%s thumbnail=%s node_count=%s/%s",
self.file_name,
self.thumbnail,
self.get_node_count(),
len(self.nodes),
)
def exception(
self, level: ExceptionLevels, source: str, text: str, node_id: int = None
) -> None:
@ -1327,17 +1205,15 @@ class Session:
:return: created node count
"""
with self._nodes_lock:
with self.nodes_lock:
count = 0
for node_id in self.nodes:
node = self.nodes[node_id]
for node in self.nodes.values():
is_p2p_ctrlnet = isinstance(node, (PtpNet, CtrlNet))
is_tap = isinstance(node, GreTapBridge) and not isinstance(
node, TunnelNode
)
if is_p2p_ctrlnet or is_tap:
continue
count += 1
return count
@ -1359,7 +1235,6 @@ class Session:
if self.state == EventTypes.RUNTIME_STATE:
logging.info("valid runtime state found, returning")
return
# start event loop and set to runtime
self.event_loop.run()
self.set_state(EventTypes.RUNTIME_STATE, send_event=True)
@ -1375,7 +1250,7 @@ class Session:
self.event_loop.stop()
# stop node services
with self._nodes_lock:
with self.nodes_lock:
funcs = []
for node_id in self.nodes:
node = self.nodes[node_id]
@ -1447,7 +1322,7 @@ class Session:
:return: service boot exceptions
"""
with self._nodes_lock:
with self.nodes_lock:
funcs = []
start = time.monotonic()
for _id in self.nodes:
@ -1545,7 +1420,6 @@ class Session:
else:
prefix_spec = CtrlNet.DEFAULT_PREFIX_LIST[net_index]
logging.debug("prefix spec: %s", prefix_spec)
server_interface = self.get_control_net_server_interfaces()[net_index]
# return any existing controlnet bridge
@ -1685,7 +1559,7 @@ class Session:
if not in runtime.
"""
if self.state == EventTypes.RUNTIME_STATE:
return time.monotonic() - self._state_time
return time.monotonic() - self.state_time
else:
return 0.0
@ -1708,7 +1582,6 @@ class Session:
"""
event_time = float(event_time)
current_time = self.runtime()
if current_time > 0:
if event_time <= current_time:
logging.warning(
@ -1718,11 +1591,9 @@ class Session:
)
return
event_time = event_time - current_time
self.event_loop.add_event(
event_time, self.run_event, node=node, name=name, data=data
)
if not name:
name = ""
logging.info(
@ -1732,8 +1603,6 @@ class Session:
data,
)
# TODO: if data is None, this blows up, but this ties into how event functions
# are ran, need to clean that up
def run_event(
self, node_id: int = None, name: str = None, data: str = None
) -> None:
@ -1745,10 +1614,12 @@ class Session:
:param data: event data
:return: nothing
"""
if data is None:
logging.warning("no data for event node(%s) name(%s)", node_id, name)
return
now = self.runtime()
if not name:
name = ""
logging.info("running event %s at time %s cmd=%s", name, now, data)
if not node_id:
utils.mute_detach(data)

View file

@ -63,7 +63,7 @@ class NodeBase:
self.session: "Session" = session
if _id is None:
_id = session.get_node_id()
_id = session.next_node_id()
self.id: int = _id
if name is None:
name = f"o{self.id}"

View file

@ -215,7 +215,7 @@ class Sdt:
for layer in CORE_LAYERS:
self.cmd(f"layer {layer}")
with self.session._nodes_lock:
with self.session.nodes_lock:
for node_id in self.session.nodes:
node = self.session.nodes[node_id]
if isinstance(node, CoreNetworkBase):

View file

@ -325,7 +325,13 @@ class CoreServices:
"""
self.session = session
# dict of default services tuples, key is node type
self.default_services = {}
self.default_services = {
"mdr": ("zebra", "OSPFv3MDR", "IPForward"),
"PC": ("DefaultRoute",),
"prouter": (),
"router": ("zebra", "OSPFv2", "OSPFv3", "IPForward"),
"host": ("DefaultRoute", "SSH"),
}
# dict of node ids to dict of custom services by name
self.custom_services = {}

View file

@ -320,8 +320,8 @@ class CoreXmlWriter:
def write_session_hooks(self) -> None:
# hook scripts
hooks = etree.Element("session_hooks")
for state in sorted(self.session._hooks, key=lambda x: x.value):
for file_name, data in self.session._hooks[state]:
for state in sorted(self.session.hooks, key=lambda x: x.value):
for file_name, data in self.session.hooks[state]:
hook = etree.SubElement(hooks, "hook")
add_attribute(hook, "name", file_name)
add_attribute(hook, "state", state.value)

View file

@ -133,7 +133,7 @@ class TestGrpc:
assert wlan_node.id in session.nodes
assert session.nodes[node1.id].netif(0) is not None
assert session.nodes[node2.id].netif(0) is not None
hook_file, hook_data = session._hooks[EventTypes.RUNTIME_STATE][0]
hook_file, hook_data = session.hooks[EventTypes.RUNTIME_STATE][0]
assert hook_file == hook.file
assert hook_data == hook.data
assert session.location.refxyz == (location_x, location_y, location_z)

View file

@ -382,7 +382,7 @@ class TestGui:
def test_file_hook_add(self, coretlv: CoreHandler):
state = EventTypes.DATACOLLECT_STATE
assert coretlv.session._hooks.get(state) is None
assert coretlv.session.hooks.get(state) is None
file_name = "test.sh"
file_data = "echo hello"
message = coreapi.CoreFileMessage.create(
@ -396,7 +396,7 @@ class TestGui:
coretlv.handle_message(message)
hooks = coretlv.session._hooks.get(state)
hooks = coretlv.session.hooks.get(state)
assert len(hooks) == 1
name, data = hooks[0]
assert file_name == name

View file

@ -48,7 +48,7 @@ class TestXml:
session.open_xml(file_path, start=True)
# verify nodes have been recreated
runtime_hooks = session._hooks.get(state)
runtime_hooks = session.hooks.get(state)
assert runtime_hooks
runtime_hook = runtime_hooks[0]
assert file_name == runtime_hook[0]