diff --git a/daemon/core/gui/coreclient.py b/daemon/core/gui/coreclient.py index ef50ace2..8c05a30f 100644 --- a/daemon/core/gui/coreclient.py +++ b/daemon/core/gui/coreclient.py @@ -283,6 +283,9 @@ class CoreClient: response = self.client.get_emane_config(self.session_id) self.emane_config = response.config + # update interface manager + self.interfaces_manager.joined(session.links) + # draw session self.app.canvas.reset_and_redraw(session) diff --git a/daemon/core/gui/interface.py b/daemon/core/gui/interface.py index 3310da90..8e2f4aa6 100644 --- a/daemon/core/gui/interface.py +++ b/daemon/core/gui/interface.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, List, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple import netaddr from netaddr import EUI, IPNetwork @@ -13,13 +13,22 @@ if TYPE_CHECKING: from core.gui.graph.node import CanvasNode +def get_index(interface: "core_pb2.Interface") -> int: + net = netaddr.IPNetwork(f"{interface.ip4}/{interface.ip4mask}") + ip_value = net.value + cidr_value = net.cidr.value + return ip_value - cidr_value + + class Subnets: def __init__(self, ip4: IPNetwork, ip6: IPNetwork) -> None: self.ip4 = ip4 self.ip6 = ip6 self.used_indexes = set() - def __eq__(self, other: "Subnets") -> bool: + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Subnets): + return False return self.key() == other.key() def __hash__(self) -> int: @@ -73,34 +82,55 @@ class InterfaceManager: self.used_subnets[subnets.key()] = subnets return subnets - def reset(self): + def reset(self) -> None: self.current_subnets = None self.used_subnets.clear() - def removed(self, links: List["core_pb2.Link"]): + def removed(self, links: List["core_pb2.Link"]) -> None: # get remaining subnets remaining_subnets = set() - - for link in links: + for edge in self.app.core.links.values(): + link = edge.link if link.HasField("interface_one"): subnets = self.get_subnets(link.interface_one) - if subnets not in remaining_subnets: - self.used_subnets.pop(subnets.key(), None) + remaining_subnets.add(subnets) if link.HasField("interface_two"): subnets = self.get_subnets(link.interface_two) - if subnets not in remaining_subnets: - self.used_subnets.pop(subnets.key(), None) + remaining_subnets.add(subnets) - def initialize_links(self, links: List["core_pb2.Link"]): + # remove all subnets from used subnets when no longer present + # or remove used indexes from subnet + interfaces = [] for link in links: if link.HasField("interface_one"): - subnets = self.get_subnets(link.interface_one) - if subnets.key() not in self.used_subnets: - self.used_subnets[subnets.key()] = subnets + interfaces.append(link.interface_one) if link.HasField("interface_two"): - subnets = self.get_subnets(link.interface_two) - if subnets.key() not in self.used_subnets: - self.used_subnets[subnets.key()] = subnets + interfaces.append(link.interface_two) + for interface in interfaces: + subnets = self.get_subnets(interface) + if subnets not in remaining_subnets: + if self.current_subnets == subnets: + self.current_subnets = None + self.used_subnets.pop(subnets.key(), None) + else: + index = get_index(interface) + subnets.used_indexes.discard(index) + + def joined(self, links: List["core_pb2.Link"]) -> None: + interfaces = [] + for link in links: + if link.HasField("interface_one"): + interfaces.append(link.interface_one) + if link.HasField("interface_two"): + interfaces.append(link.interface_two) + + # add to used subnets and mark used indexes + for interface in interfaces: + subnets = self.get_subnets(interface) + index = get_index(interface) + subnets.used_indexes.add(index) + if subnets.key() not in self.used_subnets: + self.used_subnets[subnets.key()] = subnets def next_index(self, node: "core_pb2.Node") -> int: if NodeUtils.is_router_node(node): @@ -128,7 +158,7 @@ class InterfaceManager: def determine_subnets( self, canvas_src_node: "CanvasNode", canvas_dst_node: "CanvasNode" - ): + ) -> None: src_node = canvas_src_node.core_node dst_node = canvas_dst_node.core_node is_src_container = NodeUtils.is_container_node(src_node.type) @@ -152,7 +182,7 @@ class InterfaceManager: def find_subnets( self, canvas_node: "CanvasNode", visited: Set[int] = None - ) -> Union[IPNetwork, None]: + ) -> Optional[IPNetwork]: logging.info("finding subnet for node: %s", canvas_node.core_node.name) canvas = self.app.canvas subnets = None