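# NOTE: the stray labels "Newer" / "Older" (diff-viewer residue) stood here;
# they are kept only as this comment so that the module can be parsed.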
from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Iterable
from itertools import chain
from importlib import import_module
import torch
import torch.nn as nn
import torch.utils.data as data
import constants
import utils.metrics as metrics
from utils.misc import R
class _Desc:
def __init__(self, key):
self.key = key
def __get__(self, instance, owner):
return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
def __set__(self, instance, values):
if not (isinstance(values, Iterable) and len(values)==len(instance)):
raise TypeError("incorrect type or number of values")
for i, v in zip(range(len(instance)), values):
setattr(instance[i], self.key, v)
def _func_deco(func_name):
def _wrapper(self, *args):
return tuple(getattr(ins, func_name)(*args) for ins in self)
return _wrapper
def _generator_deco(func_name):
def _wrapper(self, *args, **kwargs):
for ins in self:
yield from getattr(ins, func_name)(*args, **kwargs)
return _wrapper
# Duck typing
class Duck(tuple):
    """Tuple of members that must all be instances of ``__ducktype__``.

    Subclasses set ``__ducktype__`` to restrict the accepted member type.
    Concatenation (``+``) and repetition (``*``) are disabled.
    """
    __ducktype__ = object

    def __new__(cls, *args):
        # Reject the whole construction as soon as one member has the wrong type
        for member in args:
            if not isinstance(member, cls.__ducktype__):
                raise TypeError("please check the input type")
        return tuple.__new__(cls, args)

    def __add__(self, tup):
        raise NotImplementedError

    def __mul__(self, tup):
        raise NotImplementedError
class DuckMeta(type):
    """Metaclass that turns ``class X(Base, metaclass=DuckMeta)`` into a Duck container.

    The single declared base is inspected; for each of its public members a
    dispatcher is installed on the new class, unless the class body already
    defines that name:

    * generator functions -> ``_generator_deco`` (chains the members' generators)
    * plain functions     -> ``_func_deco`` (tuple of the members' results)
    * anything else       -> ``_Desc`` (tuple of the members' attribute values)

    The new class is then created with ``Duck`` as its only base and with
    ``__ducktype__`` set to the declared base, so ``Duck.__new__`` checks
    members against it.
    """
    def __new__(cls, name, bases, attrs):
        assert len(bases) == 1
        base = bases[0]
        # NOTE(review): the dispatch branches below were missing from the
        # damaged source and are restored from the otherwise-unused helpers
        # (_Desc, _func_deco, _generator_deco) -- confirm against history.
        for k, v in getmembers(base):
            if k.startswith('__'):
                continue
            # setdefault keeps members written explicitly in the class body
            # (e.g. a hand-written load_state_dict)
            if isgeneratorfunction(v):
                attrs.setdefault(k, _generator_deco(k))
            elif isfunction(v):
                attrs.setdefault(k, _func_deco(k))
            else:
                attrs.setdefault(k, _Desc(k))
        # Always record the member type, regardless of what the base exposes
        attrs['__ducktype__'] = base
        return super().__new__(cls, name, (Duck,), attrs)
class DuckModel(nn.Module):
    """Module that bundles several models and exposes them by index.

    The models are registered through an ``nn.ModuleList`` stored as ``_m``;
    ``len()``, indexing and ``repr()`` are delegated to that list.
    """

    def __init__(self, *models):
        super().__init__()
        ## XXX: The state_dict will be a little larger in size
        # Since some extra bytes are stored in every key
        self._m = nn.ModuleList(models)

    def __len__(self):
        # Number of bundled models
        return len(self._m)

    def __getitem__(self, idx):
        # Indexing is handled by the underlying ModuleList
        return self._m[idx]

    def __repr__(self):
        # Show the underlying ModuleList rather than the wrapper
        return repr(self._m)
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
    """Container of optimizers created through DuckMeta.

    ``param_groups`` and ``load_state_dict`` are written by hand here because
    they need different handling from plain per-member dispatch.
    """

    @property
    def param_groups(self):
        # Flatten the parameter groups of all member optimizers into one list
        return [group for optim in self for group in optim.param_groups]

    # This is special in dispatching
    def load_state_dict(self, state_dicts):
        # Pair each member optimizer with its own state dict, in order
        for member, member_state in zip(self, state_dicts):
            member.load_state_dict(member_state)
class DuckCriterion(nn.Module, metaclass=DuckMeta):
    """Container of criteria, built through DuckMeta from the nn.Module base."""
class DuckDataset(data.Dataset, metaclass=DuckMeta):
    """Container of datasets, built through DuckMeta from the data.Dataset base."""
def _import_module(pkg: str, mod: str, rel=False):
if not rel:
# Use absolute import
return import_module('.'.join([pkg, mod]), package=None)
else:
return import_module('.'+mod, package=pkg)
def single_model_factory(model_name, C):
    """Instantiate one model from its (case-insensitive) architecture name.

    Supported names are ``SIAMUNET_CONC``, ``SIAMUNET_DIFF`` and ``EF``; each
    model is built with ``C.num_feats_in`` and ``2`` as its two arguments.
    The model modules are imported lazily, only for the selected architecture.

    Raises:
        NotImplementedError: if the name is not one of the supported ones.
    """
    arch = model_name.strip().upper()

    if arch == 'SIAMUNET_CONC':
        from models.siamunet_conc import SiamUnet_conc
        return SiamUnet_conc(C.num_feats_in, 2)

    if arch == 'SIAMUNET_DIFF':
        from models.siamunet_diff import SiamUnet_diff
        return SiamUnet_diff(C.num_feats_in, 2)

    if arch == 'EF':
        from models.unet import Unet
        return Unet(C.num_feats_in, 2)

    # The message reports the name exactly as it was passed in
    raise NotImplementedError("{} is not a supported architecture".format(model_name))
def single_optim_factory(optim_name, params, C):
    """Create one optimizer over ``params`` from its (case-insensitive) name.

    ``ADAM`` uses betas (0.9, 0.999) and ``SGD`` uses momentum 0.9; both take
    the learning rate and weight decay from ``C.lr`` and ``C.weight_decay``.

    Raises:
        NotImplementedError: if the name is neither ``ADAM`` nor ``SGD``.
    """
    optim_name = optim_name.strip()
    kind = optim_name.upper()

    if kind == 'ADAM':
        return torch.optim.Adam(
            params, betas=(0.9, 0.999), lr=C.lr, weight_decay=C.weight_decay
        )

    if kind == 'SGD':
        return torch.optim.SGD(
            params, lr=C.lr, momentum=0.9, weight_decay=C.weight_decay
        )

    # The message reports the stripped name
    raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
def single_critn_factory(critn_name, C):
    """Create one loss module from its (case-insensitive) name.

    Supported names: ``L1``, ``MSE``, ``CE`` and ``NLL``.  ``CE`` and ``NLL``
    receive ``torch.Tensor(C.weights)`` as their first argument; ``C.weights``
    is only read for those two.

    Raises:
        NotImplementedError: if the name is not one of the supported ones.
    """
    # Kept from the original block; the name is not used below.
    import losses  # noqa: F401

    # Builders are lazy so that C.weights is not touched for criteria that
    # do not take weights.
    builders = {
        'L1': lambda: nn.L1Loss(),
        'MSE': lambda: nn.MSELoss(),
        'CE': lambda: nn.CrossEntropyLoss(torch.Tensor(C.weights)),
        'NLL': lambda: nn.NLLLoss(torch.Tensor(C.weights)),
    }
    # Strip like the other single_* factories so names split out of
    # "A + B" style strings still resolve.
    key = critn_name.strip().upper()
    # Only the lookup is guarded: a KeyError raised while constructing the
    # criterion must not be reported as "unsupported".
    try:
        build = builders[key]
    except KeyError:
        raise NotImplementedError(
            "{} is not a supported criterion type".format(critn_name)
        ) from None
    return build()
def _get_basic_configs(ds_name, C):
if ds_name == 'OSCD':
return dict(
root = constants.IMDB_OSCD
)
elif ds_name.startswith('AC'):
return dict(
return dict(
root = constants.IMDB_LEBEDEV
)
def single_train_ds_factory(ds_name, C):
    """Build the training DataLoader for one dataset name.

    The dataset class ``<ds_name>Dataset`` is looked up in module
    ``data.<ds_name>``, instantiated with the training configuration, and
    wrapped in a shuffling ``DataLoader`` that drops the last incomplete batch.
    Reads ``crop_size``, ``repeats``, ``batch_size``, ``num_workers`` and
    ``device`` from ``C``.
    """
    # NOTE(review): the `def` line was missing from the damaged source.  The
    # name is the one data_factory resolves via 'single_'+phase+'_ds_factory';
    # the parameters mirror single_val_ds_factory -- confirm.
    # NOTE(review): Compose, Crop and Flip are not imported in the visible
    # part of this file; they must be brought into scope.
    ds_name = ds_name.strip()
    module = _import_module('data', ds_name)
    dataset = getattr(module, ds_name+'Dataset')
    configs = dict(
        phase='train',
        transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
        repeats=C.repeats
    )
    # Update some common configurations (tolerate a helper that returns None)
    configs.update(_get_basic_configs(ds_name, C) or {})
    # The source used `dataset_obj` below without creating it
    dataset_obj = dataset(**configs)
    return data.DataLoader(
        dataset_obj,
        batch_size=C.batch_size,
        shuffle=True,
        num_workers=C.num_workers,
        # Pinned memory is only requested when not running on the CPU
        pin_memory=not (C.device == 'cpu'), drop_last=True
    )
def single_val_ds_factory(ds_name, C):
    """Build the validation DataLoader for one dataset name.

    The dataset class ``<ds_name>Dataset`` is looked up in module
    ``data.<ds_name>``, instantiated with ``phase='val'`` plus the common
    dataset settings, and wrapped in a ``DataLoader`` with batch size 1, no
    shuffling, no pinned memory and no dropped batches.
    """
    ds_name = ds_name.strip()
    module = _import_module('data', ds_name)
    dataset = getattr(module, ds_name+'Dataset')
    # NOTE(review): the dict literal was left unterminated in the damaged
    # source; further validation-specific entries may have been lost.
    configs = dict(
        phase='val',
    )
    # Update some common configurations (tolerate a helper that returns None)
    configs.update(_get_basic_configs(ds_name, C) or {})
    # Set phase-specific ones
    dataset_obj = dataset(**configs)
    # Create eval set
    return data.DataLoader(
        dataset_obj,
        batch_size=1,
        shuffle=False,
        pin_memory=False, drop_last=False
    )
def _parse_input_names(name_str):
return name_str.split('+')
def model_factory(model_names, C):
    """Create the model(s) named in a ``'+'``-separated string.

    A single name yields that model directly; several names yield a
    ``DuckModel`` holding one model per name, in order.
    """
    names = _parse_input_names(model_names)
    if len(names) <= 1:
        # Single model: hand the original string to the single factory
        return single_model_factory(model_names, C)
    members = [single_model_factory(one_name, C) for one_name in names]
    return DuckModel(*members)
def optim_factory(optim_names, models, C):
    """Create the optimizer(s) named in a ``'+'``-separated string.

    One optimizer name is required per model.  Each optimizer receives one
    parameter group per direct child module of its model, tagged with the
    child's name.  Several models (a ``DuckModel``) yield a ``DuckOptimizer``;
    a single model yields a plain optimizer.

    Raises:
        ValueError: if the number of names differs from the number of models.
    """
    # NOTE(review): the `def` line and the `name_list` assignment were missing
    # from the damaged source; they are restored from the names the body uses
    # and from the sibling *_factory functions -- confirm.
    name_list = _parse_input_names(optim_names)

    def _named_param_groups(model):
        # One parameter group per direct child, labelled with the child name
        return [
            {'params': module.parameters(), 'name': module_name}
            for module_name, module in model.named_children()
        ]

    num_models = len(models) if isinstance(models, DuckModel) else 1
    if len(name_list) != num_models:
        raise ValueError("the number of optimizers does not match the number of models")

    if num_models > 1:
        optims = [
            single_optim_factory(name, _named_param_groups(model), C)
            for name, model in zip(name_list, models)
        ]
        return DuckOptimizer(*optims)

    return single_optim_factory(optim_names, _named_param_groups(models), C)
def critn_factory(critn_names, C):
    """Create the criterion/criteria named in a ``'+'``-separated string.

    A single name yields that criterion directly; several names yield a
    ``DuckCriterion`` holding one criterion per name, in order.
    """
    names = _parse_input_names(critn_names)
    if len(names) <= 1:
        # Single criterion: hand the original string to the single factory
        return single_critn_factory(critn_names, C)
    members = [single_critn_factory(one_name, C) for one_name in names]
    return DuckCriterion(*members)
def data_factory(dataset_names, phase, C):
    """Create the data loader(s) for a ``'+'``-separated dataset string.

    ``phase`` selects the module-level factory ``single_<phase>_ds_factory``.
    A single name yields that factory's result directly; several names yield
    a ``DuckDataset`` holding one result per name, in order.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'val'``.
    """
    names = _parse_input_names(dataset_names)
    if phase not in ('train', 'val'):
        raise ValueError("phase should be either 'train' or 'val'")

    # Resolve the per-phase factory by name at call time
    factory_name = 'single_' + phase + '_ds_factory'
    make_one = globals()[factory_name]

    if len(names) <= 1:
        return make_one(dataset_names, C)
    return DuckDataset(*[make_one(one_name, C) for one_name in names])
def metric_factory(metric_names, C):
    """Instantiate the metrics named in a ``'+'``-separated string.

    Each name is stripped, looked up as an attribute of ``utils.metrics`` and
    called without arguments.  Returns the instances as a list, in order.
    ``C`` is accepted but not used here.
    """
    from utils import metrics

    instances = []
    for raw_name in _parse_input_names(metric_names):
        metric_cls = getattr(metrics, raw_name.strip())
        instances.append(metric_cls())
    return instances