# imports

# standard

import math
from numbers import Real

import geopandas as gpd

# *****************************************************************************
# *****************************************************************************

# local, internal
from src.topupopt.data.gis.utils import read_gdf_file
from src.topupopt.data.buildings.dk import heat

# *****************************************************************************
# *****************************************************************************

class TestDataBuildingsDK:
    """Tests for heat.heat_demand_profiles (src.topupopt.data.buildings.dk)."""

    # *************************************************************************
    # *************************************************************************

    @staticmethod
    def _verify_result(
            out_dict,
            out_area,
            gdf_osm,
            assessments,
            number_time_intervals,
            total_demand_true,
            total_area_true,
            ):
        """Assert that a heat_demand_profiles result matches expectations.

        Parameters
        ----------
        out_dict : dict
            Profiles keyed by OSM entrance index, then by assessment.
        out_area : Real
            Total area returned alongside the profiles.
        gdf_osm : GeoDataFrame
            The OSM entries that were passed to heat_demand_profiles.
        assessments : iterable
            Assessment keys expected in each non-empty entry of out_dict.
        number_time_intervals : int
            Expected length of every non-empty profile.
        total_demand_true : float
            Expected sum of all profiles, per assessment.
        total_area_true : float
            Expected value of out_area.
        """
        assert isinstance(out_dict, dict)
        assert isinstance(out_area, Real)
        # one output entry per OSM entry
        assert len(out_dict) == len(gdf_osm)
        assert math.isclose(out_area, total_area_true, abs_tol=1e-3)
        for q in assessments:
            # entrances with no data (empty dict or empty profile) add nothing
            # to the total; .get avoids a KeyError for empty entries, which
            # the structural check below explicitly allows
            assert math.isclose(
                sum(sum(v.get(q, ())) for v in out_dict.values()),
                total_demand_true,
                abs_tol=1e-3
                )
        # output dict must be keyed by entrance id and then by assessment
        for k, v in out_dict.items():
            assert k in gdf_osm.index
            if len(v) == 0:
                continue
            for q in assessments:
                assert q in v
                assert len(v[q]) in (number_time_intervals, 0)

    # *************************************************************************
    # *************************************************************************

    def test_demand_dict(self):
        """Exercise heat.heat_demand_profiles across its profile modes.

        Covered modes: predefined positive/negative deviation gain driven by
        air temperature, sinusoidal profiles with and without phase shift,
        state-correlated profiles without a predefined gain, and the same
        solved by optimisation (glpk) on a single OSM entry.
        """
        osm_data_filename = 'tests/data/gdf_osm.gpkg'
        building_data_filename = 'tests/data/gdf_buildings.gpkg'
        bdg_gdf_container_columns = ('ejerskaber', 'koordinater', 'bygningspunkt')
        number_time_intervals = 12
        min_to_max_ratio = 0.1
        # twelve intervals of 30 days each, in seconds
        intraperiod_time_interval_duration = (
            [30*24*3600] * number_time_intervals
            )
        total_demand_true = 1000
        # expected area for the reduced OSM sample built below
        # (5% sample: 4563; full data set: 100882)
        total_area_true = 4563
        assessments = ['q']
        annual_heat_demand = {'q': total_demand_true}
        air_temperature = {'q': [5 + i for i in range(number_time_intervals)]}

        gdf_osm = gpd.read_file(osm_data_filename)
        gdf_osm.set_index(['element_type', 'osmid'], drop=True, inplace=True)

        gdf_buildings = read_gdf_file(
            filename=building_data_filename,
            packed_columns=bdg_gdf_container_columns,
            index='index'
            )

        # drop entries to keep things fast: keep only the trailing
        # round(share * n) - 1 rows. total_area_true above is tied to this
        # exact subset, so update it if the count or selection changes.
        share_keeper_osm_entries = 0.05
        number_kept = max(
            round(share_keeper_osm_entries * len(gdf_osm)) - 1, 0
            )
        # one slice instead of dropping row by row (which is quadratic);
        # .copy() gives a standalone frame rather than a view
        gdf_osm = gdf_osm.iloc[len(gdf_osm) - number_kept:].copy()

        common_kwargs = dict(
            gdf_buildings=gdf_buildings,
            time_interval_durations=intraperiod_time_interval_duration,
            assessments=assessments,
            annual_heat_demand=annual_heat_demand,
            )

        scenarios = [
            # profiles in accordance with a set of states and a positive gain
            dict(air_temperature=air_temperature, deviation_gain=1),
            # profiles in accordance with a set of states and a negative gain
            dict(air_temperature=air_temperature, deviation_gain=-1),
            # sinusoidal profiles (no phase shift)
            dict(min_max_ratio=min_to_max_ratio),
            # sinusoidal profiles (with phase shift)
            dict(
                min_max_ratio=min_to_max_ratio,
                phase_shift_radians=math.pi/2,
                ),
            # profiles in accordance with states but without a predefined
            # gain (no optimisation)
            dict(
                air_temperature=air_temperature,
                min_max_ratio=min_to_max_ratio,
                states_correlate_profile=True,
                ),
            ]

        for scenario in scenarios:
            heat_demand_dict, total_area = heat.heat_demand_profiles(
                gdf_osm=gdf_osm,
                **common_kwargs,
                **scenario
                )
            self._verify_result(
                heat_demand_dict,
                total_area,
                gdf_osm,
                assessments,
                number_time_intervals,
                total_demand_true,
                total_area_true,
                )

        # profiles in accordance with states but without a predefined gain
        # (optimisation): keep only the last OSM entry to keep things light
        gdf_osm = gdf_osm.iloc[-1:].copy()

        heat_demand_dict, total_area = heat.heat_demand_profiles(
            gdf_osm=gdf_osm,
            **common_kwargs,
            air_temperature=air_temperature,
            min_max_ratio=min_to_max_ratio,
            states_correlate_profile=True,
            solver='glpk'
            )
        self._verify_result(
            heat_demand_dict,
            total_area,
            gdf_osm,
            assessments,
            number_time_intervals,
            total_demand_true,
            200,
            )

    # *************************************************************************
    # *************************************************************************
    
    # def test_demand_dict3(self):
        
    #     # heat_demand_dict_by_building_entrance
        
    #     osm_data_filename = 'tests/data/gdf_osm.gpkg'
    #     building_data_filename = 'tests/data/gdf_buildings.gpkg'
    #     bdg_gdf_container_columns = ('ejerskaber','koordinater','bygningspunkt')
    #     number_time_intervals = 12
    #     min_to_max_ratio = 0.1
    #     intraperiod_time_interval_duration = [
    #         30*24*3600
    #         for i in range(number_time_intervals)
    #         ]
    #     annual_heat_demand_scenario = 1000
    #     total_area = 1000
    #     states = [10 for i in range(number_time_intervals)]
        
    #     gdf_osm = gpd.read_file(osm_data_filename)
    #     gdf_osm.set_index(['element_type', 'osmid'], drop=True, inplace=True)
    
    #     gdf_buildings = read_gdf_file(
    #         filename=building_data_filename,
    #         packed_columns=bdg_gdf_container_columns,
    #         index='index'
    #         )
        
    #     # sinusoidal
        
    #     heat_demand_dict = heat.heat_demand_dict_by_building_entrance2(
    #         gdf_osm=gdf_osm,
    #         gdf_buildings=gdf_buildings,
    #         number_intervals=number_time_intervals,
    #         time_interval_durations=intraperiod_time_interval_duration,
    #         min_max_ratio=min_to_max_ratio,
    #         specific_demand=annual_heat_demand_scenario/total_area,
    #         )
    #     assert type(heat_demand_dict) == dict
    #     assert len(heat_demand_dict) == len(gdf_osm)
    #     assert math.isclose(
    #         annual_heat_demand_scenario, 
    #         sum(sum(value) for value in heat_demand_dict.values()),
    #         abs_tol=1e-3,
    #         )
        
    #     # sinusoidal with phase shift
    
    #     heat_demand_dict = heat.heat_demand_dict_by_building_entrance2(
    #         gdf_osm=gdf_osm,
    #         gdf_buildings=gdf_buildings,
    #         number_intervals=number_time_intervals,
    #         time_interval_durations=intraperiod_time_interval_duration,
    #         min_max_ratio=min_to_max_ratio,
    #         specific_demand=annual_heat_demand_scenario/total_area ,
    #         phase_shift_radians=math.pi,
    #         )
    #     assert type(heat_demand_dict) == dict
    #     assert len(heat_demand_dict) == len(gdf_osm)
    #     assert math.isclose(
    #         annual_heat_demand_scenario, 
    #         sum(sum(value) for value in heat_demand_dict.values()),
    #         abs_tol=1e-3,
    #         )
        
    #     # predefined deviation gain, positive
    
    #     heat_demand_dict = heat.heat_demand_dict_by_building_entrance2(
    #         gdf_osm=gdf_osm,
    #         gdf_buildings=gdf_buildings,
    #         number_intervals=number_time_intervals,
    #         time_interval_durations=intraperiod_time_interval_duration,
    #         states=states,
    #         specific_demand=annual_heat_demand_scenario/total_area ,
    #         deviation_gain=3,
    #         )
    #     assert type(heat_demand_dict) == dict
    #     assert len(heat_demand_dict) == len(gdf_osm)
    #     assert math.isclose(
    #         annual_heat_demand_scenario, 
    #         sum(sum(value) for value in heat_demand_dict.values()),
    #         abs_tol=1e-3,
    #         )
        
    #     # predefined deviation gain, negative
    
    #     heat_demand_dict = heat.heat_demand_dict_by_building_entrance2(
    #         gdf_osm=gdf_osm,
    #         gdf_buildings=gdf_buildings,
    #         number_intervals=number_time_intervals,
    #         time_interval_durations=intraperiod_time_interval_duration,
    #         states=states,
    #         specific_demand=annual_heat_demand_scenario/total_area ,
    #         deviation_gain=-3,
    #         )
    #     assert type(heat_demand_dict) == dict
    #     assert len(heat_demand_dict) == len(gdf_osm)
    #     assert math.isclose(
    #         annual_heat_demand_scenario, 
    #         sum(sum(value) for value in heat_demand_dict.values()),
    #         abs_tol=1e-3,
    #         )
        
    #     # optimisation
    
    #     heat_demand_dict = heat.heat_demand_dict_by_building_entrance2(
    #         gdf_osm=gdf_osm,
    #         gdf_buildings=gdf_buildings,
    #         number_intervals=number_time_intervals,
    #         time_interval_durations=intraperiod_time_interval_duration,
    #         states=states,
    #         specific_demand=annual_heat_demand_scenario/total_area,
    #         states_correlate_profile=True,
    #         solver='glpk'
    #         )
    #     assert type(heat_demand_dict) == dict
    #     assert len(heat_demand_dict) == len(gdf_osm)
    #     assert math.isclose(
    #         annual_heat_demand_scenario, 
    #         sum(sum(value) for value in heat_demand_dict.values()),
    #         abs_tol=1e-3,
    #         )
    
    # *************************************************************************
    # *************************************************************************
    
    # def test_bbr(self):
        
    #     # test get_bbr_building_data_geodataframe
        
    #     osm_data_filename = 'tests/data/gdf_osm.gpkg'
                
    #     gdf_osm = gpd.read_file(osm_data_filename)
    #     gdf_osm.set_index(['element_type', 'osmid'], drop=True, inplace=True)
        
    #     error_raised = False
    #     try:
    #         gdf_buildings, drop_list = bbr.get_bbr_building_data_geodataframe(
    #             list(gdf_osm[heat.label_osm_entrance_id]),
    #             None,
    #             None,
    #             None)
    #     except UnboundLocalError:
    #         error_raised = True
    #     assert error_raised
        
    #     # drop the rows with no data
    #     gdf_osm = gdf_osm.drop(
    #         index=[
    #             (gdf_osm[
    #                 gdf_osm[
    #                     heat.label_osm_entrance_id
    #                     ]==bdg_entrance_id].index)[0]
    #             for bdg_entrance_id in drop_list
    #             ]
    #         )

# *****************************************************************************
# *****************************************************************************