Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Iterable
from itertools import chain
from importlib import import_module
import torch
import torch.nn as nn
import torch.utils.data as data
import constants
import utils.metrics as metrics
from utils.misc import R
class _Desc:
def __init__(self, key):
self.key = key
def __get__(self, instance, owner):
return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
def __set__(self, instance, values):
if not (isinstance(values, Iterable) and len(values)==len(instance)):
raise TypeError("incorrect type or number of values")
for i, v in zip(range(len(instance)), values):
setattr(instance[i], self.key, v)
def _func_deco(func_name):
def _wrapper(self, *args):
# TODO: Add key argument support
try:
# Dispatch type 1
ret = tuple(getattr(ins, func_name)(*args) for ins in self)
except Exception:
# Dispatch type 2
if len(args) > 1 or (len(args[0]) != len(self)): raise
ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
return ret
return _wrapper
def _generator_deco(func_name):
def _wrapper(self, *args, **kwargs):
for ins in self:
yield from getattr(ins, func_name)(*args, **kwargs)
return _wrapper
# Duck typing
class Duck(tuple):
    """Immutable tuple of objects that is meant to be used like one object.

    ``__ducktype__`` is the type every member must be an instance of; it is
    ``object`` here and is set to the wrapped base class by ``DuckMeta``.
    Members are passed positionally: ``Duck(a, b, c)``.
    """
    __ducktype__ = object

    def __new__(cls, *args):
        if any(not isinstance(a, cls.__ducktype__) for a in args):
            raise TypeError("please check the input type")
        return tuple.__new__(cls, args)

    def __getnewargs__(self):
        # tuple's default returns ``(tuple(self),)``, which copy/pickle would
        # feed back into ``__new__`` as a single nested member. Unpack the
        # members instead so copy.copy / copy.deepcopy / pickle round-trip.
        return tuple(self)
class DuckMeta(type):
    """Metaclass turning ``class X(Base, metaclass=DuckMeta)`` into a Duck.

    The created class subclasses ``Duck`` (not ``Base``); its instances are
    tuples of ``Base`` objects. Every non-dunder member of ``Base`` is
    mirrored onto the new class:

    * generator functions -> ``_generator_deco`` (chained over members),
    * plain functions     -> ``_func_deco`` (per-member call, tuple result),
    * anything else       -> ``_Desc`` (per-member attribute fan-out).

    Names defined explicitly in the class body take precedence over the
    mirrored members of ``Base``.

    Raises:
        TypeError: if the class does not list exactly one base class.
    """

    def __new__(cls, name, bases, attrs):
        if len(bases) != 1:
            raise TypeError(
                "DuckMeta expects exactly one base class, got {}".format(len(bases))
            )
        base = bases[0]
        for k, v in getmembers(base):
            # Skip dunders, and keep explicit overrides from the class body
            # (e.g. DuckOptimizer.param_groups) instead of clobbering them.
            if k.startswith('__') or k in attrs:
                continue
            if isgeneratorfunction(v):
                attrs[k] = _generator_deco(k)
            elif isfunction(v):
                attrs[k] = _func_deco(k)
            else:
                attrs[k] = _Desc(k)
        # Duck.__new__ validates members against this type.
        attrs['__ducktype__'] = base
        # Replace the real base with Duck: instances hold ``base`` objects
        # rather than being ``base`` objects themselves.
        return super().__new__(cls, name, (Duck,), attrs)
class DuckModel(nn.Module, metaclass=DuckMeta):
    """Tuple of ``nn.Module`` instances whose non-dunder ``nn.Module``
    members are mirrored per-member by ``DuckMeta``."""
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
    """Tuple of torch optimizers whose members are mirrored by ``DuckMeta``."""

    @property
    def param_groups(self):
        # One flat list holding every member's group dicts, member by member.
        # The dicts are the members' own objects, so in-place edits reach them.
        return [group for optimizer in self for group in optimizer.param_groups]
class DuckCriterion(nn.Module, metaclass=DuckMeta):
    """Tuple of loss modules (``nn.Module`` instances) whose non-dunder
    members are mirrored per-member by ``DuckMeta``."""
class DuckDataset(data.Dataset, metaclass=DuckMeta):
    """Tuple of ``torch.utils.data.Dataset`` instances whose non-dunder
    members are mirrored per-member by ``DuckMeta``."""
