diff --git a/daemon/core/services/coreservices.py b/daemon/core/services/coreservices.py index 3d04f9ae..646a433d 100644 --- a/daemon/core/services/coreservices.py +++ b/daemon/core/services/coreservices.py @@ -53,18 +53,34 @@ class ServiceDependencies: def __init__(self, services: List["CoreServiceType"]) -> None: self.visited: Set[str] = set() - self.boot: List["CoreServiceType"] = [] self.services: Dict[str, "CoreServiceType"] = {} + self.paths: Dict[str, List["CoreServiceType"]] = {} + self.boot_paths: List[List["CoreServiceType"]] = [] + roots = set([x.name for x in services]) for service in services: self.services[service.name] = service + roots -= set(service.dependencies) + self.roots: List["CoreServiceType"] = [x for x in services if x.name in roots] + if services and not self.roots: + raise ValueError("circular dependency is present") - def _search(self, service: "CoreServiceType", visiting: Set[str] = None) -> None: + def _search( + self, + service: "CoreServiceType", + visiting: Set[str] = None, + path: List[str] = None, + ) -> List["CoreServiceType"]: if service.name in self.visited: - return + return self.paths[service.name] self.visited.add(service.name) if visiting is None: visiting = set() visiting.add(service.name) + if path is None: + for dependency in service.dependencies: + path = self.paths.get(dependency) + if path is not None: + break for dependency in service.dependencies: service_dependency = self.services.get(dependency) if not service_dependency: @@ -72,14 +88,19 @@ class ServiceDependencies: if dependency in visiting: raise ValueError(f"circular dependency, already visited: {dependency}") else: - self._search(service_dependency, visiting) + path = self._search(service_dependency, visiting, path) visiting.remove(service.name) - self.boot.append(service) + if path is None: + path = [] + self.boot_paths.append(path) + path.append(service) + self.paths[service.name] = path + return path - def boot_order(self) -> List["CoreServiceType"]: - for service in self.services.values(): + def boot_order(self) -> List[List["CoreServiceType"]]: + for service in self.roots: self._search(service) - return self.boot + return self.boot_paths class ServiceShim: @@ -422,13 +443,22 @@ class CoreServices: :param node: node to start services on :return: nothing """ - services = ServiceDependencies(node.services).boot_order() + boot_paths = ServiceDependencies(node.services).boot_order() + funcs = [] + for boot_path in boot_paths: + args = (node, boot_path) + funcs.append((self._boot_service_path, args, {})) + result, exceptions = utils.threadpool(funcs) + if exceptions: + raise ServiceBootError(*exceptions) + + def _boot_service_path(self, node: CoreNode, boot_path: List["CoreServiceType"]): logging.info( "booting node(%s) services: %s", node.name, - " -> ".join([x.name for x in services]), + " -> ".join([x.name for x in boot_path]), ) - for service in services: + for service in boot_path: service = self.get_service(node.id, service.name, default_service=True) try: self.boot_service(node, service) diff --git a/daemon/tests/test_services.py b/daemon/tests/test_services.py index 82135fb6..44776ea2 100644 --- a/daemon/tests/test_services.py +++ b/daemon/tests/test_services.py @@ -220,7 +220,7 @@ class TestServices: assert default_service == my_service assert custom_service and custom_service != my_service - def test_services_dependencies(self): + def test_services_dependency(self): # given service_a = CoreService() service_a.name = "a" @@ -238,20 +238,34 @@ class TestServices: service_d.dependencies = () service_e.dependencies = () services = [service_a, service_b, service_c, service_d, service_e] + expected1 = {service_a.name, service_b.name, service_c.name, service_d.name} + expected2 = [service_e] # when - results = [] permutations = itertools.permutations(services) for permutation in permutations: permutation = list(permutation) - result = ServiceDependencies(permutation).boot_order() - results.append(result) + results = ServiceDependencies(permutation).boot_order() + # then + for result in results: + result_set = {x.name for x in result} + if len(result) == 4: + a_index = result.index(service_a) + b_index = result.index(service_b) + c_index = result.index(service_c) + d_index = result.index(service_d) + assert b_index < a_index + assert b_index < c_index + assert d_index < c_index + assert result_set == expected1 + elif len(result) == 1: + assert expected2 == result + else: + raise ValueError( + f"unexpected result: {results}, perm({permutation})" + ) - # then - for result in results: - assert len(result) == len(services) - - def test_services_missing_dependency(self): + def test_services_dependency_missing(self): # given service_a = CoreService() service_a.name = "a" @@ -271,7 +285,7 @@ class TestServices: with pytest.raises(ValueError): ServiceDependencies(permutation).boot_order() - def test_services_dependencies_cycle(self): + def test_services_dependency_cycle(self): # given service_a = CoreService() service_a.name = "a" @@ -291,7 +305,7 @@ class TestServices: with pytest.raises(ValueError): ServiceDependencies(permutation).boot_order() - def test_services_common_dependency(self): + def test_services_dependency_common(self): # given service_a = CoreService() service_a.name = "a" @@ -299,18 +313,64 @@ class TestServices: service_b.name = "b" service_c = CoreService() service_c.name = "c" - service_b.dependencies = (service_a.name,) - service_c.dependencies = (service_a.name, service_b.name) - services = [service_a, service_b, service_c] + service_d = CoreService() + service_d.name = "d" + service_a.dependencies = (service_b.name,) + service_c.dependencies = (service_d.name, service_b.name) + services = [service_a, service_b, service_c, service_d] + expected = {service_a.name, service_b.name, service_c.name, service_d.name} # when - results = [] permutations = itertools.permutations(services) for permutation in permutations: permutation = list(permutation) - result = ServiceDependencies(permutation).boot_order() - results.append(result) + results = ServiceDependencies(permutation).boot_order() - # then - for result in results: - assert result == [service_a, service_b, service_c] + # then + for result in results: + assert len(result) == 4 + result_set = {x.name for x in result} + a_index = result.index(service_a) + b_index = result.index(service_b) + c_index = result.index(service_c) + d_index = result.index(service_d) + assert b_index < a_index + assert d_index < c_index + assert b_index < c_index + assert expected == result_set + + def test_services_dependency_common2(self): + # given + service_a = CoreService() + service_a.name = "a" + service_b = CoreService() + service_b.name = "b" + service_c = CoreService() + service_c.name = "c" + service_d = CoreService() + service_d.name = "d" + service_a.dependencies = (service_b.name,) + service_b.dependencies = (service_c.name, service_d.name) + service_c.dependencies = (service_d.name,) + services = [service_a, service_b, service_c, service_d] + expected = {service_a.name, service_b.name, service_c.name, service_d.name} + + # when + permutations = itertools.permutations(services) + for permutation in permutations: + permutation = list(permutation) + results = ServiceDependencies(permutation).boot_order() + + # then + for result in results: + assert len(result) == 4 + result_set = {x.name for x in result} + a_index = result.index(service_a) + b_index = result.index(service_b) + c_index = result.index(service_c) + d_index = result.index(service_d) + assert b_index < a_index + assert c_index < b_index + assert d_index < b_index + assert d_index < c_index + assert expected == result_set