# imports

# standard

import random

# third-party (external)

import numpy as np

# local, internal

import src.topupopt.problems.esipp.dynsys as dynsys

import src.topupopt.problems.esipp.converter as cvn

import src.topupopt.problems.esipp.signal as sgn

# ******************************************************************************
# ******************************************************************************


class TestConverter:
    """Tests that build a full two-node converter over different time grids."""

    # **************************************************************************
    # **************************************************************************

    # Planned coverage for converters:
    # 1) regular and irregular time steps
    # 2) time-invariant and time-varying models
    # 3) integrated and non-integrated outputs
    # 4) coefficient generation
    #
    # TODO: add tests for
    # - a stateless converter without outputs
    # - a stateless converter with 1 output
    # - a stateless converter with 2 outputs
    # - a converter based on a single ODE system without outputs
    # - a converter based on a single ODE system with 1 output
    # - a converter based on a single ODE system with 2 outputs
    # - a converter based on a multiple ODE system without outputs
    # - a converter based on a multiple ODE system with 1 output
    # - a converter based on a multiple ODE, multi-output system

    # **************************************************************************
    # **************************************************************************

    def test_full_converter_regular(self):
        """Four time steps, all of unit duration."""
        method_full_converter([1] * 4)

    # **************************************************************************
    # **************************************************************************

    def test_full_converter_irregular(self):
        """Four time steps of unequal duration."""
        method_full_converter([1, 1.5, 0.5, 1])

    # **************************************************************************
    # **************************************************************************


# ******************************************************************************
# ******************************************************************************


def get_stateless_model_data(relative_amplitude_variation: float = 0.0):
    """Return the parameters of the stateless (pure feedthrough) model.

    The minimum relative heat (nominally 0.2) is scaled by
    ``1 + relative_amplitude_variation * deviation``, where ``deviation`` is a
    single ``random.random() - 0.5`` draw in [-0.5, 0.5). With the default
    variation of 0.0 the nominal value is returned.

    Returns:
        (Aw, min_rel_heat): window area and minimum relative heat.
    """
    window_area = 6.22  # original: 6.22 m2

    # one centred random draw; only affects the result if variation != 0
    perturbation = random.random() - 0.5
    scale = 1 + relative_amplitude_variation * perturbation
    minimum_relative_heat = 0.2 * scale

    return window_area, minimum_relative_heat


# ******************************************************************************
# ******************************************************************************


def get_single_ode_model_data(relative_amplitude_variation: float = 0.0):
    """Return the parameters of the single-state (indoor temperature) model.

    Only the indoor-to-ambient resistance ``Ria`` is perturbed, by the factor
    ``1 + relative_amplitude_variation * deviation`` with ``deviation`` drawn
    once as ``random.random() - 0.5``. With the default variation of 0.0 the
    nominal values are returned.

    NOTE(review): unlike get_multi_ode_model_data, no heater power (Pw) is
    returned, although single_node_model takes one - confirm this is intended.

    Returns:
        (Ci, Ria, Aw, min_rel_heat, x0), where x0 is the integer array [20]
        holding the initial state.
    """
    ria_deviation = random.random() - 0.5
    ria_factor = 1 + relative_amplitude_variation * ria_deviation

    # 3600000 is the number of joules in one kWh
    joules_per_kwh = 3600000

    capacity_indoor = 1.360 * joules_per_kwh
    resistance_ia = ria_factor * 5.31 / joules_per_kwh
    window_area = 6.22
    minimum_relative_heat = 0.2
    initial_state = np.array([20])

    return (
        capacity_indoor,
        resistance_ia,
        window_area,
        minimum_relative_heat,
        initial_state,
    )


# *****************************************************************************
# *****************************************************************************


