Skip to content
Snippets Groups Projects
Commit 0e08f673 authored by Bobholamovic's avatar Bobholamovic
Browse files

Add Air Change Dataset

parent 5427d7e9
No related branches found
No related tags found
No related merge requests found
......@@ -12,15 +12,15 @@ as the [official repo](https://github.com/rcdaudt/fully_convolutional_change_det
```bash
# The network definition scripts are from the original repo
git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git
git clone --recurse-submodules git@github.com:Bobholamovic/FCN-CD-PyTorch.git
cd FCN-CD-PyTorch
mkdir exp
cd src
```
For training, try
```bash
# In the root directory of this repository
mkdir exp
cd src
python train.py train --exp-config ../config_base.yaml
```
......
......@@ -3,23 +3,23 @@
# Data
# Common
dataset: OSCD
crop_size: 256
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 1000
repeats: 3200
# Optimizer
optimizer: Adam
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0001
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 40
num_epochs: 10
resume: ''
load_optim: True
anew: False
......@@ -48,4 +48,4 @@ weights:
# Model
model: siamunet_conc
num_feats_in: 13
\ No newline at end of file
num_feats_in: 3
\ No newline at end of file
......@@ -3,7 +3,7 @@
# Dataset directories
IMDB_OSCD = '~/Datasets/OSCDDataset/'
IMDB_AC = '~/Datasets/SZTAKI_AirChange_Benchmark/'
IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/'
# Checkpoint templates
CKP_LATEST = 'checkpoint_latest.pth'
......
......@@ -109,7 +109,7 @@ def single_model_factory(model_name, C):
def single_optim_factory(optim_name, params, C):
name = optim_name.upper()
name = optim_name.strip().upper()
if name == 'ADAM':
return torch.optim.Adam(
params,
......@@ -117,6 +117,13 @@ def single_optim_factory(optim_name, params, C):
lr=C.lr,
weight_decay=C.weight_decay
)
elif name == 'SGD':
return torch.optim.SGD(
params,
lr=C.lr,
momentum=0.9,
weight_decay=C.weight_decay
)
else:
raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
......@@ -137,7 +144,8 @@ def single_critn_factory(critn_name, C):
def single_train_ds_factory(ds_name, C):
from data.augmentation import Compose, Crop, Flip
module = _import_module('data', ds_name.strip())
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='train',
......@@ -150,17 +158,17 @@ def single_train_ds_factory(ds_name, C):
root = constants.IMDB_OSCD
)
)
elif ds_name == 'AC':
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AC
root = constants.IMDB_AirChange
)
)
else:
pass
dataset_obj = dataset(**configs)
return data.DataLoader(
dataset_obj,
batch_size=C.batch_size,
......@@ -171,11 +179,13 @@ def single_train_ds_factory(ds_name, C):
def single_val_ds_factory(ds_name, C):
module = _import_module('data', ds_name.strip())
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='val',
transforms=(None, None, None)
transforms=(None, None, None),
repeats=1
)
if ds_name == 'OSCD':
configs.update(
......@@ -183,7 +193,7 @@ def single_val_ds_factory(ds_name, C):
root = constants.IMDB_OSCD
)
)
elif ds_name == 'AC':
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AirChange
......@@ -246,4 +256,4 @@ def data_factory(dataset_names, phase, C):
def metric_factory(metric_names, C):
from utils import metrics
name_list = _parse_input_names(metric_names)
return [getattr(metrics, name)() for name in name_list]
return [getattr(metrics, name.strip())() for name in name_list]
......@@ -212,9 +212,9 @@ class CDTrainer(Trainer):
for i, (t1, t2, label) in enumerate(pb):
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
losses.update(loss.item(), n=self.batch_size)
......@@ -267,6 +267,6 @@ class CDTrainer(Trainer):
self.logger.dump(desc)
if store:
self.save_image(name[0], CM.squeeze(-1), epoch)
self.save_image(name[0], (CM*255).squeeze(-1), epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
from ._AirChange import _AirChangeDataset
class AC_SzadaDataset(_AirChangeDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
@property
def LOCATION(self):
return 'Szada'
@property
def TEST_SAMPLE_IDS(self):
return (0,)
@property
def N_PAIRS(self):
return 7
\ No newline at end of file
from ._AirChange import _AirChangeDataset
class AC_TiszadobDataset(_AirChangeDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
@property
def LOCATION(self):
return 'Tiszadob'
@property
def TEST_SAMPLE_IDS(self):
return (2,)
@property
def N_PAIRS(self):
return 5
\ No newline at end of file
......@@ -33,7 +33,7 @@ class OSCDDataset(CDDataset):
with open(txt_file, 'r') as f:
cities = [city.strip() for city in f.read().strip().split(',')]
if self.phase == 'train':
# For training, use the first 10 pairs
# For training, use the first 11 pairs
cities = cities[:-3]
else:
# For validation, use the remaining 3 pairs
......
import abc
from os.path import join, basename
from multiprocessing import Manager
import numpy as np
from . import CDDataset
from .common import default_loader
from .augmentation import Crop
class _AirChangeDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1
):
super().__init__(root, phase, transforms, repeats)
self.cropper = Crop(bounds=(0, 0, 748, 448))
self._manager = Manager()
sync_list = self._manager.list
self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)])
self.labels = sync_list([None]*self.N_PAIRS)
@property
@abc.abstractmethod
def LOCATION(self):
return ''
@property
@abc.abstractmethod
def TEST_SAMPLE_IDS(self):
return ()
@property
@abc.abstractmethod
def N_PAIRS(self):
return 0
def _read_file_paths(self):
if self.phase == 'train':
sample_ids = range(self.N_PAIRS)
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
else:
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS]
return t1_list, t2_list, label_list
def fetch_image(self, image_name):
_, i, t = image_name.split('-')
i, t = int(i), int(t[:-4])
if self.images[t][i] is None:
image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1)))
self.images[t][i] = image if self.phase == 'train' else self.cropper(image)
return self.images[t][i]
def fetch_label(self, label_name):
index = int(label_name.split('-')[1])
if self.labels[index] is None:
label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt'))
label = (label / 255.0).astype(np.uint8) # To 0,1
self.labels[index] = label if self.phase == 'train' else self.cropper(label)
return self.labels[index]
@staticmethod
def _bmp_loader(bmp_path_wo_ext):
# Case insensitive .bmp loader
try:
return default_loader(bmp_path_wo_ext+'.bmp')
except FileNotFoundError:
return default_loader(bmp_path_wo_ext+'.BMP')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment