diff --git a/config_base.yaml b/config_base.yaml
index 8eee72e4f9a60e6ae029f8668c52dd3a7215b612..838ac12abb837780e934a6964ffc3d88c4e97510 100644
--- a/config_base.yaml
+++ b/config_base.yaml
@@ -3,23 +3,23 @@
 
 # Data
 # Common
-dataset: AC_Tiszadob
-crop_size: 112
+dataset: Lebedev
+crop_size: 224
 num_workers: 1
-repeats: 3200
+repeats: 1
 
 
 # Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
+optimizer: Adam
+lr: 1e-4
+lr_mode: step
+weight_decay: 0.0
+step: 5
 
 
 # Training related
-batch_size: 32
-num_epochs: 10
+batch_size: 8
+num_epochs: 15
 resume: ''
 load_optim: True
 anew: False
@@ -42,10 +42,10 @@ suffix_off: False
 
 # Criterion
 criterion: NLL
 weights:
-  - 1.0 # Weight of no-change class
-  - 10.0 # Weight of change class
+  - 0.117 # Weight of no-change class
+  - 0.883 # Weight of change class
 
 # Model
-model: siamunet_conc
-num_feats_in: 3
\ No newline at end of file
+model: EF
+num_feats_in: 6
\ No newline at end of file
diff --git a/src/constants.py b/src/constants.py
index e41882fb592a9df376ee92e7f28b464d8944c551..c9bfdd0d69c6ce1ad2fa67c54cb7782632a3400b 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -3,7 +3,8 @@
 
 # Dataset directories
 IMDB_OSCD = '~/Datasets/OSCDDataset/'
-IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/'
+IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/'
+IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
 
 # Checkpoint templates
 CKP_LATEST = 'checkpoint_latest.pth'
diff --git a/src/core/factories.py b/src/core/factories.py
index 2db268a521dd978a84c691ad393819719d4e043c..64e644ae772e1e8d9e0461e122354ef0522c050a 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -164,7 +164,14 @@ def single_train_ds_factory(ds_name, C):
     elif ds_name.startswith('AC'):
         configs.update(
             dict(
-                root = constants.IMDB_AirChange
+                root = constants.IMDB_AIRCHANGE
+            )
+        )
+    elif ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                root = constants.IMDB_LEBEDEV,
+                subsets = ('real',)
             )
         )
     else:
@@ -199,7 +206,14 @@ def single_val_ds_factory(ds_name, C):
     elif ds_name.startswith('AC'):
         configs.update(
             dict(
-                root = constants.IMDB_AirChange
+                root = constants.IMDB_AIRCHANGE
+            )
+        )
+    elif ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                root = constants.IMDB_LEBEDEV,
+                subsets = ('real',)
             )
         )
     else:
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 3e1381ace84c06e55652ef89cc3b5e70d7e63673..c41863f353d28eaa8821bd569d3640a1755f5b12 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -242,6 +242,11 @@ class CDTrainer(Trainer):
 
         with torch.no_grad():
             for i, (name, t1, t2, label) in enumerate(pb):
+                if self.phase == 'train' and i >= 16:
+                    # Do not validate on all images during the training phase
+                    pb.close()
+                    self.logger.warning("validation ends early")
+                    break
                 t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
 
                 prob = self.model(t1, t2)
diff --git a/src/data/Lebedev.py b/src/data/Lebedev.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad5651ee7f37bdeebc2df780fee29050abc4c04
--- /dev/null
+++ b/src/data/Lebedev.py
@@ -0,0 +1,47 @@
+from glob import glob
+from os.path import join, basename
+
+import numpy as np
+
+from . import CDDataset
+from .common import default_loader
+
+class LebedevDataset(CDDataset):
+    def __init__(
+        self,
+        root, phase='train',
+        transforms=(None, None, None),
+        repeats=1,
+        subsets=('real', 'with_shift', 'without_shift')
+    ):
+        self.subsets = subsets
+        super().__init__(root, phase, transforms, repeats)
+
+    def _read_file_paths(self):
+        t1_list, t2_list, label_list = [], [], []
+
+        for subset in self.subsets:
+            # Get subset directory
+            if subset == 'real':
+                subset_dir = join(self.root, 'Real', 'subset')
+            elif subset == 'with_shift':
+                subset_dir = join(self.root, 'Model', 'with_shift')
+            elif subset == 'without_shift':
+                subset_dir = join(self.root, 'Model', 'without_shift')
+            else:
+                raise RuntimeError('unrecognized key encountered')
+
+            pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
+            refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
+            t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
+            t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
+
+            label_list.extend(refs)
+            t1_list.extend(t1s)
+            t2_list.extend(t2s)
+
+        return t1_list, t2_list, label_list
+
+    def fetch_label(self, label_path):
+        # To {0,1}
+        return (super().fetch_label(label_path) / 255.0).astype(np.uint8)
\ No newline at end of file
diff --git a/train9.sh b/train9.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f1dcab8ef9664a4ec0e874449ae765952bdbde92
--- /dev/null
+++ b/train9.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Activate conda environment
+source activate $ME
+
+# Change directory
+cd src
+
+# Define constants
+ARCHS=("siamdiff" "siamconc" "EF")
+DATASETS=("AC_Szada" "AC_Tiszadob" "OSCD")
+
+# LOOP
+for arch in ${ARCHS[@]}
+do
+    for dataset in ${DATASETS[@]}
+    do
+        python train.py train --exp-config ../config_${arch}_${dataset}.yaml
+    done
+done
\ No newline at end of file
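
For reference, the new dataset class can be exercised on its own as a quick sanity check. The snippet below is an illustrative sketch only, not part of the patch: it assumes the Lebedev archive is unpacked at the IMDB_LEBEDEV path introduced in src/constants.py, that the interpreter is started from inside src/ (the same working directory train9.sh switches to), and that CDDataset exposes the usual torch-style __len__ interface.

    # Illustrative sketch, not part of the patch.
    from data.Lebedev import LebedevDataset

    ds = LebedevDataset(
        root='~/Datasets/HR/ChangeDetectionDataset/',  # mirrors constants.IMDB_LEBEDEV
        phase='val',
        subsets=('real',),  # same subset selection the factories pass for 'Lebedev'
    )
    print(len(ds))  # number of (t1, t2, label) file triples found under Real/subset/val/

The 'real' subset is the only one wired up in the factories above; 'with_shift' and 'without_shift' remain selectable through the subsets argument of LebedevDataset.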