def get_multi_ode_model_data(relative_amplitude_variation: float = 0.0):
    """Return the parameters of the two-state TiTh model.

    Reference values are from Bacher and Madsen (2011), model TiTh. Both
    resistances are perturbed by ``1 + relative_amplitude_variation *
    deviation``, each deviation being an independent ``random.random() - 0.5``
    draw. The Rih deviation is drawn before the Ria deviation. With the default
    variation of 0.0 the nominal values are returned.

    Returns:
        (Ci, Ch, Ria, Rih, Aw, min_rel_heat, Pw, x0), where x0 is the integer
        array [20, 20] holding the initial states.
    """
    # draw order (Rih first, then Ria) is kept for seeded reproducibility
    rih_deviation, ria_deviation = random.random() - 0.5, random.random() - 0.5

    rih_factor = 1 + relative_amplitude_variation * rih_deviation
    ria_factor = 1 + relative_amplitude_variation * ria_deviation

    # 3600000 is the number of joules in one kWh
    joules_per_kwh = 3600000

    capacity_indoor = 1.360 * joules_per_kwh  # original: 1.36 kWh/ºC
    capacity_heater = 0.309 * joules_per_kwh  # original: 0.309 kWh/ºC
    resistance_ia = ria_factor * 5.31 / joules_per_kwh  # original: 5.31 ºC/kWh
    resistance_ih = rih_factor * 0.639 / joules_per_kwh  # original: 0.639 ºC/kWh
    window_area = 6.22  # original: 6.22 m2
    heater_power = 5000  # 5 kW
    minimum_relative_heat = 0.2
    initial_states = np.array([20, 20])

    return (
        capacity_indoor,
        capacity_heater,
        resistance_ia,
        resistance_ih,
        window_area,
        minimum_relative_heat,
        heater_power,
        initial_states,
    )


# ******************************************************************************
# ******************************************************************************


def stateless_model(Aw, min_rel_heat):
    """Return the (A, B, C, D) matrices of a model without states.

    Inputs (columns of D): Ta, phi_s, phi_h above the minimum, phi_h status.
    Outputs (rows of D): solar gain (Aw * phi_s) and heat
    ((1 - min_rel_heat) * input 3 + min_rel_heat * input 4).

    A, B and C are returned as None since there are no states.
    """
    solar_row = [0, Aw, 0, 0]
    heat_row = [0, 0, (1 - min_rel_heat), min_rel_heat]
    feedthrough = np.array([solar_row, heat_row])

    return None, None, None, feedthrough


# ******************************************************************************
# ******************************************************************************


def single_node_model(Ci, Ria, Aw, min_rel_heat, Pw):
    """Return the (A, B, C, D) matrices of a one-state thermal model.

    State: indoor temperature Ti only (A is 1x1).
    Inputs: Ta, phi_s, phi_h above the minimum, phi_h status.
    Outputs: solar gain (Aw * phi_s) and heat
    (Pw * (1 - min_rel_heat) * input 3 + Pw * min_rel_heat * input 4).
    C is an all-zero integer column, so outputs depend on inputs only.
    """
    # heater power split into the modulated part and the minimum part
    power_above_min = Pw * (1 - min_rel_heat)
    power_min = Pw * min_rel_heat

    state_matrix = np.array([[-1 / (Ria * Ci)]])
    input_matrix = np.array(
        [[1 / (Ci * Ria), Aw / Ci, power_above_min / Ci, power_min / Ci]]
    )
    output_matrix = np.array([[0], [0]])
    feedthrough_matrix = np.array(
        [[0, Aw, 0, 0], [0, 0, power_above_min, power_min]]
    )

    return state_matrix, input_matrix, output_matrix, feedthrough_matrix


# ******************************************************************************
# ******************************************************************************


