# imports

# standard
import math

# local
# import numpy as np
# import networkx as nx
import pyomo.environ as pyo

# import src.topupopt.problems.esipp.utils as utils
from src.topupopt.data.misc.utils import generate_pseudo_unique_key
from src.topupopt.problems.esipp.problem import InfrastructurePlanningProblem
from src.topupopt.problems.esipp.network import Arcs, Network
from src.topupopt.problems.esipp.resource import ResourcePrice
# from src.topupopt.problems.esipp.utils import compute_cost_volume_metrics
from src.topupopt.problems.esipp.utils import statistics
from src.topupopt.problems.esipp.time import EconomicTimeFrame
# from src.topupopt.problems.esipp.converter import Converter

# *****************************************************************************
# *****************************************************************************

class TestESIPPProblem:
    
    solver = 'glpk'
    # solver = 'scip'
    # solver = 'cbc'
    
    def build_solve_ipp(
        self,
        solver: str = None,
        solver_options: dict = None,
        use_sos_arcs: bool = False,
        arc_sos_weight_key: str = (InfrastructurePlanningProblem.SOS1_ARC_WEIGHTS_NONE),
        arc_use_real_variables_if_possible: bool = False,
        use_sos_sense: bool = False,
        sense_sos_weight_key: int = (
            InfrastructurePlanningProblem.SOS1_SENSE_WEIGHT_NOMINAL_HIGHER
        ),
        sense_use_real_variables_if_possible: bool = False,
        sense_use_arc_interfaces: bool = False,
        perform_analysis: bool = False,
        plot_results: bool = False,
        print_solver_output: bool = False,
        time_frame: EconomicTimeFrame = None,
        networks: dict = None,
        converters: dict = None,
        static_losses_mode=None,
        mandatory_arcs: list = None,
        max_number_parallel_arcs: dict = None,
        arc_groups_dict: dict = None,
        init_aux_sets: bool = False,
        assessment_weights: dict = None,
        simplify_problem: bool = False,
    ):
        """
        Build an InfrastructurePlanningProblem, solve it and return it.

        Parameters
        ----------
        solver : str, optional
            Solver name passed to ``ipp.optimise``. Defaults to the class-level
            ``solver`` attribute.
        solver_options : dict, optional
            Forwarded unchanged to ``ipp.optimise``.
        use_sos_arcs : bool
            If True, ``ipp.use_sos1_for_arc_selection`` is applied to every arc
            whose technology has not already been selected.
        arc_sos_weight_key, arc_use_real_variables_if_possible
            Forwarded to ``ipp.use_sos1_for_arc_selection``.
        use_sos_sense : bool
            If True, ``ipp.use_sos1_for_flow_senses`` is applied to every arc
            flagged as undirected (``Network.KEY_ARC_UND``).
        sense_sos_weight_key, sense_use_real_variables_if_possible
            Forwarded to ``ipp.use_sos1_for_flow_senses``.
        sense_use_arc_interfaces : bool
            With ``use_sos_sense`` it is forwarded as ``use_interface_variables``;
            without it, ``ipp.use_interface_variables_for_arc_selection`` is
            applied to every arc whose technology has not been selected.
        perform_analysis, plot_results : bool
            Accepted for caller compatibility; not used by this method.
        print_solver_output : bool
            Forwarded to ``ipp.optimise``.
        time_frame : EconomicTimeFrame
            Time frame passed to the problem constructor.
        networks : dict, optional
            Networks to add, keyed by network key. None means no networks.
        converters : dict, optional
            Converters to add, keyed by converter key. Anything that is not a
            dict is treated as "no converters".
        static_losses_mode
            Compared (``==``) against the problem's ``STATIC_LOSS_MODE_*``
            constants. Any other value, including the default None, raises
            ValueError.
        mandatory_arcs : list or tuple, optional
            Full arc keys ``(network_key, *arc_key)`` to make mandatory.
        max_number_parallel_arcs : dict, optional
            Maps ``(network_key, node_a, node_b)`` to a limit on parallel arcs.
            None means no limits.
        arc_groups_dict : dict, optional
            Each value is passed to ``ipp.create_arc_group``.
        init_aux_sets : bool
            Forwarded to ``ipp.instantiate`` as ``initialise_ancillary_sets``.
        assessment_weights : dict, optional
            Anything that is not a dict is replaced by an empty dict.
        simplify_problem : bool
            If True, ``ipp.simplify_peak_total_assessments`` is called before
            instantiation.

        Returns
        -------
        InfrastructurePlanningProblem
            The problem object after ``optimise`` has been called.

        Raises
        ------
        ValueError
            If ``static_losses_mode`` matches none of the known modes.
        """
        if solver is None:
            solver = self.solver

        if not isinstance(assessment_weights, dict):
            assessment_weights = {}  # default

        if not isinstance(converters, dict):
            converters = {}

        # None defaults previously crashed when iterated; treat them as empty
        if networks is None:
            networks = {}

        if max_number_parallel_arcs is None:
            max_number_parallel_arcs = {}

        # create the problem object; no time weighting is applied yet, so the
        # time weights and normalised interval durations are left as None
        ipp = InfrastructurePlanningProblem(
            time_frame=time_frame,
            time_weights=None,
            normalised_time_interval_duration=None,
            assessment_weights=assessment_weights,
        )

        # add networks and converters
        for netkey, net in networks.items():
            ipp.add_network(network_key=netkey, network=net)

        for cvtkey, cvt in converters.items():
            ipp.add_converter(converter_key=cvtkey, converter=cvt)

        # define arcs as mandatory: each entry is (network_key, *arc_key)
        if isinstance(mandatory_arcs, (list, tuple)):
            for full_arc_key in mandatory_arcs:
                ipp.make_arc_mandatory(full_arc_key[0], full_arc_key[1:])

        # set up the use of sos for arc selection (already selected arcs are
        # skipped)
        if use_sos_arcs:
            for network_key in ipp.networks:
                network = ipp.networks[network_key]
                for arc_key in network.edges(keys=True):
                    if network.edges[arc_key][Network.KEY_ARC_TECH].has_been_selected():
                        continue

                    ipp.use_sos1_for_arc_selection(
                        network_key,
                        arc_key,
                        use_real_variables_if_possible=(
                            arc_use_real_variables_if_possible
                        ),
                        sos1_weight_method=arc_sos_weight_key,
                    )

        if use_sos_sense:
            # set up the use of sos for flow sense determination (undirected
            # arcs only)
            for network_key in ipp.networks:
                network = ipp.networks[network_key]
                for arc_key in network.edges(keys=True):
                    if not network.edges[arc_key][Network.KEY_ARC_UND]:
                        continue

                    ipp.use_sos1_for_flow_senses(
                        network_key,
                        arc_key,
                        use_real_variables_if_possible=(
                            sense_use_real_variables_if_possible
                        ),
                        use_interface_variables=sense_use_arc_interfaces,
                        sos1_weight_method=sense_sos_weight_key,
                    )

        elif sense_use_arc_interfaces:
            # set up the use of arc interfaces w/o sos1 (already selected arcs
            # are skipped)
            for network_key in ipp.networks:
                network = ipp.networks[network_key]
                for arc_key in network.edges(keys=True):
                    if network.edges[arc_key][Network.KEY_ARC_TECH].has_been_selected():
                        continue

                    ipp.use_interface_variables_for_arc_selection(network_key, arc_key)

        # static losses: the mode must match one of the known constants
        if static_losses_mode == ipp.STATIC_LOSS_MODE_ARR:
            ipp.place_static_losses_arrival_node()

        elif static_losses_mode == ipp.STATIC_LOSS_MODE_DEP:
            ipp.place_static_losses_departure_node()

        elif static_losses_mode == ipp.STATIC_LOSS_MODE_US:
            ipp.place_static_losses_upstream()

        elif static_losses_mode == ipp.STATIC_LOSS_MODE_DS:
            ipp.place_static_losses_downstream()

        else:
            raise ValueError("Unknown static loss modelling mode.")

        # arc groups
        if arc_groups_dict is not None:
            for arc_group in arc_groups_dict.values():
                ipp.create_arc_group(arc_group)

        # maximum number of parallel arcs, keyed by (network, node_a, node_b)
        for key, limit in max_number_parallel_arcs.items():
            ipp.set_maximum_number_parallel_arcs(
                network_key=key[0],
                node_a=key[1],
                node_b=key[2],
                limit=limit,
            )

        if simplify_problem:
            ipp.simplify_peak_total_assessments()

        # instantiate and optimise
        ipp.instantiate(initialise_ancillary_sets=init_aux_sets)
        ipp.optimise(
            solver_name=solver,
            solver_options=solver_options,
            output_options={},
            print_solver_output=print_solver_output,
        )

        return ipp

    # *************************************************************************
    # *************************************************************************

    def test_problem_increasing_imp_prices(self):
        """
        Meet a fixed demand by importing at increasing (convex) prices.

        Node A has a base flow of 1.0 and is fed from the import node through
        an arc with efficiency 0.5, so 2.0 units flow on the arc. The import
        cost is 0.5 * 1.0 + 1.5 * 2.0 = 3.5 (sdncf -3.5). With an arc amplitude
        of 2.0 the capex is 4.0, so the objective is -3.5 - 4.0 = -7.5.
        """
        # assessment
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0,)},
            time_interval_durations={q: (1,)},
        )

        # 2 nodes: one import, one regular
        mynet = Network()

        # import node: price 1.0 for the first 0.5 units, 2.0 beyond that
        # (unbounded last segment)
        node_IMP = 'I'
        mynet.add_import_node(
            node_key=node_IMP,
            prices={
                qpk: ResourcePrice(prices=[1.0, 2.0], volumes=[0.5, None])
                for qpk in tf.qpk()
            },
        )

        # regular node with a base flow of 1.0
        node_A = 'A'
        mynet.add_source_sink_node(node_key=node_A, base_flow={(q, 0): 1.0})

        # arc IA: efficiency 0.5, single capacity option of 3
        arc_tech_IA = Arcs(
            name="any",
            efficiency={(q, 0): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_IMP, node_key_b=node_A, arcs=arc_tech_IA)

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            time_frame=tf,
            networks={"mynet": mynet},
            # True must compare equal to one of the STATIC_LOSS_MODE_*
            # constants, otherwise build_solve_ipp raises ValueError
            static_losses_mode=True,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
            simplify_problem=False
        )

        # problem size with convex prices (no nonconvex price block)
        assert not ipp.has_peak_total_assessments()
        problem_stats = ipp.results["Problem"][0]
        assert problem_stats["Number of constraints"] == 10
        assert problem_stats["Number of variables"] == 11
        assert problem_stats["Number of nonzeros"] == 20

        # *********************************************************************
        # *********************************************************************

        # validation

        # the arc should be installed since it is required for feasibility
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_IMP, node_A, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # the flow on arc IA should be 2.0 (0.5 efficiency, 1.0 at node A)
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_IMP, node_A, 0, q, 0)]),
            2.0,
            abs_tol=1e-6,
        )

        # arc amplitude should be two
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_IMP, node_A, 0)]),
            2.0,
            abs_tol=0.01,
        )

        # capex should be four
        assert math.isclose(pyo.value(ipp.instance.var_capex), 4.0, abs_tol=1e-3)

        # sdncf should be -3.5: 0.5 at 1.0 plus 1.5 at 2.0
        assert math.isclose(pyo.value(ipp.instance.var_sdncf_q[q]), -3.5, abs_tol=1e-3)

        # the objective function should be -7.5: -3.5 - 4.0
        assert math.isclose(pyo.value(ipp.instance.obj_f), -7.5, abs_tol=1e-3)
        
    # *************************************************************************
    # *************************************************************************

    def test_problem_decreasing_imp_prices(self):
        """
        Meet a fixed demand by importing at decreasing (nonconvex) prices.

        As in the increasing-price case, 2.0 units flow on the arc to deliver
        1.0 to node A. The import cost is 0.5 * 2.0 + 1.5 * 1.0 = 2.5
        (sdncf -2.5), the capex is 4.0, so the objective is -2.5 - 4.0 = -6.5.
        The nonconvex prices add constraints, variables and nonzeros compared
        to the convex case.
        """
        # assessment
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0,)},
            time_interval_durations={q: (1,)},
        )

        # 2 nodes: one import, one regular
        mynet = Network()

        # import node: price 2.0 for the first 0.5 units, then 1.0 for up to
        # 3.0 more units (bounded last segment)
        node_IMP = 'I'
        mynet.add_import_node(
            node_key=node_IMP,
            prices={
                qpk: ResourcePrice(prices=[2.0, 1.0], volumes=[0.5, 3.0])
                for qpk in tf.qpk()
            },
        )

        # regular node with a base flow of 1.0
        node_A = 'A'
        mynet.add_source_sink_node(node_key=node_A, base_flow={(q, 0): 1.0})

        # arc IA: efficiency 0.5, single capacity option of 3
        arc_tech_IA = Arcs(
            name="any",
            efficiency={(q, 0): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_IMP, node_key_b=node_A, arcs=arc_tech_IA)

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            time_frame=tf,
            networks={"mynet": mynet},
            # True must compare equal to one of the STATIC_LOSS_MODE_*
            # constants, otherwise build_solve_ipp raises ValueError
            static_losses_mode=True,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
            simplify_problem=False
        )

        assert not ipp.has_peak_total_assessments()
        problem_stats = ipp.results["Problem"][0]
        assert problem_stats["Number of constraints"] == 14 # 10 prior to nonconvex block
        assert problem_stats["Number of variables"] == 13 # 11 prior to nonconvex block
        assert problem_stats["Number of nonzeros"] == 28 # 20 prior to nonconvex block

        # *********************************************************************
        # *********************************************************************

        # validation

        # the arc should be installed since it is required for feasibility
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_IMP, node_A, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # the flow on arc IA should be 2.0 (0.5 efficiency, 1.0 at node A)
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_IMP, node_A, 0, q, 0)]),
            2.0,
            abs_tol=1e-6,
        )

        # arc amplitude should be two
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_IMP, node_A, 0)]),
            2.0,
            abs_tol=0.01,
        )

        # capex should be four
        assert math.isclose(pyo.value(ipp.instance.var_capex), 4.0, abs_tol=1e-3)

        # sdncf should be -2.5: 0.5 at 2.0 plus 1.5 at 1.0
        assert math.isclose(pyo.value(ipp.instance.var_sdncf_q[q]), -2.5, abs_tol=1e-3)

        # the objective function should be -6.5: -2.5 - 4.0
        assert math.isclose(pyo.value(ipp.instance.obj_f), -6.5, abs_tol=1e-3)
                
    # *************************************************************************
    # *************************************************************************

    def test_problem_decreasing_imp_prices_infinite_capacity(self):
        """
        Decreasing (nonconvex) import prices whose last segment has no volume
        limit (volume None) must make building or solving the problem fail.

        NOTE(review): any exception is accepted, so an unrelated failure (e.g.
        a missing solver) would also satisfy this test.
        """
        assessment = 0
        time_frame = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={assessment: (0,)},
            reporting_period_durations={assessment: (365 * 24 * 3600,)},
            time_intervals={assessment: (0,)},
            time_interval_durations={assessment: (1,)},
        )

        # one import node feeding one regular node
        network = Network()
        import_node = 'I'
        network.add_import_node(
            node_key=import_node,
            prices={
                key: ResourcePrice(prices=[2.0, 1.0], volumes=[0.5, None])
                for key in time_frame.qpk()
            },
        )

        demand_node = 'A'
        network.add_source_sink_node(
            node_key=demand_node, base_flow={(assessment, 0): 1.0}
        )

        network.add_directed_arc(
            node_key_a=import_node,
            node_key_b=demand_node,
            arcs=Arcs(
                name="any",
                efficiency={(assessment, 0): 0.5},
                efficiency_reverse=None,
                static_loss=None,
                capacity=[3],
                minimum_cost=[2],
                specific_capacity_cost=1,
                capacity_is_instantaneous=False,
                validate=False,
            ),
        )
        network.identify_node_types()

        # building/solving is expected to raise
        caught = None
        try:
            self.build_solve_ipp(
                solver_options={},
                perform_analysis=False,
                plot_results=False,
                print_solver_output=False,
                time_frame=time_frame,
                networks={"mynet": network},
                static_losses_mode=True,
                mandatory_arcs=[],
                max_number_parallel_arcs={},
                simplify_problem=False,
            )
        except Exception as exc:
            caught = exc
        assert caught is not None

    # *************************************************************************
    # *************************************************************************

    def test_problem_decreasing_exp_prices(self):
        """
        Export a fixed supply at decreasing (convex) export prices.

        Node A has a base flow of -1.0 and feeds the export node through an
        arc with efficiency 0.5, so 1.0 unit flows on the arc and 0.5 is
        exported entirely in the first segment at price 2.0 (sdncf 1.0). With
        an arc amplitude of 1.0 the capex is 3.0, so the objective is
        1.0 - 3.0 = -2.0.
        """
        # assessment
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0,)},
            time_interval_durations={q: (1,)},
        )

        # 2 nodes: one export, one regular
        mynet = Network()

        # export node: price 2.0 for the first 0.5 units, 1.0 beyond that
        node_EXP = generate_pseudo_unique_key(mynet.nodes())
        mynet.add_export_node(
            node_key=node_EXP,
            prices={
                qpk: ResourcePrice(prices=[2.0, 1.0], volumes=[0.5, None])
                for qpk in tf.qpk()
            },
        )

        # regular node with a base flow of -1.0
        node_A = 'A'
        mynet.add_source_sink_node(node_key=node_A, base_flow={(q, 0): -1.0})

        # arc AE: efficiency 0.5, single capacity option of 3
        arc_tech_AE = Arcs(
            name="any",
            efficiency={(q, 0): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_A, node_key_b=node_EXP, arcs=arc_tech_AE)

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            time_frame=tf,
            networks={"mynet": mynet},
            # True must compare equal to one of the STATIC_LOSS_MODE_*
            # constants, otherwise build_solve_ipp raises ValueError
            static_losses_mode=True,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
            simplify_problem=False,
        )

        # problem size with convex prices (no nonconvex price block)
        assert not ipp.has_peak_total_assessments()
        problem_stats = ipp.results["Problem"][0]
        assert problem_stats["Number of constraints"] == 10
        assert problem_stats["Number of variables"] == 11
        assert problem_stats["Number of nonzeros"] == 20

        # *********************************************************************
        # *********************************************************************

        # validation

        # the arc should be installed since it is required for feasibility
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_A, node_EXP, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # the flow on arc AE should be 1.0 (all of node A's supply)
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_A, node_EXP, 0, q, 0)]),
            1.0,
            abs_tol=1e-6,
        )

        # arc amplitude should be one
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_A, node_EXP, 0)]),
            1.0,
            abs_tol=0.01,
        )

        # capex should be three
        assert math.isclose(pyo.value(ipp.instance.var_capex), 3.0, abs_tol=1e-3)

        # sdncf should be 1.0: 0.5 exported at 2.0
        assert math.isclose(pyo.value(ipp.instance.var_sdncf_q[q]), 1.0, abs_tol=1e-3)

        # the objective function should be -2.0: 1.0 - 3.0
        assert math.isclose(pyo.value(ipp.instance.obj_f), -2.0, abs_tol=1e-3)
        
    # *************************************************************************
    # *************************************************************************

    def test_problem_increasing_exp_prices(self):
        """
        Export a fixed supply at increasing (nonconvex) export prices.

        1.0 unit flows on the arc and 0.5 reaches the export node: 0.25 at
        price 1.0 plus 0.25 at price 2.0, i.e. revenue 0.75 (sdncf 0.75). With
        a capex of 3.0 the objective is 0.75 - 3.0 = -2.25. The nonconvex
        prices add constraints, variables and nonzeros.
        """
        # assessment
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0,)},
            time_interval_durations={q: (1,)},
        )

        # 2 nodes: one export, one regular
        mynet = Network()

        # export node: price 1.0 for the first 0.25 units, then 2.0 for up to
        # 3.0 more units (bounded last segment)
        node_EXP = generate_pseudo_unique_key(mynet.nodes())
        mynet.add_export_node(
            node_key=node_EXP,
            prices={
                qpk: ResourcePrice(prices=[1.0, 2.0], volumes=[0.25, 3.0])
                for qpk in tf.qpk()
            },
        )

        # regular node with a base flow of -1.0
        node_A = 'A'
        mynet.add_source_sink_node(node_key=node_A, base_flow={(q, 0): -1.0})

        # arc AE: efficiency 0.5, single capacity option of 3
        arc_tech_AE = Arcs(
            name="any",
            efficiency={(q, 0): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_A, node_key_b=node_EXP, arcs=arc_tech_AE)

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            time_frame=tf,
            networks={"mynet": mynet},
            # True must compare equal to one of the STATIC_LOSS_MODE_*
            # constants, otherwise build_solve_ipp raises ValueError
            static_losses_mode=True,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
            simplify_problem=False,
        )

        assert not ipp.has_peak_total_assessments()
        problem_stats = ipp.results["Problem"][0]
        assert problem_stats["Number of constraints"] == 14 # 10 before nonconvex block
        assert problem_stats["Number of variables"] == 13 # 11 before nonconvex block
        assert problem_stats["Number of nonzeros"] == 28 # 20 before nonconvex block

        # *********************************************************************
        # *********************************************************************

        # validation

        # the arc should be installed since it is required for feasibility
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_A, node_EXP, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # the flow on arc AE should be 1.0 (all of node A's supply)
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_A, node_EXP, 0, q, 0)]),
            1.0,
            abs_tol=1e-6,
        )

        # arc amplitude should be one
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_A, node_EXP, 0)]),
            1.0,
            abs_tol=0.01,
        )

        # capex should be three
        assert math.isclose(pyo.value(ipp.instance.var_capex), 3.0, abs_tol=1e-3)

        # sdncf should be 0.75: 0.25 at 1.0 plus 0.25 at 2.0
        assert math.isclose(pyo.value(ipp.instance.var_sdncf_q[q]), 0.75, abs_tol=1e-3)

        # the objective function should be -2.25: 0.75 - 3.0
        assert math.isclose(pyo.value(ipp.instance.obj_f), -2.25, abs_tol=1e-3)
                
    # *************************************************************************
    # *************************************************************************

    def test_problem_increasing_exp_prices_infinite_capacity(self):
        """
        Increasing (nonconvex) export prices whose last segment has no volume
        limit (volume None) must make building or solving the problem fail.

        NOTE(review): any exception is accepted, so an unrelated failure (e.g.
        a missing solver) would also satisfy this test; narrow the exception
        type once the expected one is confirmed.
        """
        # assessment
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0,)},
            time_interval_durations={q: (1,)},
        )

        # 2 nodes: one export, one regular
        mynet = Network()

        # export node: price 1.0 for the first 0.25 units, then 2.0 with no
        # volume limit (unbounded last segment)
        node_EXP = generate_pseudo_unique_key(mynet.nodes())
        mynet.add_export_node(
            node_key=node_EXP,
            prices={
                qpk: ResourcePrice(prices=[1.0, 2.0], volumes=[0.25, None])
                for qpk in tf.qpk()
            },
        )

        # regular node with a base flow of -1.0
        node_A = 'A'
        mynet.add_source_sink_node(node_key=node_A, base_flow={(q, 0): -1.0})

        # arc AE: efficiency 0.5, single capacity option of 3
        arc_tech_AE = Arcs(
            name="any",
            efficiency={(q, 0): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_A, node_key_b=node_EXP, arcs=arc_tech_AE)

        # identify node types
        mynet.identify_node_types()

        # trigger the error
        error_raised = False
        try:
            # no sos, regular time intervals
            self.build_solve_ipp(
                solver_options={},
                perform_analysis=False,
                plot_results=False,  # True,
                print_solver_output=False,
                time_frame=tf,
                networks={"mynet": mynet},
                # True must compare equal to one of the STATIC_LOSS_MODE_*
                # constants, otherwise build_solve_ipp raises ValueError
                static_losses_mode=True,
                mandatory_arcs=[],
                max_number_parallel_arcs={},
                simplify_problem=False,
            )
        except Exception:
            error_raised = True
        assert error_raised

    # *************************************************************************
    # *************************************************************************

    def test_problem_increasing_imp_decreasing_exp_prices(self):
        """
        Combine convex import and export prices over two time intervals.

        Interval 0: node A has a base flow of 1.0, so 2.0 units are imported
        at a cost of 0.5 * 1.0 + 1.5 * 2.0 = 3.5.
        Interval 1: node A has a base flow of -1.0, so 1.0 unit flows on arc AE
        and 0.5 is exported at 2.0, for a revenue of 1.0.

        The results are sdncf -2.5 and capex 4.0 + 3.0 = 7.0, giving an
        objective of -2.5 - 7.0 = -9.5.
        """
        # scenario
        q = 0

        tf = EconomicTimeFrame(
            discount_rate=0.0,
            reporting_periods={q: (0,)},
            reporting_period_durations={q: (365 * 24 * 3600,)},
            time_intervals={q: (0, 1)},
            time_interval_durations={q: (1, 1)},
        )

        # 3 nodes: one import, one export, one regular
        mynet = Network()

        # import node: price 1.0 for the first 0.5 units, 2.0 beyond that
        node_IMP = 'I'
        mynet.add_import_node(
            node_key=node_IMP,
            prices={
                qpk: ResourcePrice(prices=[1.0, 2.0], volumes=[0.5, None])
                for qpk in tf.qpk()
            },
        )

        # export node: price 2.0 for the first 0.5 units, 1.0 beyond that
        node_EXP = generate_pseudo_unique_key(mynet.nodes())
        mynet.add_export_node(
            node_key=node_EXP,
            prices={
                qpk: ResourcePrice(prices=[2.0, 1.0], volumes=[0.5, None])
                for qpk in tf.qpk()
            },
        )

        # regular node: base flow 1.0 in interval 0 and -1.0 in interval 1
        node_A = 'A'
        mynet.add_source_sink_node(
            node_key=node_A, base_flow={(q, 0): 1.0, (q, 1): -1.0}
        )

        # arc IA: efficiency 0.5 in both intervals, single capacity option of 3
        arc_tech_IA = Arcs(
            name="any",
            efficiency={(q, 0): 0.5, (q, 1): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_IMP, node_key_b=node_A, arcs=arc_tech_IA)

        # arc AE: efficiency 0.5 in both intervals, single capacity option of 3
        arc_tech_AE = Arcs(
            name="any",
            efficiency={(q, 0): 0.5, (q, 1): 0.5},
            efficiency_reverse=None,
            static_loss=None,
            capacity=[3],
            minimum_cost=[2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
            validate=False,
        )
        mynet.add_directed_arc(node_key_a=node_A, node_key_b=node_EXP, arcs=arc_tech_AE)

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            time_frame=tf,
            networks={"mynet": mynet},
            # True must compare equal to one of the STATIC_LOSS_MODE_*
            # constants, otherwise build_solve_ipp raises ValueError
            static_losses_mode=True,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
            simplify_problem=False,
        )

        assert not ipp.has_peak_total_assessments()
        problem_stats = ipp.results["Problem"][0]
        assert problem_stats["Number of constraints"] == 23
        assert problem_stats["Number of variables"] == 26
        assert problem_stats["Number of nonzeros"] == 57

        # *********************************************************************
        # *********************************************************************

        # validation

        # both arcs should be installed since they are required for
        # feasibility
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_IMP, node_A, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(node_A, node_EXP, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # interval 0: import only
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_IMP, node_A, 0, q, 0)]),
            2.0,
            abs_tol=1e-6,
        )
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_A, node_EXP, 0, q, 0)]),
            0.0,
            abs_tol=1e-6,
        )
        # interval 1: export only
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_IMP, node_A, 0, q, 1)]),
            0.0,
            abs_tol=1e-6,
        )
        assert math.isclose(
            pyo.value(ipp.instance.var_v_glljqk[("mynet", node_A, node_EXP, 0, q, 1)]),
            1.0,
            abs_tol=1e-6,
        )

        # IA amplitude
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_IMP, node_A, 0)]),
            2.0,
            abs_tol=0.01,
        )
        # AE amplitude
        assert math.isclose(
            pyo.value(ipp.instance.var_v_amp_gllj[("mynet", node_A, node_EXP, 0)]),
            1.0,
            abs_tol=0.01,
        )

        # capex should be 7.0: 4+3
        assert math.isclose(pyo.value(ipp.instance.var_capex), 7.0, abs_tol=1e-3)

        # sdncf should be -2.5: -3.5+1.0
        assert math.isclose(pyo.value(ipp.instance.var_sdncf_q[q]), -2.5, abs_tol=1e-3)

        # the objective function should be -9.5: sdncf - capex = -2.5-7.0
        assert math.isclose(pyo.value(ipp.instance.obj_f), -9.5, abs_tol=1e-3)

            
    # *************************************************************************
    # *************************************************************************
        
    def test_direct_imp_exp_network_higher_exp_prices(self):
        """Check that flow is routed from import to export when exporting pays.

        The network has exactly two nodes: an import node (price 0.5) and an
        export node (price 1.5), linked by one optional directed arc. Every
        unit moved through the arc earns a positive margin, so the arc should
        be installed and used. This implies strictly positive imports, import
        costs, exports, export revenue and capex, with revenue above costs.
        """

        # time frame: two reporting periods and two time intervals for q = 0
        q = 0
        tf = EconomicTimeFrame(
            discount_rate=3.5 / 100,
            reporting_periods={q: (0, 1)},
            reporting_period_durations={q: (365 * 24 * 3600, 365 * 24 * 3600)},
            time_intervals={q: (0, 1)},
            time_interval_durations={q: (1, 1)},
        )

        # 2 nodes: one import node and one export node
        mynet = Network()

        # import node: cheap resource
        imp_node_key = "thatimpnode"
        imp_prices = {
            qpk: ResourcePrice(
                prices=0.5,
                volumes=None,
            )
            for qpk in tf.qpk()
        }
        mynet.add_import_node(
            node_key=imp_node_key,
            prices=imp_prices,
        )

        # export node: sells for more than the import price
        exp_node_key = "thatexpnode"
        exp_prices = {
            qpk: ResourcePrice(
                prices=1.5,
                volumes=None,
            )
            for qpk in tf.qpk()
        }
        mynet.add_export_node(
            node_key=exp_node_key,
            prices=exp_prices,
        )

        # add arc without fixed losses from import node to export
        # NOTE(review): efficiency also lists (0, 2) and (0, 3), which are not
        # time intervals of tf; this only passes because validate=False
        arc_tech_IE = Arcs(
            name="IE",
            efficiency={(0, 0): 1, (0, 1): 1, (0, 2): 1, (0, 3): 1},
            efficiency_reverse=None,
            static_loss=None,
            validate=False,
            capacity=[0.5, 1.0, 2.0],
            minimum_cost=[5, 5.1, 5.2],
            specific_capacity_cost=1,
            capacity_is_instantaneous=False,
        )
        mynet.add_directed_arc(
            node_key_a=imp_node_key, node_key_b=exp_node_key, arcs=arc_tech_IE
        )

        # identify node types
        mynet.identify_node_types()

        # no sos, regular time intervals
        ipp = self.build_solve_ipp(
            solver_options={},
            perform_analysis=False,
            plot_results=False,  # True,
            print_solver_output=False,
            networks={"mynet": mynet},
            time_frame=tf,
            static_losses_mode=InfrastructurePlanningProblem.STATIC_LOSS_MODE_DEP,
            mandatory_arcs=[],
            max_number_parallel_arcs={},
        )

        # export prices are higher: it makes sense to install the arc since the
        # revenue (@ max. cap.) exceeds the cost of installing the arc
        assert (
            True
            in ipp.networks["mynet"]
            .edges[(imp_node_key, exp_node_key, 0)][Network.KEY_ARC_TECH]
            .options_selected
        )

        # overview (only flows and cash flows are checked below)
        (
            imports_qpk,
            exports_qpk,
            _balance_qpk,
            import_costs_qpk,
            export_revenue_qpk,
            _ncf_qpk,
            _aggregate_static_demand_qpk,
            _aggregate_static_supply_qpk,
            _aggregate_static_balance_qpk,
        ) = statistics(ipp)

        # keys of the time intervals in the first reporting period, used for
        # every aggregate below so imports and exports cover the same span
        qpk_first_period = [(q, 0, k) for k in tf.time_intervals[q]]

        # the installed arc carries flow, so there must be imports (and costs);
        # a "> -tol" check would be vacuous for these non-negative quantities
        abs_tol = 1e-3
        imports_qp = sum(imports_qpk[qpk] for qpk in qpk_first_period)
        import_costs_qp = sum(import_costs_qpk[qpk] for qpk in qpk_first_period)
        assert imports_qp > abs_tol
        assert import_costs_qp > abs_tol

        # ... and likewise there must be exports (and revenue)
        abs_tol = 1e-2
        exports_qp = sum(exports_qpk[qpk] for qpk in qpk_first_period)
        export_revenue_qp = sum(
            export_revenue_qpk[qpk] for qpk in qpk_first_period
        )
        assert exports_qp > abs_tol
        assert export_revenue_qp > abs_tol

        # the revenue should exceed the costs
        assert export_revenue_qp > import_costs_qp - abs_tol

        # the arc was installed (minimum cost >= 5), so capex must be positive
        abs_tol = 1e-6
        assert pyo.value(ipp.instance.var_capex) > abs_tol
        
    # *************************************************************************
    # *************************************************************************

# *****************************************************************************
# *****************************************************************************