def _import_module(pkg: str, mod: str, rel=False):
if not rel:
# Use absolute import
return import_module('.'.join([pkg, mod]), package=None)
else:
return import_module('.'+mod, package=pkg)
def single_model_factory(model_name, C):
    """Instantiate one model architecture by name.

    Args:
        model_name: 'SIAMUNET_CONC', 'SIAMUNET_DIFF' or 'EF'; case and
            surrounding whitespace are ignored.
        C: config object; ``C.num_feats_in`` is passed as the first
            constructor argument.

    The constant ``2`` is passed as the second constructor argument
    (presumably the number of output classes -- confirm against models.*).
    The model module is imported only when its architecture is requested.

    Raises:
        NotImplementedError: if ``model_name`` is not a known architecture.
    """
    arch = model_name.strip().upper()
    if arch == 'EF':
        from models.unet import Unet
        return Unet(C.num_feats_in, 2)
    if arch == 'SIAMUNET_DIFF':
        from models.siamunet_diff import SiamUnet_diff
        return SiamUnet_diff(C.num_feats_in, 2)
    if arch == 'SIAMUNET_CONC':
        from models.siamunet_conc import SiamUnet_conc
        return SiamUnet_conc(C.num_feats_in, 2)
    raise NotImplementedError("{} is not a supported architecture".format(model_name))
def single_optim_factory(optim_name, params, C):
    """Create one torch optimizer over ``params`` by name.

    Args:
        optim_name: 'ADAM' or 'SGD'; case and surrounding whitespace are
            ignored.
        params: passed unchanged to the torch optimizer constructor.
        C: config object providing ``lr`` and ``weight_decay``.

    Adam uses ``betas=(0.9, 0.999)``; SGD uses ``momentum=0.9``.

    Raises:
        NotImplementedError: if ``optim_name`` is not supported.
    """
    # Normalize like single_model_factory so 'adam', ' Adam ', ... all match.
    name = optim_name.strip().upper()
    if name == 'ADAM':
        return torch.optim.Adam(
            params,
            betas=(0.9, 0.999),
            lr=C.lr,
            weight_decay=C.weight_decay
        )
    elif name == 'SGD':
        return torch.optim.SGD(
            params,
            lr=C.lr,
            momentum=0.9,
            weight_decay=C.weight_decay
        )
    else:
        raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
def single_critn_factory(critn_name, C):
    """Create one loss module by name.

    Args:
        critn_name: 'L1', 'MSE', 'CE' or 'NLL'; case and surrounding
            whitespace are ignored.
        C: config object. ``C.weights`` is read only for 'CE' and 'NLL', where
            ``torch.Tensor(C.weights)`` is passed as the loss's ``weight``.

    Raises:
        NotImplementedError: if ``critn_name`` is not supported.
    """
    # Builders are lambdas so that only the requested criterion is built and
    # C.weights is not required for losses that do not use it.
    builders = {
        'L1': lambda: nn.L1Loss(),
        'MSE': lambda: nn.MSELoss(),
        'CE': lambda: nn.CrossEntropyLoss(torch.Tensor(C.weights)),
        'NLL': lambda: nn.NLLLoss(torch.Tensor(C.weights)),
    }
    try:
        build = builders[critn_name.strip().upper()]
    except KeyError:
        raise NotImplementedError(
            "{} is not a supported criterion type".format(critn_name)
        ) from None
    return build()
def single_train_ds_factory(ds_name, C):
    """Build a training DataLoader for the dataset named ``ds_name``.

    The dataset class ``<ds_name>Dataset`` is taken from the project module
    ``data.<ds_name>`` and constructed with ``phase='train'``, a transforms
    triple and ``repeats=C.repeats``. For 'OSCD' a ``root`` of
    ``constants.IMDB_OSCD`` is added.

    Args:
        ds_name: dataset name; surrounding whitespace is ignored.
        C: config object providing ``crop_size``, ``repeats``, ``batch_size``,
            ``num_workers`` and ``device``.

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` that drops the last
        incomplete batch and pins memory unless ``C.device == 'cpu'``.
    """
    from data.augmentation import Compose, Crop, Flip
    ds_name = ds_name.strip()
    module = _import_module('data', ds_name)
    dataset = getattr(module, ds_name + 'Dataset')
    configs = dict(
        phase='train',
        # NOTE(review): the meaning of the three transform slots is defined by
        # the dataset class; only the first receives augmentation here.
        transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
        repeats=C.repeats
    )
    if ds_name == 'OSCD':
        configs['root'] = constants.IMDB_OSCD
    dataset_obj = dataset(**configs)
    return data.DataLoader(
        dataset_obj,
        batch_size=C.batch_size,
        shuffle=True,
        num_workers=C.num_workers,
        pin_memory=not (C.device == 'cpu'),
        drop_last=True
    )
