From 7f0846c13cd76c56fc4c60280e3875a16822056b Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Sat, 25 Jan 2020 15:17:17 +0800 Subject: [PATCH] Add Lebedev --- config_base.yaml | 28 +++++++++++++------------- src/constants.py | 3 ++- src/core/factories.py | 18 +++++++++++++++-- src/core/trainers.py | 5 +++++ src/data/Lebedev.py | 47 +++++++++++++++++++++++++++++++++++++++++++ train9.sh | 20 ++++++++++++++++++ 6 files changed, 104 insertions(+), 17 deletions(-) create mode 100644 src/data/Lebedev.py create mode 100755 train9.sh diff --git a/config_base.yaml b/config_base.yaml index 8eee72e..838ac12 100644 --- a/config_base.yaml +++ b/config_base.yaml @@ -3,23 +3,23 @@ # Data # Common -dataset: AC_Tiszadob -crop_size: 112 +dataset: Lebedev +crop_size: 224 num_workers: 1 -repeats: 3200 +repeats: 1 # Optimizer -optimizer: SGD -lr: 0.001 -lr_mode: const -weight_decay: 0.0005 -step: 2 +optimizer: Adam +lr: 1e-4 +lr_mode: step +weight_decay: 0.0 +step: 5 # Training related -batch_size: 32 -num_epochs: 10 +batch_size: 8 +num_epochs: 15 resume: '' load_optim: True anew: False @@ -42,10 +42,10 @@ suffix_off: False # Criterion criterion: NLL weights: - - 1.0 # Weight of no-change class - - 10.0 # Weight of change class + - 0.117 # Weight of no-change class + - 0.883 # Weight of change class # Model -model: siamunet_conc -num_feats_in: 3 \ No newline at end of file +model: EF +num_feats_in: 6 \ No newline at end of file diff --git a/src/constants.py b/src/constants.py index e41882f..c9bfdd0 100644 --- a/src/constants.py +++ b/src/constants.py @@ -3,7 +3,8 @@ # Dataset directories IMDB_OSCD = '~/Datasets/OSCDDataset/' -IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/' +IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/' +IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/' # Checkpoint templates CKP_LATEST = 'checkpoint_latest.pth' diff --git a/src/core/factories.py b/src/core/factories.py index 2db268a..64e644a 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -164,7 +164,14 @@ def single_train_ds_factory(ds_name, C): elif ds_name.startswith('AC'): configs.update( dict( - root = constants.IMDB_AirChange + root = constants.IMDB_AIRCHANGE + ) + ) + elif ds_name == 'Lebedev': + configs.update( + dict( + root = constants.IMDB_LEBEDEV, + subsets = ('real',) ) ) else: @@ -199,7 +206,14 @@ def single_val_ds_factory(ds_name, C): elif ds_name.startswith('AC'): configs.update( dict( - root = constants.IMDB_AirChange + root = constants.IMDB_AIRCHANGE + ) + ) + elif ds_name == 'Lebedev': + configs.update( + dict( + root = constants.IMDB_LEBEDEV, + subsets = ('real',) ) ) else: diff --git a/src/core/trainers.py b/src/core/trainers.py index 3e1381a..c41863f 100644 --- a/src/core/trainers.py +++ b/src/core/trainers.py @@ -242,6 +242,11 @@ class CDTrainer(Trainer): with torch.no_grad(): for i, (name, t1, t2, label) in enumerate(pb): + if self.phase == 'train' and i >= 16: + # Do not validate all images on training phase + pb.close() + self.logger.warning("validation ends early") + break t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device) prob = self.model(t1, t2) diff --git a/src/data/Lebedev.py b/src/data/Lebedev.py new file mode 100644 index 0000000..bad5651 --- /dev/null +++ b/src/data/Lebedev.py @@ -0,0 +1,47 @@ +from glob import glob +from os.path import join, basename + +import numpy as np + +from . import CDDataset +from .common import default_loader + +class LebedevDataset(CDDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1, + subsets=('real', 'with_shift', 'without_shift') + ): + self.subsets = subsets + super().__init__(root, phase, transforms, repeats) + + def _read_file_paths(self): + t1_list, t2_list, label_list = [], [], [] + + for subset in self.subsets: + # Get subset directory + if subset == 'real': + subset_dir = join(self.root, 'Real', 'subset') + elif subset == 'with_shift': + subset_dir = join(self.root, 'Model', 'with_shift') + elif subset == 'without_shift': + subset_dir = join(self.root, 'Model', 'without_shift') + else: + raise RuntimeError('unrecognized key encountered') + + pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg' + refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern))) + t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs) + t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs) + + label_list.extend(refs) + t1_list.extend(t1s) + t2_list.extend(t2s) + + return t1_list, t2_list, label_list + + def fetch_label(self, label_path): + # To {0,1} + return (super().fetch_label(label_path) / 255.0).astype(np.uint8) \ No newline at end of file diff --git a/train9.sh b/train9.sh new file mode 100755 index 0000000..f1dcab8 --- /dev/null +++ b/train9.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Activate conda environment +source activate $ME + +# Change directory +cd src + +# Define constants +ARCHS=("siamdiff" "siamconc" "EF") +DATASETS=("AC_Szada" "AC_Tiszadob" "OSCD") + +# LOOP +for arch in ${ARCHS[@]} +do + for dataset in ${DATASETS[@]} + do + python train.py train --exp-config ../config_${arch}_${dataset}.yaml + done +done \ No newline at end of file -- GitLab