# imports

# local
# import numpy as np
# import networkx as nx
from src.topupopt.problems.esipp.problem import InfrastructurePlanningProblem
from src.topupopt.problems.esipp.network import Network
from src.topupopt.problems.esipp.time import EconomicTimeFrame
from src.topupopt.problems.esipp.blocks.prices import NODE_PRICE_OTHER

# *****************************************************************************
# *****************************************************************************

def check_problem_size(ipp: "InfrastructurePlanningProblem", nc, nv, nnz):
    """Assert that a problem's reported size matches the expected counts.

    The counts are read from the first entry of ``ipp.results["Problem"]``
    (presumably the solver's problem summary filled in by ``optimise`` --
    confirm against ``InfrastructurePlanningProblem``).

    :param ipp: problem object whose ``results`` have been populated
    :param nc: expected number of constraints
    :param nv: expected number of variables
    :param nnz: expected number of nonzeros
    :raises AssertionError: on the first count (checked in the order
        constraints, variables, nonzeros) that differs from its expectation;
        the message names the count and shows expected vs. actual values
    """
    problem_summary = ipp.results["Problem"][0]
    for label, expected in (
        ("Number of constraints", nc),
        ("Number of variables", nv),
        ("Number of nonzeros", nnz),
    ):
        actual = problem_summary[label]
        assert actual == expected, f"{label}: expected {expected}, got {actual}"

# *****************************************************************************
# *****************************************************************************

def _iter_unselected_arcs(ipp: InfrastructurePlanningProblem):
    """Yield ``(network_key, arc_key)`` for every arc not yet selected.

    Arcs whose technology reports ``has_been_selected()`` (pre-existing
    arcs) are skipped. Arcs are yielded lazily, so the caller's action on one
    arc happens before the next arc is inspected.
    """
    for network_key in ipp.networks:
        network = ipp.networks[network_key]
        for arc_key in network.edges(keys=True):
            if network.edges[arc_key][Network.KEY_ARC_TECH].has_been_selected():
                continue
            yield network_key, arc_key


def _configure_flow_senses(
    ipp: InfrastructurePlanningProblem,
    use_sos_sense: bool,
    sense_sos_weight_key: int,
    sense_use_real_variables_if_possible: bool,
    sense_use_arc_interfaces: bool,
):
    """Configure how flow senses are modelled: SOS1, interface vars, or neither."""
    if use_sos_sense:
        # SOS1 for flow senses is only set up on arcs whose KEY_ARC_UND
        # attribute is truthy (presumably undirected arcs)
        for network_key in ipp.networks:
            network = ipp.networks[network_key]
            for arc_key in network.edges(keys=True):
                if not network.edges[arc_key][Network.KEY_ARC_UND]:
                    continue
                ipp.use_sos1_for_flow_senses(
                    network_key,
                    arc_key,
                    use_real_variables_if_possible=(
                        sense_use_real_variables_if_possible
                    ),
                    use_interface_variables=sense_use_arc_interfaces,
                    sos1_weight_method=sense_sos_weight_key,
                )
    elif sense_use_arc_interfaces:
        # arc interfaces without SOS1: applied to every not-yet-selected arc
        for network_key, arc_key in _iter_unselected_arcs(ipp):
            ipp.use_interface_variables_for_arc_selection(network_key, arc_key)


def _place_static_losses(ipp: InfrastructurePlanningProblem, static_losses_mode):
    """Place static losses according to one of the ``ipp.STATIC_LOSS_MODE_*``.

    :raises ValueError: if ``static_losses_mode`` matches none of the modes
    """
    if static_losses_mode == ipp.STATIC_LOSS_MODE_ARR:
        ipp.place_static_losses_arrival_node()
    elif static_losses_mode == ipp.STATIC_LOSS_MODE_DEP:
        ipp.place_static_losses_departure_node()
    elif static_losses_mode == ipp.STATIC_LOSS_MODE_US:
        ipp.place_static_losses_upstream()
    elif static_losses_mode == ipp.STATIC_LOSS_MODE_DS:
        ipp.place_static_losses_downstream()
    else:
        raise ValueError("Unknown static loss modelling mode.")


