diff --git a/src/topupopt/problems/esipp/network.py b/src/topupopt/problems/esipp/network.py index 60fa3a8eecab944fe485868de5e52c5996cfaa4e..d2252565f648cd420aef72beece0b835fed49813 100644 --- a/src/topupopt/problems/esipp/network.py +++ b/src/topupopt/problems/esipp/network.py @@ -17,7 +17,7 @@ from math import inf import networkx as nx from ...data.gis.identify import get_edges_involving_node from ...data.gis.identify import find_unconnected_nodes -from .resource import are_prices_time_invariant, ResourcePrice +# from .resource import are_prices_time_invariant, ResourcePrice # ***************************************************************************** # ***************************************************************************** @@ -624,10 +624,7 @@ class Network(nx.MultiDiGraph): ) def __init__(self, network_type = NET_TYPE_HYBRID, **kwargs): - - # run base class init routine - nx.MultiDiGraph.__init__(self, **kwargs) - + # initialise the node type self.import_nodes = set() self.export_nodes = set() @@ -643,6 +640,9 @@ class Network(nx.MultiDiGraph): # nodes without outgoing directed arcs limitations self.nodes_w_out_dir_arc_limitations = dict() + # run base class init routine + nx.MultiDiGraph.__init__(self, **kwargs) + # process the input data for node_key in self.nodes(): self._process_node_data(node_key, data=self.nodes[node_key]) @@ -718,30 +718,6 @@ class Network(nx.MultiDiGraph): self._set_up_node(node_key, **kwargs) # TODO: automatically identify import and export nodes (without defining them explicitly) - - # ************************************************************************* - # ************************************************************************* - - # TODO: use a decorator function to prevent the original method(s) from being used inappropriately - - def add_node(self, node_key, **kwargs): - - self._handle_node(node_key, **kwargs) - - # ************************************************************************* - # ************************************************************************* - - # TODO: automatically check if node already exists and implications when "adding" one - - def add_nodes(self, node_key_data: list): - - # process the input data - for entry in node_key_data: - if type(entry) != tuple : - raise ValueError('The input must be a list of tuples.') - self._process_node_data(entry[0], entry[1]) - # add the nodes - nx.MultiDiGraph.add_nodes_from(self, node_key_data) # ************************************************************************* # ************************************************************************* @@ -782,11 +758,82 @@ class Network(nx.MultiDiGraph): # ************************************************************************* # ************************************************************************* - + + # TODO: use a decorator function to prevent the original method(s) from being used inappropriately + + def add_node(self, node_key, **kwargs): + # check if the node can be added and add it + self._handle_node(node_key, **kwargs) + + # ************************************************************************* + # ************************************************************************* + def modify_node(self, node_key, **kwargs): if not self.has_node(node_key): raise ValueError('The node indicated does not exist.') self._handle_node(node_key, **kwargs) + + # ************************************************************************* + # ************************************************************************* + + # TODO: automatically check if node already exists and implications when "adding" one + + # def add_nodes(self, node_key_data: list): + + # # process the input data + # for entry in node_key_data: + # if type(entry) != tuple : + # raise ValueError('The input must be a list of tuples.') + # # self._handle_node(entry[0], **entry[1]) + # self._process_node_data(entry[0], entry[1]) + # # add the nodes + # nx.MultiDiGraph.add_nodes_from(self, node_key_data) + + # ************************************************************************* + # ************************************************************************* + + def add_nodes_from(self, nodes_for_adding, **kwargs): + + # input formats: + # 1) container of node keys + # 2) container of tuples + # process the input data + for entry in nodes_for_adding: + if type(entry) == tuple and len(entry) == 2 and type(entry[1]) == dict: + # option 2 + # update the dict + new_dict = kwargs.copy() + new_dict.update(entry[1]) + self._handle_node(entry[0], **new_dict) + else: + # option 1 + self._handle_node(entry, **kwargs) + + # ************************************************************************* + # ************************************************************************* + + def modify_nodes_from(self, nodes_for_adding, **kwargs): + + # input formats: + # 1) container of node keys + # 2) container of tuples + # process the input data + for entry in nodes_for_adding: + if type(entry) == tuple and len(entry) == 2 and type(entry[1]) == dict: + # option 2 + new_dict = kwargs.copy() + new_dict.update(entry[1]) + if not self.has_node(entry[0]): + raise ValueError('The node indicated does not exist.') + self._handle_node(entry[0], **new_dict) + else: + # option 1 + if not self.has_node(entry): + raise ValueError('The node indicated does not exist.') + self._handle_node(entry, **kwargs) + + # ************************************************************************* + # ************************************************************************* def _handle_node(self, node_key, **kwargs): diff --git a/tests/test_esipp_network.py b/tests/test_esipp_network.py index b965a3e023df307db740f2ce3d981ba0cd993e09..d843c921dd472ab02ae10981e0519457c2902cc4 100644 --- a/tests/test_esipp_network.py +++ b/tests/test_esipp_network.py @@ -1,23 +1,15 @@ # imports # standard - import random - from networkx import binomial_tree, MultiDiGraph # local - from src.topupopt.problems.esipp.network import Arcs, Network - from src.topupopt.problems.esipp.network import ArcsWithoutLosses - from src.topupopt.problems.esipp.network import ArcsWithoutProportionalLosses - from src.topupopt.problems.esipp.network import ArcsWithoutStaticLosses - from src.topupopt.problems.esipp.resource import ResourcePrice - from src.topupopt.data.misc.utils import generate_pseudo_unique_key # ***************************************************************************** @@ -2179,11 +2171,8 @@ class TestNetwork: # create network network = Network() - # add node A - network.add_waypoint_node("A") - - # add node B - network.add_waypoint_node("B") + # add nodes A and B + network.add_nodes_from(['A','B']) # add arcs key_list = [ @@ -2303,11 +2292,12 @@ class TestNetwork: # add nodes node_a = 'A' - net.add_waypoint_node(node_a) + # net.add_waypoint_node(node_a) node_b = 'B' - net.add_waypoint_node(node_b) + # net.add_waypoint_node(node_b) node_c = 'C' - net.add_waypoint_node(node_c) + # net.add_waypoint_node(node_c) + net.add_nodes_from([node_a,node_b,node_c]) # add arcs node_pairs = ((node_a, node_b), (node_b, node_a),)