From 7f0846c13cd76c56fc4c60280e3875a16822056b Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Sat, 25 Jan 2020 15:17:17 +0800
Subject: [PATCH] Add Lebedev

---
 config_base.yaml      | 28 +++++++++++++-------------
 src/constants.py      |  3 ++-
 src/core/factories.py | 18 +++++++++++++++--
 src/core/trainers.py  |  5 +++++
 src/data/Lebedev.py   | 47 +++++++++++++++++++++++++++++++++++++++++++
 train9.sh             | 20 ++++++++++++++++++
 6 files changed, 104 insertions(+), 17 deletions(-)
 create mode 100644 src/data/Lebedev.py
 create mode 100755 train9.sh

diff --git a/config_base.yaml b/config_base.yaml
index 8eee72e..838ac12 100644
--- a/config_base.yaml
+++ b/config_base.yaml
@@ -3,23 +3,23 @@
 
 # Data
 # Common
-dataset: AC_Tiszadob
-crop_size: 112
+dataset: Lebedev
+crop_size: 224
 num_workers: 1
-repeats: 3200
+repeats: 1
 
 
 # Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
+optimizer: Adam
+lr: 1e-4
+lr_mode: step
+weight_decay: 0.0
+step: 5
 
 
 # Training related
-batch_size: 32
-num_epochs: 10
+batch_size: 8 
+num_epochs: 15
 resume: ''
 load_optim: True
 anew: False
@@ -42,10 +42,10 @@ suffix_off: False
 # Criterion
 criterion: NLL
 weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
+  - 0.117   # Weight of no-change class
+  - 0.883   # Weight of change class
 
 
 # Model
-model: siamunet_conc
-num_feats_in: 3
\ No newline at end of file
+model: EF
+num_feats_in: 6
\ No newline at end of file
diff --git a/src/constants.py b/src/constants.py
index e41882f..c9bfdd0 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -3,7 +3,8 @@
 
 # Dataset directories
 IMDB_OSCD = '~/Datasets/OSCDDataset/'
-IMDB_AirChange = '~/Datasets/SZTAKI_AirChange_Benchmark/'
+IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/'
+IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
 
 # Checkpoint templates
 CKP_LATEST = 'checkpoint_latest.pth'
diff --git a/src/core/factories.py b/src/core/factories.py
index 2db268a..64e644a 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -164,7 +164,14 @@ def single_train_ds_factory(ds_name, C):
     elif ds_name.startswith('AC'):
         configs.update(
             dict(
-                root = constants.IMDB_AirChange
+                root = constants.IMDB_AIRCHANGE
+            )
+        )
+    elif ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                root = constants.IMDB_LEBEDEV,
+                subsets = ('real',)
             )
         )
     else:
@@ -199,7 +206,14 @@ def single_val_ds_factory(ds_name, C):
     elif ds_name.startswith('AC'):
         configs.update(
             dict(
-                root = constants.IMDB_AirChange
+                root = constants.IMDB_AIRCHANGE
+            )
+        )
+    elif ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                root = constants.IMDB_LEBEDEV,
+                subsets = ('real',)
             )
         )
     else:
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 3e1381a..c41863f 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -242,6 +242,11 @@ class CDTrainer(Trainer):
 
         with torch.no_grad():
             for i, (name, t1, t2, label) in enumerate(pb):
+                if self.phase == 'train' and i >= 16: 
+                    # Do not validate all images on training phase
+                    pb.close()
+                    self.logger.warning("validation ends early")
+                    break
                 t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
 
                 prob = self.model(t1, t2)
diff --git a/src/data/Lebedev.py b/src/data/Lebedev.py
new file mode 100644
index 0000000..bad5651
--- /dev/null
+++ b/src/data/Lebedev.py
@@ -0,0 +1,47 @@
+from glob import glob
+from os.path import join, basename
+
+import numpy as np
+
+from . import CDDataset
+from .common import default_loader
+
+class LebedevDataset(CDDataset):
+    def __init__(
+        self, 
+        root, phase='train', 
+        transforms=(None, None, None), 
+        repeats=1,
+        subsets=('real', 'with_shift', 'without_shift')
+    ):
+        self.subsets = subsets
+        super().__init__(root, phase, transforms, repeats)
+
+    def _read_file_paths(self):
+        t1_list, t2_list, label_list = [], [], []
+
+        for subset in self.subsets:
+            # Get subset directory
+            if subset == 'real':
+                subset_dir = join(self.root, 'Real', 'subset')
+            elif subset == 'with_shift':
+                subset_dir = join(self.root, 'Model', 'with_shift')
+            elif subset == 'without_shift':
+                subset_dir = join(self.root, 'Model', 'without_shift')
+            else:
+                raise RuntimeError('unrecognized key encountered')
+
+            pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
+            refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
+            t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
+            t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
+
+            label_list.extend(refs)
+            t1_list.extend(t1s)
+            t2_list.extend(t2s)
+
+        return t1_list, t2_list, label_list
+
+    def fetch_label(self, label_path):
+        # To {0,1}
+        return (super().fetch_label(label_path) / 255.0).astype(np.uint8)  
\ No newline at end of file
diff --git a/train9.sh b/train9.sh
new file mode 100755
index 0000000..f1dcab8
--- /dev/null
+++ b/train9.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Activate conda environment
+source activate $ME
+
+# Change directory
+cd src
+
+# Define constants
+ARCHS=("siamdiff" "siamconc" "EF")
+DATASETS=("AC_Szada" "AC_Tiszadob" "OSCD")
+
+# LOOP
+for arch in ${ARCHS[@]}
+do
+    for dataset in ${DATASETS[@]}
+    do
+        python train.py train --exp-config ../config_${arch}_${dataset}.yaml
+    done
+done
\ No newline at end of file
-- 
GitLab