convenience method created for dealing with udp server cases

This commit is contained in:
Blake J. Harnden 2018-03-16 12:39:23 -07:00
parent 7260f823cb
commit ee5bbdd949

View file

@ -14,7 +14,7 @@ import time
from core import coreobj from core import coreobj
from core import logger from core import logger
from core.api import coreapi from core.api import coreapi
from core.coreserver import CoreServer from core.coreserver import CoreServer, CoreUdpServer
from core.data import ConfigData from core.data import ConfigData
from core.data import EventData from core.data import EventData
from core.data import NodeData from core.data import NodeData
@ -87,6 +87,17 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
utils.close_onexec(request.fileno()) utils.close_onexec(request.fileno())
SocketServer.BaseRequestHandler.__init__(self, request, client_address, server) SocketServer.BaseRequestHandler.__init__(self, request, client_address, server)
def _get_server(self):
"""
Retrieve server to interface with, in cases where the server is a UDP instance.
:return: core.coreserver.CoreServer
"""
server = self.server
if isinstance(server, CoreUdpServer):
server = self.server.mainserver
return server
def setup(self): def setup(self):
""" """
Client has connected, set up a new connection. Client has connected, set up a new connection.
@ -94,7 +105,6 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
:return: nothing :return: nothing
""" """
logger.info("new TCP connection: %s", self.client_address) logger.info("new TCP connection: %s", self.client_address)
# self.register()
def finish(self): def finish(self):
""" """
@ -1118,16 +1128,13 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
""" """
replies = [] replies = []
server = self._get_server()
# execute a Python script or XML file # execute a Python script or XML file
execute_server = message.get_tlv(RegisterTlvs.EXECUTE_SERVER.value) execute_server = message.get_tlv(RegisterTlvs.EXECUTE_SERVER.value)
if execute_server: if execute_server:
try: try:
logger.info("executing: %s", execute_server) logger.info("executing: %s", execute_server)
# assumed to be udp server
if not isinstance(self.server, CoreServer):
server = self.server.mainserver
else:
server = self.server
if message.flags & MessageFlags.STRING.value: if message.flags & MessageFlags.STRING.value:
old_session_ids = set(server.get_session_ids()) old_session_ids = set(server.get_session_ids())
sys.argv = shlex.split(execute_server) sys.argv = shlex.split(execute_server)
@ -1159,7 +1166,7 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
logger.info("executed %s with unknown session ID", execute_server) logger.info("executed %s with unknown session ID", execute_server)
return replies return replies
logger.info("checking session %d for RUNTIME state" % sid) logger.info("checking session %d for RUNTIME state" % sid)
session = self.server.get_session(session_id=sid) session = server.get_session(session_id=sid)
retries = 10 retries = 10
# wait for session to enter RUNTIME state, to prevent GUI from # wait for session to enter RUNTIME state, to prevent GUI from
# connecting while nodes are still being instantiated # connecting while nodes are still being instantiated
@ -1193,14 +1200,14 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
# TODO: need to replicate functionality? # TODO: need to replicate functionality?
# self.server.set_session_master(self) # self.server.set_session_master(self)
# find the session containing this client and set the session to master # find the session containing this client and set the session to master
for session in self.server.sessions.itervalues(): for session in server.sessions.itervalues():
if self in session.broker.session_clients: if self in session.broker.session_clients:
logger.info("setting session to master: %s", session.session_id) logger.info("setting session to master: %s", session.session_id)
session.master = True session.master = True
break break
replies.append(self.register()) replies.append(self.register())
replies.append(self.server.to_session_message()) replies.append(server.to_session_message())
return replies return replies
@ -1439,6 +1446,8 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
node_counts = coreapi.str_to_list(node_count_str) node_counts = coreapi.str_to_list(node_count_str)
logger.info("SESSION message flags=0x%x sessions=%s" % (message.flags, session_id_str)) logger.info("SESSION message flags=0x%x sessions=%s" % (message.flags, session_id_str))
server = self._get_server()
if message.flags == 0: if message.flags == 0:
# modify a session # modify a session
i = 0 i = 0
@ -1447,7 +1456,7 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
if session_id == 0: if session_id == 0:
session = self.session session = self.session
else: else:
session = self.server.get_session(session_id=session_id) session = server.get_session(session_id=session_id)
if session is None: if session is None:
logger.info("session %s not found", session_id) logger.info("session %s not found", session_id)
@ -1469,18 +1478,13 @@ class CoreRequestHandler(SocketServer.BaseRequestHandler):
i += 1 i += 1
else: else:
if message.flags & MessageFlags.STRING.value and not message.flags & MessageFlags.ADD.value: if message.flags & MessageFlags.STRING.value and not message.flags & MessageFlags.ADD.value:
# assumed to be udp server
if not isinstance(self.server, CoreServer):
server = self.server.mainserver
else:
server = self.server
# status request flag: send list of sessions # status request flag: send list of sessions
return server.to_session_message(), return server.to_session_message(),
# handle ADD or DEL flags # handle ADD or DEL flags
for session_id in session_ids: for session_id in session_ids:
session_id = int(session_id) session_id = int(session_id)
session = self.server.get_session(session_id=session_id) session = server.get_session(session_id=session_id)
if session is None: if session is None:
logger.info("session %s not found (flags=0x%x)", session_id, message.flags) logger.info("session %s not found (flags=0x%x)", session_id, message.flags)