Commit 9126a3ce authored by Bobholamovic

Add option to save optim state

parent 39d0b776
1 merge request: !2 "Update outdated code"
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 8
 num_epochs: 15
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
@@ -22,6 +22,7 @@ batch_size: 32
 num_epochs: 10
 resume: ''
 load_optim: True
+save_optim: True
 anew: False
 trace_freq: 1
 device: cuda
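The same one-line addition is applied to each of the config files above (the file names are not preserved in this view). The diff does not show how the YAML reaches the trainer; a minimal sketch, assuming the usual PyYAML pattern in which the parsed mapping becomes the `ctx` dict the Trainer reads (the loader function and file path here are hypothetical, not part of this commit):

# Hypothetical config loader; this commit only adds the `save_optim` key.
import yaml

def load_config(path):
    # YAML `True`/`False` parse to Python bools, so the flag can be used
    # directly in conditionals.
    with open(path) as f:
        return yaml.safe_load(f)

ctx = load_config('config.yaml')  # hypothetical path
if ctx['save_optim']:
    print('optimizer state will be written into each checkpoint')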
@@ -106,7 +106,11 @@ class Trainer:
                 acc, epoch, max_acc, best_epoch))
             # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+            self._save_checkpoint(
+                self.model.state_dict(),
+                self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+                (max_acc, best_epoch), epoch+1, is_best
+            )

     def evaluate(self):
         if self.checkpoint:
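The body of `_save_checkpoint` is outside this diff. A minimal sketch consistent with the call site above, assuming `torch.save` and a flat checkpoint dict; only the 'optimizer' key is confirmed by the loading code below, while the other keys, the output directory, and the file names are assumptions:

import os
import torch

def _save_checkpoint(self, state_dict, optim_state, max_acc_and_epoch, epoch, is_best):
    # When save_optim is False the call site passes {}, so the file shrinks
    # by the size of the optimizer state (roughly the model size again for
    # Adam-style optimizers with per-parameter buffers).
    checkpoint = {
        'state_dict': state_dict,
        'optimizer': optim_state,  # key name confirmed by the load path
        'max_acc': max_acc_and_epoch,  # assumed key
        'epoch': epoch,                # assumed key
    }
    latest = os.path.join(self.ctx['out_dir'], 'checkpoint_latest.pth')  # assumed layout
    torch.save(checkpoint, latest)
    if is_best:
        torch.save(checkpoint, os.path.join(self.ctx['out_dir'], 'model_best.pth'))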
@@ -164,11 +168,8 @@ class Trainer:
         else:
             self._init_max_acc_and_epoch = max_acc_and_epoch
         if self.ctx['load_optim'] and self.is_training:
-            try:
-                # Note that weight decay might be modified here
-                self.optimizer.load_state_dict(checkpoint['optimizer'])
-            except KeyError:
-                self.logger.warning("Warning: failed to load optimizer parameters.")
+            # Note that weight decay might be modified here
+            self.optimizer.load_state_dict(checkpoint['optimizer'])

         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
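Removing the try/except makes saving and loading symmetric: now that writing optimizer state is an explicit choice via `save_optim`, a failure to restore it is a real error that propagates instead of being logged and swallowed. A self-contained round trip illustrating the pattern (the model, optimizer, `ctx` dict, and file name are stand-ins, not the repository's own objects):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
ctx = {'save_optim': True, 'load_optim': True}

# Save, mirroring the patched call site above.
torch.save({
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict() if ctx['save_optim'] else {},
}, 'checkpoint_latest.pth')

# Resume. With the try/except gone, a checkpoint written under
# save_optim: False fails loudly here rather than printing a warning.
checkpoint = torch.load('checkpoint_latest.pth')
model.load_state_dict(checkpoint['state_dict'])
if ctx['load_optim']:
    optimizer.load_state_dict(checkpoint['optimizer'])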
@@ -62,6 +62,7 @@ def parse_args():
     group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
                              help='number of epochs to train (default: %(default)s)')
     group_train.add_argument('--load-optim', action='store_true')
+    group_train.add_argument('--save-optim', action='store_true')
     group_train.add_argument('--resume', default='', type=str, metavar='PATH',
                              help='path to latest checkpoint')
     group_train.add_argument('--anew', action='store_true',
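Since both flags use `action='store_true'`, they default to False, so writing optimizer state is opt-in from the command line (argparse maps `--save-optim` to `args.save_optim`), while the YAML configs above set it to True. A minimal check, reduced to the two flags:

import argparse

parser = argparse.ArgumentParser()
group_train = parser.add_argument_group('training')
group_train.add_argument('--load-optim', action='store_true')
group_train.add_argument('--save-optim', action='store_true')

args = parser.parse_args(['--save-optim'])
assert args.save_optim is True
assert args.load_optim is False  # store_true flags default to False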