From 6327bb3d51b976fbdeadb3ee8133dcaa63e74193 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pedro=20L=2E=20Magalh=C3=A3es?= <pmlpm@posteo.de>
Date: Tue, 23 Jul 2024 23:23:01 +0200
Subject: [PATCH] Made add_node and add_nodes_from methods identical to those
 of networkx.

---
 src/topupopt/problems/esipp/network.py | 107 ++++++++++++++++++-------
 tests/test_esipp_network.py            |  22 ++---
 2 files changed, 83 insertions(+), 46 deletions(-)

diff --git a/src/topupopt/problems/esipp/network.py b/src/topupopt/problems/esipp/network.py
index 60fa3a8..d225256 100644
--- a/src/topupopt/problems/esipp/network.py
+++ b/src/topupopt/problems/esipp/network.py
@@ -17,7 +17,7 @@ from math import inf
 import networkx as nx
 from ...data.gis.identify import get_edges_involving_node
 from ...data.gis.identify import find_unconnected_nodes
-from .resource import are_prices_time_invariant, ResourcePrice
+# from .resource import are_prices_time_invariant, ResourcePrice
 
 # *****************************************************************************
 # *****************************************************************************
@@ -624,10 +624,7 @@ class Network(nx.MultiDiGraph):
         )
 
     def __init__(self, network_type = NET_TYPE_HYBRID, **kwargs):
-        
-        # run base class init routine
-        nx.MultiDiGraph.__init__(self, **kwargs)
-        
+    
         # initialise the node type
         self.import_nodes = set()
         self.export_nodes = set()
@@ -643,6 +640,9 @@ class Network(nx.MultiDiGraph):
         # nodes without outgoing directed arcs limitations
         self.nodes_w_out_dir_arc_limitations = dict()
         
+        # run base class init routine
+        nx.MultiDiGraph.__init__(self, **kwargs)
+        
         # process the input data
         for node_key in self.nodes():
             self._process_node_data(node_key, data=self.nodes[node_key])
@@ -718,30 +718,6 @@ class Network(nx.MultiDiGraph):
         self._set_up_node(node_key, **kwargs)
     
     # TODO: automatically identify import and export nodes (without defining them explicitly)
-
-    # *************************************************************************
-    # *************************************************************************
-    
-    # TODO: use a decorator function to prevent the original method(s) from being used inappropriately
-    
-    def add_node(self, node_key, **kwargs):
-        
-        self._handle_node(node_key, **kwargs)
-
-    # *************************************************************************
-    # *************************************************************************
-    
-    # TODO: automatically check if node already exists and implications when "adding" one
-    
-    def add_nodes(self, node_key_data: list):
-        
-        # process the input data
-        for entry in node_key_data:
-            if type(entry) != tuple :
-                raise ValueError('The input must be a list of tuples.')
-            self._process_node_data(entry[0], entry[1])
-        # add the nodes
-        nx.MultiDiGraph.add_nodes_from(self, node_key_data)
         
     # *************************************************************************
     # *************************************************************************
@@ -782,11 +758,82 @@ class Network(nx.MultiDiGraph):
 
     # *************************************************************************
     # *************************************************************************
-        
+    
+    # TODO: use a decorator function to prevent the original method(s) from being used inappropriately
+    
+    def add_node(self, node_key, **kwargs):
+        # check if the node can be added and add it
+        self._handle_node(node_key, **kwargs)
+
+    # *************************************************************************
+    # *************************************************************************
+            
     def modify_node(self, node_key, **kwargs):
         if not self.has_node(node_key):
             raise ValueError('The node indicated does not exist.')
         self._handle_node(node_key, **kwargs)
+
+    # *************************************************************************
+    # *************************************************************************
+    
+    # TODO: automatically check if node already exists and implications when "adding" one
+    
+    # def add_nodes(self, node_key_data: list):
+        
+    #     # process the input data
+    #     for entry in node_key_data:
+    #         if type(entry) != tuple :
+    #             raise ValueError('The input must be a list of tuples.')
+    #         # self._handle_node(entry[0], **entry[1])
+    #         self._process_node_data(entry[0], entry[1])
+    #     # add the nodes
+    #     nx.MultiDiGraph.add_nodes_from(self, node_key_data)
+
+    # *************************************************************************
+    # *************************************************************************
+    
+    def add_nodes_from(self, nodes_for_adding, **kwargs):
+        
+        # input formats:
+        # 1) container of node keys
+        # 2) container of tuples
+        # process the input data
+        for entry in nodes_for_adding:
+            if type(entry) == tuple and len(entry) == 2 and type(entry[1]) == dict:
+                # option 2
+                # update the dict
+                new_dict = kwargs.copy()
+                new_dict.update(entry[1])
+                self._handle_node(entry[0], **new_dict)
+            else:
+                # option 1
+                self._handle_node(entry, **kwargs)
+
+    # *************************************************************************
+    # *************************************************************************
+    
+    def modify_nodes_from(self, nodes_for_adding, **kwargs):
+        
+        # input formats:
+        # 1) container of node keys
+        # 2) container of tuples
+        # process the input data
+        for entry in nodes_for_adding:
+            if type(entry) == tuple and len(entry) == 2 and type(entry[1]) == dict:
+                # option 2
+                new_dict = kwargs.copy()
+                new_dict.update(entry[1])
+                if not self.has_node(entry[0]):
+                    raise ValueError('The node indicated does not exist.')
+                self._handle_node(entry[0], **new_dict)
+            else:
+                # option 1
+                if not self.has_node(entry):
+                    raise ValueError('The node indicated does not exist.')
+                self._handle_node(entry, **kwargs)
+
+    # *************************************************************************
+    # *************************************************************************
         
     def _handle_node(self, node_key, **kwargs):
         
diff --git a/tests/test_esipp_network.py b/tests/test_esipp_network.py
index b965a3e..d843c92 100644
--- a/tests/test_esipp_network.py
+++ b/tests/test_esipp_network.py
@@ -1,23 +1,15 @@
 # imports
 
 # standard
-
 import random
-
 from networkx import binomial_tree, MultiDiGraph
 
 # local
-
 from src.topupopt.problems.esipp.network import Arcs, Network
-
 from src.topupopt.problems.esipp.network import ArcsWithoutLosses
-
 from src.topupopt.problems.esipp.network import ArcsWithoutProportionalLosses
-
 from src.topupopt.problems.esipp.network import ArcsWithoutStaticLosses
-
 from src.topupopt.problems.esipp.resource import ResourcePrice
-
 from src.topupopt.data.misc.utils import generate_pseudo_unique_key
 
 # *****************************************************************************
@@ -2179,11 +2171,8 @@ class TestNetwork:
         # create network
         network = Network()
 
-        # add node A
-        network.add_waypoint_node("A")
-
-        # add node B
-        network.add_waypoint_node("B")
+        # add nodes A and B
+        network.add_nodes_from(['A','B'])
 
         # add arcs
         key_list = [
@@ -2303,11 +2292,12 @@ class TestNetwork:
         
         # add nodes
         node_a = 'A'
-        net.add_waypoint_node(node_a)
+        # net.add_waypoint_node(node_a)
         node_b = 'B'
-        net.add_waypoint_node(node_b)
+        # net.add_waypoint_node(node_b)
         node_c = 'C'
-        net.add_waypoint_node(node_c)
+        # net.add_waypoint_node(node_c)
+        net.add_nodes_from([node_a,node_b,node_c])
         
         # add arcs
         node_pairs = ((node_a, node_b), (node_b, node_a),)
-- 
GitLab