def two_node_model(Ci, Ch, Ria, Rih, Aw, min_rel_heat, Pw):
    """Return the (A, B, C, D) matrices of the two-state TiTh thermal model.

    States: indoor temperature Ti and heater temperature Th.
    Inputs: Ta, phi_s, phi_h above the minimum, phi_h status.
    Outputs: solar gain (Aw * phi_s) and heat
    (Pw * (1 - min_rel_heat) * input 3 + Pw * min_rel_heat * input 4).
    C is an all-zero 2x2 integer matrix, so outputs depend on inputs only.
    """
    # conductances between the nodes
    g_ih = 1 / Rih
    g_ia = 1 / Ria

    # heater power split into the modulated part and the minimum part
    power_above_min = Pw * (1 - min_rel_heat)
    power_min = Pw * min_rel_heat

    indoor_row = [-(g_ih + g_ia) / Ci, 1 / (Ci * Rih)]
    heater_row = [1 / (Ch * Rih), -1 / (Ch * Rih)]
    state_matrix = np.array([indoor_row, heater_row])

    input_matrix = np.array(
        [
            [1 / (Ci * Ria), Aw / Ci, 0, 0],
            [0, 0, power_above_min / Ch, power_min / Ch],
        ]
    )
    output_matrix = np.array([[0, 0], [0, 0]])
    feedthrough_matrix = np.array(
        [[0, Aw, 0, 0], [0, 0, power_above_min, power_min]]
    )

    return state_matrix, input_matrix, output_matrix, feedthrough_matrix


# ******************************************************************************
# ******************************************************************************


def get_two_node_model_signals(number_samples):
    """Create the input, state and output signals of the two-node model.

    Every signal is constructed with ``number_samples`` samples.

    Returns:
        (list_inputs, list_states, list_outputs), each a list of signal
        objects ordered as described in the comments below.
    """
    # inputs, in order:
    # 1) ambient temperature (real, can be fixed later)
    # 2) solar irradiation (real, can be fixed later)
    # 3) relative heat above the minimum (non-negative real)
    # 4) heater status (binary)
    input_signal_types = (
        sgn.FreeUnboundedSignal,
        sgn.FreeUnboundedSignal,
        sgn.NonNegativeRealSignal,
        sgn.BinarySignal,
    )
    list_inputs = [
        signal_type(number_samples) for signal_type in input_signal_types
    ]

    # states: 1) indoor temperature, 2) heater temperature (both real)
    list_states = [sgn.FreeUnboundedSignal(number_samples) for _ in range(2)]

    # outputs: 1) solar gain, 2) heat input (both non-negative real)
    list_outputs = [sgn.NonNegativeRealSignal(number_samples) for _ in range(2)]

    return list_inputs, list_states, list_outputs


# ******************************************************************************
# ******************************************************************************


def method_full_converter(time_step_durations: list):
    """Build a two-node converter over the given time steps and extract its matrices.

    Uses the nominal TiTh parameters (no random variation), wraps the
    resulting state-space model in a ``dynsys.DynamicSystem`` and a
    ``cvn.Converter``, and unpacks the six dictionaries returned by
    ``matrix_dictionaries()``. No assertions are made on them yet.
    """
    (
        capacity_indoor,
        capacity_heater,
        resistance_ia,
        resistance_ih,
        window_area,
        minimum_relative_heat,
        heater_power,
        initial_states,
    ) = get_multi_ode_model_data()

    a, b, c, d = two_node_model(
        Ci=capacity_indoor,
        Ch=capacity_heater,
        Ria=resistance_ia,
        Rih=resistance_ih,
        Aw=window_area,
        min_rel_heat=minimum_relative_heat,
        Pw=heater_power,
    )

    # one sample per time step
    inputs, states, outputs = get_two_node_model_signals(len(time_step_durations))

    system = dynsys.DynamicSystem(
        time_interval_durations=time_step_durations, A=a, B=b, C=c, D=d
    )

    converter = cvn.Converter(
        sys=system,
        time_frame=None,
        initial_states=initial_states,
        turn_key_cost=3,
        inputs=inputs,
        states=states,
        outputs=outputs,
    )

    # the unpack checks that exactly six dictionaries are returned
    (a_innk, b_inmk, c_irnk, d_irmk, e_x_ink, e_y_irk) = (
        converter.matrix_dictionaries()
    )

    # TODO: check the dicts


# ******************************************************************************
# ******************************************************************************