From 0e08f67378e61f7afe3341afb01e6b74ec9ad426 Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Tue, 10 Dec 2019 02:08:05 +0800 Subject: [PATCH] Add Air Change Dataset --- README.md | 8 ++--- config_base.yaml | 14 ++++---- src/constants.py | 2 +- src/core/factories.py | 28 ++++++++++----- src/core/trainers.py | 6 ++-- src/data/AC_Szada.py | 23 ++++++++++++ src/data/AC_Tiszadob.py | 23 ++++++++++++ src/data/OSCD.py | 2 +- src/data/_AirChange.py | 79 +++++++++++++++++++++++++++++++++++++++++ 9 files changed, 160 insertions(+), 25 deletions(-) create mode 100644 src/data/AC_Szada.py create mode 100644 src/data/AC_Tiszadob.py create mode 100644 src/data/_AirChange.py diff --git a/README.md b/README.md index a802bac..1368d86 100644 --- a/README.md +++ b/README.md @@ -12,15 +12,15 @@ as the [official repo](https://github.com/rcdaudt/fully_convolutional_change_det ```bash # The network definition scripts are from the original repo -git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git +git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git +cd FCN-CD-PyTorch +mkdir exp +cd src ``` For training, try ```bash -# In the root directory of this repository -mkdir exp -cd src python train.py train --exp-config ../config_base.yaml ``` diff --git a/config_base.yaml b/config_base.yaml index f581c71..8eee72e 100644 --- a/config_base.yaml +++ b/config_base.yaml @@ -3,23 +3,23 @@ # Data # Common -dataset: OSCD -crop_size: 256 +dataset: AC_Tiszadob +crop_size: 112 num_workers: 1 -repeats: 1000 +repeats: 3200 # Optimizer -optimizer: Adam +optimizer: SGD lr: 0.001 lr_mode: const -weight_decay: 0.0001 +weight_decay: 0.0005 step: 2 # Training related batch_size: 32 -num_epochs: 40 +num_epochs: 10 resume: '' load_optim: True anew: False @@ -48,4 +48,4 @@ weights: # Model model: siamunet_conc -num_feats_in: 13 \ No newline at end of file +num_feats_in: 3 \ No newline at end of file diff --git a/src/constants.py b/src/constants.py index dfa17b4..e41882f 100644 --- a/src/constants.py +++ b/src/constants.py @@ -3,7 +3,7 @@ # Dataset directories IMDB_OSCD = '~/Datasets/OSCDDataset/' -IMDB_AC = '~/Datasets/SZTAKI_AirChange_Benchmark/' +IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/' # Checkpoint templates CKP_LATEST = 'checkpoint_latest.pth' diff --git a/src/core/factories.py b/src/core/factories.py index 60e32db..652b776 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -109,7 +109,7 @@ def single_model_factory(model_name, C): def single_optim_factory(optim_name, params, C): - name = optim_name.upper() + name = optim_name.strip().upper() if name == 'ADAM': return torch.optim.Adam( params, @@ -117,6 +117,13 @@ def single_optim_factory(optim_name, params, C): lr=C.lr, weight_decay=C.weight_decay ) + elif name == 'SGD': + return torch.optim.SGD( + params, + lr=C.lr, + momentum=0.9, + weight_decay=C.weight_decay + ) else: raise NotImplementedError("{} is not a supported optimizer type".format(optim_name)) @@ -137,7 +144,8 @@ def single_critn_factory(critn_name, C): def single_train_ds_factory(ds_name, C): from data.augmentation import Compose, Crop, Flip - module = _import_module('data', ds_name.strip()) + ds_name = ds_name.strip() + module = _import_module('data', ds_name) dataset = getattr(module, ds_name+'Dataset') configs = dict( phase='train', @@ -150,17 +158,17 @@ def single_train_ds_factory(ds_name, C): root = constants.IMDB_OSCD ) ) - elif ds_name == 'AC': + elif ds_name.startswith('AC'): configs.update( dict( - root = 
constants.IMDB_AC + root = constants.IMDB_AirChange ) ) else: pass dataset_obj = dataset(**configs) - + return data.DataLoader( dataset_obj, batch_size=C.batch_size, @@ -171,11 +179,13 @@ def single_train_ds_factory(ds_name, C): def single_val_ds_factory(ds_name, C): - module = _import_module('data', ds_name.strip()) + ds_name = ds_name.strip() + module = _import_module('data', ds_name) dataset = getattr(module, ds_name+'Dataset') configs = dict( phase='val', - transforms=(None, None, None) + transforms=(None, None, None), + repeats=1 ) if ds_name == 'OSCD': configs.update( @@ -183,7 +193,7 @@ def single_val_ds_factory(ds_name, C): root = constants.IMDB_OSCD ) ) - elif ds_name == 'AC': + elif ds_name.startswith('AC'): configs.update( dict( root = constants.IMDB_AirChange @@ -246,4 +256,4 @@ def data_factory(dataset_names, phase, C): def metric_factory(metric_names, C): from utils import metrics name_list = _parse_input_names(metric_names) - return [getattr(metrics, name)() for name in name_list] + return [getattr(metrics, name.strip())() for name in name_list] diff --git a/src/core/trainers.py b/src/core/trainers.py index 2c5b139..3e1381a 100644 --- a/src/core/trainers.py +++ b/src/core/trainers.py @@ -212,9 +212,9 @@ class CDTrainer(Trainer): for i, (t1, t2, label) in enumerate(pb): t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device) - + prob = self.model(t1, t2) - + loss = self.criterion(prob, label) losses.update(loss.item(), n=self.batch_size) @@ -267,6 +267,6 @@ class CDTrainer(Trainer): self.logger.dump(desc) if store: - self.save_image(name[0], CM.squeeze(-1), epoch) + self.save_image(name[0], (CM*255).squeeze(-1), epoch) return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) \ No newline at end of file diff --git a/src/data/AC_Szada.py b/src/data/AC_Szada.py new file mode 100644 index 0000000..a9b9619 --- /dev/null +++ b/src/data/AC_Szada.py @@ -0,0 +1,23 @@ +from ._AirChange import _AirChangeDataset + + +class AC_SzadaDataset(_AirChangeDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + + @property + def LOCATION(self): + return 'Szada' + + @property + def TEST_SAMPLE_IDS(self): + return (0,) + + @property + def N_PAIRS(self): + return 7 \ No newline at end of file diff --git a/src/data/AC_Tiszadob.py b/src/data/AC_Tiszadob.py new file mode 100644 index 0000000..830450d --- /dev/null +++ b/src/data/AC_Tiszadob.py @@ -0,0 +1,23 @@ +from ._AirChange import _AirChangeDataset + + +class AC_TiszadobDataset(_AirChangeDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + + @property + def LOCATION(self): + return 'Tiszadob' + + @property + def TEST_SAMPLE_IDS(self): + return (2,) + + @property + def N_PAIRS(self): + return 5 \ No newline at end of file diff --git a/src/data/OSCD.py b/src/data/OSCD.py index 53c1225..e6ba2a7 100644 --- a/src/data/OSCD.py +++ b/src/data/OSCD.py @@ -33,7 +33,7 @@ class OSCDDataset(CDDataset): with open(txt_file, 'r') as f: cities = [city.strip() for city in f.read().strip().split(',')] if self.phase == 'train': - # For training, use the first 10 pairs + # For training, use the first 11 pairs cities = cities[:-3] else: # For validation, use the remaining 3 pairs diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py new file mode 100644 index 0000000..4d8555d --- 
/dev/null +++ b/src/data/_AirChange.py @@ -0,0 +1,79 @@ +import abc +from os.path import join, basename +from multiprocessing import Manager + +import numpy as np + +from . import CDDataset +from .common import default_loader +from .augmentation import Crop + + +class _AirChangeDataset(CDDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + self.cropper = Crop(bounds=(0, 0, 748, 448)) + + self._manager = Manager() + sync_list = self._manager.list + self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)]) + self.labels = sync_list([None]*self.N_PAIRS) + + @property + @abc.abstractmethod + def LOCATION(self): + return '' + + @property + @abc.abstractmethod + def TEST_SAMPLE_IDS(self): + return () + + @property + @abc.abstractmethod + def N_PAIRS(self): + return 0 + + def _read_file_paths(self): + if self.phase == 'train': + sample_ids = range(self.N_PAIRS) + t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + else: + t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS] + t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS] + label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS] + + return t1_list, t2_list, label_list + + def fetch_image(self, image_name): + _, i, t = image_name.split('-') + i, t = int(i), int(t[:-4]) + if self.images[t][i] is None: + image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1))) + self.images[t][i] = image if self.phase == 'train' else self.cropper(image) + return self.images[t][i] + + def fetch_label(self, label_name): + index = int(label_name.split('-')[1]) + if self.labels[index] is None: + label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt')) + label = (label / 255.0).astype(np.uint8) # To 0,1 + self.labels[index] = label if self.phase == 'train' else self.cropper(label) + return self.labels[index] + + @staticmethod + def _bmp_loader(bmp_path_wo_ext): + # Case insensitive .bmp loader + try: + return default_loader(bmp_path_wo_ext+'.bmp') + except FileNotFoundError: + return default_loader(bmp_path_wo_ext+'.BMP') + + \ No newline at end of file -- GitLab
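For reference, `_read_file_paths` in `_AirChange.py` builds synthetic names such as `Tiszadob-2-1.bmp`, and `fetch_image`/`fetch_label` translate them back into the directory layout the patch assumes under `IMDB_AirChange` (`<root>/<location>/<pair index + 1>/im1`, `im2` and `gt`). The helper below is a hypothetical restatement of that mapping for illustration only; `resolve_airchange_path` is not part of the patch.

```python
# Hypothetical helper mirroring the name -> path mapping performed by
# fetch_image/fetch_label in _AirChange.py (illustration only).
from os.path import join

def resolve_airchange_path(root, name):
    # 'Tiszadob-2-1.bmp'  -> <root>/Tiszadob/3/im2  (extension added by _bmp_loader)
    # 'Tiszadob-2-cm.bmp' -> <root>/Tiszadob/3/gt
    location, idx, kind = name.split('-')
    idx = int(idx)
    if kind == 'cm.bmp':            # change map, i.e. the ground-truth label
        return join(root, location, str(idx + 1), 'gt')
    t = int(kind[:-4])              # '0.bmp' -> 0, '1.bmp' -> 1
    return join(root, location, str(idx + 1), 'im' + str(t + 1))
```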
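The train/validation split follows directly from the per-location constants: every pair index not listed in `TEST_SAMPLE_IDS` is used for training, and the listed ones for validation. For `AC_Tiszadob` (`N_PAIRS = 5`, `TEST_SAMPLE_IDS = (2,)`) that gives pairs 0, 1, 3 and 4 for training and pair 2 for validation, as this minimal sketch shows:

```python
# Worked example of the split implied by AC_TiszadobDataset's constants.
N_PAIRS = 5
TEST_SAMPLE_IDS = (2,)

train_ids = [i for i in range(N_PAIRS) if i not in TEST_SAMPLE_IDS]
val_ids = list(TEST_SAMPLE_IDS)
print(train_ids, val_ids)  # [0, 1, 3, 4] [2]
```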
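Adding a further AirChange location would follow the same pattern as `AC_Szada.py` and `AC_Tiszadob.py`. A hypothetical sketch is shown below; the location name, held-out ids and pair count are placeholders, not real benchmark values. The only requirements imposed by `single_train_ds_factory` are that the module name starts with `AC` (so its root resolves to `constants.IMDB_AirChange`) and that it defines a class named `<module name>Dataset`.

```python
# src/data/AC_Example.py -- hypothetical sketch; all values below are placeholders
from ._AirChange import _AirChangeDataset


class AC_ExampleDataset(_AirChangeDataset):
    def __init__(self, root, phase='train', transforms=(None, None, None), repeats=1):
        super().__init__(root, phase, transforms, repeats)

    @property
    def LOCATION(self):
        return 'Example'    # sub-directory under the benchmark root

    @property
    def TEST_SAMPLE_IDS(self):
        return (0,)         # pair(s) held out for validation

    @property
    def N_PAIRS(self):
        return 1            # total number of image pairs at this location
```

Training on it would then only require setting `dataset: AC_Example` in `config_base.yaml` and running `python train.py train --exp-config ../config_base.yaml` from `src/`, as in the README.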