From 9fb617877c2942b1aab194a6c6b911251baeec29 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Sat, 14 Mar 2020 19:57:15 +0800
Subject: [PATCH] Update custom framework
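
Add configuration files for the EF, siamunet_conc, and siamunet_diff models
on the AC_Szada, AC_Tiszadob, and OSCD datasets, and update the custom
training framework accordingly.

A minimal usage sketch, assuming the CLI described in the README (the config
file chosen here is only an example):

    cd src
    python train.py train --exp-config ../config_EF_OSCD.yaml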

---
 .gitignore                       |   6 +-
 README.md                        |   7 +-
 config_EF_AC_Szada.yaml          |  51 ++++++++++
 config_EF_AC_Tiszadob.yaml       |  51 ++++++++++
 config_EF_OSCD.yaml              |  51 ++++++++++
 config_siamconc_AC_Szada.yaml    |  51 ++++++++++
 config_siamconc_AC_Tiszadob.yaml |  51 ++++++++++
 config_siamconc_OSCD.yaml        |  51 ++++++++++
 config_siamdiff_AC_Szada.yaml    |  51 ++++++++++
 config_siamdiff_AC_Tiszadob.yaml |  51 ++++++++++
 config_siamdiff_OSCD.yaml        |  51 ++++++++++
 src/core/factories.py            |  80 ++++++---------
 src/core/trainers.py             |  91 ++++++++++-------
 src/data/__init__.py             |   9 +-
 src/data/augmentation.py         | 161 ++++++++++++++++++++++---------
 src/train.py                     |  11 +--
 src/utils/misc.py                |  18 +++-
 train9.sh                        |   4 +-
 18 files changed, 697 insertions(+), 149 deletions(-)
 create mode 100644 config_EF_AC_Szada.yaml
 create mode 100644 config_EF_AC_Tiszadob.yaml
 create mode 100644 config_EF_OSCD.yaml
 create mode 100644 config_siamconc_AC_Szada.yaml
 create mode 100644 config_siamconc_AC_Tiszadob.yaml
 create mode 100644 config_siamconc_OSCD.yaml
 create mode 100644 config_siamdiff_AC_Szada.yaml
 create mode 100644 config_siamdiff_AC_Tiszadob.yaml
 create mode 100644 config_siamdiff_OSCD.yaml

diff --git a/.gitignore b/.gitignore
index 9b19707..d488f64 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,9 +130,9 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-# Config files
-config*.yaml
-!/config_base.yaml
+# # Config files
+# config*.yaml
+# !/config_base.yaml
 
 # Experiment folder
 /exp/
\ No newline at end of file
diff --git a/README.md b/README.md
index b691223..ed0b271 100644
--- a/README.md
+++ b/README.md
@@ -44,4 +44,9 @@ For evaluation, try
 python train.py val --exp-config ../config_base.yaml --resume path_to_checkpoint
 ```
 
-You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/outs`.
\ No newline at end of file
+You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/outs`.
+
+---
+# Changelog
+
+2020.3.14 Added the configuration files for my experiments.
\ No newline at end of file
diff --git a/config_EF_AC_Szada.yaml b/config_EF_AC_Szada.yaml
new file mode 100644
index 0000000..7e969ae
--- /dev/null
+++ b/config_EF_AC_Szada.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Szada
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: EF
+num_feats_in: 6
\ No newline at end of file
diff --git a/config_EF_AC_Tiszadob.yaml b/config_EF_AC_Tiszadob.yaml
new file mode 100644
index 0000000..3c93b3a
--- /dev/null
+++ b/config_EF_AC_Tiszadob.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Tiszadob
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: EF
+num_feats_in: 6
\ No newline at end of file
diff --git a/config_EF_OSCD.yaml b/config_EF_OSCD.yaml
new file mode 100644
index 0000000..a91842c
--- /dev/null
+++ b/config_EF_OSCD.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: OSCD
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: EF
+num_feats_in: 26
\ No newline at end of file
diff --git a/config_siamconc_AC_Szada.yaml b/config_siamconc_AC_Szada.yaml
new file mode 100644
index 0000000..4142e68
--- /dev/null
+++ b/config_siamconc_AC_Szada.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Szada
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_conc
+num_feats_in: 3
\ No newline at end of file
diff --git a/config_siamconc_AC_Tiszadob.yaml b/config_siamconc_AC_Tiszadob.yaml
new file mode 100644
index 0000000..8eee72e
--- /dev/null
+++ b/config_siamconc_AC_Tiszadob.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Tiszadob
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_conc
+num_feats_in: 3
\ No newline at end of file
diff --git a/config_siamconc_OSCD.yaml b/config_siamconc_OSCD.yaml
new file mode 100644
index 0000000..6a25726
--- /dev/null
+++ b/config_siamconc_OSCD.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: OSCD
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_conc
+num_feats_in: 13
\ No newline at end of file
diff --git a/config_siamdiff_AC_Szada.yaml b/config_siamdiff_AC_Szada.yaml
new file mode 100644
index 0000000..3c0ad37
--- /dev/null
+++ b/config_siamdiff_AC_Szada.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Szada
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_diff
+num_feats_in: 3
\ No newline at end of file
diff --git a/config_siamdiff_AC_Tiszadob.yaml b/config_siamdiff_AC_Tiszadob.yaml
new file mode 100644
index 0000000..02f67ba
--- /dev/null
+++ b/config_siamdiff_AC_Tiszadob.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: AC_Tiszadob
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_diff
+num_feats_in: 3
\ No newline at end of file
diff --git a/config_siamdiff_OSCD.yaml b/config_siamdiff_OSCD.yaml
new file mode 100644
index 0000000..90a1671
--- /dev/null
+++ b/config_siamdiff_OSCD.yaml
@@ -0,0 +1,51 @@
+# Basic configurations
+
+
+# Data
+# Common
+dataset: OSCD
+crop_size: 112
+num_workers: 1
+repeats: 3200
+
+
+# Optimizer
+optimizer: SGD
+lr: 0.001
+lr_mode: const
+weight_decay: 0.0005
+step: 2
+
+
+# Training related
+batch_size: 32
+num_epochs: 10
+resume: ''
+load_optim: True
+anew: False
+trace_freq: 1
+device: cuda
+metrics: 'F1Score+Accuracy+Recall+Precision'
+
+
+# Experiment
+exp_dir: ../exp/
+out_dir: ''
+# tag: ''
+# suffix: ''
+# DO NOT specify the exp-config option in this file
+save_on: False
+log_off: False
+suffix_off: False
+
+
+# Criterion
+criterion: NLL
+weights: 
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
+
+
+# Model
+model: siamunet_diff
+num_feats_in: 13
\ No newline at end of file
diff --git a/src/core/factories.py b/src/core/factories.py
index e700371..5a35789 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -11,6 +11,7 @@ import torch.utils.data as data
 import constants
 import utils.metrics as metrics
 from utils.misc import R
+from data.augmentation import *
 
 
 class _Desc:
@@ -38,16 +39,6 @@ def _generator_deco(func_name):
     return _wrapper
 
 
-def _mark(func):
-    func.__marked__ = True
-    return func
-
-
-def _unmark(func):
-    func.__marked__ = False
-    return func
-
-
 # Duck typing
 class Duck(tuple):
     __ducktype__ = object
@@ -56,6 +47,12 @@ class Duck(tuple):
             raise TypeError("please check the input type")
         return tuple.__new__(cls, args)
 
+    def __add__(self, tup):
+        raise NotImplementedError
+
+    def __mul__(self, tup):
+        raise NotImplementedError
+
 
 class DuckMeta(type):
     def __new__(cls, name, bases, attrs):
@@ -63,61 +60,43 @@ class DuckMeta(type):
         for k, v in getmembers(bases[0]):
             if k.startswith('__'):
                 continue
-            if k in attrs and hasattr(attrs[k], '__marked__'):
-                if attrs[k].__marked__:
-                    continue
             if isgeneratorfunction(v):
-                attrs[k] = _generator_deco(k)
+                attrs.setdefault(k, _generator_deco(k))
             elif isfunction(v):
-                attrs[k] = _func_deco(k)
+                attrs.setdefault(k, _func_deco(k))
             else:
-                attrs[k] = _Desc(k)
+                attrs.setdefault(k, _Desc(k))
         attrs['__ducktype__'] = bases[0]
         return super().__new__(cls, name, (Duck,), attrs)
 
 
-class DuckModel(nn.Module, metaclass=DuckMeta):
-    DELIM = ':'
-    @_mark
-    def load_state_dict(self, state_dict):
-        dicts = [dict() for _ in range(len(self))]
-        for k, v in state_dict.items():
-            i, *k = k.split(self.DELIM)
-            k = self.DELIM.join(k)
-            i = int(i)
-            dicts[i][k] = v
-        for i in range(len(self)):  self[i].load_state_dict(dicts[i])
+class DuckModel(nn.Module):
+    def __init__(self, *models):
+        super().__init__()
+        ## XXX: The state_dict will be slightly larger in size,
+        # since a few extra bytes are stored in every key
+        self._m = nn.ModuleList(models)
+
+    def __len__(self):
+        return len(self._m)
 
-    @_mark
-    def state_dict(self):
-        dict_ = dict()
-        for i, ins in enumerate(self):
-            dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
-        return dict_
+    def __getitem__(self, idx):
+        return self._m[idx]
+
+    def __repr__(self):
+        return repr(self._m)
 
 
 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
-    DELIM = ':'
+    # Because this is an instance method
     @property
     def param_groups(self):
         return list(chain.from_iterable(ins.param_groups for ins in self))
 
-    @_mark
-    def state_dict(self):
-        dict_ = dict()
-        for i, ins in enumerate(self):
-            dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
-        return dict_
-
-    @_mark
-    def load_state_dict(self, state_dict):
-        dicts = [dict() for _ in range(len(self))]
-        for k, v in state_dict.items():
-            i, *k = k.split(self.DELIM)
-            k = self.DELIM.join(k)
-            i = int(i)
-            dicts[i][k] = v
-        for i in range(len(self)):  self[i].load_state_dict(dicts[i])
+    # This one needs special dispatching
+    def load_state_dict(self, state_dicts):
+        for optim, state_dict in zip(self, state_dicts):
+            optim.load_state_dict(state_dict)
 
 
 class DuckCriterion(nn.Module, metaclass=DuckMeta):
@@ -205,7 +184,6 @@ def _get_basic_configs(ds_name, C):
         
 
 def single_train_ds_factory(ds_name, C):
-    from data.augmentation import Compose, Crop, Flip
     ds_name = ds_name.strip()
     module = _import_module('data', ds_name)
     dataset = getattr(module, ds_name+'Dataset')
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 33dcabd..445b96a 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -20,7 +20,7 @@ class Trainer:
         super().__init__()
         context = deepcopy(settings)
         self.ctx = MappingProxyType(vars(context))
-        self.phase = context.cmd
+        self.mode = ('train', 'val').index(context.cmd)
 
         self.logger = R['LOGGER']
         self.gpc = R['GPC']     # Global Path Controller
@@ -44,27 +44,43 @@ class Trainer:
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
-        self.optimizer = optim_factory(optimizer, self.model, context)
         self.metrics = metric_factory(context.metrics, context)
 
-        self.train_loader = data_factory(dataset, 'train', context)
-        self.val_loader = data_factory(dataset, 'val', context)
+        if self.is_training:
+            self.train_loader = data_factory(dataset, 'train', context)
+            self.val_loader = data_factory(dataset, 'val', context)
+            self.optimizer = optim_factory(optimizer, self.model, context)
+        else:
+            self.val_loader = data_factory(dataset, 'val', context)
         
         self.start_epoch = 0
-        self._init_max_acc = 0.0
+        self._init_max_acc_and_epoch = (0.0, 0)
+
+    @property
+    def is_training(self):
+        return self.mode == 0
 
-    def train_epoch(self):
+    def train_epoch(self, epoch):
         raise NotImplementedError
 
     def validate_epoch(self, epoch=0, store=False):
         raise NotImplementedError
 
+    def _write_prompt(self):
+        self.logger.dump(input("\nWrite some notes: "))
+
+    def run(self):
+        if self.is_training:
+            self._write_prompt()
+            self.train()
+        else:
+            self.evaluate()
+
     def train(self):
         if self.load_checkpoint:
             self._resume_from_checkpoint()
 
-        max_acc = self._init_max_acc
-        best_epoch = self.get_ckp_epoch()
+        max_acc, best_epoch = self._init_max_acc_and_epoch
 
         for epoch in range(self.start_epoch, self.num_epochs):
             lr = self._adjust_learning_rate(epoch)
@@ -72,8 +88,8 @@ class Trainer:
             self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
 
             # Train for one epoch
-            self.train_epoch()
-
+            self.train_epoch(epoch)
+            
             # Clear the history of metric objects
             for m in self.metrics:
                 m.reset()
@@ -81,7 +97,7 @@ class Trainer:
             # Evaluate the model on validation set
             self.logger.show_nl("Validate")
             acc = self.validate_epoch(epoch=epoch, store=self.save)
-                
+            
             is_best = acc > max_acc
             if is_best:
                 max_acc = acc
@@ -90,14 +106,14 @@ class Trainer:
                                 acc, epoch, max_acc, best_epoch))
 
             # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), max_acc, epoch+1, is_best)
+            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
         
-    def validate(self):
+    def evaluate(self):
         if self.checkpoint: 
             if self._resume_from_checkpoint():
-                self.validate_epoch(self.get_ckp_epoch(), self.save)
+                self.validate_epoch(self.ckp_epoch, self.save)
         else:
-            self.logger.warning("no checkpoint assigned!")
+            self.logger.warning("Warning: no checkpoint assigned!")
 
     def _adjust_learning_rate(self, epoch):
         if self.ctx['lr_mode'] == 'step':
@@ -114,13 +130,14 @@ class Trainer:
         return lr
 
     def _resume_from_checkpoint(self):
+        ## XXX: This could be slow!
         if not os.path.isfile(self.checkpoint):
-            self.logger.error("=> no checkpoint found at '{}'".format(self.checkpoint))
+            self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
             return False
 
-        self.logger.show("=> loading checkpoint '{}'".format(
+        self.logger.show("=> Loading checkpoint '{}'".format(
                         self.checkpoint))
-        checkpoint = torch.load(self.checkpoint)
+        checkpoint = torch.load(self.checkpoint, map_location=self.device)
 
         state_dict = self.model.state_dict()
         ckp_dict = checkpoint.get('state_dict', checkpoint)
@@ -129,32 +146,35 @@ class Trainer:
         
         num_to_update = len(update_dict)
         if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
-            if self.phase == 'val' and (num_to_update < len(state_dict)):
-                self.logger.error("=> mismatched checkpoint for validation")
+            if not self.is_training and (num_to_update < len(state_dict)):
+                self.logger.error("=> Mismatched checkpoint for evaluation")
                 return False
-            self.logger.warning("warning: trying to load an mismatched checkpoint")
+            self.logger.warning("Warning: trying to load an mismatched checkpoint.")
             if num_to_update == 0:
-                self.logger.error("=> no parameter is to be loaded")
+                self.logger.error("=> No parameter is to be loaded.")
                 return False
             else:
-                self.logger.warning("=> {} params are to be loaded".format(num_to_update))
-        elif (not self.ctx['anew']) or (self.phase != 'train'):
-            # Note in the non-anew mode, it is not guaranteed that the contained field 
-            # max_acc be the corresponding one of the loaded checkpoint.
-            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
-            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)
-            if self.ctx['load_optim']:
+                self.logger.warning("=> {} params are to be loaded.".format(num_to_update))
+        elif (not self.ctx['anew']) or not self.is_training:
+            self.start_epoch = checkpoint.get('epoch', 0)
+            max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
+            # For backward compatibility
+            if isinstance(max_acc_and_epoch, (float, int)):
+                self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
+            else:
+                self._init_max_acc_and_epoch = max_acc_and_epoch
+            if self.ctx['load_optim'] and self.is_training:
                 try:
                     # Note that weight decay might be modified here
                     self.optimizer.load_state_dict(checkpoint['optimizer'])
                 except KeyError:
-                    self.logger.warning("warning: failed to load optimizer parameters")
+                    self.logger.warning("Warning: failed to load optimizer parameters.")
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
 
-        self.logger.show("=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
-            self.checkpoint, self.get_ckp_epoch(), self._init_max_acc
+        self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format(
+            self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch
             ))
         return True
         
@@ -183,7 +203,8 @@ class Trainer:
                 )
             )
     
-    def get_ckp_epoch(self):
+    @property
+    def ckp_epoch(self):
         # Get current epoch of the checkpoint
         # For dismatched ckp or no ckp, set to 0
         return max(self.start_epoch-1, 0)
@@ -207,7 +228,7 @@ class CDTrainer(Trainer):
     def __init__(self, arch, dataset, optimizer, settings):
         super().__init__(arch, dataset, 'NLL', optimizer, settings)
 
-    def train_epoch(self):
+    def train_epoch(self, epoch):
         losses = AverageMeter()
         len_train = len(self.train_loader)
         pb = tqdm(self.train_loader)
@@ -246,7 +267,7 @@ class CDTrainer(Trainer):
 
         with torch.no_grad():
             for i, (name, t1, t2, label) in enumerate(pb):
-                if self.phase == 'train' and i >= 16: 
+                if self.is_training and i >= 16: 
                     # Do not validate all images on training phase
                     pb.close()
                     self.logger.warning("validation ends early")
diff --git a/src/data/__init__.py b/src/data/__init__.py
index f4bf36e..14a3da5 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1,4 +1,4 @@
-from os.path import join, expanduser, basename
+from os.path import join, expanduser, basename, exists
 
 import torch
 import torch.utils.data as data
@@ -16,9 +16,12 @@ class CDDataset(data.Dataset):
     ):
         super().__init__()
         self.root = expanduser(root)
+        if not exists(self.root):
+            raise FileNotFoundError
         self.phase = phase
-        self.transforms = transforms
-        self.repeats = repeats
+        self.transforms = list(transforms)
+        self.transforms += [None]*(3-len(self.transforms))
+        self.repeats = int(repeats)
 
         self.t1_list, self.t2_list, self.label_list = self._read_file_paths()
         self.len = len(self.label_list)
diff --git a/src/data/augmentation.py b/src/data/augmentation.py
index 231d34e..0c5c02b 100644
--- a/src/data/augmentation.py
+++ b/src/data/augmentation.py
@@ -1,9 +1,24 @@
 import random
+import math
 from functools import partial, wraps
 
 import numpy as np
 import cv2
 
+
+__all__ = [
+    'Compose', 'Choose', 
+    'Scale', 'DiscreteScale', 
+    'Flip', 'HorizontalFlip', 'VerticalFlip', 'Rotate', 
+    'Crop', 'MSCrop',
+    'Shift', 'XShift', 'YShift',
+    'HueShift', 'SaturationShift', 'RGBShift', 'RShift', 'GShift', 'BShift',
+    'PCAJitter', 
+    'ContraBrightScale', 'ContrastScale', 'BrightnessScale',
+    'AddGaussNoise'
+]
+
+
 rand = random.random
 randi = random.randint
 choice = random.choice
@@ -11,11 +26,10 @@ uniform = random.uniform
 # gauss = random.gauss
 gauss = random.normalvariate    # This one is thread-safe
 
-# The transformations treat numpy ndarrays only
+# The transformations handle only 2-D or 3-D numpy ndarrays, with the optional 3rd dim as the channel dim
 
 def _istuple(x): return isinstance(x, (tuple, list))
 
-
 class Transform:
     def __init__(self, random_state=False):
         self.random_state = random_state
@@ -28,6 +42,7 @@ class Transform:
     def _set_rand_param(self):
         raise NotImplementedError
 
+
 class Compose:
     def __init__(self, *tf):
         assert len(tf) > 0
@@ -39,17 +54,27 @@ class Compose:
             x = x[0]
             for tf in self.tfs: x = tf(x)
         return x
-        
+
+
+class Choose:
+    def __init__(self, *tf):
+        assert len(tf) > 1
+        self.tfs = tf
+    def __call__(self, *x):
+        idx = randi(0, len(self.tfs)-1)
+        return self.tfs[idx](*x)
+
+
 class Scale(Transform):
     def __init__(self, scale=(0.5,1.0)):
         if _istuple(scale):
             assert len(scale) == 2
-            self.scale_range = scale #sorted(scale)
-            self.scale = scale[0]
+            self.scale_range = tuple(scale) #sorted(scale)
+            self.scale = float(scale[0])
             super(Scale, self).__init__(random_state=True)
         else:
             super(Scale, self).__init__(random_state=False)
-            self.scale = scale
+            self.scale = float(scale)
     def _transform(self, x):
         # assert x.ndim == 3
         h, w = x.shape[:2]
@@ -61,11 +86,12 @@ class Scale(Transform):
     def _set_rand_param(self):
         self.scale = uniform(*self.scale_range)
         
+
 class DiscreteScale(Scale):
     def __init__(self, bins=(0.5, 0.75), keep_prob=0.5):
         super(DiscreteScale, self).__init__(scale=(min(bins), 1.0))
-        self.bins = bins
-        self.keep_prob = keep_prob
+        self.bins = tuple(bins)
+        self.keep_prob = float(keep_prob)
     def _set_rand_param(self):
         self.scale = 1.0 if rand()<self.keep_prob else choice(self.bins)
 
@@ -115,7 +141,11 @@ class VerticalFlip(Flip):
     def __init__(self, flip=None):
         if flip is not None: flip = self._directions[~flip]
         super(VerticalFlip, self).__init__(direction=flip)
-        
+
+
+class Rotate(Flip):
+    _directions = ('90', '180', '270', 'no')
+
 
 class Crop(Transform):
     _inner_bounds = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r')
@@ -148,8 +178,10 @@ class Crop(Transform):
         elif self.bounds == 'r':
             return x[:,w//2:]
         elif len(self.bounds) == 2:
-            assert self.crop_size < (h, w)
+            assert self.crop_size <= (h, w)
             ch, cw = self.crop_size
+            if (ch,cw) == (h,w):
+                return x
             cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
             return x[cy:cy+ch, cx:cx+cw]
         else:
@@ -188,6 +220,59 @@ class MSCrop(Crop):
         self.bounds = (left, top, left+cw, top+ch)
 
 
+class Shift(Transform):
+    def __init__(self, x_shift=(-0.0625, 0.0625), y_shift=(-0.0625, 0.0625), circular=True):
+        super(Shift, self).__init__(random_state=_istuple(x_shift) or _istuple(y_shift))
+
+        if _istuple(x_shift):
+            self.xshift_range = tuple(x_shift)
+            self.xshift = float(x_shift[0])
+        else:
+            self.xshift = float(x_shift)
+            self.xshift_range = (self.xshift, self.xshift)
+
+        if _istuple(y_shift):
+            self.yshift_range = tuple(y_shift)
+            self.yshift = float(y_shift[0])
+        else:
+            self.yshift = float(y_shift)
+            self.yshift_range = (self.yshift, self.yshift)
+
+        self.circular = circular
+
+    def _transform(self, im):
+        h, w = im.shape[:2]
+        xsh = -int(self.xshift*w)
+        ysh = -int(self.yshift*h)
+        if self.circular:
+            # Shift along the x-axis
+            im_shifted = np.concatenate((im[:, xsh:], im[:, :xsh]), axis=1)
+            # Shift along the y-axis
+            im_shifted = np.concatenate((im_shifted[ysh:], im_shifted[:ysh]), axis=0)
+        else:
+            zeros = np.zeros(im.shape)
+            im1, im2 = (zeros, im) if xsh < 0 else (im, zeros)
+            im_shifted = np.concatenate((im1[:, xsh:], im2[:, :xsh]), axis=1)
+            im1, im2 = (zeros, im_shifted) if ysh < 0 else (im_shifted, zeros)
+            im_shifted = np.concatenate((im1[ysh:], im2[:ysh]), axis=0)
+
+        return im_shifted
+        
+    def _set_rand_param(self):
+        self.xshift = uniform(*self.xshift_range)
+        self.yshift = uniform(*self.yshift_range)
+
+
+class XShift(Shift):
+    def __init__(self, x_shift=(-0.0625, 0.0625), circular=True):
+        super(XShift, self).__init__(x_shift, 0.0, circular)
+
+
+class YShift(Shift):
+    def __init__(self, y_shift=(-0.0625, 0.0625), circular=True):
+        super(YShift, self).__init__(0.0, y_shift, circular)
+
+
 # Color jittering and transformation
 # The followings partially refer to https://github.com/albu/albumentations/
 class _ValueTransform(Transform):
@@ -201,8 +286,12 @@ class _ValueTransform(Transform):
         def wrapper(obj, x):
             # # Make a copy
             # x = x.copy()
-            x = tf(obj, np.clip(x, *obj.limit))
-            return np.clip(x, *obj.limit)
+            dtype = x.dtype
+            # The calculations are done in floating point to avoid overflow
+            # This is a crude yet simple approach
+            x = tf(obj, np.clip(x.astype(np.float32), *obj.limit))
+            # Convert back to the original type
+            return np.clip(x, *obj.limit).astype(dtype)
         return wrapper
         
 
@@ -222,7 +311,7 @@ class ColorJitter(_ValueTransform):
         else:
             if _istuple(shift):
                 if len(shift) != _nc:
-                    raise ValueError("specify the shift value (or range) for every channel")
+                    raise ValueError("please specify the shift value (or range) for every channel.")
                 rs = all(_istuple(s) for s in shift)
                 self.shift = self.range = shift
             else:
@@ -233,23 +322,20 @@ class ColorJitter(_ValueTransform):
         self.random_state = rs
         
         def _(x):
-            return x, ()
+            return x
         self.convert_to = _
         self.convert_back = _
     
     @_ValueTransform.keep_range
     def _transform(self, x):
-        # CAUTION! 
-        # Type conversion here
-        x, params = self.convert_to(x)
+        x = self.convert_to(x)
         for i, c in enumerate(self._channel):
-            x[...,c] += self.shift[i]
-            x[...,c] = self._clip(x[...,c])
-        x, _ = self.convert_back(x, *params)
+            x[...,c] = self._clip(x[...,c]+float(self.shift[i]))
+        x = self.convert_back(x)
         return x
         
     def _clip(self, x):
-        return np.clip(x, *self.limit)
+        return x
         
     def _set_rand_param(self):
         if len(self._channel) == 1:
@@ -262,19 +348,21 @@ class HSVShift(ColorJitter):
     def __init__(self, shift, limit):
         super().__init__(shift, limit)
         def _convert_to(x):
-            type_x = x.dtype
             x = x.astype(np.float32)
             # Normalize to [0,1]
             x -= self.limit[0]
             x /= self.limit_range
             x = cv2.cvtColor(x, code=cv2.COLOR_RGB2HSV)
-            return x, (type_x,)
-        def _convert_back(x, type_x):
+            return x
+        def _convert_back(x):
             x = cv2.cvtColor(x.astype(np.float32), code=cv2.COLOR_HSV2RGB)
-            return x.astype(type_x) * self.limit_range + self.limit[0], ()
+            return x * self.limit_range + self.limit[0]
         # Pack conversion methods
         self.convert_to = _convert_to
         self.convert_back = _convert_back
+
+        def _clip(self, x):
+            raise NotImplementedError
         
 
 class HueShift(HSVShift):
@@ -332,7 +420,7 @@ class PCAJitter(_ValueTransform):
         old_shape = x.shape
         x = np.reshape(x, (-1,3), order='F')   # For RGB
         x_mean = np.mean(x, 0)
-        x -= x_mean
+        x = x - x_mean
         cov_x = np.cov(x, rowvar=False)
         eig_vals, eig_vecs = np.linalg.eig(np.mat(cov_x))
         # The eigen vectors are already unit "length"
@@ -354,9 +442,9 @@ class ContraBrightScale(_ValueTransform):
     
     @_ValueTransform.keep_range
     def _transform(self, x):
-        if self.alpha != 1:
+        if not math.isclose(self.alpha, 1.0):
             x *= self.alpha
-        if self.beta != 0:
+        if not math.isclose(self.beta, 0.0):
             x += self.beta*np.mean(x)
         return x
     
@@ -387,7 +475,7 @@ class _AddNoise(_ValueTransform):
     def __call__(self, *args):
         shape = args[0].shape
         if any(im.shape != shape for im in args):
-            raise ValueError("the input images should be of same size")
+            raise ValueError("the input images should be of same size.")
         self._im_shape = shape
         return super().__call__(*args)
         
@@ -398,17 +486,4 @@ class AddGaussNoise(_AddNoise):
         self.mu = mu
         self.sigma = sigma
     def _set_rand_param(self):
-        self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu
-        
-
-def __test():
-    a = np.arange(12).reshape((2,2,3)).astype(np.float64)
-    tf = Compose(BrightnessScale(), AddGaussNoise(), HueShift())
-    print(a[...,0])
-    c = tf(a)
-    print(c[...,0])
-    print(a[...,0])
-    
-    
-if __name__ == '__main__':
-    __test()
+        self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index 9a21a8b..c434b8b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -131,7 +131,7 @@ def main():
     args = parse_args()
     gpc, logger = set_gpc_and_logger(args)
 
-    if exists(args.exp_config):
+    if args.exp_config:
         # Make a copy of the config file
         cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False)
         shutil.copy(args.exp_config, cfg_path)
@@ -147,16 +147,11 @@ def main():
 
     try:
         trainer = CDTrainer(args.model, args.dataset, args.optimizer, args)
-        if args.cmd == 'train':
-            trainer.train()
-        elif args.cmd == 'val':
-            trainer.validate()
-        else:
-            pass
+        trainer.run()
     except BaseException as e:
         import traceback
         # Catch ALL kinds of exceptions
-        logger.error(traceback.format_exc())
+        logger.fatal(traceback.format_exc())
         exit(1)
 
 if __name__ == '__main__':
diff --git a/src/utils/misc.py b/src/utils/misc.py
index ca3fe55..fbcedf6 100644
--- a/src/utils/misc.py
+++ b/src/utils/misc.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import sys
 from time import localtime
 from collections import OrderedDict
 from weakref import proxy
@@ -17,8 +18,13 @@ class Logger:
         Logger._count += 1
         self._logger.setLevel(logging.DEBUG)
 
+        self._err_handler = logging.StreamHandler(stream=sys.stderr)
+        self._err_handler.setLevel(logging.ERROR)
+        self._err_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT))
+        self._logger.addHandler(self._err_handler)
+
         if scrn:
-            self._scrn_handler = logging.StreamHandler()
+            self._scrn_handler = logging.StreamHandler(stream=sys.stdout)
             self._scrn_handler.setLevel(logging.INFO)
             self._scrn_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT))
             self._logger.addHandler(self._scrn_handler)
@@ -50,9 +56,12 @@ class Logger:
     def error(self, *args, **kwargs):
         return self._logger.error(*args, **kwargs)
 
+    def fatal(self, *args, **kwargs):
+        return self._logger.critical(*args, **kwargs)
+
     @staticmethod
-    def make_desc(counter, total, *triples):
-        desc = "[{}/{}]".format(counter, total)
+    def make_desc(counter, total, *triples, opt_str=''):
+        desc = "[{}/{}] {}".format(counter, total, opt_str)
         # The three elements of each triple are
         # (name to display, AverageMeter object, formatting string)
         for name, obj, fmt in triples:
@@ -258,6 +267,7 @@ class _Tree:
     def add_node(self, path, val=None):
         if not path.strip():
             raise ValueError("the path is null")
+        path = path.strip('/')
         if val is None:
             val = self._def_val
         names = self.parse_path(path)
@@ -281,6 +291,8 @@ class OutPathGetter:
     def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs):
         super().__init__()
         self._root = root.rstrip('/')    # Work robustly for multiple ending '/'s
+        if len(self._root) == 0 and len(root) > 0:
+            self._root = '/'    # In case of the system root dir
         self._suffix = suffix
         self._keys = dict(log=log, out=out, weight=weight, **subs)
         self._dir_tree = _Tree(
diff --git a/train9.sh b/train9.sh
index f1dcab8..b70d801 100755
--- a/train9.sh
+++ b/train9.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
-# Activate conda environment
-source activate $ME
+# # Activate conda environment
+# source activate $ME
 
 # Change directory
 cd src
-- 
GitLab