diff --git a/config_EF_AC_Szada.yaml b/config_EF_AC_Szada.yaml
index 7e969aed3ac85133561b7e505bcea0c5332f9549..0c9b9172fa4d2efea38ed74b124b6f1f2f8773fa 100644
--- a/config_EF_AC_Szada.yaml
+++ b/config_EF_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_AC_Tiszadob.yaml b/config_EF_AC_Tiszadob.yaml
index 3c93b3a120c51f9ca439a68278a45826e783f82a..91632ddce86b521d5f860551e13032a2058ca24e 100644
--- a/config_EF_AC_Tiszadob.yaml
+++ b/config_EF_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_EF_OSCD.yaml b/config_EF_OSCD.yaml
index a91842cbf08dc599b7abece1390fdc549151bfe9..9902142254332ac316c750523f4b0c489cf40833 100644
--- a/config_EF_OSCD.yaml
+++ b/config_EF_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_base.yaml b/config_base.yaml
index 838ac12abb837780e934a6964ffc3d88c4e97510..f1b5ae959b85cab26407821c3706ed40efd39c85 100644
--- a/config_base.yaml
+++ b/config_base.yaml
@@ -22,6 +22,7 @@ batch_size: 8
 num_epochs: 15
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Szada.yaml b/config_siamconc_AC_Szada.yaml
index 4142e68416e556a94f943fb9afc764ab9dc0b769..5be549448d7570dd2a082d7e5af720d28a5f9cea 100644
--- a/config_siamconc_AC_Szada.yaml
+++ b/config_siamconc_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_AC_Tiszadob.yaml b/config_siamconc_AC_Tiszadob.yaml
index 8eee72e4f9a60e6ae029f8668c52dd3a7215b612..979facfe47a19d064f2cd0707898ea781f727af5 100644
--- a/config_siamconc_AC_Tiszadob.yaml
+++ b/config_siamconc_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamconc_OSCD.yaml b/config_siamconc_OSCD.yaml
index 6a2572626949064c269d671ca7fa09d2bdb7df3c..640f83dfe48257b7dbaa3dad7efd0b6cd586ac81 100644
--- a/config_siamconc_OSCD.yaml
+++ b/config_siamconc_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Szada.yaml b/config_siamdiff_AC_Szada.yaml
index 3c0ad37fa99072ea4bea3c67596c223efa5931b7..055f97092c06b4ed72973f65f4861a9fca72e2ff 100644
--- a/config_siamdiff_AC_Szada.yaml
+++ b/config_siamdiff_AC_Szada.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_AC_Tiszadob.yaml b/config_siamdiff_AC_Tiszadob.yaml
index 02f67ba1d19c18279dd4a69d51ee1afa7bc98a06..931fc0c49fe07eb7afb349807e1f763d3e07acab 100644
--- a/config_siamdiff_AC_Tiszadob.yaml
+++ b/config_siamdiff_AC_Tiszadob.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/config_siamdiff_OSCD.yaml b/config_siamdiff_OSCD.yaml
index 90a16710b2cdcdc5376c85e1994f5ae043c0d474..ca28ae30da2981461b07152eb2c132256233c332 100644
--- a/config_siamdiff_OSCD.yaml
+++ b/config_siamdiff_OSCD.yaml
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
diff --git a/src/core/trainers.py b/src/core/trainers.py
index 445b96ab2685efc7bc69788fd28da09eaf73517c..83330b3be8ef599b7d9390a8a5852f8651e50d16 100644
--- a/src/core/trainers.py
+++ b/src/core/trainers.py
@@ -106,7 +106,11 @@ class Trainer:
                 acc, epoch, max_acc, best_epoch))
 
             # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+            self._save_checkpoint(
+                self.model.state_dict(),
+                self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+                (max_acc, best_epoch), epoch+1, is_best
+            )
 
     def evaluate(self):
         if self.checkpoint:
@@ -164,11 +168,8 @@ class Trainer:
             else:
                 self._init_max_acc_and_epoch = max_acc_and_epoch
             if self.ctx['load_optim'] and self.is_training:
-                try:
-                    # Note that weight decay might be modified here
-                    self.optimizer.load_state_dict(checkpoint['optimizer'])
-                except KeyError:
-                    self.logger.warning("Warning: failed to load optimizer parameters.")
+                # Note that weight decay might be modified here
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
diff --git a/src/train.py b/src/train.py
index c434b8b863f15707893deeb7b78b671225d534ce..782336b048b55f8e0345f84f47d566c92fe6a90d 100644
--- a/src/train.py
+++ b/src/train.py
@@ -62,6 +62,7 @@ def parse_args():
     group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
                         help='number of epochs to train (default: %(default)s)')
     group_train.add_argument('--load-optim', action='store_true')
+    group_train.add_argument('--save-optim', action='store_true')
     group_train.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint')
     group_train.add_argument('--anew', action='store_true',
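
Note: after this change, `save_optim` and `load_optim` must be kept consistent across runs. Below is a minimal sketch of the resulting checkpoint round-trip; the helper names `save_checkpoint`/`load_checkpoint` and the standalone `ctx` dict are illustrative assumptions, not part of the diff, which modifies `Trainer._save_checkpoint` usage and checkpoint loading in src/core/trainers.py.

    # Sketch only -- helper names and checkpoint keys are assumptions for
    # illustration; the repository's actual logic lives in src/core/trainers.py.
    import torch

    def save_checkpoint(path, model, optimizer, ctx):
        # Mirrors the diff: when ctx['save_optim'] is False, an empty dict is
        # stored in place of the optimizer state, shrinking the checkpoint.
        torch.save({
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict() if ctx['save_optim'] else {},
        }, path)

    def load_checkpoint(path, model, optimizer, ctx):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['state_dict'])
        if ctx['load_optim']:
            # The try/except KeyError fallback was removed in the diff, so
            # resuming with load_optim: True from a checkpoint written with
            # save_optim: False raises here (an empty dict lacks the
            # 'state'/'param_groups' keys) instead of warning and continuing.
            optimizer.load_state_dict(checkpoint['optimizer'])

The new `--save-optim` switch in src/train.py exposes the same knob from the command line (a store_true flag, so it defaults to False), while the YAML configs above enable it explicitly.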