Skip to content
Snippets Groups Projects
factories.py 7.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    from functools import wraps
    from inspect import isfunction, isgeneratorfunction, getmembers
    from collections.abc import Iterable
    from itertools import chain
    from importlib import import_module
    
    import torch
    import torch.nn as nn
    import torch.utils.data as data
    
    import constants
    import utils.metrics as metrics
    from utils.misc import R
    
    class _Desc:
        def __init__(self, key):
            self.key = key
        def __get__(self, instance, owner):
            return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
        def __set__(self, instance, values):
            if not (isinstance(values, Iterable) and len(values)==len(instance)):
                raise TypeError("incorrect type or number of values")
            for i, v in zip(range(len(instance)), values):
                setattr(instance[i], self.key, v)
    
    
    def _func_deco(func_name):
        def _wrapper(self, *args):
            # TODO: Add key argument support
            try:
                # Dispatch type 1
                ret = tuple(getattr(ins, func_name)(*args) for ins in self)
            except Exception:
                # Dispatch type 2
                if len(args) > 1 or (len(args[0]) != len(self)): raise
                ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
            return ret
        return _wrapper
    
    
    def _generator_deco(func_name):
        def _wrapper(self, *args, **kwargs):
            for ins in self:
                yield from getattr(ins, func_name)(*args, **kwargs)
        return _wrapper
    
    
    # Duck typing
    class Duck(tuple):
        """Immutable, type-checked tuple of objects.

        Every positional argument passed to the constructor must be an instance
        of the class attribute ``__ducktype__`` (``object`` by default, i.e. no
        restriction); otherwise ``TypeError`` is raised.
        """
        __ducktype__ = object

        def __new__(cls, *args):
            expected = cls.__ducktype__
            if not all(isinstance(member, expected) for member in args):
                raise TypeError("please check the input type")
            return tuple.__new__(cls, args)
    
    
    class DuckMeta(type):
        """Metaclass that turns ``class X(Base, metaclass=DuckMeta)`` into a ``Duck``.

        The single declared base is not inherited from. Instead, the new class
        derives from ``Duck`` (a tuple) and receives one forwarding attribute for
        every member of the base whose name does not start with ``__``:

        * generator functions are wrapped with ``_generator_deco``;
        * plain functions are wrapped with ``_func_deco``;
        * anything else becomes a ``_Desc`` descriptor.

        The declared base is stored as ``__ducktype__`` so that ``Duck.__new__``
        only accepts instances of it.
        """
        def __new__(cls, name, bases, attrs):
            # Validate explicitly: an ``assert`` would vanish under ``python -O``.
            if len(bases) != 1:
                raise TypeError(
                    "{} must declare exactly one base class, got {}".format(name, len(bases))
                )
            proto = bases[0]
            for k, v in getmembers(proto):
                if k.startswith('__'):
                    continue
                if k in attrs:
                    # Anything written in the class body is a deliberate
                    # override and must not be clobbered by a generated
                    # dispatcher.
                    continue
                if isgeneratorfunction(v):
                    attrs[k] = _generator_deco(k)
                elif isfunction(v):
                    attrs[k] = _func_deco(k)
                else:
                    attrs[k] = _Desc(k)
            attrs['__ducktype__'] = proto
            return super().__new__(cls, name, (Duck,), attrs)
    
    
    class DuckModel(nn.Module, metaclass=DuckMeta):
        """Group of ``nn.Module`` objects assembled by ``DuckMeta``.

        ``DuckMeta`` makes this class a ``Duck`` tuple whose elements must be
        ``nn.Module`` instances and forwards the non-dunder members of
        ``nn.Module`` to every element.
        """
    
    
    class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
        """Group of ``torch.optim.Optimizer`` objects assembled by ``DuckMeta``."""

        @property
        def param_groups(self):
            """Return one flat list with the parameter groups of every member."""
            merged = []
            for member in self:
                merged.extend(member.param_groups)
            return merged
    
    
    class DuckCriterion(nn.Module, metaclass=DuckMeta):
        """Group of ``nn.Module`` criteria assembled by ``DuckMeta``.

        ``DuckMeta`` makes this class a ``Duck`` tuple whose elements must be
        ``nn.Module`` instances and forwards the non-dunder members of
        ``nn.Module`` to every element.
        """
    
    
    class DuckDataset(data.Dataset, metaclass=DuckMeta):
        """Group of ``data.Dataset`` objects assembled by ``DuckMeta``.

        ``DuckMeta`` makes this class a ``Duck`` tuple that only accepts
        ``data.Dataset`` instances as elements.
        """
    
    
    def _import_module(pkg: str, mod: str, rel=False):
        if not rel:
            # Use absolute import
            return import_module('.'.join([pkg, mod]), package=None)
        else:
            return import_module('.'+mod, package=pkg)
    
    
    def single_model_factory(model_name, C):
        """Instantiate the single model identified by ``model_name``.

        The name is matched case-insensitively after stripping surrounding
        whitespace. Supported names are ``SIAMUNET_CONC``, ``SIAMUNET_DIFF`` and
        ``EF``.

        Args:
            model_name: Name of the architecture.
            C: Configuration object; ``C.num_feats_in`` is passed as the first
                constructor argument of the model.

        Returns:
            The freshly constructed model.

        Raises:
            NotImplementedError: if ``model_name`` is not one of the supported
                names.
        """
        name = model_name.strip().upper()
        # Each architecture is imported only when requested.
        # NOTE(review): the literal 2 is presumably the number of output
        # classes -- confirm against the model constructors.
        if name == 'SIAMUNET_CONC':
            from models.siamunet_conc import SiamUnet_conc
            return SiamUnet_conc(C.num_feats_in, 2)
        if name == 'SIAMUNET_DIFF':
            from models.siamunet_diff import SiamUnet_diff
            return SiamUnet_diff(C.num_feats_in, 2)
        if name == 'EF':
            from models.unet import Unet
            return Unet(C.num_feats_in, 2)
        raise NotImplementedError("{} is not a supported architecture".format(model_name))
    
    
    def single_optim_factory(optim_name, params, C):
        """Instantiate the single optimizer identified by ``optim_name``.

        The name is matched case-insensitively after stripping surrounding
        whitespace. Supported names are ``ADAM`` and ``SGD``.

        Args:
            optim_name: Name of the optimizer type.
            params: Parameters (or parameter groups) handed to the optimizer.
            C: Configuration object providing ``lr`` and ``weight_decay``.

        Returns:
            A ``torch.optim.Adam`` (betas fixed to ``(0.9, 0.999)``) or a
            ``torch.optim.SGD`` (momentum fixed to ``0.9``).

        Raises:
            NotImplementedError: if ``optim_name`` is not a supported type.
        """
        name = optim_name.strip().upper()
        if name == 'ADAM':
            return torch.optim.Adam(
                params,
                betas=(0.9, 0.999),
                lr=C.lr,
                weight_decay=C.weight_decay
            )
        if name == 'SGD':
            return torch.optim.SGD(
                params,
                lr=C.lr,
                momentum=0.9,
                weight_decay=C.weight_decay
            )
        raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
    
    
    def single_critn_factory(critn_name, C):
        """Instantiate the single loss function identified by ``critn_name``.

        The name is matched case-insensitively after stripping surrounding
        whitespace. Supported names are ``L1``, ``MSE``, ``CE`` and ``NLL``.

        Args:
            critn_name: Name of the criterion type.
            C: Configuration object. ``C.weights`` is read only for ``CE`` and
                ``NLL``, where it is converted to a tensor and passed as the
                first constructor argument (the ``weight`` of the loss).

        Returns:
            The constructed ``nn.Module`` loss.

        Raises:
            NotImplementedError: if ``critn_name`` is not a supported type.
        """
        # Builders are lazy so that ``C.weights`` is only touched by the
        # criteria that actually need it; L1/MSE work without that attribute.
        builders = {
            'L1': lambda: nn.L1Loss(),
            'MSE': lambda: nn.MSELoss(),
            'CE': lambda: nn.CrossEntropyLoss(torch.Tensor(C.weights)),
            'NLL': lambda: nn.NLLLoss(torch.Tensor(C.weights)),
        }
        # Only the lookup is guarded: a KeyError raised while building the loss
        # must not be misreported as an unsupported name.
        try:
            build = builders[critn_name.strip().upper()]
        except KeyError:
            raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
        return build()
    
    
    def single_train_ds_factory(ds_name, C):
        """Build the training ``DataLoader`` for the dataset named ``ds_name``.

        The dataset class ``<ds_name>Dataset`` is looked up in the module
        ``data.<ds_name>`` and constructed with ``phase='train'``, a crop+flip
        augmentation as first transform and ``repeats=C.repeats``. A ``root``
        argument is added for ``OSCD`` and for names starting with ``AC``.

        Args:
            ds_name: Dataset name; surrounding whitespace is ignored.
            C: Configuration object providing ``crop_size``, ``repeats``,
                ``batch_size``, ``num_workers`` and ``device``.

        Returns:
            A shuffling ``DataLoader`` that drops the last incomplete batch and
            pins memory unless ``C.device`` equals ``'cpu'``.
        """
        from data.augmentation import Compose, Crop, Flip

        ds_name = ds_name.strip()
        module = _import_module('data', ds_name)
        dataset = getattr(module, ds_name+'Dataset')
        configs = dict(
            phase='train',
            # NOTE(review): the meaning of the three transform slots is defined
            # by the dataset class -- only the first one is augmented here.
            transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
            repeats=C.repeats
        )
        # Datasets that need a root directory take it from ``constants``.
        if ds_name == 'OSCD':
            configs['root'] = constants.IMDB_OSCD
        elif ds_name.startswith('AC'):
            configs['root'] = constants.IMDB_AirChange

        dataset_obj = dataset(**configs)

        return data.DataLoader(
            dataset_obj,
            batch_size=C.batch_size,
            shuffle=True,
            num_workers=C.num_workers,
            pin_memory=C.device != 'cpu',
            drop_last=True
        )
    
    
    def single_val_ds_factory(ds_name, C):
        """Build the validation ``DataLoader`` for the dataset named ``ds_name``.

        The dataset class ``<ds_name>Dataset`` is looked up in the module
        ``data.<ds_name>`` and constructed with ``phase='val'``, no transforms
        and ``repeats=1``. A ``root`` argument is added for ``OSCD`` and for
        names starting with ``AC``.

        Args:
            ds_name: Dataset name; surrounding whitespace is ignored.
            C: Configuration object providing ``num_workers``.

        Returns:
            A ``DataLoader`` with batch size 1 that neither shuffles, pins memory
            nor drops the last batch.
        """
        ds_name = ds_name.strip()
        module = _import_module('data', ds_name)
        dataset = getattr(module, ds_name+'Dataset')
        configs = dict(
            phase='val',
            transforms=(None, None, None),
            repeats=1
        )
        # Datasets that need a root directory take it from ``constants``.
        if ds_name == 'OSCD':
            configs['root'] = constants.IMDB_OSCD
        elif ds_name.startswith('AC'):
            configs['root'] = constants.IMDB_AirChange

        dataset_obj = dataset(**configs)

        # Create eval set
        return data.DataLoader(
            dataset_obj,
            batch_size=1,
            shuffle=False,
            num_workers=C.num_workers,
            pin_memory=False,
            drop_last=False
        )
    
    
    def _parse_input_names(name_str):
        return name_str.split('+')
    
    
    def model_factory(model_names, C):
        """Build the model(s) described by ``model_names``.

        A single name yields the model itself; several names joined with ``'+'``
        yield a ``DuckModel`` holding one model per name, in order.
        """
        names = _parse_input_names(model_names)
        if len(names) <= 1:
            return single_model_factory(model_names, C)
        members = [single_model_factory(each, C) for each in names]
        return DuckModel(*members)
    
    
    def optim_factory(optim_names, params, C):
        """Build the optimizer(s) described by ``optim_names``.

        A single name yields the optimizer itself; several names joined with
        ``'+'`` yield a ``DuckOptimizer`` holding one optimizer per name, each
        created over the same ``params``.

        Args:
            optim_names: One optimizer name or several joined with ``'+'``.
            params: Iterable of parameters (or parameter groups) to optimize.
            C: Configuration object forwarded to ``single_optim_factory``.
        """
        name_list = _parse_input_names(optim_names)
        if len(name_list) > 1:
            # ``params`` is often a one-shot iterator such as
            # ``model.parameters()``; without materialising it the first
            # optimizer would drain it and the others would receive nothing.
            params = list(params)
            return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
        return single_optim_factory(optim_names, params, C)
    
    
    def critn_factory(critn_names, C):
        """Build the criterion/criteria described by ``critn_names``.

        A single name yields the criterion itself; several names joined with
        ``'+'`` yield a ``DuckCriterion`` holding one criterion per name, in
        order.
        """
        names = _parse_input_names(critn_names)
        if len(names) <= 1:
            return single_critn_factory(critn_names, C)
        members = [single_critn_factory(each, C) for each in names]
        return DuckCriterion(*members)
    
    
    def data_factory(dataset_names, phase, C):
        """Build the data loader(s) described by ``dataset_names`` for ``phase``.

        ``phase`` selects the module-level function ``single_<phase>_ds_factory``.
        A single name yields what that function returns; several names joined
        with ``'+'`` yield a ``DuckDataset`` holding one result per name.

        Raises:
            ValueError: if ``phase`` is neither ``'train'`` nor ``'val'``.
        """
        names = _parse_input_names(dataset_names)
        if phase in ('train', 'val'):
            # Resolve the per-phase builder by name at call time.
            builder = globals()['single_{}_ds_factory'.format(phase)]
        else:
            raise ValueError("phase should be either 'train' or 'val'")
        if len(names) <= 1:
            return builder(dataset_names, C)
        members = [builder(each, C) for each in names]
        return DuckDataset(*members)
    
    
    def metric_factory(metric_names, C):
        """Instantiate the metric(s) described by ``metric_names``.

        Each name (several may be joined with ``'+'``) is stripped of surrounding
        whitespace, looked up as an attribute of ``utils.metrics`` and called
        without arguments.

        Args:
            metric_names: One metric name or several joined with ``'+'``.
            C: Configuration object; not used by this function.

        Returns:
            A list with one metric object per name, in order.
        """
        from utils import metrics
        name_list = _parse_input_names(metric_names)
        return [getattr(metrics, name.strip())() for name in name_list]