diff --git a/README.md b/README.md index a802bacecde3e1a540f7edded0de3ce691ea99d6..1368d8660f88981380db4f46b65e0ae92273bf52 100644 --- a/README.md +++ b/README.md @@ -12,15 +12,15 @@ as the [official repo](https://github.com/rcdaudt/fully_convolutional_change_det ```bash # The network definition scripts are from the original repo -git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git +git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git +cd FCN-CD-PyTorch +mkdir exp +cd src ``` For training, try ```bash -# In the root directory of this repository -mkdir exp -cd src python train.py train --exp-config ../config_base.yaml ``` diff --git a/config_base.yaml b/config_base.yaml index f581c7126650b139df6c0f1e46ccbc823867f67a..8eee72e4f9a60e6ae029f8668c52dd3a7215b612 100644 --- a/config_base.yaml +++ b/config_base.yaml @@ -3,23 +3,23 @@ # Data # Common -dataset: OSCD -crop_size: 256 +dataset: AC_Tiszadob +crop_size: 112 num_workers: 1 -repeats: 1000 +repeats: 3200 # Optimizer -optimizer: Adam +optimizer: SGD lr: 0.001 lr_mode: const -weight_decay: 0.0001 +weight_decay: 0.0005 step: 2 # Training related batch_size: 32 -num_epochs: 40 +num_epochs: 10 resume: '' load_optim: True anew: False @@ -48,4 +48,4 @@ weights: # Model model: siamunet_conc -num_feats_in: 13 \ No newline at end of file +num_feats_in: 3 \ No newline at end of file diff --git a/src/constants.py b/src/constants.py index dfa17b4ced36f6618437d1dbdfd63bf65d329503..e41882fb592a9df376ee92e7f28b464d8944c551 100644 --- a/src/constants.py +++ b/src/constants.py @@ -3,7 +3,7 @@ # Dataset directories IMDB_OSCD = '~/Datasets/OSCDDataset/' -IMDB_AC = '~/Datasets/SZTAKI_AirChange_Benchmark/' +IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/' # Checkpoint templates CKP_LATEST = 'checkpoint_latest.pth' diff --git a/src/core/factories.py b/src/core/factories.py index 60e32db631bcd8b1aba7fc80618e4cf9b6a4999c..652b776bc041e413d3337da19b40ebf566cc51d4 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -109,7 +109,7 @@ def single_model_factory(model_name, C): def single_optim_factory(optim_name, params, C): - name = optim_name.upper() + name = optim_name.strip().upper() if name == 'ADAM': return torch.optim.Adam( params, @@ -117,6 +117,13 @@ def single_optim_factory(optim_name, params, C): lr=C.lr, weight_decay=C.weight_decay ) + elif name == 'SGD': + return torch.optim.SGD( + params, + lr=C.lr, + momentum=0.9, + weight_decay=C.weight_decay + ) else: raise NotImplementedError("{} is not a supported optimizer type".format(optim_name)) @@ -137,7 +144,8 @@ def single_critn_factory(critn_name, C): def single_train_ds_factory(ds_name, C): from data.augmentation import Compose, Crop, Flip - module = _import_module('data', ds_name.strip()) + ds_name = ds_name.strip() + module = _import_module('data', ds_name) dataset = getattr(module, ds_name+'Dataset') configs = dict( phase='train', @@ -150,17 +158,17 @@ def single_train_ds_factory(ds_name, C): root = constants.IMDB_OSCD ) ) - elif ds_name == 'AC': + elif ds_name.startswith('AC'): configs.update( dict( - root = constants.IMDB_AC + root = constants.IMDB_AirChange ) ) else: pass dataset_obj = dataset(**configs) - + return data.DataLoader( dataset_obj, batch_size=C.batch_size, @@ -171,11 +179,13 @@ def single_train_ds_factory(ds_name, C): def single_val_ds_factory(ds_name, C): - module = _import_module('data', ds_name.strip()) + ds_name = ds_name.strip() + module = _import_module('data', ds_name) dataset = getattr(module, ds_name+'Dataset') configs = dict( phase='val', - transforms=(None, None, None) + transforms=(None, None, None), + repeats=1 ) if ds_name == 'OSCD': configs.update( @@ -183,7 +193,7 @@ def single_val_ds_factory(ds_name, C): root = constants.IMDB_OSCD ) ) - elif ds_name == 'AC': + elif ds_name.startswith('AC'): configs.update( dict( root = constants.IMDB_AirChange @@ -246,4 +256,4 @@ def data_factory(dataset_names, phase, C): def metric_factory(metric_names, C): from utils import metrics name_list = _parse_input_names(metric_names) - return [getattr(metrics, name)() for name in name_list] + return [getattr(metrics, name.strip())() for name in name_list] diff --git a/src/core/trainers.py b/src/core/trainers.py index 2c5b139b894a9262a67495d47b88ac6578e0f056..3e1381ace84c06e55652ef89cc3b5e70d7e63673 100644 --- a/src/core/trainers.py +++ b/src/core/trainers.py @@ -212,9 +212,9 @@ class CDTrainer(Trainer): for i, (t1, t2, label) in enumerate(pb): t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device) - + prob = self.model(t1, t2) - + loss = self.criterion(prob, label) losses.update(loss.item(), n=self.batch_size) @@ -267,6 +267,6 @@ class CDTrainer(Trainer): self.logger.dump(desc) if store: - self.save_image(name[0], CM.squeeze(-1), epoch) + self.save_image(name[0], (CM*255).squeeze(-1), epoch) return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) \ No newline at end of file diff --git a/src/data/AC_Szada.py b/src/data/AC_Szada.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b9619157a2f2dc4d892747d636fa27945a967c --- /dev/null +++ b/src/data/AC_Szada.py @@ -0,0 +1,23 @@ +from ._AirChange import _AirChangeDataset + + +class AC_SzadaDataset(_AirChangeDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + + @property + def LOCATION(self): + return 'Szada' + + @property + def TEST_SAMPLE_IDS(self): + return (0,) + + @property + def N_PAIRS(self): + return 7 \ No newline at end of file diff --git a/src/data/AC_Tiszadob.py b/src/data/AC_Tiszadob.py new file mode 100644 index 0000000000000000000000000000000000000000..830450d18d88b9953c7c6740c315b20b40efdbbe --- /dev/null +++ b/src/data/AC_Tiszadob.py @@ -0,0 +1,23 @@ +from ._AirChange import _AirChangeDataset + + +class AC_TiszadobDataset(_AirChangeDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + + @property + def LOCATION(self): + return 'Tiszadob' + + @property + def TEST_SAMPLE_IDS(self): + return (2,) + + @property + def N_PAIRS(self): + return 5 \ No newline at end of file diff --git a/src/data/OSCD.py b/src/data/OSCD.py index 53c1225161d89fa7c55e1cf78474fcb0ff32cd09..e6ba2a77c7b99e375b4d394ed67b41f60008a0b1 100644 --- a/src/data/OSCD.py +++ b/src/data/OSCD.py @@ -33,7 +33,7 @@ class OSCDDataset(CDDataset): with open(txt_file, 'r') as f: cities = [city.strip() for city in f.read().strip().split(',')] if self.phase == 'train': - # For training, use the first 10 pairs + # For training, use the first 11 pairs cities = cities[:-3] else: # For validation, use the remaining 3 pairs diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8555d843bc8b5933e166a4c9aa1612b252cf49 --- /dev/null +++ b/src/data/_AirChange.py @@ -0,0 +1,79 @@ +import abc +from os.path import join, basename +from multiprocessing import Manager + +import numpy as np + +from . import CDDataset +from .common import default_loader +from .augmentation import Crop + + +class _AirChangeDataset(CDDataset): + def __init__( + self, + root, phase='train', + transforms=(None, None, None), + repeats=1 + ): + super().__init__(root, phase, transforms, repeats) + self.cropper = Crop(bounds=(0, 0, 748, 448)) + + self._manager = Manager() + sync_list = self._manager.list + self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)]) + self.labels = sync_list([None]*self.N_PAIRS) + + @property + @abc.abstractmethod + def LOCATION(self): + return '' + + @property + @abc.abstractmethod + def TEST_SAMPLE_IDS(self): + return () + + @property + @abc.abstractmethod + def N_PAIRS(self): + return 0 + + def _read_file_paths(self): + if self.phase == 'train': + sample_ids = range(self.N_PAIRS) + t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + else: + t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS] + t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS] + label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS] + + return t1_list, t2_list, label_list + + def fetch_image(self, image_name): + _, i, t = image_name.split('-') + i, t = int(i), int(t[:-4]) + if self.images[t][i] is None: + image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1))) + self.images[t][i] = image if self.phase == 'train' else self.cropper(image) + return self.images[t][i] + + def fetch_label(self, label_name): + index = int(label_name.split('-')[1]) + if self.labels[index] is None: + label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt')) + label = (label / 255.0).astype(np.uint8) # To 0,1 + self.labels[index] = label if self.phase == 'train' else self.cropper(label) + return self.labels[index] + + @staticmethod + def _bmp_loader(bmp_path_wo_ext): + # Case insensitive .bmp loader + try: + return default_loader(bmp_path_wo_ext+'.bmp') + except FileNotFoundError: + return default_loader(bmp_path_wo_ext+'.BMP') + + \ No newline at end of file