Skip to content
Snippets Groups Projects
Commit 0e08f673 authored by Bobholamovic's avatar Bobholamovic
Browse files

Add Air Change Dataset

parent 5427d7e9
Branches
No related tags found
No related merge requests found
...@@ -13,14 +13,14 @@ as the [official repo](https://github.com/rcdaudt/fully_convolutional_change_det ...@@ -13,14 +13,14 @@ as the [official repo](https://github.com/rcdaudt/fully_convolutional_change_det
```bash ```bash
# The network definition scripts are from the original repo # The network definition scripts are from the original repo
git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git
cd FCN-CD-PyTorch
mkdir exp
cd src
``` ```
For training, try For training, try
```bash ```bash
# In the root directory of this repository
mkdir exp
cd src
python train.py train --exp-config ../config_base.yaml python train.py train --exp-config ../config_base.yaml
``` ```
......
...@@ -3,23 +3,23 @@ ...@@ -3,23 +3,23 @@
# Data # Data
# Common # Common
dataset: OSCD dataset: AC_Tiszadob
crop_size: 256 crop_size: 112
num_workers: 1 num_workers: 1
repeats: 1000 repeats: 3200
# Optimizer # Optimizer
optimizer: Adam optimizer: SGD
lr: 0.001 lr: 0.001
lr_mode: const lr_mode: const
weight_decay: 0.0001 weight_decay: 0.0005
step: 2 step: 2
# Training related # Training related
batch_size: 32 batch_size: 32
num_epochs: 40 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
anew: False anew: False
...@@ -48,4 +48,4 @@ weights: ...@@ -48,4 +48,4 @@ weights:
# Model # Model
model: siamunet_conc model: siamunet_conc
num_feats_in: 13 num_feats_in: 3
\ No newline at end of file \ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Dataset directories # Dataset directories
IMDB_OSCD = '~/Datasets/OSCDDataset/' IMDB_OSCD = '~/Datasets/OSCDDataset/'
IMDB_AC = '~/Datasets/SZTAKI_AirChange_Benchmark/' IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/'
# Checkpoint templates # Checkpoint templates
CKP_LATEST = 'checkpoint_latest.pth' CKP_LATEST = 'checkpoint_latest.pth'
......
...@@ -109,7 +109,7 @@ def single_model_factory(model_name, C): ...@@ -109,7 +109,7 @@ def single_model_factory(model_name, C):
def single_optim_factory(optim_name, params, C): def single_optim_factory(optim_name, params, C):
name = optim_name.upper() name = optim_name.strip().upper()
if name == 'ADAM': if name == 'ADAM':
return torch.optim.Adam( return torch.optim.Adam(
params, params,
...@@ -117,6 +117,13 @@ def single_optim_factory(optim_name, params, C): ...@@ -117,6 +117,13 @@ def single_optim_factory(optim_name, params, C):
lr=C.lr, lr=C.lr,
weight_decay=C.weight_decay weight_decay=C.weight_decay
) )
elif name == 'SGD':
return torch.optim.SGD(
params,
lr=C.lr,
momentum=0.9,
weight_decay=C.weight_decay
)
else: else:
raise NotImplementedError("{} is not a supported optimizer type".format(optim_name)) raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
...@@ -137,7 +144,8 @@ def single_critn_factory(critn_name, C): ...@@ -137,7 +144,8 @@ def single_critn_factory(critn_name, C):
def single_train_ds_factory(ds_name, C): def single_train_ds_factory(ds_name, C):
from data.augmentation import Compose, Crop, Flip from data.augmentation import Compose, Crop, Flip
module = _import_module('data', ds_name.strip()) ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset') dataset = getattr(module, ds_name+'Dataset')
configs = dict( configs = dict(
phase='train', phase='train',
...@@ -150,10 +158,10 @@ def single_train_ds_factory(ds_name, C): ...@@ -150,10 +158,10 @@ def single_train_ds_factory(ds_name, C):
root = constants.IMDB_OSCD root = constants.IMDB_OSCD
) )
) )
elif ds_name == 'AC': elif ds_name.startswith('AC'):
configs.update( configs.update(
dict( dict(
root = constants.IMDB_AC root = constants.IMDB_AirChange
) )
) )
else: else:
...@@ -171,11 +179,13 @@ def single_train_ds_factory(ds_name, C): ...@@ -171,11 +179,13 @@ def single_train_ds_factory(ds_name, C):
def single_val_ds_factory(ds_name, C): def single_val_ds_factory(ds_name, C):
module = _import_module('data', ds_name.strip()) ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset') dataset = getattr(module, ds_name+'Dataset')
configs = dict( configs = dict(
phase='val', phase='val',
transforms=(None, None, None) transforms=(None, None, None),
repeats=1
) )
if ds_name == 'OSCD': if ds_name == 'OSCD':
configs.update( configs.update(
...@@ -183,7 +193,7 @@ def single_val_ds_factory(ds_name, C): ...@@ -183,7 +193,7 @@ def single_val_ds_factory(ds_name, C):
root = constants.IMDB_OSCD root = constants.IMDB_OSCD
) )
) )
elif ds_name == 'AC': elif ds_name.startswith('AC'):
configs.update( configs.update(
dict( dict(
root = constants.IMDB_AirChange root = constants.IMDB_AirChange
...@@ -246,4 +256,4 @@ def data_factory(dataset_names, phase, C): ...@@ -246,4 +256,4 @@ def data_factory(dataset_names, phase, C):
def metric_factory(metric_names, C): def metric_factory(metric_names, C):
from utils import metrics from utils import metrics
name_list = _parse_input_names(metric_names) name_list = _parse_input_names(metric_names)
return [getattr(metrics, name)() for name in name_list] return [getattr(metrics, name.strip())() for name in name_list]
...@@ -267,6 +267,6 @@ class CDTrainer(Trainer): ...@@ -267,6 +267,6 @@ class CDTrainer(Trainer):
self.logger.dump(desc) self.logger.dump(desc)
if store: if store:
self.save_image(name[0], CM.squeeze(-1), epoch) self.save_image(name[0], (CM*255).squeeze(-1), epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
from ._AirChange import _AirChangeDataset
class AC_SzadaDataset(_AirChangeDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
@property
def LOCATION(self):
return 'Szada'
@property
def TEST_SAMPLE_IDS(self):
return (0,)
@property
def N_PAIRS(self):
return 7
\ No newline at end of file
from ._AirChange import _AirChangeDataset
class AC_TiszadobDataset(_AirChangeDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
@property
def LOCATION(self):
return 'Tiszadob'
@property
def TEST_SAMPLE_IDS(self):
return (2,)
@property
def N_PAIRS(self):
return 5
\ No newline at end of file
...@@ -33,7 +33,7 @@ class OSCDDataset(CDDataset): ...@@ -33,7 +33,7 @@ class OSCDDataset(CDDataset):
with open(txt_file, 'r') as f: with open(txt_file, 'r') as f:
cities = [city.strip() for city in f.read().strip().split(',')] cities = [city.strip() for city in f.read().strip().split(',')]
if self.phase == 'train': if self.phase == 'train':
# For training, use the first 10 pairs # For training, use the first 11 pairs
cities = cities[:-3] cities = cities[:-3]
else: else:
# For validation, use the remaining 3 pairs # For validation, use the remaining 3 pairs
......
import abc
from os.path import join, basename
from multiprocessing import Manager
import numpy as np
from . import CDDataset
from .common import default_loader
from .augmentation import Crop
class _AirChangeDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
self.cropper = Crop(bounds=(0, 0, 748, 448))
self._manager = Manager()
sync_list = self._manager.list
self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)])
self.labels = sync_list([None]*self.N_PAIRS)
@property
@abc.abstractmethod
def LOCATION(self):
return ''
@property
@abc.abstractmethod
def TEST_SAMPLE_IDS(self):
return ()
@property
@abc.abstractmethod
def N_PAIRS(self):
return 0
def _read_file_paths(self):
if self.phase == 'train':
sample_ids = range(self.N_PAIRS)
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
else:
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS]
return t1_list, t2_list, label_list
def fetch_image(self, image_name):
_, i, t = image_name.split('-')
i, t = int(i), int(t[:-4])
if self.images[t][i] is None:
image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1)))
self.images[t][i] = image if self.phase == 'train' else self.cropper(image)
return self.images[t][i]
def fetch_label(self, label_name):
index = int(label_name.split('-')[1])
if self.labels[index] is None:
label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt'))
label = (label / 255.0).astype(np.uint8) # To 0,1
self.labels[index] = label if self.phase == 'train' else self.cropper(label)
return self.labels[index]
@staticmethod
def _bmp_loader(bmp_path_wo_ext):
# Case insensitive .bmp loader
try:
return default_loader(bmp_path_wo_ext+'.bmp')
except FileNotFoundError:
return default_loader(bmp_path_wo_ext+'.BMP')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment