Skip to content
Snippets Groups Projects
Commit 36e94e06 authored by Bobholamovic's avatar Bobholamovic
Browse files

Happy New Year

parent 4d7e2b88
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ import constants
import utils.metrics as metrics
from utils.misc import R
class _Desc:
def __init__(self, key):
self.key = key
......@@ -26,15 +27,7 @@ class _Desc:
def _func_deco(func_name):
def _wrapper(self, *args):
# TODO: Add key argument support
try:
# Dispatch type 1
ret = tuple(getattr(ins, func_name)(*args) for ins in self)
except Exception:
# Dispatch type 2
if len(args) > 1 or (len(args[0]) != len(self)): raise
ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
return ret
return tuple(getattr(ins, func_name)(*args) for ins in self)
return _wrapper
......@@ -45,6 +38,16 @@ def _generator_deco(func_name):
return _wrapper
def _mark(func):
func.__marked__ = True
return func
def _unmark(func):
func.__marked__ = False
return func
# Duck typing
class Duck(tuple):
__ducktype__ = object
......@@ -60,6 +63,9 @@ class DuckMeta(type):
for k, v in getmembers(bases[0]):
if k.startswith('__'):
continue
if k in attrs and hasattr(attrs[k], '__marked__'):
if attrs[k].__marked__:
continue
if isgeneratorfunction(v):
attrs[k] = _generator_deco(k)
elif isfunction(v):
......@@ -71,14 +77,48 @@ class DuckMeta(type):
class DuckModel(nn.Module, metaclass=DuckMeta):
pass
DELIM = ':'
@_mark
def load_state_dict(self, state_dict):
dicts = [dict() for _ in range(len(self))]
for k, v in state_dict.items():
i, *k = k.split(self.DELIM)
k = self.DELIM.join(k)
i = int(i)
dicts[i][k] = v
for i in range(len(self)): self[i].load_state_dict(dicts[i])
@_mark
def state_dict(self):
dict_ = dict()
for i, ins in enumerate(self):
dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
return dict_
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
DELIM = ':'
@property
def param_groups(self):
return list(chain.from_iterable(ins.param_groups for ins in self))
@_mark
def state_dict(self):
dict_ = dict()
for i, ins in enumerate(self):
dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
return dict_
@_mark
def load_state_dict(self, state_dict):
dicts = [dict() for _ in range(len(self))]
for k, v in state_dict.items():
i, *k = k.split(self.DELIM)
k = self.DELIM.join(k)
i = int(i)
dicts[i][k] = v
for i in range(len(self)): self[i].load_state_dict(dicts[i])
class DuckCriterion(nn.Module, metaclass=DuckMeta):
pass
......@@ -112,7 +152,8 @@ def single_model_factory(model_name, C):
def single_optim_factory(optim_name, params, C):
name = optim_name.strip().upper()
optim_name = optim_name.strip()
name = optim_name.upper()
if name == 'ADAM':
return torch.optim.Adam(
params,
......@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C):
def single_critn_factory(critn_name, C):
import losses
critn_name = critn_name.strip()
try:
criterion, params = {
'L1': (nn.L1Loss, ()),
......@@ -145,6 +187,19 @@ def single_critn_factory(critn_name, C):
raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
def _get_basic_configs(ds_name, C):
if ds_name == 'OSCD':
return dict(
root = constants.IMDB_OSCD
)
elif ds_name.startswith('AC'):
return dict(
root = constants.IMDB_AirChange
)
else:
return dict()
def single_train_ds_factory(ds_name, C):
from data.augmentation import Compose, Crop, Flip
ds_name = ds_name.strip()
......@@ -155,21 +210,13 @@ def single_train_ds_factory(ds_name, C):
transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
repeats=C.repeats
)
if ds_name == 'OSCD':
configs.update(
dict(
root = constants.IMDB_OSCD
)
)
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AirChange
)
)
else:
pass
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
pass
dataset_obj = dataset(**configs)
return data.DataLoader(
......@@ -190,21 +237,13 @@ def single_val_ds_factory(ds_name, C):
transforms=(None, None, None),
repeats=1
)
if ds_name == 'OSCD':
configs.update(
dict(
root = constants.IMDB_OSCD
)
)
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AirChange
)
)
else:
pass
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
pass
dataset_obj = dataset(**configs)
# Create eval set
......@@ -229,12 +268,24 @@ def model_factory(model_names, C):
return single_model_factory(model_names, C)
def optim_factory(optim_names, params, C):
def optim_factory(optim_names, models, C):
name_list = _parse_input_names(optim_names)
if len(name_list) > 1:
return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
num_models = len(models) if isinstance(models, DuckModel) else 1
if len(name_list) != num_models:
raise ValueError("the number of optimizers does not match the number of models")
if num_models > 1:
optims = []
for name, model in zip(name_list, models):
param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
optims.append(single_optim_factory(name, param_groups, C))
return DuckOptimizer(*optims)
else:
return single_optim_factory(optim_names, params, C)
return single_optim_factory(
optim_names,
[{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()],
C
)
def critn_factory(critn_names, C):
......
......@@ -33,8 +33,8 @@ class Trainer:
self.lr = float(context.lr)
self.save = context.save_on or context.out_dir
self.out_dir = context.out_dir
self.trace_freq = context.trace_freq
self.device = context.device
self.trace_freq = int(context.trace_freq)
self.device = torch.device(context.device)
self.suffix_off = context.suffix_off
for k, v in sorted(self.ctx.items()):
......@@ -44,7 +44,7 @@ class Trainer:
self.model.to(self.device)
self.criterion = critn_factory(criterion, context)
self.criterion.to(self.device)
self.optimizer = optim_factory(optimizer, self.model.parameters(), context)
self.optimizer = optim_factory(optimizer, self.model, context)
self.metrics = metric_factory(context.metrics, context)
self.train_loader = data_factory(dataset, 'train', context)
......@@ -74,10 +74,14 @@ class Trainer:
# Train for one epoch
self.train_epoch()
# Clear the history of metric objects
for m in self.metrics:
m.reset()
# Evaluate the model on validation set
self.logger.show_nl("Validate")
acc = self.validate_epoch(epoch=epoch, store=self.save)
is_best = acc > max_acc
if is_best:
max_acc = acc
......@@ -250,7 +254,7 @@ class CDTrainer(Trainer):
losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays
CM = to_array(torch.argmax(prob, 1)).astype('uint8')
CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
label = to_array(label[0]).astype('uint8')
for m in self.metrics:
m.update(CM, label)
......@@ -267,6 +271,6 @@ class CDTrainer(Trainer):
self.logger.dump(desc)
if store:
self.save_image(name[0], (CM*255).squeeze(-1), epoch)
self.save_image(name[0], CM*255, epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
from functools import partial
import numpy as np
from sklearn import metrics
class AverageMeter:
def __init__(self, callback=None):
super().__init__()
self.callback = callback
if callback is not None:
self.compute = callback
self.reset()
def compute(self, *args):
if self.callback is not None:
return self.callback(*args)
elif len(args) == 1:
if len(args) == 1:
return args[0]
else:
raise NotImplementedError
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, *args, n=1):
......@@ -27,36 +29,75 @@ class AverageMeter:
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
# These metrics only for numpy arrays
class Metric(AverageMeter):
__name__ = 'Metric'
def __init__(self, callback, **configs):
super().__init__(callback)
self.configs = configs
def __init__(self, n_classes=2, mode='accum', reduction='binary'):
super().__init__(None)
self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)))
assert mode in ('accum', 'separ')
self.mode = mode
assert reduction in ('mean', 'none', 'binary')
if reduction == 'binary' and n_classes != 2:
raise ValueError("binary reduction only works in 2-class cases")
self.reduction = reduction
def compute(self, pred, true):
return self.callback(true.ravel(), pred.ravel(), **self.configs)
def _compute(self, cm):
raise NotImplementedError
def compute(self, cm):
if self.reduction == 'none':
# Do not reduce size
return self._compute(cm)
elif self.reduction == 'mean':
# Micro averaging
return self._compute(cm).mean()
else:
# The pos_class be 1
return self._compute(cm)[1]
def update(self, pred, true, n=1):
# Note that this is no thread-safe
self._cm.update(true.ravel(), pred.ravel())
if self.mode == 'accum':
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
else:
raise NotImplementedError
super().update(cm, n=n)
def __repr__(self):
return self.__name__+' '+super().__repr__()
class Precision(Metric):
__name__ = 'Prec.'
def __init__(self, **configs):
super().__init__(metrics.precision_score, **configs)
def _compute(self, cm):
return np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
class Recall(Metric):
__name__ = 'Recall'
def __init__(self, **configs):
super().__init__(metrics.recall_score, **configs)
def _compute(self, cm):
return np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
class Accuracy(Metric):
__name__ = 'OA'
def __init__(self, **configs):
super().__init__(metrics.accuracy_score, **configs)
def __init__(self, n_classes=2, mode='accum'):
super().__init__(n_classes=n_classes, mode=mode, reduction='none')
def _compute(self, cm):
return np.nan_to_num(np.diag(cm).sum()/cm.sum())
class F1Score(Metric):
__name__ = 'F1'
def __init__(self, **configs):
super().__init__(metrics.f1_score, **configs)
\ No newline at end of file
def _compute(self, cm):
prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment