From 9126a3ce33e88b6fa23a5575c9b8d6b9a5bc6a77 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Sat, 2 May 2020 17:56:25 +0800
Subject: [PATCH] Add option to save optim state

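Checkpoints written by the Trainer always include the optimizer state,
even when only the model weights matter. Add a save_optim option (a
config key plus a matching --save-optim command-line flag) that
controls whether optimizer.state_dict() is stored in the checkpoint;
when it is off, an empty dict is written instead. Loading becomes
strict at the same time: the try/except around
optimizer.load_state_dict() is dropped, so resuming with load_optim
enabled from a checkpoint that carries no optimizer parameters now
fails loudly instead of merely logging a warning.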
---
 config_EF_AC_Szada.yaml          |  1 +
 config_EF_AC_Tiszadob.yaml       |  1 +
 config_EF_OSCD.yaml              |  1 +
 config_base.yaml                 |  1 +
 config_siamconc_AC_Szada.yaml    |  1 +
 config_siamconc_AC_Tiszadob.yaml |  1 +
 config_siamconc_OSCD.yaml        |  1 +
 config_siamdiff_AC_Szada.yaml    |  1 +
 config_siamdiff_AC_Tiszadob.yaml |  1 +
 config_siamdiff_OSCD.yaml        |  1 +
 src/core/trainers.py             | 13 +++++++------
 src/train.py                     |  1 +
 12 files changed, 18 insertions(+), 6 deletions(-)
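
Note: a minimal sketch (not part of the patch) of how save_optim and
load_optim interact, assuming _save_checkpoint stores its second
argument under the 'optimizer' key that the loading code in trainers.py
reads; the tiny model and optimizer below are stand-ins:

    from torch import nn, optim

    model = nn.Linear(4, 2)  # stand-in for the real network
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    ckpt = {'optimizer': {}}  # what a run with save_optim off produces
    # With the try/except removed, resuming with load_optim enabled now
    # raises here (a KeyError for the missing 'param_groups') instead of
    # logging a warning and continuing with a fresh optimizer.
    optimizer.load_state_dict(ckpt['optimizer'])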

diff --git a/config_EF_AC_Szada.yaml b/config_EF_AC_Szada.yaml
index 7e969ae..0c9b917 100644
--- a/config_EF_AC_Szada.yaml
+++ b/config_EF_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_AC_Tiszadob.yaml b/config_EF_AC_Tiszadob.yaml
index 3c93b3a..91632dd 100644
--- a/config_EF_AC_Tiszadob.yaml
+++ b/config_EF_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_OSCD.yaml b/config_EF_OSCD.yaml
index a91842c..9902142 100644
--- a/config_EF_OSCD.yaml
+++ b/config_EF_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_base.yaml b/config_base.yaml
index 838ac12..f1b5ae9 100644
--- a/config_base.yaml
+++ b/config_base.yaml
@@ -22,6 +22,7 @@ batch_size: 8
 num_epochs: 15
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Szada.yaml b/config_siamconc_AC_Szada.yaml
index 4142e68..5be5494 100644
--- a/config_siamconc_AC_Szada.yaml
+++ b/config_siamconc_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Tiszadob.yaml b/config_siamconc_AC_Tiszadob.yaml
index 8eee72e..979facf 100644
--- a/config_siamconc_AC_Tiszadob.yaml
+++ b/config_siamconc_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_OSCD.yaml b/config_siamconc_OSCD.yaml
index 6a25726..640f83d 100644
--- a/config_siamconc_OSCD.yaml
+++ b/config_siamconc_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Szada.yaml b/config_siamdiff_AC_Szada.yaml
index 3c0ad37..055f970 100644
--- a/config_siamdiff_AC_Szada.yaml
+++ b/config_siamdiff_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Tiszadob.yaml b/config_siamdiff_AC_Tiszadob.yaml
index 02f67ba..931fc0c 100644
--- a/config_siamdiff_AC_Tiszadob.yaml
+++ b/config_siamdiff_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_OSCD.yaml b/config_siamdiff_OSCD.yaml
index 90a1671..ca28ae3 100644
--- a/config_siamdiff_OSCD.yaml
+++ b/config_siamdiff_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 445b96a..83330b3 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -106,7 +106,11 @@ class Trainer:
                                 acc, epoch, max_acc, best_epoch))
 
             # The checkpoint records the next epoch (the one to resume from)
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+            self._save_checkpoint(
+                self.model.state_dict(),
+                self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+                (max_acc, best_epoch), epoch+1, is_best
+            )
         
     def evaluate(self):
         if self.checkpoint: 
@@ -164,11 +168,8 @@ class Trainer:
             else:
                 self._init_max_acc_and_epoch = max_acc_and_epoch
             if self.ctx['load_optim'] and self.is_training:
-                try:
-                    # Note that weight decay might be modified here
-                    self.optimizer.load_state_dict(checkpoint['optimizer'])
-                except KeyError:
-                    self.logger.warning("Warning: failed to load optimizer parameters.")
+                # Note that loading the optimizer state may overwrite hyperparameters such as weight decay
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
diff --git a/src/train.py b/src/train.py
index c434b8b..782336b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -62,6 +62,7 @@ def parse_args():
     group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
                         help='number of epochs to train (default: %(default)s)')
     group_train.add_argument('--load-optim', action='store_true')
+    group_train.add_argument('--save-optim', action='store_true')
     group_train.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint')
     group_train.add_argument('--anew', action='store_true',
-- 
GitLab