From 9fb617877c2942b1aab194a6c6b911251baeec29 Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Sat, 14 Mar 2020 19:57:15 +0800 Subject: [PATCH] Update custom framework --- .gitignore | 6 +- README.md | 7 +- config_EF_AC_Szada.yaml | 51 ++++++++++ config_EF_AC_Tiszadob.yaml | 51 ++++++++++ config_EF_OSCD.yaml | 51 ++++++++++ config_siamconc_AC_Szada.yaml | 51 ++++++++++ config_siamconc_AC_Tiszadob.yaml | 51 ++++++++++ config_siamconc_OSCD.yaml | 51 ++++++++++ config_siamdiff_AC_Szada.yaml | 51 ++++++++++ config_siamdiff_AC_Tiszadob.yaml | 51 ++++++++++ config_siamdiff_OSCD.yaml | 51 ++++++++++ src/core/factories.py | 80 ++++++--------- src/core/trainers.py | 91 ++++++++++------- src/data/__init__.py | 9 +- src/data/augmentation.py | 161 ++++++++++++++++++++++--------- src/train.py | 11 +-- src/utils/misc.py | 18 +++- train9.sh | 4 +- 18 files changed, 697 insertions(+), 149 deletions(-) create mode 100644 config_EF_AC_Szada.yaml create mode 100644 config_EF_AC_Tiszadob.yaml create mode 100644 config_EF_OSCD.yaml create mode 100644 config_siamconc_AC_Szada.yaml create mode 100644 config_siamconc_AC_Tiszadob.yaml create mode 100644 config_siamconc_OSCD.yaml create mode 100644 config_siamdiff_AC_Szada.yaml create mode 100644 config_siamdiff_AC_Tiszadob.yaml create mode 100644 config_siamdiff_OSCD.yaml diff --git a/.gitignore b/.gitignore index 9b19707..d488f64 100644 --- a/.gitignore +++ b/.gitignore @@ -130,9 +130,9 @@ dmypy.json # Pyre type checker .pyre/ -# Config files -config*.yaml -!/config_base.yaml +# # Config files +# config*.yaml +# !/config_base.yaml # Experiment folder /exp/ \ No newline at end of file diff --git a/README.md b/README.md index b691223..ed0b271 100644 --- a/README.md +++ b/README.md @@ -44,4 +44,9 @@ For evaluation, try python train.py val --exp-config ../config_base.yaml --resume path_to_checkpoint ``` -You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/outs`. \ No newline at end of file +You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/outs`. + +--- +# Changed + +2020.3.14 Add the configuration files of my experiments. 
\ No newline at end of file diff --git a/config_EF_AC_Szada.yaml b/config_EF_AC_Szada.yaml new file mode 100644 index 0000000..7e969ae --- /dev/null +++ b/config_EF_AC_Szada.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Szada +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: EF +num_feats_in: 6 \ No newline at end of file diff --git a/config_EF_AC_Tiszadob.yaml b/config_EF_AC_Tiszadob.yaml new file mode 100644 index 0000000..3c93b3a --- /dev/null +++ b/config_EF_AC_Tiszadob.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Tiszadob +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: EF +num_feats_in: 6 \ No newline at end of file diff --git a/config_EF_OSCD.yaml b/config_EF_OSCD.yaml new file mode 100644 index 0000000..a91842c --- /dev/null +++ b/config_EF_OSCD.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: OSCD +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: EF +num_feats_in: 26 \ No newline at end of file diff --git a/config_siamconc_AC_Szada.yaml b/config_siamconc_AC_Szada.yaml new file mode 100644 index 0000000..4142e68 --- /dev/null +++ b/config_siamconc_AC_Szada.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Szada +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_conc 
+num_feats_in: 3 \ No newline at end of file diff --git a/config_siamconc_AC_Tiszadob.yaml b/config_siamconc_AC_Tiszadob.yaml new file mode 100644 index 0000000..8eee72e --- /dev/null +++ b/config_siamconc_AC_Tiszadob.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Tiszadob +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_conc +num_feats_in: 3 \ No newline at end of file diff --git a/config_siamconc_OSCD.yaml b/config_siamconc_OSCD.yaml new file mode 100644 index 0000000..6a25726 --- /dev/null +++ b/config_siamconc_OSCD.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: OSCD +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_conc +num_feats_in: 13 \ No newline at end of file diff --git a/config_siamdiff_AC_Szada.yaml b/config_siamdiff_AC_Szada.yaml new file mode 100644 index 0000000..3c0ad37 --- /dev/null +++ b/config_siamdiff_AC_Szada.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Szada +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_diff +num_feats_in: 3 \ No newline at end of file diff --git a/config_siamdiff_AC_Tiszadob.yaml b/config_siamdiff_AC_Tiszadob.yaml new file mode 100644 index 0000000..02f67ba --- /dev/null +++ b/config_siamdiff_AC_Tiszadob.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: AC_Tiszadob +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL 
+weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_diff +num_feats_in: 3 \ No newline at end of file diff --git a/config_siamdiff_OSCD.yaml b/config_siamdiff_OSCD.yaml new file mode 100644 index 0000000..90a1671 --- /dev/null +++ b/config_siamdiff_OSCD.yaml @@ -0,0 +1,51 @@ +# Basic configurations + + +# Data +# Common +dataset: OSCD +crop_size: 112 +num_workers: 1 +repeats: 3200 + + +# Optimizer +optimizer: SGD +lr: 0.001 +lr_mode: const +weight_decay: 0.0005 +step: 2 + + +# Training related +batch_size: 32 +num_epochs: 10 +resume: '' +load_optim: True +anew: False +trace_freq: 1 +device: cuda +metrics: 'F1Score+Accuracy+Recall+Precision' + + +# Experiment +exp_dir: ../exp/ +out_dir: '' +# tag: '' +# suffix: '' +# DO NOT specify exp-config term +save_on: False +log_off: False +suffix_off: False + + +# Criterion +criterion: NLL +weights: + - 1.0 # Weight of no-change class + - 10.0 # Weight of change class + + +# Model +model: siamunet_diff +num_feats_in: 13 \ No newline at end of file diff --git a/src/core/factories.py b/src/core/factories.py index e700371..5a35789 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -11,6 +11,7 @@ import torch.utils.data as data import constants import utils.metrics as metrics from utils.misc import R +from data.augmentation import * class _Desc: @@ -38,16 +39,6 @@ def _generator_deco(func_name): return _wrapper -def _mark(func): - func.__marked__ = True - return func - - -def _unmark(func): - func.__marked__ = False - return func - - # Duck typing class Duck(tuple): __ducktype__ = object @@ -56,6 +47,12 @@ class Duck(tuple): raise TypeError("please check the input type") return tuple.__new__(cls, args) + def __add__(self, tup): + raise NotImplementedError + + def __mul__(self, tup): + raise NotImplementedError + class DuckMeta(type): def __new__(cls, name, bases, attrs): @@ -63,61 +60,43 @@ class DuckMeta(type): for k, v in getmembers(bases[0]): if k.startswith('__'): continue - if k in attrs and hasattr(attrs[k], '__marked__'): - if attrs[k].__marked__: - continue if isgeneratorfunction(v): - attrs[k] = _generator_deco(k) + attrs.setdefault(k, _generator_deco(k)) elif isfunction(v): - attrs[k] = _func_deco(k) + attrs.setdefault(k, _func_deco(k)) else: - attrs[k] = _Desc(k) + attrs.setdefault(k, _Desc(k)) attrs['__ducktype__'] = bases[0] return super().__new__(cls, name, (Duck,), attrs) -class DuckModel(nn.Module, metaclass=DuckMeta): - DELIM = ':' - @_mark - def load_state_dict(self, state_dict): - dicts = [dict() for _ in range(len(self))] - for k, v in state_dict.items(): - i, *k = k.split(self.DELIM) - k = self.DELIM.join(k) - i = int(i) - dicts[i][k] = v - for i in range(len(self)): self[i].load_state_dict(dicts[i]) +class DuckModel(nn.Module): + def __init__(self, *models): + super().__init__() + ## XXX: The state_dict will be a little larger in size + # Since some extra bytes are stored in every key + self._m = nn.ModuleList(models) + + def __len__(self): + return len(self._m) - @_mark - def state_dict(self): - dict_ = dict() - for i, ins in enumerate(self): - dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()}) - return dict_ + def __getitem__(self, idx): + return self._m[idx] + + def __repr__(self): + return repr(self._m) class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta): - DELIM = ':' + # Cuz this is an instance method @property def param_groups(self): return list(chain.from_iterable(ins.param_groups for ins in self)) - 
@_mark - def state_dict(self): - dict_ = dict() - for i, ins in enumerate(self): - dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()}) - return dict_ - - @_mark - def load_state_dict(self, state_dict): - dicts = [dict() for _ in range(len(self))] - for k, v in state_dict.items(): - i, *k = k.split(self.DELIM) - k = self.DELIM.join(k) - i = int(i) - dicts[i][k] = v - for i in range(len(self)): self[i].load_state_dict(dicts[i]) + # This is special in dispatching + def load_state_dict(self, state_dicts): + for optim, state_dict in zip(self, state_dicts): + optim.load_state_dict(state_dict) class DuckCriterion(nn.Module, metaclass=DuckMeta): @@ -205,7 +184,6 @@ def _get_basic_configs(ds_name, C): def single_train_ds_factory(ds_name, C): - from data.augmentation import Compose, Crop, Flip ds_name = ds_name.strip() module = _import_module('data', ds_name) dataset = getattr(module, ds_name+'Dataset') diff --git a/src/core/trainers.py b/src/core/trainers.py index 33dcabd..445b96a 100644 --- a/src/core/trainers.py +++ b/src/core/trainers.py @@ -20,7 +20,7 @@ class Trainer: super().__init__() context = deepcopy(settings) self.ctx = MappingProxyType(vars(context)) - self.phase = context.cmd + self.mode = ('train', 'val').index(context.cmd) self.logger = R['LOGGER'] self.gpc = R['GPC'] # Global Path Controller @@ -44,27 +44,43 @@ class Trainer: self.model.to(self.device) self.criterion = critn_factory(criterion, context) self.criterion.to(self.device) - self.optimizer = optim_factory(optimizer, self.model, context) self.metrics = metric_factory(context.metrics, context) - self.train_loader = data_factory(dataset, 'train', context) - self.val_loader = data_factory(dataset, 'val', context) + if self.is_training: + self.train_loader = data_factory(dataset, 'train', context) + self.val_loader = data_factory(dataset, 'val', context) + self.optimizer = optim_factory(optimizer, self.model, context) + else: + self.val_loader = data_factory(dataset, 'val', context) self.start_epoch = 0 - self._init_max_acc = 0.0 + self._init_max_acc_and_epoch = (0.0, 0) + + @property + def is_training(self): + return self.mode == 0 - def train_epoch(self): + def train_epoch(self, epoch): raise NotImplementedError def validate_epoch(self, epoch=0, store=False): raise NotImplementedError + def _write_prompt(self): + self.logger.dump(input("\nWrite some notes: ")) + + def run(self): + if self.is_training: + self._write_prompt() + self.train() + else: + self.evaluate() + def train(self): if self.load_checkpoint: self._resume_from_checkpoint() - max_acc = self._init_max_acc - best_epoch = self.get_ckp_epoch() + max_acc, best_epoch = self._init_max_acc_and_epoch for epoch in range(self.start_epoch, self.num_epochs): lr = self._adjust_learning_rate(epoch) @@ -72,8 +88,8 @@ class Trainer: self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr)) # Train for one epoch - self.train_epoch() - + self.train_epoch(epoch) + # Clear the history of metric objects for m in self.metrics: m.reset() @@ -81,7 +97,7 @@ class Trainer: # Evaluate the model on validation set self.logger.show_nl("Validate") acc = self.validate_epoch(epoch=epoch, store=self.save) - + is_best = acc > max_acc if is_best: max_acc = acc @@ -90,14 +106,14 @@ class Trainer: acc, epoch, max_acc, best_epoch)) # The checkpoint saves next epoch - self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), max_acc, epoch+1, is_best) + self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), 
(max_acc, best_epoch), epoch+1, is_best) - def validate(self): + def evaluate(self): if self.checkpoint: if self._resume_from_checkpoint(): - self.validate_epoch(self.get_ckp_epoch(), self.save) + self.validate_epoch(self.ckp_epoch, self.save) else: - self.logger.warning("no checkpoint assigned!") + self.logger.warning("Warning: no checkpoint assigned!") def _adjust_learning_rate(self, epoch): if self.ctx['lr_mode'] == 'step': @@ -114,13 +130,14 @@ class Trainer: return lr def _resume_from_checkpoint(self): + ## XXX: This could be slow! if not os.path.isfile(self.checkpoint): - self.logger.error("=> no checkpoint found at '{}'".format(self.checkpoint)) + self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint)) return False - self.logger.show("=> loading checkpoint '{}'".format( + self.logger.show("=> Loading checkpoint '{}'".format( self.checkpoint)) - checkpoint = torch.load(self.checkpoint) + checkpoint = torch.load(self.checkpoint, map_location=self.device) state_dict = self.model.state_dict() ckp_dict = checkpoint.get('state_dict', checkpoint) @@ -129,32 +146,35 @@ class Trainer: num_to_update = len(update_dict) if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)): - if self.phase == 'val' and (num_to_update < len(state_dict)): - self.logger.error("=> mismatched checkpoint for validation") + if not self.is_training and (num_to_update < len(state_dict)): + self.logger.error("=> Mismatched checkpoint for evaluation") return False - self.logger.warning("warning: trying to load an mismatched checkpoint") + self.logger.warning("Warning: trying to load an mismatched checkpoint.") if num_to_update == 0: - self.logger.error("=> no parameter is to be loaded") + self.logger.error("=> No parameter is to be loaded.") return False else: - self.logger.warning("=> {} params are to be loaded".format(num_to_update)) - elif (not self.ctx['anew']) or (self.phase != 'train'): - # Note in the non-anew mode, it is not guaranteed that the contained field - # max_acc be the corresponding one of the loaded checkpoint. 
- self.start_epoch = checkpoint.get('epoch', self.start_epoch) - self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc) - if self.ctx['load_optim']: + self.logger.warning("=> {} params are to be loaded.".format(num_to_update)) + elif (not self.ctx['anew']) or not self.is_training: + self.start_epoch = checkpoint.get('epoch', 0) + max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch)) + # For backward compatibility + if isinstance(max_acc_and_epoch, (float, int)): + self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch) + else: + self._init_max_acc_and_epoch = max_acc_and_epoch + if self.ctx['load_optim'] and self.is_training: try: # Note that weight decay might be modified here self.optimizer.load_state_dict(checkpoint['optimizer']) except KeyError: - self.logger.warning("warning: failed to load optimizer parameters") + self.logger.warning("Warning: failed to load optimizer parameters.") state_dict.update(update_dict) self.model.load_state_dict(state_dict) - self.logger.show("=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format( - self.checkpoint, self.get_ckp_epoch(), self._init_max_acc + self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format( + self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch )) return True @@ -183,7 +203,8 @@ class Trainer: ) ) - def get_ckp_epoch(self): + @property + def ckp_epoch(self): # Get current epoch of the checkpoint # For dismatched ckp or no ckp, set to 0 return max(self.start_epoch-1, 0) @@ -207,7 +228,7 @@ class CDTrainer(Trainer): def __init__(self, arch, dataset, optimizer, settings): super().__init__(arch, dataset, 'NLL', optimizer, settings) - def train_epoch(self): + def train_epoch(self, epoch): losses = AverageMeter() len_train = len(self.train_loader) pb = tqdm(self.train_loader) @@ -246,7 +267,7 @@ class CDTrainer(Trainer): with torch.no_grad(): for i, (name, t1, t2, label) in enumerate(pb): - if self.phase == 'train' and i >= 16: + if self.is_training and i >= 16: # Do not validate all images on training phase pb.close() self.logger.warning("validation ends early") diff --git a/src/data/__init__.py b/src/data/__init__.py index f4bf36e..14a3da5 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,4 +1,4 @@ -from os.path import join, expanduser, basename +from os.path import join, expanduser, basename, exists import torch import torch.utils.data as data @@ -16,9 +16,12 @@ class CDDataset(data.Dataset): ): super().__init__() self.root = expanduser(root) + if not exists(self.root): + raise FileNotFoundError self.phase = phase - self.transforms = transforms - self.repeats = repeats + self.transforms = list(transforms) + self.transforms += [None]*(3-len(self.transforms)) + self.repeats = int(repeats) self.t1_list, self.t2_list, self.label_list = self._read_file_paths() self.len = len(self.label_list) diff --git a/src/data/augmentation.py b/src/data/augmentation.py index 231d34e..0c5c02b 100644 --- a/src/data/augmentation.py +++ b/src/data/augmentation.py @@ -1,9 +1,24 @@ import random +import math from functools import partial, wraps import numpy as np import cv2 + +__all__ = [ + 'Compose', 'Choose', + 'Scale', 'DiscreteScale', + 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Rotate', + 'Crop', 'MSCrop', + 'Shift', 'XShift', 'YShift', + 'HueShift', 'SaturationShift', 'RGBShift', 'RShift', 'GShift', 'BShift', + 'PCAJitter', + 'ContraBrightScale', 'ContrastScale', 'BrightnessScale', + 'AddGaussNoise' +] + + rand = random.random randi = random.randint 
choice = random.choice @@ -11,11 +26,10 @@ uniform = random.uniform # gauss = random.gauss gauss = random.normalvariate # This one is thread-safe -# The transformations treat numpy ndarrays only +# The transformations treat 2-D or 3-D numpy ndarrays only, with the optional 3rd dim as the channel dim def _istuple(x): return isinstance(x, (tuple, list)) - class Transform: def __init__(self, random_state=False): self.random_state = random_state @@ -28,6 +42,7 @@ class Transform: def _set_rand_param(self): raise NotImplementedError + class Compose: def __init__(self, *tf): assert len(tf) > 0 @@ -39,17 +54,27 @@ class Compose: x = x[0] for tf in self.tfs: x = tf(x) return x - + + +class Choose: + def __init__(self, *tf): + assert len(tf) > 1 + self.tfs = tf + def __call__(self, *x): + idx = randi(0, len(self.tfs)-1) + return self.tfs[idx](*x) + + class Scale(Transform): def __init__(self, scale=(0.5,1.0)): if _istuple(scale): assert len(scale) == 2 - self.scale_range = scale #sorted(scale) - self.scale = scale[0] + self.scale_range = tuple(scale) #sorted(scale) + self.scale = float(scale[0]) super(Scale, self).__init__(random_state=True) else: super(Scale, self).__init__(random_state=False) - self.scale = scale + self.scale = float(scale) def _transform(self, x): # assert x.ndim == 3 h, w = x.shape[:2] @@ -61,11 +86,12 @@ class Scale(Transform): def _set_rand_param(self): self.scale = uniform(*self.scale_range) + class DiscreteScale(Scale): def __init__(self, bins=(0.5, 0.75), keep_prob=0.5): super(DiscreteScale, self).__init__(scale=(min(bins), 1.0)) - self.bins = bins - self.keep_prob = keep_prob + self.bins = tuple(bins) + self.keep_prob = float(keep_prob) def _set_rand_param(self): self.scale = 1.0 if rand()<self.keep_prob else choice(self.bins) @@ -115,7 +141,11 @@ class VerticalFlip(Flip): def __init__(self, flip=None): if flip is not None: flip = self._directions[~flip] super(VerticalFlip, self).__init__(direction=flip) - + + +class Rotate(Flip): + _directions = ('90', '180', '270', 'no') + class Crop(Transform): _inner_bounds = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r') @@ -148,8 +178,10 @@ class Crop(Transform): elif self.bounds == 'r': return x[:,w//2:] elif len(self.bounds) == 2: - assert self.crop_size < (h, w) + assert self.crop_size <= (h, w) ch, cw = self.crop_size + if (ch,cw) == (h,w): + return x cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1]) return x[cy:cy+ch, cx:cx+cw] else: @@ -188,6 +220,59 @@ class MSCrop(Crop): self.bounds = (left, top, left+cw, top+ch) +class Shift(Transform): + def __init__(self, x_shift=(-0.0625, 0.0625), y_shift=(-0.0625, 0.0625), circular=True): + super(Shift, self).__init__(random_state=_istuple(x_shift) or _istuple(y_shift)) + + if _istuple(x_shift): + self.xshift_range = tuple(x_shift) + self.xshift = float(x_shift[0]) + else: + self.xshift = float(x_shift) + self.xshift_range = (self.xshift, self.xshift) + + if _istuple(y_shift): + self.yshift_range = tuple(y_shift) + self.yshift = float(y_shift[0]) + else: + self.yshift = float(y_shift) + self.yshift_range = (self.yshift, self.yshift) + + self.circular = circular + + def _transform(self, im): + h, w = im.shape[:2] + xsh = -int(self.xshift*w) + ysh = -int(self.yshift*h) + if self.circular: + # Shift along the x-axis + im_shifted = np.concatenate((im[:, xsh:], im[:, :xsh]), axis=1) + # Shift along the y-axis + im_shifted = np.concatenate((im_shifted[ysh:], im_shifted[:ysh]), axis=0) + else: + zeros = np.zeros(im.shape) + im1, im2 = (zeros, im) if xsh < 0 else (im, zeros) + 
im_shifted = np.concatenate((im1[:, xsh:], im2[:, :xsh]), axis=1) + im1, im2 = (zeros, im_shifted) if ysh < 0 else (im_shifted, zeros) + im_shifted = np.concatenate((im1[ysh:], im2[:ysh]), axis=0) + + return im_shifted + + def _set_rand_param(self): + self.xshift = uniform(*self.xshift_range) + self.yshift = uniform(*self.yshift_range) + + +class XShift(Shift): + def __init__(self, x_shift=(-0.0625, 0.0625), circular=True): + super(XShift, self).__init__(x_shift, 0.0, circular) + + +class YShift(Shift): + def __init__(self, y_shift=(-0.0625, 0.0625), circular=True): + super(YShift, self).__init__(0.0, y_shift, circular) + + # Color jittering and transformation # The followings partially refer to https://github.com/albu/albumentations/ class _ValueTransform(Transform): @@ -201,8 +286,12 @@ class _ValueTransform(Transform): def wrapper(obj, x): # # Make a copy # x = x.copy() - x = tf(obj, np.clip(x, *obj.limit)) - return np.clip(x, *obj.limit) + dtype = x.dtype + # The calculations are done with floating type in case of overflow + # This is a stupid yet simple way + x = tf(obj, np.clip(x.astype(np.float32), *obj.limit)) + # Convert back to the original type + return np.clip(x, *obj.limit).astype(dtype) return wrapper @@ -222,7 +311,7 @@ class ColorJitter(_ValueTransform): else: if _istuple(shift): if len(shift) != _nc: - raise ValueError("specify the shift value (or range) for every channel") + raise ValueError("please specify the shift value (or range) for every channel.") rs = all(_istuple(s) for s in shift) self.shift = self.range = shift else: @@ -233,23 +322,20 @@ class ColorJitter(_ValueTransform): self.random_state = rs def _(x): - return x, () + return x self.convert_to = _ self.convert_back = _ @_ValueTransform.keep_range def _transform(self, x): - # CAUTION! 
- # Type conversion here - x, params = self.convert_to(x) + x = self.convert_to(x) for i, c in enumerate(self._channel): - x[...,c] += self.shift[i] - x[...,c] = self._clip(x[...,c]) - x, _ = self.convert_back(x, *params) + x[...,c] = self._clip(x[...,c]+float(self.shift[i])) + x = self.convert_back(x) return x def _clip(self, x): - return np.clip(x, *self.limit) + return x def _set_rand_param(self): if len(self._channel) == 1: @@ -262,19 +348,21 @@ class HSVShift(ColorJitter): def __init__(self, shift, limit): super().__init__(shift, limit) def _convert_to(x): - type_x = x.dtype x = x.astype(np.float32) # Normalize to [0,1] x -= self.limit[0] x /= self.limit_range x = cv2.cvtColor(x, code=cv2.COLOR_RGB2HSV) - return x, (type_x,) - def _convert_back(x, type_x): + return x + def _convert_back(x): x = cv2.cvtColor(x.astype(np.float32), code=cv2.COLOR_HSV2RGB) - return x.astype(type_x) * self.limit_range + self.limit[0], () + return x * self.limit_range + self.limit[0] # Pack conversion methods self.convert_to = _convert_to self.convert_back = _convert_back + + def _clip(self, x): + raise NotImplementedError class HueShift(HSVShift): @@ -332,7 +420,7 @@ class PCAJitter(_ValueTransform): old_shape = x.shape x = np.reshape(x, (-1,3), order='F') # For RGB x_mean = np.mean(x, 0) - x -= x_mean + x = x - x_mean cov_x = np.cov(x, rowvar=False) eig_vals, eig_vecs = np.linalg.eig(np.mat(cov_x)) # The eigen vectors are already unit "length" @@ -354,9 +442,9 @@ class ContraBrightScale(_ValueTransform): @_ValueTransform.keep_range def _transform(self, x): - if self.alpha != 1: + if not math.isclose(self.alpha, 1.0): x *= self.alpha - if self.beta != 0: + if not math.isclose(self.beta, 0.0): x += self.beta*np.mean(x) return x @@ -387,7 +475,7 @@ class _AddNoise(_ValueTransform): def __call__(self, *args): shape = args[0].shape if any(im.shape != shape for im in args): - raise ValueError("the input images should be of same size") + raise ValueError("the input images should be of same size.") self._im_shape = shape return super().__call__(*args) @@ -398,17 +486,4 @@ class AddGaussNoise(_AddNoise): self.mu = mu self.sigma = sigma def _set_rand_param(self): - self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu - - -def __test(): - a = np.arange(12).reshape((2,2,3)).astype(np.float64) - tf = Compose(BrightnessScale(), AddGaussNoise(), HueShift()) - print(a[...,0]) - c = tf(a) - print(c[...,0]) - print(a[...,0]) - - -if __name__ == '__main__': - __test() + self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu \ No newline at end of file diff --git a/src/train.py b/src/train.py index 9a21a8b..c434b8b 100644 --- a/src/train.py +++ b/src/train.py @@ -131,7 +131,7 @@ def main(): args = parse_args() gpc, logger = set_gpc_and_logger(args) - if exists(args.exp_config): + if args.exp_config: # Make a copy of the config file cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False) shutil.copy(args.exp_config, cfg_path) @@ -147,16 +147,11 @@ def main(): try: trainer = CDTrainer(args.model, args.dataset, args.optimizer, args) - if args.cmd == 'train': - trainer.train() - elif args.cmd == 'val': - trainer.validate() - else: - pass + trainer.run() except BaseException as e: import traceback # Catch ALL kinds of exceptions - logger.error(traceback.format_exc()) + logger.fatal(traceback.format_exc()) exit(1) if __name__ == '__main__': diff --git a/src/utils/misc.py b/src/utils/misc.py index ca3fe55..fbcedf6 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -1,5 +1,6 
@@ import logging import os +import sys from time import localtime from collections import OrderedDict from weakref import proxy @@ -17,8 +18,13 @@ class Logger: Logger._count += 1 self._logger.setLevel(logging.DEBUG) + self._err_handler = logging.StreamHandler(stream=sys.stderr) + self._err_handler.setLevel(logging.ERROR) + self._err_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT)) + self._logger.addHandler(self._err_handler) + if scrn: - self._scrn_handler = logging.StreamHandler() + self._scrn_handler = logging.StreamHandler(stream=sys.stdout) self._scrn_handler.setLevel(logging.INFO) self._scrn_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT)) self._logger.addHandler(self._scrn_handler) @@ -50,9 +56,12 @@ class Logger: def error(self, *args, **kwargs): return self._logger.error(*args, **kwargs) + def fatal(self, *args, **kwargs): + return self._logger.critical(*args, **kwargs) + @staticmethod - def make_desc(counter, total, *triples): - desc = "[{}/{}]".format(counter, total) + def make_desc(counter, total, *triples, opt_str=''): + desc = "[{}/{}] {}".format(counter, total, opt_str) # The three elements of each triple are # (name to display, AverageMeter object, formatting string) for name, obj, fmt in triples: @@ -258,6 +267,7 @@ class _Tree: def add_node(self, path, val=None): if not path.strip(): raise ValueError("the path is null") + path = path.strip('/') if val is None: val = self._def_val names = self.parse_path(path) @@ -281,6 +291,8 @@ class OutPathGetter: def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs): super().__init__() self._root = root.rstrip('/') # Work robustly for multiple ending '/'s + if len(self._root) == 0 and len(root) > 0: + self._root = '/' # In case of the system root dir self._suffix = suffix self._keys = dict(log=log, out=out, weight=weight, **subs) self._dir_tree = _Tree( diff --git a/train9.sh b/train9.sh index f1dcab8..b70d801 100755 --- a/train9.sh +++ b/train9.sh @@ -1,7 +1,7 @@ #!/bin/bash -# Activate conda environment -source activate $ME +# # Activate conda environment +# source activate $ME # Change directory cd src -- GitLab
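
For reference, each of the new per-experiment YAML files is meant to be consumed through the existing `--exp-config` switch shown in the README. A minimal usage sketch, assuming the working directory is `src` (as in `train9.sh`) and with the checkpoint path left as a placeholder:

    # Train the siamunet_diff model on the AC_Szada dataset
    python train.py train --exp-config ../config_siamdiff_AC_Szada.yaml

    # Evaluate a saved checkpoint under the same configuration
    python train.py val --exp-config ../config_siamdiff_AC_Szada.yaml --resume path_to_checkpoint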