Skip to content
Snippets Groups Projects
factories.py 8.72 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    from functools import wraps
    from inspect import isfunction, isgeneratorfunction, getmembers
    from collections.abc import Iterable
    from itertools import chain
    from importlib import import_module
    
    import torch
    import torch.nn as nn
    import torch.utils.data as data
    
    import constants
    import utils.metrics as metrics
    from utils.misc import R
    
    Bobholamovic's avatar
    Bobholamovic committed
    from data.augmentation import *
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    class _Desc:
        def __init__(self, key):
            self.key = key
        def __get__(self, instance, owner):
            return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
        def __set__(self, instance, values):
            if not (isinstance(values, Iterable) and len(values)==len(instance)):
                raise TypeError("incorrect type or number of values")
            for i, v in zip(range(len(instance)), values):
                setattr(instance[i], self.key, v)
    
    
    def _func_deco(func_name):
        def _wrapper(self, *args):
    
    Bobholamovic's avatar
    Bobholamovic committed
            return tuple(getattr(ins, func_name)(*args) for ins in self)
    
    Bobholamovic's avatar
    Bobholamovic committed
        return _wrapper
    
    
    def _generator_deco(func_name):
        def _wrapper(self, *args, **kwargs):
            for ins in self:
                yield from getattr(ins, func_name)(*args, **kwargs)
        return _wrapper
    
    
    # Duck typing
    class Duck(tuple):
        """Tuple whose elements must all be instances of ``__ducktype__``.

        Concatenation (``+``) and repetition (``*``) are deliberately disabled.
        """
        __ducktype__ = object

        def __new__(cls, *args):
            expected = cls.__ducktype__
            if not all(isinstance(member, expected) for member in args):
                raise TypeError("please check the input type")
            return super().__new__(cls, args)

        def __add__(self, tup):
            raise NotImplementedError

        def __mul__(self, tup):
            raise NotImplementedError
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    class DuckMeta(type):
        """Metaclass that turns ``class X(Base, metaclass=DuckMeta)`` into a Duck.

        The created class derives from ``Duck`` only (not from ``Base``). For each
        member of ``Base`` whose name does not start with ``'__'`` an attribute of
        the same name is added, unless the class body already defines it:

        * generator functions become chained generators (``_generator_deco``),
        * plain functions are called on every element (``_func_deco``),
        * anything else is exposed through a ``_Desc`` descriptor.

        ``Base`` is stored as ``__ducktype__`` so that ``Duck.__new__`` can
        type-check the elements.
        """

        def __new__(cls, name, bases, attrs):
            # Explicit check instead of ``assert`` so it still runs under ``-O``.
            if len(bases) != 1:
                raise TypeError("a Duck class must declare exactly one base class")
            ducktype = bases[0]
            for k, v in getmembers(ducktype):
                if k.startswith('__'):
                    # Leave dunder members to Duck / tuple.
                    continue
                # setdefault keeps anything the class body defined explicitly.
                if isgeneratorfunction(v):
                    attrs.setdefault(k, _generator_deco(k))
                elif isfunction(v):
                    attrs.setdefault(k, _func_deco(k))
                else:
                    attrs.setdefault(k, _Desc(k))
            attrs['__ducktype__'] = ducktype
            return super().__new__(cls, name, (Duck,), attrs)
    
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    class DuckModel(nn.Module):
        """Module that bundles several models and exposes them by index."""

        def __init__(self, *models):
            super().__init__()
            # XXX: keeping the models in a ModuleList makes the state_dict a
            # little larger, since some extra bytes are stored in every key.
            members = nn.ModuleList(models)
            self._m = members

        def __getitem__(self, idx):
            members = self._m
            return members[idx]

        def __len__(self):
            members = self._m
            return len(members)

        def __repr__(self):
            # Show the wrapped ModuleList rather than the wrapper itself.
            members = self._m
            return repr(members)
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    
    class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
        """Tuple of optimizers that can be driven like a single optimizer.

        Members of ``torch.optim.Optimizer`` that are not overridden here are
        generated by ``DuckMeta`` and dispatched to every wrapped optimizer.
        """

        # Defined by hand rather than generated by DuckMeta.
        # NOTE(review): presumably because ``param_groups`` is set per instance
        # on torch optimizers and so is not found by DuckMeta's scan of the
        # class - confirm.
        @property
        def param_groups(self):
            """Return the param groups of all wrapped optimizers as one flat list."""
            return list(chain.from_iterable(ins.param_groups for ins in self))

        # This is special in dispatching: each optimizer receives its own state
        # dict instead of all of them receiving the same argument.
        def load_state_dict(self, state_dicts):
            """Load ``state_dicts[i]`` into the i-th wrapped optimizer.

            Raises:
                ValueError: if the number of state dicts differs from the number
                    of wrapped optimizers (a bare ``zip`` would silently skip
                    the surplus on either side).
            """
            state_dicts = tuple(state_dicts)
            if len(state_dicts) != len(self):
                raise ValueError(
                    "expected {} state dicts but got {}".format(len(self), len(state_dicts))
                )
            for optim, state_dict in zip(self, state_dicts):
                optim.load_state_dict(state_dict)
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    class DuckCriterion(nn.Module, metaclass=DuckMeta):
        """Tuple of ``nn.Module`` criteria; member access is dispatched by DuckMeta."""
    
    
    class DuckDataset(data.Dataset, metaclass=DuckMeta):
        """Tuple of ``data.Dataset`` objects; member access is dispatched by DuckMeta."""
    
    
    def _import_module(pkg: str, mod: str, rel=False):
        if not rel:
            # Use absolute import
            return import_module('.'.join([pkg, mod]), package=None)
        else:
            return import_module('.'+mod, package=pkg)
    
    
    def single_model_factory(model_name, C):
        """Instantiate the network selected by ``model_name``.

        The name is matched case-insensitively after stripping whitespace.
        Supported names are ``SIAMUNET_CONC``, ``SIAMUNET_DIFF`` and ``EF``; each
        model is constructed with ``(C.num_feats_in, 2)``. The model module is
        imported only once its name has been selected.

        Raises:
            NotImplementedError: if the name matches none of the supported ones.
        """
        key = model_name.strip().upper()
        if key == 'SIAMUNET_CONC':
            from models.siamunet_conc import SiamUnet_conc as arch
        elif key == 'SIAMUNET_DIFF':
            from models.siamunet_diff import SiamUnet_diff as arch
        elif key == 'EF':
            from models.unet import Unet as arch
        else:
            raise NotImplementedError("{} is not a supported architecture".format(model_name))
        return arch(C.num_feats_in, 2)
    
    
    def single_optim_factory(optim_name, params, C):
        """Create a torch optimizer over ``params``.

        ``optim_name`` is stripped and matched case-insensitively against
        ``ADAM`` (betas fixed to ``(0.9, 0.999)``) and ``SGD`` (momentum fixed to
        ``0.9``). Learning rate and weight decay are read from ``C.lr`` and
        ``C.weight_decay``.

        Raises:
            NotImplementedError: for any other optimizer name.
        """
        optim_name = optim_name.strip()
        kind = optim_name.upper()
        if kind not in ('ADAM', 'SGD'):
            raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))

        # Settings shared by both optimizer types
        shared = dict(lr=C.lr, weight_decay=C.weight_decay)
        if kind == 'ADAM':
            return torch.optim.Adam(params, betas=(0.9, 0.999), **shared)
        return torch.optim.SGD(params, momentum=0.9, **shared)
    
    
    def single_critn_factory(critn_name, C):
        """Create the loss module selected by ``critn_name``.

        The name is stripped and matched case-insensitively. ``L1`` and ``MSE``
        take no arguments; ``CE`` and ``NLL`` are given ``torch.Tensor(C.weights)``
        as their first argument.

        Raises:
            NotImplementedError: if the name is not one of L1, MSE, CE, NLL.
        """
        critn_name = critn_name.strip()

        # Each entry is a zero-argument builder, so ``C.weights`` is only read
        # for the criteria that actually use it; L1/MSE work without it.
        builders = {
            'L1': lambda: nn.L1Loss(),
            'MSE': lambda: nn.MSELoss(),
            'CE': lambda: nn.CrossEntropyLoss(torch.Tensor(C.weights)),
            'NLL': lambda: nn.NLLLoss(torch.Tensor(C.weights)),
        }
        # Only the lookup is guarded: a KeyError raised while constructing the
        # criterion must not be reported as an unsupported name.
        try:
            build = builders[critn_name.upper()]
        except KeyError:
            raise NotImplementedError(
                "{} is not a supported criterion type".format(critn_name)
            ) from None
        return build()
    
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    def _get_basic_configs(ds_name, C):
        if ds_name == 'OSCD':
            return dict(
                root = constants.IMDB_OSCD
            )
        elif ds_name.startswith('AC'):
            return dict(
    
    Bobholamovic's avatar
    Bobholamovic committed
                root = constants.IMDB_AIRCHANGE
    
    Bobholamovic's avatar
    Bobholamovic committed
            )
    
    Bobholamovic's avatar
    Bobholamovic committed
        elif ds_name == 'Lebedev':
    
    Bobholamovic's avatar
    Bobholamovic committed
            return dict(
                root = constants.IMDB_LEBEDEV
            )
    
    Bobholamovic's avatar
    Bobholamovic committed
        else:
            return dict()
            
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    def single_train_ds_factory(ds_name, C):
        """Build the training ``DataLoader`` for the dataset named ``ds_name``.

        The dataset class ``<ds_name>Dataset`` is looked up in module
        ``data.<ds_name>`` and created with ``phase='train'``, a
        ``Compose(Crop(C.crop_size), Flip())`` transform in the first slot of
        ``transforms``, ``repeats=C.repeats`` and the options from
        ``_get_basic_configs``. The loader shuffles, drops the last incomplete
        batch and pins memory unless ``C.device`` equals ``'cpu'``.
        """
        ds_name = ds_name.strip()
        dataset_cls = getattr(_import_module('data', ds_name), ds_name + 'Dataset')

        configs = {
            'phase': 'train',
            'transforms': (Compose(Crop(C.crop_size), Flip()), None, None),
            'repeats': C.repeats,
        }
        # Common, phase-independent options
        configs.update(_get_basic_configs(ds_name, C))
        # Phase-specific options
        if ds_name == 'Lebedev':
            configs['subsets'] = ('real',)

        return data.DataLoader(
            dataset_cls(**configs),
            batch_size=C.batch_size,
            shuffle=True,
            num_workers=C.num_workers,
            pin_memory=not (C.device == 'cpu'),
            drop_last=True,
        )
    
    
    def single_val_ds_factory(ds_name, C):
        """Build the validation ``DataLoader`` for the dataset named ``ds_name``.

        The dataset class ``<ds_name>Dataset`` is looked up in module
        ``data.<ds_name>`` and created with ``phase='val'``, no transforms,
        ``repeats=1`` and the options from ``_get_basic_configs``. The loader
        uses batch size 1, one worker, no shuffling, no pinned memory and keeps
        the last batch.
        """
        ds_name = ds_name.strip()
        dataset_cls = getattr(_import_module('data', ds_name), ds_name + 'Dataset')

        configs = {
            'phase': 'val',
            'transforms': (None, None, None),
            'repeats': 1,
        }
        # Common, phase-independent options
        configs.update(_get_basic_configs(ds_name, C))
        # Phase-specific options
        if ds_name == 'Lebedev':
            configs['subsets'] = ('real',)

        # Create eval set
        return data.DataLoader(
            dataset_cls(**configs),
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=False,
            drop_last=False,
        )
    
    
    def _parse_input_names(name_str):
        return name_str.split('+')
    
    
    def model_factory(model_names, C):
        """Create the model(s) named in the ``'+'``-separated ``model_names``.

        A single name yields that model directly; several names yield a
        ``DuckModel`` wrapping one model per name, in order.
        """
        name_list = _parse_input_names(model_names)
        if len(name_list) <= 1:
            return single_model_factory(model_names, C)
        members = [single_model_factory(name, C) for name in name_list]
        return DuckModel(*members)
    
    
    
    Bobholamovic's avatar
    Bobholamovic committed
    def optim_factory(optim_names, models, C):
        """Create the optimizer(s) named in the ``'+'``-separated ``optim_names``.

        One optimizer is created per model. Each optimizer receives one param
        group per direct child module of its model, tagged with the child's name
        under the ``'name'`` key. A ``DuckModel`` holding more than one model
        yields a ``DuckOptimizer``; otherwise a single optimizer is returned.

        Raises:
            ValueError: if the number of names differs from the number of models.
        """
        def named_param_groups(model):
            # One param group per direct child, labelled with the child's name
            return [
                {'params': child.parameters(), 'name': child_name}
                for child_name, child in model.named_children()
            ]

        name_list = _parse_input_names(optim_names)
        num_models = len(models) if isinstance(models, DuckModel) else 1
        if len(name_list) != num_models:
            raise ValueError("the number of optimizers does not match the number of models")

        if num_models <= 1:
            return single_optim_factory(optim_names, named_param_groups(models), C)

        optims = [
            single_optim_factory(name, named_param_groups(model), C)
            for name, model in zip(name_list, models)
        ]
        return DuckOptimizer(*optims)
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    
    def critn_factory(critn_names, C):
        """Create the criterion(s) named in the ``'+'``-separated ``critn_names``.

        A single name yields that criterion directly; several names yield a
        ``DuckCriterion`` wrapping one criterion per name, in order.
        """
        name_list = _parse_input_names(critn_names)
        if len(name_list) <= 1:
            return single_critn_factory(critn_names, C)
        members = [single_critn_factory(name, C) for name in name_list]
        return DuckCriterion(*members)
    
    
    def data_factory(dataset_names, phase, C):
        """Create the data loader(s) for the ``'+'``-separated ``dataset_names``.

        ``phase`` selects ``single_train_ds_factory`` or ``single_val_ds_factory``.
        A single name returns that factory's result directly; several names are
        wrapped in a ``DuckDataset``.

        NOTE(review): the single_*_ds_factory functions return ``data.DataLoader``
        objects, whereas ``DuckDataset`` is declared over ``data.Dataset``; the
        multi-dataset branch may therefore be rejected by Duck's type check -
        verify.

        Raises:
            ValueError: if ``phase`` is neither ``'train'`` nor ``'val'``.
        """
        name_list = _parse_input_names(dataset_names)
        if phase not in ('train', 'val'):
            raise ValueError("phase should be either 'train' or 'val'")
        fact = single_train_ds_factory if phase == 'train' else single_val_ds_factory
        if len(name_list) <= 1:
            return fact(dataset_names, C)
        members = [fact(name, C) for name in name_list]
        return DuckDataset(*members)
    
    
    def metric_factory(metric_names, C):
        """Instantiate the metrics named in the ``'+'``-separated ``metric_names``.

        Each stripped name is looked up as an attribute of ``utils.metrics`` and
        called without arguments. ``C`` is accepted but not used. Always returns
        a list, even for a single name.
        """
        from utils import metrics
        instances = []
        for name in _parse_input_names(metric_names):
            metric_cls = getattr(metrics, name.strip())
            instances.append(metric_cls())
        return instances