# imports

# standard
from math import inf
import random

# external
import osmnx as ox

# internal
from src.topupopt.problems.esipp.network import Network, ArcsWithoutLosses

# import src.topupopt.data.dhn.network as tuo_dhn
import src.topupopt.data.dhn.utils as utils
from src.topupopt.data.dhn.network import PipeTrenchOptions
from topupheat.pipes.single import StandardisedPipe, StandardisedPipeDatabase
import topupheat.pipes.trenches as trenches
from topupheat.common.fluids import FluidDatabase  # , Fluid

# *****************************************************************************
# *****************************************************************************

class TestDistrictHeatingNetworkUtils:
    """Tests for the district heating helpers in ``src.topupopt.data.dhn.utils``.

    All tests build their pipe trenches from the same fluid and pipe data
    files and the same physical parameters. Those are defined once as class
    attributes and assembled by the private ``_make_*``/``_add_*`` helpers.
    """

    # fluid and pipe data files
    WATERDATA_FILE = "tests/data/incropera2006_saturated_water.csv"
    SINGLEPIPEDATA_FILES = ("tests/data/isoplus_singlepipes_s1.csv",)

    # network details
    SUPPLY_TEMPERATURE = 85 + 273.15  # K
    RETURN_TEMPERATURE = 45 + 273.15  # K
    PRESSURE = 1e5
    # trench
    PIPE_DISTANCE = 0.52  # m
    PIPE_DEPTH = 0.66  # m
    # environmental
    OUTDOOR_TEMPERATURE = 6 + 273.15  # K
    H_GS = inf  # W/m2K (a finite alternative used before: 14.6)
    SOIL_K = 1.5  # W/mK
    TIME_INTERVAL_DURATION = 3600
    # more information
    MAX_SPECIFIC_PRESSURE_LOSS = 100  # Pa/m

    # *************************************************************************
    # *************************************************************************

    @classmethod
    def _make_fluid_db(cls):
        """Return the liquid-phase fluid database read from WATERDATA_FILE."""
        return FluidDatabase(
            fluid="fluid", phase=FluidDatabase.fluid_LIQUID, source=cls.WATERDATA_FILE
        )

    @classmethod
    def _make_pipe_db(cls):
        """Return the standardised single pipe database used by the tests."""
        # a fresh list each time, so the shared class constant is never handed
        # to (and possibly mutated by) the database
        return StandardisedPipeDatabase(source=list(cls.SINGLEPIPEDATA_FILES))

    @staticmethod
    def _make_pipes(pipedb, number_pipes):
        """Return one StandardisedPipe for each of the first pipe tuples in pipedb.

        :param pipedb: the pipe database to draw pipe tuples from.
        :param number_pipes: how many pipes to create (indices 0..number_pipes-1).
        :returns: a list with ``number_pipes`` pipes.
        """
        return [
            StandardisedPipe(pipe_tuple=pipedb.pipe_tuples[index], db=pipedb)
            for index in range(number_pipes)
        ]

    @classmethod
    def _make_trench(cls, fluid_db, supply_pipe):
        """Return a SupplyReturnPipeTrench built from the shared parameters.

        :param fluid_db: fluid database passed through to the trench.
        :param supply_pipe: either a single pipe, in which case every trench
            parameter is passed as a scalar, or a list of pipes, in which case
            every parameter is passed as a list with one entry per pipe.
        """
        if isinstance(supply_pipe, list):
            number_options = len(supply_pipe)

            def per_option(value):
                return [value for _ in range(number_options)]

        else:

            def per_option(value):
                return value

        return trenches.SupplyReturnPipeTrench(
            pipe_center_depth=per_option(cls.PIPE_DEPTH),
            pipe_center_distance=per_option(cls.PIPE_DISTANCE),
            fluid_db=fluid_db,
            phase=FluidDatabase.fluid_LIQUID,
            pressure=per_option(cls.PRESSURE),
            supply_temperature=per_option(cls.SUPPLY_TEMPERATURE),
            return_temperature=per_option(cls.RETURN_TEMPERATURE),
            max_specific_pressure_loss=per_option(cls.MAX_SPECIFIC_PRESSURE_LOSS),
            supply_pipe=supply_pipe,
        )

    @classmethod
    def _add_scalar_loss_scenarios(cls, arcs):
        """Add static loss scenarios "scenario0" and "scenario1" using scalars."""
        arcs.set_static_losses(
            scenario_key="scenario0",
            ground_thermal_conductivity=cls.SOIL_K,
            ground_air_heat_transfer_coefficient=cls.H_GS,
            time_interval_duration=cls.TIME_INTERVAL_DURATION,
            temperature_surroundings=cls.OUTDOOR_TEMPERATURE,
        )
        # a second scenario with every parameter shifted away from the first
        arcs.set_static_losses(
            scenario_key="scenario1",
            ground_thermal_conductivity=cls.SOIL_K + 1,
            ground_air_heat_transfer_coefficient=cls.H_GS + 1,
            time_interval_duration=cls.TIME_INTERVAL_DURATION + 100,
            temperature_surroundings=cls.OUTDOOR_TEMPERATURE + 1,
        )

    @classmethod
    def _add_multistep_loss_scenario(cls, arcs, number_steps=3):
        """Add static loss scenario "scenario2" with number_steps values per input."""
        arcs.set_static_losses(
            scenario_key="scenario2",
            ground_thermal_conductivity=[cls.SOIL_K] * number_steps,
            ground_air_heat_transfer_coefficient=[cls.H_GS] * number_steps,
            time_interval_duration=[cls.TIME_INTERVAL_DURATION] * number_steps,
            temperature_surroundings=[cls.OUTDOOR_TEMPERATURE] * number_steps,
        )

    @staticmethod
    def _make_arcs_without_losses():
        """Return the ArcsWithoutLosses object used as a non-trench arc."""
        return ArcsWithoutLosses(
            name="hello",
            capacity=[1, 2, 3],
            minimum_cost=[4, 10, 16],
            specific_capacity_cost=3,
            capacity_is_instantaneous=False,
        )

    @staticmethod
    def _raises(exception_type, func, *args, **kwargs):
        """Return True if func(*args, **kwargs) raises exception_type.

        Any other exception propagates, so the calling test errors out.
        """
        try:
            func(*args, **kwargs)
        except exception_type:
            return True
        return False

    # *************************************************************************
    # *************************************************************************

    def test_cost_pipes_single_arc(self):
        fluid_db = self._make_fluid_db()
        pipedb = self._make_pipe_db()
        pipe = self._make_pipes(pipedb, 1)[0]

        # single-option trench, defined with scalar parameters
        mytrench = self._make_trench(fluid_db, pipe)

        # create arcs object with multiple static loss values as a first case
        trench_length = 50
        myarcs = PipeTrenchOptions(
            trench=mytrench, name="hellotrench", length=trench_length
        )
        self._add_multistep_loss_scenario(myarcs)

        mypipecosts = utils.cost_pipes(mytrench, trench_length)
        assert mypipecosts == (50.0,)

        # unrecognised input: using a string for the trench length
        assert self._raises(ValueError, utils.cost_pipes, mytrench, "50")

    # *************************************************************************
    # *************************************************************************

    def test_cost_pipes_multiple_arcs(self):
        fluid_db = self._make_fluid_db()
        pipedb = self._make_pipe_db()
        pipe = self._make_pipes(pipedb, 1)[0]

        # two trench options that share the same pipe
        number_options = 2
        mytrench = self._make_trench(fluid_db, [pipe] * number_options)

        # PipeTrenchOptions with scalar and multi-step static loss scenarios
        trench_length = 50
        myarcs = PipeTrenchOptions(
            trench=mytrench, name="hellotrench", length=trench_length
        )
        self._add_scalar_loss_scenarios(myarcs)
        self._add_multistep_loss_scenario(myarcs)

        # both a single length and a tuple with one length per option are
        # accepted, and here they give the same costs
        mypipecosts = utils.cost_pipes(mytrench, trench_length)
        assert mypipecosts == (100.0, 100.0)
        mypipecosts = utils.cost_pipes(mytrench, (trench_length, trench_length))
        assert mypipecosts == (100.0, 100.0)

        # unrecognised input: using a list for the trench lengths
        assert self._raises(ValueError, utils.cost_pipes, mytrench, [trench_length])

    # *************************************************************************
    # *************************************************************************

    def test_plotting_heating_demand(self):
        # monthly end-use demand and total demand; these values appear to have
        # been extracted from a previously solved problem (gross demand and
        # total inflow per interval)
        monthly_end_use_demand = [
            1466.7343572178731,
            1558.9721332796835,
            1466.7343572178725,
            1214.7360666398415,
            870.4999999999999,
            526.2639333601583,
            274.2656427821285,
            182.0278667203175,
            274.26564278212857,
            526.2639333601578,
            870.4999999999999,
            1214.7360666398413,
        ]

        monthly_total_demand = [
            1628.7218570926373,
            1721.9500038070653,
            1625.9983377978033,
            1361.9918030563408,
            1004.4476307714823,
            644.0561428602034,
            388.96294399263485,
            296.6013715992422,
            400.8473918244898,
            664.0492504106618,
            1019.8602740534218,
            1375.1761123698027,
        ]

        # the losses are the part of the total demand not used by end users
        monthly_losses = [
            total_demand - end_use_demand
            for end_use_demand, total_demand in zip(
                monthly_end_use_demand, monthly_total_demand
            )
        ]

        months = [
            "Jan",
            "Fev",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ]

        utils.plot_heating_demand(
            losses=monthly_losses, end_use_demand=monthly_end_use_demand, labels=months
        )

    # *************************************************************************
    # *************************************************************************

    def test_summarising_output(self):
        fluid_db = self._make_fluid_db()
        pipedb = self._make_pipe_db()

        # three trench options, each using a different pipe from the database
        mytrench = self._make_trench(fluid_db, self._make_pipes(pipedb, 3))

        # *********************************************************************

        # arc 1: third option selected
        myarcs1 = PipeTrenchOptions(trench=mytrench, name="hellotrench", length=50)
        self._add_scalar_loss_scenarios(myarcs1)
        myarcs1.options_selected[2] = True

        # arc 2: first option selected
        myarcs2 = PipeTrenchOptions(trench=mytrench, name="hellotrench", length=25)
        self._add_scalar_loss_scenarios(myarcs2)
        myarcs2.options_selected[0] = True

        # *********************************************************************

        # create network: two trench arcs plus one arc that is not a trench
        mynet = Network()
        mynet.add_directed_arc(node_key_a=0, node_key_b=1, arcs=myarcs1)
        mynet.add_directed_arc(node_key_a=1, node_key_b=2, arcs=myarcs2)
        mynet.add_directed_arc(0, 2, arcs=self._make_arcs_without_losses())

        # *********************************************************************

        out = utils.summarise_network_by_pipe_technology(mynet, False)

        # each selected pipe designation maps to the length of its arc
        assert "DN20" in out and out["DN20"] == 25
        assert "DN32" in out and out["DN32"] == 50

        # also run with the second argument set to True
        utils.summarise_network_by_pipe_technology(mynet, True)

    # *************************************************************************
    # *************************************************************************

    def test_summarising_output_osmnx(self):
        # a seeded local generator keeps the random option choices below
        # reproducible from run to run
        rng = random.Random(0)

        # get the street network with osmnx
        _protonet = ox.graph_from_point(
            (55.71654, 9.11728),
            network_type="drive",
            custom_filter='["highway"~"residential|tertiary|unclassified|service"]',
            truncate_by_edge=True,
        )

        # create a network object
        network = Network(incoming_graph_data=_protonet)

        fluid_db = self._make_fluid_db()
        pipedb = self._make_pipe_db()

        # four trench options, each using a different pipe from the database
        mytrench = self._make_trench(fluid_db, self._make_pipes(pipedb, 4))

        # *********************************************************************

        # snapshot the edge keys, since the loop replaces each arc's technology
        for edge_key in list(network.edges(keys=True)):
            myarcs = PipeTrenchOptions(
                trench=mytrench,
                name="hellotrench",
                length=network.edges[edge_key]["length"],
            )
            self._add_scalar_loss_scenarios(myarcs)
            # select one option at random
            selected_option = rng.randint(0, myarcs.number_options() - 1)
            myarcs.options_selected[selected_option] = True
            network.modify_network_arc(*edge_key, {Network.KEY_ARC_TECH: myarcs})

        # deselect the (only) selected option of one randomly chosen trench
        edge_key = rng.choice(list(network.edges(keys=True)))
        deselected_arcs = network.edges[edge_key][Network.KEY_ARC_TECH]
        deselected_arcs.options_selected[
            deselected_arcs.options_selected.index(True)
        ] = False

        # add non-trench Arcs object
        network.add_directed_arc(0, 2, arcs=self._make_arcs_without_losses())
        # give the endpoints of that arc coordinates
        network.add_node(0, x=55, y=12)
        network.add_node(2, x=55.01, y=12.01)

        # *********************************************************************

        utils.summarise_network_by_pipe_technology(network, False)

        utils.plot_network_layout(network=network, include_basemap=False)

        utils.plot_network_layout(network=network, include_basemap=True)

    # *************************************************************************
    # *************************************************************************


# ******************************************************************************
# ******************************************************************************