diff --git a/src/core/factories.py b/src/core/factories.py
index 2db268a521dd978a84c691ad393819719d4e043c..c6ac334a3212934dfe99f768764c9c3e6d0a16e1 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -12,6 +12,7 @@ import constants
 import utils.metrics as metrics
 from utils.misc import R
 
+
 class _Desc:
     def __init__(self, key):
         self.key = key
@@ -26,15 +27,7 @@ class _Desc:
 
 def _func_deco(func_name):
     def _wrapper(self, *args):
-        # TODO: Add key argument support
-        try:
-            # Dispatch type 1
-            ret = tuple(getattr(ins, func_name)(*args) for ins in self)
-        except Exception:
-            # Dispatch type 2
-            if len(args) > 1 or (len(args[0]) != len(self)): raise
-            ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
-        return ret
+        return tuple(getattr(ins, func_name)(*args) for ins in self)
     return _wrapper
 
 
@@ -45,6 +38,16 @@ def _generator_deco(func_name):
     return _wrapper
 
 
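+# Mark/unmark helpers: DuckMeta leaves attributes whose __marked__ flag
+# is True untouched, so hand-written overrides survive auto-wrapping.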
+def _mark(func):
+    func.__marked__ = True
+    return func
+
+
+def _unmark(func):
+    func.__marked__ = False
+    return func
+
+
 # Duck typing
 class Duck(tuple):
     __ducktype__ = object
@@ -60,6 +63,9 @@ class DuckMeta(type):
         for k, v in getmembers(bases[0]):
             if k.startswith('__'):
                 continue
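+            # Keep attributes explicitly marked with @_mark (manual overrides)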
+            if k in attrs and getattr(attrs[k], '__marked__', False):
+                continue
             if isgeneratorfunction(v):
                 attrs[k] = _generator_deco(k)
             elif isfunction(v):
@@ -71,14 +77,48 @@ class DuckMeta(type):
 
 
 class DuckModel(nn.Module, metaclass=DuckMeta):
-    pass
+    DELIM = ':'
+
+    @_mark
+    def load_state_dict(self, state_dict):
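+        # Split the flat state dict on DELIM-prefixed keys
+        # (e.g. '0:conv.weight') and dispatch to each sub-model.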
+        dicts = [dict() for _ in range(len(self))]
+        for key, v in state_dict.items():
+            i, *sub_key = key.split(self.DELIM)
+            dicts[int(i)][self.DELIM.join(sub_key)] = v
+        for ins, d in zip(self, dicts):
+            ins.load_state_dict(d)
+
+    @_mark
+    def state_dict(self):
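+        # Flatten the sub-models' state dicts, prefixing each key with
+        # the sub-model index (e.g. '0:conv.weight').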
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({
+                self.DELIM.join([str(i), key]): val
+                for key, val in ins.state_dict().items()
+            })
+        return dict_
 
 
 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
+    DELIM = ':'
+
     @property
     def param_groups(self):
         return list(chain.from_iterable(ins.param_groups for ins in self))
 
+    @_mark
+    def state_dict(self):
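+        # Same index-prefixed flat-dict scheme as DuckModel.state_dict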
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({
+                self.DELIM.join([str(i), key]): val
+                for key, val in ins.state_dict().items()
+            })
+        return dict_
+
+    @_mark
+    def load_state_dict(self, state_dict):
+        dicts = [dict() for _ in range(len(self))]
+        for key, v in state_dict.items():
+            i, *sub_key = key.split(self.DELIM)
+            dicts[int(i)][self.DELIM.join(sub_key)] = v
+        for ins, d in zip(self, dicts):
+            ins.load_state_dict(d)
+
 
 class DuckCriterion(nn.Module, metaclass=DuckMeta):
     pass
@@ -112,7 +152,8 @@ def single_model_factory(model_name, C):
 
 
 def single_optim_factory(optim_name, params, C):
-    name = optim_name.strip().upper()
+    optim_name = optim_name.strip()
+    name = optim_name.upper()
     if name == 'ADAM':
         return torch.optim.Adam(
             params, 
@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C):
 
 def single_critn_factory(critn_name, C):
     import losses
+    critn_name = critn_name.strip()
     try:
         criterion, params = {
             'L1': (nn.L1Loss, ()),
@@ -145,6 +187,19 @@ def single_critn_factory(critn_name, C):
         raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
 
 
+def _get_basic_configs(ds_name, C):
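+    # Dataset-specific base configs shared by the train and val factories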
+    if ds_name == 'OSCD':
+        return dict(
+            root = constants.IMDB_OSCD
+        )
+    elif ds_name.startswith('AC'):
+        return dict(
+            root = constants.IMDB_AirChange
+        )
+    else:
+        return dict()
+
+
 def single_train_ds_factory(ds_name, C):
     from data.augmentation import Compose, Crop, Flip
     ds_name = ds_name.strip()
@@ -155,21 +210,13 @@ def single_train_ds_factory(ds_name, C):
         transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
         repeats=C.repeats
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root = constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(
-            dict(
-                root = constants.IMDB_AirChange
-            )
-        )
-    else:
-        pass
+
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
 
+    # Set phase-specific ones (currently none)
+
     dataset_obj = dataset(**configs)
     
     return data.DataLoader(
@@ -190,21 +237,13 @@ def single_val_ds_factory(ds_name, C):
         transforms=(None, None, None),
         repeats=1
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root = constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(
-            dict(
-                root = constants.IMDB_AirChange
-            )
-        )
-    else:
-        pass
 
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
+
+    # Set phase-specific ones (currently none)
+
     dataset_obj = dataset(**configs)  
 
     # Create eval set
@@ -229,12 +268,24 @@ def model_factory(model_names, C):
         return single_model_factory(model_names, C)
 
 
-def optim_factory(optim_names, params, C):
+def optim_factory(optim_names, models, C):
     name_list = _parse_input_names(optim_names)
-    if len(name_list) > 1:
-        return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
+    num_models = len(models) if isinstance(models, DuckModel) else 1
+    if len(name_list) != num_models:
+        raise ValueError("The number of optimizer names must match the number of models")
+
+    if num_models > 1:
+        optims = []
+        for name, model in zip(name_list, models):
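+            # One param group per top-level child module, tagged with its name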
+            param_groups = [
+                {'params': module.parameters(), 'name': module_name}
+                for module_name, module in model.named_children()
+            ]
+            optims.append(single_optim_factory(name, param_groups, C))
+        return DuckOptimizer(*optims)
     else:
-        return single_optim_factory(optim_names, params, C)
+        return single_optim_factory(
+            optim_names,
+            [
+                {'params': module.parameters(), 'name': module_name}
+                for module_name, module in models.named_children()
+            ],
+            C
+        )
 
 
 def critn_factory(critn_names, C):
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 3e1381ace84c06e55652ef89cc3b5e70d7e63673..47f290d9c016c0156c61bcaaef1ef78801df5417 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -33,8 +33,8 @@ class Trainer:
         self.lr = float(context.lr)
         self.save = context.save_on or context.out_dir
         self.out_dir = context.out_dir
-        self.trace_freq = context.trace_freq
-        self.device = context.device
+        self.trace_freq = int(context.trace_freq)
+        self.device = torch.device(context.device)
         self.suffix_off = context.suffix_off
 
         for k, v in sorted(self.ctx.items()):
@@ -44,7 +44,7 @@ class Trainer:
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
-        self.optimizer = optim_factory(optimizer, self.model.parameters(), context)
+        self.optimizer = optim_factory(optimizer, self.model, context)
         self.metrics = metric_factory(context.metrics, context)
 
         self.train_loader = data_factory(dataset, 'train', context)
@@ -74,10 +74,14 @@ class Trainer:
             # Train for one epoch
             self.train_epoch()
 
+            # Clear the history of metric objects
+            for m in self.metrics:
+                m.reset()
+
             # Evaluate the model on validation set
             self.logger.show_nl("Validate")
             acc = self.validate_epoch(epoch=epoch, store=self.save)
-            
+
             is_best = acc > max_acc
             if is_best:
                 max_acc = acc
@@ -250,7 +254,7 @@ class CDTrainer(Trainer):
                 losses.update(loss.item(), n=self.batch_size)
 
                 # Convert to numpy arrays
-                CM = to_array(torch.argmax(prob, 1)).astype('uint8')
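+                # Class map: argmax over the class dim of the first sample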
+                CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
                 label = to_array(label[0]).astype('uint8')
                 for m in self.metrics:
                     m.update(CM, label)
@@ -267,6 +271,6 @@ class CDTrainer(Trainer):
                 self.logger.dump(desc)
                     
                 if store:
-                    self.save_image(name[0], (CM*255).squeeze(-1), epoch)
+                    self.save_image(name[0], CM*255, epoch)
 
         return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index f4be69fcf79fbd5e6feb433bbed9f3fb3b1ed47a..3e2b24faf3bf8774c7b8d3391055b615642b7709 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -1,24 +1,26 @@
+from functools import partial
+
+import numpy as np
 from sklearn import metrics
 
 
 class AverageMeter:
     def __init__(self, callback=None):
         super().__init__()
-        self.callback = callback
+        if callback is not None:
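+            # Shadow the class-level compute() with the given callback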
+            self.compute = callback
         self.reset()
 
     def compute(self, *args):
-        if self.callback is not None:
-            return self.callback(*args) 
-        elif len(args) == 1:
+        if len(args) == 1:
             return args[0]
         else:
             raise NotImplementedError
 
     def reset(self):
-        self.val = 0.0
-        self.avg = 0.0
-        self.sum = 0.0
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
         self.count = 0
 
     def update(self, *args, n=1):
@@ -27,36 +29,75 @@ class AverageMeter:
         self.count += n
         self.avg = self.sum / self.count
 
+    def __repr__(self):
+        return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
+
 
+# These metrics operate on numpy arrays only
 class Metric(AverageMeter):
     __name__ = 'Metric'
-    def __init__(self, callback, **configs):
-        super().__init__(callback)
-        self.configs = configs
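+    # mode='accum' accumulates the confusion matrix over all updates;
+    # mode='separ' computes each update from its own confusion matrix.
+    # reduction controls how the per-class values collapse to a scalar.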
+    def __init__(self, n_classes=2, mode='accum', reduction='binary'):
+        super().__init__(None)
+        self._cm = AverageMeter(
+            partial(metrics.confusion_matrix, labels=np.arange(n_classes))
+        )
+        assert mode in ('accum', 'separ')
+        self.mode = mode
+        assert reduction in ('mean', 'none', 'binary')
+        if reduction == 'binary' and n_classes != 2:
+            raise ValueError("binary reduction only works in 2-class cases")
+        self.reduction = reduction
     
-    def compute(self, pred, true):
-        return self.callback(true.ravel(), pred.ravel(), **self.configs)
+    def _compute(self, cm):
+        raise NotImplementedError
+
+    def compute(self, cm):
+        if self.reduction == 'none':
+            # No reduction: return the per-class values as-is
+            return self._compute(cm)
+        elif self.reduction == 'mean':
+            # Macro averaging: mean of the per-class values
+            return self._compute(cm).mean()
+        else:
+            # Binary case: the positive class is assumed to be 1
+            return self._compute(cm)[1]
+
+    def update(self, pred, true, n=1):
+        # Note that this is not thread-safe
+        self._cm.update(true.ravel(), pred.ravel())
+        if self.mode == 'accum':
+            cm = self._cm.sum
+        elif self.mode == 'separ':
+            cm = self._cm.val
+        else:
+            raise NotImplementedError
+        super().update(cm, n=n)
+
+    def __repr__(self):
+        return self.__name__ + ' ' + super().__repr__()
 
 
 class Precision(Metric):
     __name__ = 'Prec.'
-    def __init__(self, **configs):
-        super().__init__(metrics.precision_score, **configs)
+    def _compute(self, cm):
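+        # Per-class precision: diag (TP) over column sums (predicted positives)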
+        return np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
 
 
 class Recall(Metric):
     __name__ = 'Recall'
-    def __init__(self, **configs):
-        super().__init__(metrics.recall_score, **configs)
+    def _compute(self, cm):
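+        # Per-class recall: diag (TP) over row sums (actual positives)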
+        return np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
 
 
 class Accuracy(Metric):
     __name__ = 'OA'
-    def __init__(self, **configs):
-        super().__init__(metrics.accuracy_score, **configs)
+    def __init__(self, n_classes=2, mode='accum'):
+        super().__init__(n_classes=n_classes, mode=mode, reduction='none')
+
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm).sum()/cm.sum())
 
 
 class F1Score(Metric):
     __name__ = 'F1'
-    def __init__(self, **configs):
-        super().__init__(metrics.f1_score, **configs)
\ No newline at end of file
+    def _compute(self, cm):
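+        # Harmonic mean: F1 = 2 * prec * recall / (prec + recall), per class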
+        prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
+        recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
+        return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file