def build_solve_ipp(
    solver: str = 'glpk',
    solver_options: dict = None,
    use_sos_arcs: bool = False,
    use_sos_arc_groups: bool = False,
    arc_sos_weight_key: str = InfrastructurePlanningProblem.SOS1_ARC_WEIGHTS_NONE,
    arc_use_real_variables_if_possible: bool = False,
    use_sos_sense: bool = False,
    sense_sos_weight_key: int = (
        InfrastructurePlanningProblem.SOS1_SENSE_WEIGHT_NOMINAL_HIGHER
    ),
    sense_use_real_variables_if_possible: bool = False,
    sense_use_arc_interfaces: bool = False,
    perform_analysis: bool = False,
    plot_results: bool = False,
    print_solver_output: bool = False,
    time_frame: EconomicTimeFrame = None,
    networks: dict = None,
    converters: dict = None,
    static_losses_mode=None,
    mandatory_arcs: list = None,
    max_number_parallel_arcs: dict = None,
    arc_groups_dict: dict = None,
    init_aux_sets: bool = False,
    assessment_weights: dict = None,
    simplify_problem: bool = False,
    use_prices_block: bool = False,
    node_price_model: int = NODE_PRICE_OTHER
):
    """Build, instantiate and solve an ``InfrastructurePlanningProblem``.

    :param solver: solver name passed to ``ipp.optimise`` as ``solver_name``
    :param solver_options: passed to ``ipp.optimise`` unchanged
    :param use_sos_arcs: use SOS1 for the selection of every arc that has not
        already been selected
    :param use_sos_arc_groups: ``use_sos1`` flag for every arc group created
    :param arc_sos_weight_key: SOS1 weight method for arc selection and groups
    :param arc_use_real_variables_if_possible: forwarded to SOS1 arc selection
    :param use_sos_sense: use SOS1 for flow senses (only on arcs whose
        ``Network.KEY_ARC_UND`` attribute is truthy)
    :param sense_sos_weight_key: SOS1 weight method for flow senses
    :param sense_use_real_variables_if_possible: forwarded to SOS1 flow senses
    :param sense_use_arc_interfaces: with ``use_sos_sense`` it is passed as
        ``use_interface_variables``; without it, interface variables are used
        for the selection of every not-yet-selected arc
    :param perform_analysis: not used by this function
    :param plot_results: not used by this function
    :param print_solver_output: forwarded to ``ipp.optimise``
    :param time_frame: economic time frame for the problem
    :param networks: mapping of network key -> network; ``None`` means none
    :param converters: mapping of converter key -> converter; any non-dict
        value (e.g. ``None``) means none
    :param static_losses_mode: one of ``ipp.STATIC_LOSS_MODE_ARR/DEP/US/DS``
    :param mandatory_arcs: list/tuple of full arc keys; the first element of
        each is the network key, the remaining elements form the arc key
    :param max_number_parallel_arcs: mapping of
        ``(network_key, node_a, node_b)`` -> maximum number of parallel arcs;
        ``None`` means no limits
    :param arc_groups_dict: mapping whose values are passed to
        ``ipp.create_arc_group``; ``None`` means no groups
    :param init_aux_sets: forwarded to ``ipp.instantiate`` as
        ``initialise_ancillary_sets``
    :param assessment_weights: assessment weights; any non-dict value means
        an empty dict
    :param simplify_problem: call ``ipp.simplify_peak_total_assessments()``
        before instantiation
    :param use_prices_block: forwarded to the problem constructor
    :param node_price_model: forwarded to the problem constructor
    :returns: the problem object after ``ipp.optimise`` has run
    :raises ValueError: if ``static_losses_mode`` is not a known mode; note
        this includes the default ``None``
    """
    # isinstance (rather than an exact type check) keeps dict subclasses such
    # as OrderedDict instead of silently replacing them with {}
    if not isinstance(assessment_weights, dict):
        assessment_weights = {}
    if not isinstance(converters, dict):
        converters = {}
    if networks is None:
        networks = {}

    # time weights (relative weight of each time interval versus the average
    # interval duration) are not modelled yet, hence the None values
    ipp = InfrastructurePlanningProblem(
        time_frame=time_frame,
        time_weights=None,
        normalised_time_interval_duration=None,
        assessment_weights=assessment_weights,
        use_prices_block=use_prices_block,
        node_price_model=node_price_model,
    )

    for netkey, net in networks.items():
        ipp.add_network(network_key=netkey, network=net)

    for cvtkey, cvt in converters.items():
        ipp.add_converter(converter_key=cvtkey, converter=cvt)

    if isinstance(mandatory_arcs, (list, tuple)):
        for full_arc_key in mandatory_arcs:
            # first element: network key; remaining elements: arc key
            ipp.make_arc_mandatory(full_arc_key[0], full_arc_key[1:])

    if use_sos_arcs:
        for network_key, arc_key in _iter_unselected_arcs(ipp):
            ipp.use_sos1_for_arc_selection(
                network_key,
                arc_key,
                use_real_variables_if_possible=(
                    arc_use_real_variables_if_possible
                ),
                sos1_weight_method=arc_sos_weight_key,
            )

    _configure_flow_senses(
        ipp,
        use_sos_sense=use_sos_sense,
        sense_sos_weight_key=sense_sos_weight_key,
        sense_use_real_variables_if_possible=sense_use_real_variables_if_possible,
        sense_use_arc_interfaces=sense_use_arc_interfaces,
    )

    _place_static_losses(ipp, static_losses_mode)

    if arc_groups_dict is not None:
        for arc_group in arc_groups_dict.values():
            ipp.create_arc_group(
                arc_group,
                use_sos1=use_sos_arc_groups,
                sos1_weight_method=arc_sos_weight_key,
            )

    # the default None previously crashed here (iterating over None)
    if max_number_parallel_arcs is not None:
        for key, limit in max_number_parallel_arcs.items():
            ipp.set_maximum_number_parallel_arcs(
                network_key=key[0],
                node_a=key[1],
                node_b=key[2],
                limit=limit,
            )

    if simplify_problem:
        ipp.simplify_peak_total_assessments()

    ipp.instantiate(initialise_ancillary_sets=init_aux_sets)

    ipp.optimise(
        solver_name=solver,
        solver_options=solver_options,
        output_options={},
        print_solver_output=print_solver_output,
    )
    return ipp

# *****************************************************************************
# *****************************************************************************