from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Sequence
from itertools import chain
from importlib import import_module

import torch
import torch.nn as nn
import torch.utils.data as data

import constants
import utils.metrics as metrics
from utils.misc import R
from data.augmentation import Compose, Crop, Flip


class _Desc:
    def __init__(self, key):
        self.key = key
    def __get__(self, instance, owner):
        if instance is None:
            return self
        return tuple(getattr(ins, self.key) for ins in instance)
    def __set__(self, instance, values):
        # Require a sequence so that len() is defined and the order is stable
        if not (isinstance(values, Sequence) and len(values) == len(instance)):
            raise TypeError("incorrect type or number of values")
        for ins, v in zip(instance, values):
            setattr(ins, self.key, v)
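
# A minimal sketch of _Desc on a plain tuple subclass (Box and Pair are
# improvised names, not part of this module):
#
#   >>> class Box:
#   ...     def __init__(self, val): self.val = val
#   >>> class Pair(tuple):
#   ...     val = _Desc('val')
#   >>> p = Pair((Box(1), Box(2)))
#   >>> p.val
#   (1, 2)
#   >>> p.val = (3, 4)
#   >>> p.val
#   (3, 4)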


def _func_deco(func_name):
    def _wrapper(self, *args, **kwargs):
        return tuple(getattr(ins, func_name)(*args, **kwargs) for ins in self)
    return _wrapper


def _generator_deco(func_name):
    def _wrapper(self, *args, **kwargs):
        for ins in self:
            yield from getattr(ins, func_name)(*args, **kwargs)
    return _wrapper
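
# A minimal sketch of the generated dispatchers (Talker/Talkers improvised):
#
#   >>> class Talker:
#   ...     def __init__(self, word): self.word = word
#   ...     def say(self): return self.word
#   >>> class Talkers(tuple):
#   ...     say = _func_deco('say')
#   >>> Talkers((Talker('hi'), Talker('yo'))).say()
#   ('hi', 'yo')
#
# _generator_deco works analogously, but chains the members' generators into
# a single stream instead of collecting a tuple.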


# Duck typing: a fixed-length, homogeneous tuple whose attribute reads/writes
# and method calls are dispatched to every member
class Duck(tuple):
    __ducktype__ = object
    def __new__(cls, *args):
        if any(not isinstance(a, cls.__ducktype__) for a in args):
            raise TypeError("all members must be instances of {}".format(cls.__ducktype__.__name__))
        return tuple.__new__(cls, args)

    def __add__(self, tup):
        raise NotImplementedError

    def __mul__(self, tup):
        raise NotImplementedError
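
# Duck refuses concatenation and repetition so that the member count stays
# fixed, and type-checks its members on construction. A sketch:
#
#   >>> class DuckInt(Duck):
#   ...     __ducktype__ = int
#   >>> DuckInt(1, 2)
#   (1, 2)
#   >>> DuckInt(1, 'a')   # raises TypeError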


class DuckMeta(type):
    def __new__(cls, name, bases, attrs):
        assert len(bases) == 1, "DuckMeta expects exactly one base class"
        for k, v in getmembers(bases[0]):
            if k.startswith('__'):
                continue
            if isgeneratorfunction(v):
                attrs.setdefault(k, _generator_deco(k))
            elif isfunction(v):
                attrs.setdefault(k, _func_deco(k))
            else:
                attrs.setdefault(k, _Desc(k))
        attrs['__ducktype__'] = bases[0]
        return super().__new__(cls, name, (Duck,), attrs)
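
# DuckMeta turns a class into a Duck of instances of its single base: every
# public method of the base is replaced by a dispatching wrapper and every
# other public attribute by a _Desc. A sketch (Foo is improvised):
#
#   >>> class Foo:
#   ...     def ping(self): return 'pong'
#   >>> class DuckFoo(Foo, metaclass=DuckMeta):
#   ...     pass
#   >>> DuckFoo(Foo(), Foo()).ping()
#   ('pong', 'pong')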


class DuckModel(nn.Module):
    def __init__(self, *models):
        super().__init__()
        ## XXX: The state_dict of a DuckModel is slightly larger than the sum of
        # its members', since every key gains a '_m.{index}.' prefix
        self._m = nn.ModuleList(models)

    def __len__(self):
        return len(self._m)

    def __getitem__(self, idx):
        return self._m[idx]

    def __repr__(self):
        return repr(self._m)
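
# A minimal sketch: two throwaway networks wrapped as one module, so that
# .to(device), .train()/.eval() and checkpointing treat them as a unit while
# indexing picks out a member:
#
#   >>> dm = DuckModel(nn.Linear(3, 2), nn.Linear(3, 2))
#   >>> len(dm)
#   2
#   >>> dm[0]
#   Linear(in_features=3, out_features=2, bias=True)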


class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
    # param_groups is created per instance in Optimizer.__init__, so DuckMeta
    # sees no class attribute to wrap; expose a flattened view manually
    @property
    def param_groups(self):
        return list(chain.from_iterable(ins.param_groups for ins in self))

    # Dispatching is special here: each wrapped optimizer receives its own
    # state_dict from the given sequence
    def load_state_dict(self, state_dicts):
        for optim, state_dict in zip(self, state_dicts):
            optim.load_state_dict(state_dict)
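
# Sketch with two throwaway optimizers: methods such as zero_grad() are
# dispatched to every member by DuckMeta, while param_groups flattens the
# members' groups (e.g. for LR schedulers that walk optimizer.param_groups):
#
#   >>> m1, m2 = nn.Linear(2, 2), nn.Linear(2, 2)
#   >>> opt = DuckOptimizer(torch.optim.SGD(m1.parameters(), lr=0.1),
#   ...                     torch.optim.SGD(m2.parameters(), lr=0.1))
#   >>> _ = opt.zero_grad()
#   >>> len(opt.param_groups)
#   2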


class DuckCriterion(nn.Module, metaclass=DuckMeta):
    pass


class DuckDataset(data.Dataset, metaclass=DuckMeta):
    pass


def _import_module(pkg: str, mod: str, rel=False):
    if not rel:
        # Use absolute import
        return import_module('.'.join([pkg, mod]), package=None)
    else:
        return import_module('.'+mod, package=pkg)
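
# e.g. _import_module('data', 'OSCD') performs the absolute import of
# data.OSCD, while _import_module('data', 'OSCD', rel=True) resolves '.OSCD'
# relative to the 'data' package.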


def single_model_factory(model_name, C):
    name = model_name.strip().upper()
    if name == 'SIAMUNET_CONC':
        from models.siamunet_conc import SiamUnet_conc
        return SiamUnet_conc(C.num_feats_in, 2)
    elif name == 'SIAMUNET_DIFF':
        from models.siamunet_diff import SiamUnet_diff
        return SiamUnet_diff(C.num_feats_in, 2)
    elif name == 'EF':
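        # 'EF' presumably selects the early-fusion variant, where both
        # temporal inputs are fused at the input of a single U-Net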
        from models.unet import Unet
        return Unet(C.num_feats_in, 2)
    else:
        raise NotImplementedError("{} is not a supported architecture".format(model_name))


def single_optim_factory(optim_name, params, C):
    optim_name = optim_name.strip()
    name = optim_name.upper()
    if name == 'ADAM':
        return torch.optim.Adam(
            params, 
            betas=(0.9, 0.999),
            lr=C.lr,
            weight_decay=C.weight_decay
        )
    elif name == 'SGD':
        return torch.optim.SGD(
            params, 
            lr=C.lr,
            momentum=0.9,
            weight_decay=C.weight_decay
        )
    else:
        raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
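
# A minimal sketch (the config object is improvised; only .lr and
# .weight_decay are read here):
#
#   >>> C = type('C', (), {'lr': 1e-3, 'weight_decay': 1e-4})()
#   >>> opt = single_optim_factory('Adam', nn.Linear(4, 2).parameters(), C)
#   >>> type(opt).__name__
#   'Adam'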


def single_critn_factory(critn_name, C):
    import losses
    critn_name = critn_name.strip()
    try:
        criterion, params = {
            'L1': (nn.L1Loss, ()),
            'MSE': (nn.MSELoss, ()),
            'CE': (nn.CrossEntropyLoss, (torch.Tensor(C.weights),)),
            'NLL': (nn.NLLLoss, (torch.Tensor(C.weights),))
        }[critn_name.upper()]
        return criterion(*params)
    except KeyError:
        raise NotImplementedError("{} is not a supported criterion type".format(critn_name)) from None
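
# Note that both weight tensors in the mapping above are built eagerly, so
# C.weights must be a valid sequence of class weights even when 'L1' or 'MSE'
# is requested. For example (C improvised):
#
#   >>> C = type('C', (), {'weights': (1.0, 5.0)})()
#   >>> single_critn_factory('ce', C)
#   CrossEntropyLoss()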


def _get_basic_configs(ds_name, C):
    if ds_name == 'OSCD':
        return dict(
            root = constants.IMDB_OSCD
        )
    elif ds_name.startswith('AC'):
        return dict(
            root = constants.IMDB_AIRCHANGE
        )
    elif ds_name == 'Lebedev':
        return dict(
            root = constants.IMDB_LEBEDEV
        )
    else:
        return dict()
        

def single_train_ds_factory(ds_name, C):
    ds_name = ds_name.strip()
    module = _import_module('data', ds_name)
    dataset = getattr(module, ds_name+'Dataset')
    configs = dict(
        phase='train', 
        transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
        repeats=C.repeats
    )
    
    # Update some common configurations
    configs.update(_get_basic_configs(ds_name, C))

    # Set phase-specific ones
    if ds_name == 'Lebedev':
        configs.update(
            dict(
                subsets = ('real',)
            )
        )
    
    dataset_obj = dataset(**configs)
    
    return data.DataLoader(
        dataset_obj,
        batch_size=C.batch_size,
        shuffle=True,
        num_workers=C.num_workers,
        pin_memory=(C.device != 'cpu'),
        drop_last=True
    )
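
# Usage sketch: single_train_ds_factory('OSCD', C) resolves
# data.OSCD.OSCDDataset, applies random crops and flips, and wraps it in a
# shuffled DataLoader; drop_last=True keeps every batch at the full
# batch_size, and memory pinning is enabled whenever C.device is not 'cpu'.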


def single_val_ds_factory(ds_name, C):
    ds_name = ds_name.strip()
    module = _import_module('data', ds_name)
    dataset = getattr(module, ds_name+'Dataset')
    configs = dict(
        phase='val', 
        transforms=(None, None, None),
        repeats=1
    )

    # Update some common configurations
    configs.update(_get_basic_configs(ds_name, C))

    # Set phase-specific ones
    if ds_name == 'Lebedev':
        configs.update(
            dict(
                subsets = ('real',)
            )
        )
    
    dataset_obj = dataset(**configs)  

    # Create eval set
    return data.DataLoader(
        dataset_obj,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False
    )


def _parse_input_names(name_str):
    return name_str.split('+')


def model_factory(model_names, C):
    name_list = _parse_input_names(model_names)
    if len(name_list) > 1:
        return DuckModel(*(single_model_factory(name, C) for name in name_list))
    else:
        return single_model_factory(model_names, C)


def optim_factory(optim_names, models, C):
    name_list = _parse_input_names(optim_names)
    num_models = len(models) if isinstance(models, DuckModel) else 1
    if len(name_list) != num_models:
        raise ValueError("the number of optimizers does not match the number of models")
    
    if num_models > 1:
        optims = []
        for name, model in zip(name_list, models):
            param_groups = [
                {'params': module.parameters(), 'name': module_name}
                for module_name, module in model.named_children()
            ]
            optims.append(single_optim_factory(name, param_groups, C))
        return DuckOptimizer(*optims)
    else:
        return single_optim_factory(
            optim_names,
            [
                {'params': module.parameters(), 'name': module_name}
                for module_name, module in models.named_children()
            ],
            C
        )
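
# Sketch of the multi-model path (throwaway modules, improvised config):
# 'Adam+SGD' paired with a two-member DuckModel yields a DuckOptimizer with
# one member per model, each holding named per-child param groups:
#
#   >>> models = DuckModel(nn.Sequential(nn.Linear(2, 2)),
#   ...                    nn.Sequential(nn.Linear(2, 2)))
#   >>> C = type('C', (), {'lr': 1e-3, 'weight_decay': 0.0})()
#   >>> optims = optim_factory('Adam+SGD', models, C)
#   >>> len(optims)
#   2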


def critn_factory(critn_names, C):
    name_list = _parse_input_names(critn_names)
    if len(name_list) > 1:
        return DuckCriterion(*(single_critn_factory(name, C) for name in name_list))
    else:
        return single_critn_factory(critn_names, C)


def data_factory(dataset_names, phase, C):
    name_list = _parse_input_names(dataset_names)
    if phase not in ('train', 'val'):
        raise ValueError("phase should be either 'train' or 'val'")
    fact = globals()['single_'+phase+'_ds_factory']
    if len(name_list) > 1:
        return DuckDataset(*(fact(name, C) for name in name_list))
    else:
        return fact(dataset_names, C)


def metric_factory(metric_names, C):
    name_list = _parse_input_names(metric_names)
    return [getattr(metrics, name.strip())() for name in name_list]
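
# The '+' convention applies here as in the other factories, e.g.
# metric_factory('Precision+Recall', C) would return
# [metrics.Precision(), metrics.Recall()], assuming those classes exist in
# utils.metrics.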