diff --git a/src/core/factories.py b/src/core/factories.py
index 2db268a521dd978a84c691ad393819719d4e043c..c6ac334a3212934dfe99f768764c9c3e6d0a16e1 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -12,6 +12,7 @@
 import constants
 import utils.metrics as metrics
 from utils.misc import R
+
 
 class _Desc:
     def __init__(self, key):
         self.key = key
@@ -26,15 +27,7 @@ class _Desc:
 
 def _func_deco(func_name):
     def _wrapper(self, *args):
-        # TODO: Add key argument support
-        try:
-            # Dispatch type 1
-            ret = tuple(getattr(ins, func_name)(*args) for ins in self)
-        except Exception:
-            # Dispatch type 2
-            if len(args) > 1 or (len(args[0]) != len(self)): raise
-            ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
-        return ret
+        return tuple(getattr(ins, func_name)(*args) for ins in self)
     return _wrapper
 
 
@@ -45,6 +38,16 @@ def _generator_deco(func_name):
     return _wrapper
 
 
+def _mark(func):
+    func.__marked__ = True
+    return func
+
+
+def _unmark(func):
+    func.__marked__ = False
+    return func
+
+
 # Duck typing
 class Duck(tuple):
     __ducktype__ = object
@@ -60,6 +63,9 @@ class DuckMeta(type):
         for k, v in getmembers(bases[0]):
             if k.startswith('__'):
                 continue
+            if k in attrs and hasattr(attrs[k], '__marked__'):
+                if attrs[k].__marked__:
+                    continue
             if isgeneratorfunction(v):
                 attrs[k] = _generator_deco(k)
             elif isfunction(v):
@@ -71,14 +77,48 @@ class DuckMeta(type):
 
 
 class DuckModel(nn.Module, metaclass=DuckMeta):
-    pass
+    DELIM = ':'
+    @_mark
+    def load_state_dict(self, state_dict):
+        dicts = [dict() for _ in range(len(self))]
+        for k, v in state_dict.items():
+            i, *k = k.split(self.DELIM)
+            k = self.DELIM.join(k)
+            i = int(i)
+            dicts[i][k] = v
+        for i in range(len(self)): self[i].load_state_dict(dicts[i])
+
+    @_mark
+    def state_dict(self):
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({self.DELIM.join([str(i), key]): val for key, val in ins.state_dict().items()})
+        return dict_
 
 
 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
+    DELIM = ':'
     @property
     def param_groups(self):
         return list(chain.from_iterable(ins.param_groups for ins in self))
 
+    @_mark
+    def state_dict(self):
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({self.DELIM.join([str(i), key]): val for key, val in ins.state_dict().items()})
+        return dict_
+
+    @_mark
+    def load_state_dict(self, state_dict):
+        dicts = [dict() for _ in range(len(self))]
+        for k, v in state_dict.items():
+            i, *k = k.split(self.DELIM)
+            k = self.DELIM.join(k)
+            i = int(i)
+            dicts[i][k] = v
+        for i in range(len(self)): self[i].load_state_dict(dicts[i])
+
 
 class DuckCriterion(nn.Module, metaclass=DuckMeta):
     pass
@@ -112,7 +152,8 @@
 
 
 def single_optim_factory(optim_name, params, C):
-    name = optim_name.strip().upper()
+    optim_name = optim_name.strip()
+    name = optim_name.upper()
     if name == 'ADAM':
         return torch.optim.Adam(
             params,
@@ -133,6 +174,7 @@
 
 def single_critn_factory(critn_name, C):
     import losses
+    critn_name = critn_name.strip()
     try:
         criterion, params = {
             'L1': (nn.L1Loss, ()),
@@ -145,6 +187,19 @@
         raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
 
 
+def _get_basic_configs(ds_name, C):
+    if ds_name == 'OSCD':
+        return dict(
+            root = constants.IMDB_OSCD
+        )
+    elif ds_name.startswith('AC'):
+        return dict(
+            root = constants.IMDB_AirChange
+        )
+    else:
+        return dict()
+
+
 def single_train_ds_factory(ds_name, C):
     from data.augmentation import Compose, Crop, Flip
     ds_name = ds_name.strip()
@@ -155,21 +210,13 @@
         transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
         repeats=C.repeats
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root = constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(
-            dict(
-                root = constants.IMDB_AirChange
-            )
-        )
-    else:
-        pass
+
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
+
+    # Set phase-specific ones
+    pass
 
     dataset_obj = dataset(**configs)
 
     return data.DataLoader(
@@ -190,21 +237,13 @@
         transforms=(None, None, None),
         repeats=1
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root = constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(
-            dict(
-                root = constants.IMDB_AirChange
-            )
-        )
-    else:
-        pass
+
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
+
+    # Set phase-specific ones
+    pass
 
     dataset_obj = dataset(**configs)
 
     # Create eval set
@@ -229,12 +268,24 @@
         return single_model_factory(model_names, C)
 
 
-def optim_factory(optim_names, params, C):
+def optim_factory(optim_names, models, C):
     name_list = _parse_input_names(optim_names)
-    if len(name_list) > 1:
-        return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
+    num_models = len(models) if isinstance(models, DuckModel) else 1
+    if len(name_list) != num_models:
+        raise ValueError("the number of optimizers does not match the number of models")
+
+    if num_models > 1:
+        optims = []
+        for name, model in zip(name_list, models):
+            param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
+            optims.append(single_optim_factory(name, param_groups, C))
+        return DuckOptimizer(*optims)
     else:
-        return single_optim_factory(optim_names, params, C)
+        return single_optim_factory(
+            optim_names,
+            [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()],
+            C
+        )
 
 
 def critn_factory(critn_names, C):
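
A quick illustration of the `DELIM` key scheme that the `@_mark`-ed `state_dict`/`load_state_dict` overrides above rely on: each sub-instance's keys are prefixed with its index, and only the first delimiter is significant on the way back, so keys that themselves contain `:` survive the round trip. A minimal, torch-free sketch (the plain dicts stand in for real `state_dict()` results):

```python
DELIM = ':'

# Two hypothetical sub-model states.
states = [{'conv.weight': 1}, {'conv.weight': 2}]

# Merge, as in DuckModel.state_dict(): prefix keys with the instance index.
merged = {DELIM.join([str(i), k]): v
          for i, d in enumerate(states) for k, v in d.items()}
assert merged == {'0:conv.weight': 1, '1:conv.weight': 2}

# Split, as in DuckModel.load_state_dict(): route keys back by index.
# Only the first DELIM is split off, so nested keys stay intact.
restored = [dict() for _ in states]
for k, v in merged.items():
    i, *rest = k.split(DELIM)
    restored[int(i)][DELIM.join(rest)] = v
assert restored == states
```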
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 3e1381ace84c06e55652ef89cc3b5e70d7e63673..47f290d9c016c0156c61bcaaef1ef78801df5417 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -33,8 +33,8 @@ class Trainer:
         self.lr = float(context.lr)
         self.save = context.save_on or context.out_dir
         self.out_dir = context.out_dir
-        self.trace_freq = context.trace_freq
-        self.device = context.device
+        self.trace_freq = int(context.trace_freq)
+        self.device = torch.device(context.device)
         self.suffix_off = context.suffix_off
 
         for k, v in sorted(self.ctx.items()):
@@ -44,7 +44,7 @@ class Trainer:
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
-        self.optimizer = optim_factory(optimizer, self.model.parameters(), context)
+        self.optimizer = optim_factory(optimizer, self.model, context)
         self.metrics = metric_factory(context.metrics, context)
 
         self.train_loader = data_factory(dataset, 'train', context)
@@ -74,10 +74,14 @@ class Trainer:
             # Train for one epoch
             self.train_epoch()
 
+            # Clear the history of metric objects
+            for m in self.metrics:
+                m.reset()
+
             # Evaluate the model on validation set
             self.logger.show_nl("Validate")
             acc = self.validate_epoch(epoch=epoch, store=self.save)
-            
+
             is_best = acc > max_acc
             if is_best:
                 max_acc = acc
@@ -250,7 +254,7 @@ class CDTrainer(Trainer):
                 losses.update(loss.item(), n=self.batch_size)
 
                 # Convert to numpy arrays
-                CM = to_array(torch.argmax(prob, 1)).astype('uint8')
+                CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
                 label = to_array(label[0]).astype('uint8')
                 for m in self.metrics:
                     m.update(CM, label)
@@ -267,6 +271,6 @@ class CDTrainer(Trainer):
                 self.logger.dump(desc)
 
                 if store:
-                    self.save_image(name[0], (CM*255).squeeze(-1), epoch)
+                    self.save_image(name[0], CM*255, epoch)
 
         return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
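
The `argmax` change above is a shape fix: `prob` comes out of the model as `(N, C, H, W)`, and the old code kept the batch dimension in `CM`, leaving a singleton axis that `save_image` then had to `squeeze` away. Indexing the single evaluated sample first and reducing over the class dimension yields a 2-D map directly. A small sketch of the two variants (the `(1, 2, 4, 4)` shape is an assumption for illustration):

```python
import torch

prob = torch.rand(1, 2, 4, 4)      # (batch, classes, H, W)

old = torch.argmax(prob, 1)        # (1, 4, 4): batch dim kept
new = torch.argmax(prob[0], 0)     # (4, 4): directly usable as an image

assert old.shape == (1, 4, 4)
assert new.shape == (4, 4)
assert torch.equal(old[0], new)    # same labels, minus the batch dim
```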
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index f4be69fcf79fbd5e6feb433bbed9f3fb3b1ed47a..3e2b24faf3bf8774c7b8d3391055b615642b7709 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -1,24 +1,26 @@
+from functools import partial
+
+import numpy as np
 from sklearn import metrics
 
 
 class AverageMeter:
     def __init__(self, callback=None):
         super().__init__()
-        self.callback = callback
+        if callback is not None:
+            self.compute = callback
         self.reset()
 
     def compute(self, *args):
-        if self.callback is not None:
-            return self.callback(*args)
-        elif len(args) == 1:
+        if len(args) == 1:
             return args[0]
         else:
             raise NotImplementedError
 
     def reset(self):
-        self.val = 0.0
-        self.avg = 0.0
-        self.sum = 0.0
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
         self.count = 0
 
     def update(self, *args, n=1):
@@ -27,36 +29,81 @@ class AverageMeter:
         self.count += n
         self.avg = self.sum / self.count
 
+    def __repr__(self):
+        return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
+
 
+# These metrics operate on numpy arrays only
 class Metric(AverageMeter):
     __name__ = 'Metric'
-    def __init__(self, callback, **configs):
-        super().__init__(callback)
-        self.configs = configs
+    def __init__(self, n_classes=2, mode='accum', reduction='binary'):
+        # Create the confusion-matrix meter first; super().__init__() calls reset()
+        self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)))
+        super().__init__(None)
+        assert mode in ('accum', 'separ')
+        self.mode = mode
+        assert reduction in ('mean', 'none', 'binary')
+        if reduction == 'binary' and n_classes != 2:
+            raise ValueError("binary reduction only works in 2-class cases")
+        self.reduction = reduction
 
-    def compute(self, pred, true):
-        return self.callback(true.ravel(), pred.ravel(), **self.configs)
+    def reset(self):
+        super().reset()
+        # Also clear the accumulated confusion matrix ('accum' mode history)
+        self._cm.reset()
+
+    def _compute(self, cm):
+        raise NotImplementedError
+
+    def compute(self, cm):
+        if self.reduction == 'none':
+            # Do not reduce size
+            return self._compute(cm)
+        elif self.reduction == 'mean':
+            # Macro averaging over classes
+            return self._compute(cm).mean()
+        else:
+            # The positive class is assumed to be index 1
+            return self._compute(cm)[1]
+
+    def update(self, pred, true, n=1):
+        # Note that this is not thread-safe
+        self._cm.update(true.ravel(), pred.ravel())
+        if self.mode == 'accum':
+            cm = self._cm.sum
+        elif self.mode == 'separ':
+            cm = self._cm.val
+        else:
+            raise NotImplementedError
+        super().update(cm, n=n)
+
+    def __repr__(self):
+        return self.__name__+' '+super().__repr__()
 
 
 class Precision(Metric):
     __name__ = 'Prec.'
-    def __init__(self, **configs):
-        super().__init__(metrics.precision_score, **configs)
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
 
 
 class Recall(Metric):
     __name__ = 'Recall'
-    def __init__(self, **configs):
-        super().__init__(metrics.recall_score, **configs)
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
 
 
 class Accuracy(Metric):
     __name__ = 'OA'
-    def __init__(self, **configs):
-        super().__init__(metrics.accuracy_score, **configs)
+    def __init__(self, n_classes=2, mode='accum'):
+        super().__init__(n_classes=n_classes, mode=mode, reduction='none')
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm).sum()/cm.sum())
 
 
 class F1Score(Metric):
     __name__ = 'F1'
-    def __init__(self, **configs):
-        super().__init__(metrics.f1_score, **configs)
\ No newline at end of file
+    def _compute(self, cm):
+        prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
+        recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
+        return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file
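
The per-class formulas implemented by `Precision`, `Recall`, and `F1Score` above can be checked by hand against a small confusion matrix (sklearn's convention: rows are true labels, columns are predictions). A self-contained sketch with arbitrary counts:

```python
import numpy as np

cm = np.array([[50, 10],    # 50 true negatives, 10 false positives
               [ 5, 35]])   # 5 false negatives, 35 true positives

prec = np.nan_to_num(np.diag(cm) / cm.sum(axis=0))   # per-class precision
rec = np.nan_to_num(np.diag(cm) / cm.sum(axis=1))    # per-class recall
f1 = np.nan_to_num(2 * prec * rec / (prec + rec))

# The 'binary' reduction reports index 1, the positive (changed) class:
print(prec[1])   # 35/45 ≈ 0.778
print(rec[1])    # 35/40 = 0.875
print(f1[1])     # ≈ 0.824
```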