Skip to content
Snippets Groups Projects
Commit 7f0846c1 authored by Bobholamovic's avatar Bobholamovic
Browse files

Add Lebedev

parent 4d7e2b88
No related branches found
No related tags found
No related merge requests found
...@@ -3,23 +3,23 @@ ...@@ -3,23 +3,23 @@
# Data # Data
# Common # Common
dataset: AC_Tiszadob dataset: Lebedev
crop_size: 112 crop_size: 224
num_workers: 1 num_workers: 1
repeats: 3200 repeats: 1
# Optimizer # Optimizer
optimizer: SGD optimizer: Adam
lr: 0.001 lr: 1e-4
lr_mode: const lr_mode: step
weight_decay: 0.0005 weight_decay: 0.0
step: 2 step: 5
# Training related # Training related
batch_size: 32 batch_size: 8
num_epochs: 10 num_epochs: 15
resume: '' resume: ''
load_optim: True load_optim: True
anew: False anew: False
...@@ -42,10 +42,10 @@ suffix_off: False ...@@ -42,10 +42,10 @@ suffix_off: False
# Criterion # Criterion
criterion: NLL criterion: NLL
weights: weights:
- 1.0 # Weight of no-change class - 0.117 # Weight of no-change class
- 10.0 # Weight of change class - 0.883 # Weight of change class
# Model # Model
model: siamunet_conc model: EF
num_feats_in: 3 num_feats_in: 6
\ No newline at end of file \ No newline at end of file
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# Dataset directories # Dataset directories
IMDB_OSCD = '~/Datasets/OSCDDataset/' IMDB_OSCD = '~/Datasets/OSCDDataset/'
IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/' IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/'
IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
# Checkpoint templates # Checkpoint templates
CKP_LATEST = 'checkpoint_latest.pth' CKP_LATEST = 'checkpoint_latest.pth'
......
...@@ -164,7 +164,14 @@ def single_train_ds_factory(ds_name, C): ...@@ -164,7 +164,14 @@ def single_train_ds_factory(ds_name, C):
elif ds_name.startswith('AC'): elif ds_name.startswith('AC'):
configs.update( configs.update(
dict( dict(
root = constants.IMDB_AirChange root = constants.IMDB_AIRCHANGE
)
)
elif ds_name == 'Lebedev':
configs.update(
dict(
root = constants.IMDB_LEBEDEV,
subsets = ('real',)
) )
) )
else: else:
...@@ -199,7 +206,14 @@ def single_val_ds_factory(ds_name, C): ...@@ -199,7 +206,14 @@ def single_val_ds_factory(ds_name, C):
elif ds_name.startswith('AC'): elif ds_name.startswith('AC'):
configs.update( configs.update(
dict( dict(
root = constants.IMDB_AirChange root = constants.IMDB_AIRCHANGE
)
)
elif ds_name == 'Lebedev':
configs.update(
dict(
root = constants.IMDB_LEBEDEV,
subsets = ('real',)
) )
) )
else: else:
......
...@@ -242,6 +242,11 @@ class CDTrainer(Trainer): ...@@ -242,6 +242,11 @@ class CDTrainer(Trainer):
with torch.no_grad(): with torch.no_grad():
for i, (name, t1, t2, label) in enumerate(pb): for i, (name, t1, t2, label) in enumerate(pb):
if self.phase == 'train' and i >= 16:
# Do not validate all images on training phase
pb.close()
self.logger.warning("validation ends early")
break
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device) t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2) prob = self.model(t1, t2)
......
from glob import glob
from os.path import join, basename
import numpy as np
from . import CDDataset
from .common import default_loader
class LebedevDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1,
subsets=('real', 'with_shift', 'without_shift')
):
self.subsets = subsets
super().__init__(root, phase, transforms, repeats)
def _read_file_paths(self):
t1_list, t2_list, label_list = [], [], []
for subset in self.subsets:
# Get subset directory
if subset == 'real':
subset_dir = join(self.root, 'Real', 'subset')
elif subset == 'with_shift':
subset_dir = join(self.root, 'Model', 'with_shift')
elif subset == 'without_shift':
subset_dir = join(self.root, 'Model', 'without_shift')
else:
raise RuntimeError('unrecognized key encountered')
pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
label_list.extend(refs)
t1_list.extend(t1s)
t2_list.extend(t2s)
return t1_list, t2_list, label_list
def fetch_label(self, label_path):
# To {0,1}
return (super().fetch_label(label_path) / 255.0).astype(np.uint8)
\ No newline at end of file
#!/bin/bash
# Activate conda environment
source activate $ME
# Change directory
cd src
# Define constants
ARCHS=("siamdiff" "siamconc" "EF")
DATASETS=("AC_Szada" "AC_Tiszadob" "OSCD")
# LOOP
for arch in ${ARCHS[@]}
do
for dataset in ${DATASETS[@]}
do
python train.py train --exp-config ../config_${arch}_${dataset}.yaml
done
done
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment