From 9126a3ce33e88b6fa23a5575c9b8d6b9a5bc6a77 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Sat, 2 May 2020 17:56:25 +0800
Subject: [PATCH] Add option to save optim state

---
 config_EF_AC_Szada.yaml          |  1 +
 config_EF_AC_Tiszadob.yaml       |  1 +
 config_EF_OSCD.yaml              |  1 +
 config_base.yaml                 |  1 +
 config_siamconc_AC_Szada.yaml    |  1 +
 config_siamconc_AC_Tiszadob.yaml |  1 +
 config_siamconc_OSCD.yaml        |  1 +
 config_siamdiff_AC_Szada.yaml    |  1 +
 config_siamdiff_AC_Tiszadob.yaml |  1 +
 config_siamdiff_OSCD.yaml        |  1 +
 src/core/trainers.py             | 13 +++++++------
 src/train.py                     |  1 +
 12 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/config_EF_AC_Szada.yaml b/config_EF_AC_Szada.yaml
index 7e969ae..0c9b917 100644
--- a/config_EF_AC_Szada.yaml
+++ b/config_EF_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_AC_Tiszadob.yaml b/config_EF_AC_Tiszadob.yaml
index 3c93b3a..91632dd 100644
--- a/config_EF_AC_Tiszadob.yaml
+++ b/config_EF_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_OSCD.yaml b/config_EF_OSCD.yaml
index a91842c..9902142 100644
--- a/config_EF_OSCD.yaml
+++ b/config_EF_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_base.yaml b/config_base.yaml
index 838ac12..f1b5ae9 100644
--- a/config_base.yaml
+++ b/config_base.yaml
@@ -22,6 +22,7 @@ batch_size: 8
 num_epochs: 15
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Szada.yaml b/config_siamconc_AC_Szada.yaml
index 4142e68..5be5494 100644
--- a/config_siamconc_AC_Szada.yaml
+++ b/config_siamconc_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Tiszadob.yaml b/config_siamconc_AC_Tiszadob.yaml
index 8eee72e..979facf 100644
--- a/config_siamconc_AC_Tiszadob.yaml
+++ b/config_siamconc_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_OSCD.yaml b/config_siamconc_OSCD.yaml
index 6a25726..640f83d 100644
--- a/config_siamconc_OSCD.yaml
+++ b/config_siamconc_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Szada.yaml b/config_siamdiff_AC_Szada.yaml
index 3c0ad37..055f970 100644
--- a/config_siamdiff_AC_Szada.yaml
+++ b/config_siamdiff_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Tiszadob.yaml b/config_siamdiff_AC_Tiszadob.yaml
index 02f67ba..931fc0c 100644
--- a/config_siamdiff_AC_Tiszadob.yaml
+++ b/config_siamdiff_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_OSCD.yaml b/config_siamdiff_OSCD.yaml
index 90a1671..ca28ae3 100644
--- a/config_siamdiff_OSCD.yaml
+++ b/config_siamdiff_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 445b96a..83330b3 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -106,7 +106,11 @@ class Trainer:
                 acc, epoch, max_acc, best_epoch))
 
         # The checkpoint saves next epoch
-        self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+        self._save_checkpoint(
+            self.model.state_dict(),
+            self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+            (max_acc, best_epoch), epoch+1, is_best
+        )
 
     def evaluate(self):
         if self.checkpoint:
@@ -164,11 +168,8 @@
             else:
                 self._init_max_acc_and_epoch = max_acc_and_epoch
             if self.ctx['load_optim'] and self.is_training:
-                try:
-                    # Note that weight decay might be modified here
-                    self.optimizer.load_state_dict(checkpoint['optimizer'])
-                except KeyError:
-                    self.logger.warning("Warning: failed to load optimizer parameters.")
+                # Note that weight decay might be modified here
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
diff --git a/src/train.py b/src/train.py
index c434b8b..782336b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -62,6 +62,7 @@ def parse_args():
     group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
                         help='number of epochs to train (default: %(default)s)')
     group_train.add_argument('--load-optim', action='store_true')
+    group_train.add_argument('--save-optim', action='store_true')
    group_train.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint')
     group_train.add_argument('--anew', action='store_true',
-- 
GitLab
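
Note: for reference, the save/load round trip this patch sets up is sketched below. This is a minimal, self-contained illustration of the pattern only; the repository's actual _save_checkpoint implementation is not shown in the diff, so the checkpoint dictionary layout ('state_dict'/'optimizer' keys) and the file name are assumptions.

    # Minimal sketch of the conditional optimizer-state save/load pattern.
    # Assumed: a torch-style checkpoint dict; not the repo's real internals.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    ctx = {'save_optim': True, 'load_optim': True}  # mirrors the YAML keys

    # Saving: with save_optim=False an empty dict is stored in place of the
    # optimizer state, matching `state_dict() if ctx['save_optim'] else {}`.
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict() if ctx['save_optim'] else {},
    }
    torch.save(checkpoint, 'checkpoint.pth')

    # Resuming: after this patch the KeyError guard around optimizer loading
    # is gone, so load_optim=True is only meaningful for checkpoints that
    # were written with save_optim=True (an empty dict would not load).
    checkpoint = torch.load('checkpoint.pth')
    if ctx['load_optim']:
        optimizer.load_state_dict(checkpoint['optimizer'])
    model.load_state_dict(checkpoint['state_dict'])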