def single_val_ds_factory(ds_name, C):
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='val',
)
if ds_name == 'OSCD':
configs.update(
dict(
root = constants.IMDB_OSCD
)
)
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
configs.update(
dict(
root = constants.IMDB_AirChange
)
)
else:
pass
dataset_obj = dataset(**configs)
# Create eval set
return data.DataLoader(
dataset_obj,
batch_size=1,
shuffle=False,
num_workers=C.num_workers,
pin_memory=False, drop_last=False
)
def _parse_input_names(name_str):
return name_str.split('+')
def model_factory(model_names, C):
    """Build one model, or a DuckModel bundling several, from a '+' spec.

    A single name returns the bare model from ``single_model_factory``;
    several names (e.g. 'EF+SIAMUNET_DIFF') return a ``DuckModel`` holding
    one model per name, in order.
    """
    names = _parse_input_names(model_names)
    if len(names) <= 1:
        return single_model_factory(model_names, C)
    models = [single_model_factory(n, C) for n in names]
    return DuckModel(*models)
def optim_factory(optim_names, params, C):
    """Build one optimizer, or a DuckOptimizer bundling several, from a '+' spec.

    A single name returns the bare optimizer; several names return a
    ``DuckOptimizer`` holding one optimizer per name, in order.
    """
    names = _parse_input_names(optim_names)
    if len(names) <= 1:
        return single_optim_factory(optim_names, params, C)
    # NOTE(review): the same ``params`` object is handed to every optimizer.
    # If it is a one-shot iterator (e.g. a generator), the first optimizer
    # consumes it and later ones see it exhausted -- confirm intended input.
    optimizers = [single_optim_factory(n, params, C) for n in names]
    return DuckOptimizer(*optimizers)
def critn_factory(critn_names, C):
    """Build one criterion, or a DuckCriterion bundling several, from a '+' spec.

    A single name returns the bare loss module; several names return a
    ``DuckCriterion`` holding one loss per name, in order.
    """
    names = _parse_input_names(critn_names)
    if len(names) <= 1:
        return single_critn_factory(critn_names, C)
    criteria = [single_critn_factory(n, C) for n in names]
    return DuckCriterion(*criteria)
def data_factory(dataset_names, phase, C):
    """Build the DataLoader(s) for a '+'-separated dataset spec.

    Args:
        dataset_names: one name, or several joined with '+'.
        phase: 'train' or 'val'; selects single_train_ds_factory or
            single_val_ds_factory.
        C: config object forwarded to the per-dataset factory.

    Raises:
        ValueError: if ``phase`` is neither 'train' nor 'val'.
    """
    names = _parse_input_names(dataset_names)
    if phase not in ('train', 'val'):
        raise ValueError("phase should be either 'train' or 'val'")
    builder = {
        'train': single_train_ds_factory,
        'val': single_val_ds_factory,
    }[phase]
    if len(names) <= 1:
        return builder(dataset_names, C)
    # NOTE(review): the per-dataset factories return DataLoader objects, but
    # DuckDataset only accepts data.Dataset members, so this branch appears
    # to raise TypeError("please check the input type") -- confirm and fix.
    loaders = [builder(n, C) for n in names]
    return DuckDataset(*loaders)
def metric_factory(metric_names, C):
    """Instantiate one metric per '+'-separated name.

    Each whitespace-stripped name is looked up as an attribute of
    ``utils.metrics`` and called with no arguments. ``C`` is unused.

    Returns:
        A list of metric objects, in the order the names were given.
    """
    from utils import metrics
    instances = []
    for raw_name in _parse_input_names(metric_names):
        metric_cls = getattr(metrics, raw_name.strip())
        instances.append(metric_cls())
    return instances