
Update outdated code

12 files changed: +18 −6

Files (+7 −6 in the file shown below)
@@ -106,7 +106,11 @@ class Trainer:
                 acc, epoch, max_acc, best_epoch))
             # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+            self._save_checkpoint(
+                self.model.state_dict(),
+                self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+                (max_acc, best_epoch), epoch+1, is_best
+            )


     def evaluate(self):
         if self.checkpoint:
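
For context, a minimal sketch of a saving routine compatible with the new call site. The body of `_save_checkpoint` is not part of this diff, so everything below beyond the argument order and the `'optimizer'` key (which the loading hunk reads) is an assumption: the other payload keys, the file names, and the use of `torch.save`/`shutil.copyfile` are illustrative, not taken from this repository.

import os
import shutil

import torch

def save_checkpoint(ckpt_dir, state_dict, optim_state, max_acc_and_epoch, epoch, is_best):
    # Free-function analogue of Trainer._save_checkpoint (hypothetical body).
    checkpoint = {
        'state_dict': state_dict,      # model weights ('state_dict' key is assumed)
        'optimizer': optim_state,      # {} when ctx['save_optim'] is off, per the new call site
        'max_acc': max_acc_and_epoch,  # (best accuracy so far, epoch it was reached)
        'epoch': epoch,                # epoch to resume from ("the checkpoint saves next epoch")
    }
    path = os.path.join(ckpt_dir, 'checkpoint_latest.pth')  # hypothetical file layout
    torch.save(checkpoint, path)
    if is_best:
        # Keep a separate copy of the best checkpoint so later epochs don't overwrite it.
        shutil.copyfile(path, os.path.join(ckpt_dir, 'model_best.pth'))

With this scheme the checkpoint always contains an 'optimizer' entry, even when optimizer saving is disabled, which matters for the loading change in the next hunk.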
@@ -164,11 +168,8 @@ class Trainer:
         else:
             self._init_max_acc_and_epoch = max_acc_and_epoch
         if self.ctx['load_optim'] and self.is_training:
-            try:
-                # Note that weight decay might be modified here
-                self.optimizer.load_state_dict(checkpoint['optimizer'])
-            except KeyError:
-                self.logger.warning("Warning: failed to load optimizer parameters.")
+            # Note that weight decay might be modified here
+            self.optimizer.load_state_dict(checkpoint['optimizer'])
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)

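
Dropping the try/except is consistent with the save path always writing an 'optimizer' entry (possibly an empty dict), so the KeyError branch can no longer trigger for checkpoints produced by this code. A hedged sketch of the surrounding load path, to show where these lines sit; the `torch.load` call, the `'state_dict'` key, and the way `update_dict` is filtered are assumptions beyond what the diff shows.

import torch

def load_checkpoint(path, model, optimizer, load_optim, is_training):
    # Hypothetical counterpart to the loading code in the hunk above.
    checkpoint = torch.load(path, map_location='cpu')  # map_location is an assumption
    if load_optim and is_training:
        # Note that weight decay might be modified here (comment carried over from the diff)
        optimizer.load_state_dict(checkpoint['optimizer'])
    # Start from the model's current state_dict and overlay the stored weights,
    # mirroring the state_dict.update(update_dict) / load_state_dict pattern above.
    state_dict = model.state_dict()
    update_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in state_dict}
    state_dict.update(update_dict)
    model.load_state_dict(state_dict)
    return checkpoint['epoch'], checkpoint['max_acc']

The update-then-load pattern lets keys missing from the checkpoint keep their freshly initialized values instead of failing a strict load.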