Update outdated code

Open · manli requested to merge github/fork/Bobholamovic/master into master
@@ -2,76 +2,82 @@ import shutil
 import os
 from types import MappingProxyType
 from copy import deepcopy
+from abc import ABCMeta, abstractmethod
 
 import torch
-from skimage import io
-from tqdm import tqdm
 
 import constants
-from data.common import to_array
-from utils.misc import R
-from utils.metrics import AverageMeter
-from utils.utils import mod_crop
-from .factories import (model_factory, optim_factory, critn_factory, data_factory)
+from .misc import Logger, OutPathGetter, R
+from .factories import (model_factory, optim_factory, critn_factory, data_factory, metric_factory)
 
 
-class Trainer:
+class Trainer(metaclass=ABCMeta):
     def __init__(self, model, dataset, criterion, optimizer, settings):
         super().__init__()
 
+        # Make a copy of settings in case of unexpected changes
         context = deepcopy(settings)
-        self.ctx = MappingProxyType(vars(context))
-        self.mode = ('train', 'val').index(context.cmd)
+        # self.ctx is a proxy so that context will be read-only outside __init__
+        self.ctx = MappingProxyType(context)
+        self.mode = ('train', 'eval').index(context['cmd'])
+        self.debug = context['debug_on']
+        self.log = not context['log_off']
+        self.batch_size = context['batch_size']
+        self.checkpoint = context['resume']
+        self.load_checkpoint = (len(self.checkpoint)>0)
+        self.num_epochs = context['num_epochs']
+        self.lr = float(context['lr'])
+        self.track_intvl = int(context['track_intvl'])
+        self.device = torch.device(context['device'])
 
-        self.logger = R['LOGGER']
-        self.gpc = R['GPC']   # Global Path Controller
+        self.gpc = OutPathGetter(
+            root=os.path.join(context['exp_dir'], context['tag']),
+            suffix=context['suffix']
+        )   # Global Path Controller
+        self.logger = Logger(
+            scrn=True,
+            log_dir=self.gpc.get_dir('log') if self.log else '',
+            phase=context['cmd']
+        )
         self.path = self.gpc.get_path
 
-        self.batch_size = context.batch_size
-        self.checkpoint = context.resume
-        self.load_checkpoint = (len(self.checkpoint)>0)
-        self.num_epochs = context.num_epochs
-        self.lr = float(context.lr)
-        self.save = context.save_on or context.out_dir
-        self.out_dir = context.out_dir
-        self.trace_freq = int(context.trace_freq)
-        self.device = torch.device(context.device)
-        self.suffix_off = context.suffix_off
-
-        for k, v in sorted(context.items()):
+        for k, v in sorted(self.ctx.items()):
             self.logger.show("{}: {}".format(k,v))
 
         self.model = model_factory(model, context)
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
+        self.metrics = metric_factory(context['metrics'], context)
 
         if self.is_training:
             self.train_loader = data_factory(dataset, 'train', context)
-            self.val_loader = data_factory(dataset, 'val', context)
+            self.eval_loader = data_factory(dataset, 'eval', context)
             self.optimizer = optim_factory(optimizer, self.model, context)
         else:
-            self.val_loader = data_factory(dataset, 'val', context)
+            self.eval_loader = data_factory(dataset, 'eval', context)
 
         self.start_epoch = 0
-        self._init_max_acc_and_epoch = (0.0, 0)
+        self._init_acc_epoch = (0.0, -1)
 
     @property
     def is_training(self):
         return self.mode == 0
 
+    @abstractmethod
     def train_epoch(self, epoch):
-        raise NotImplementedError
+        pass
 
-    def validate_epoch(self, epoch=0, store=False):
-        raise NotImplementedError
+    @abstractmethod
+    def evaluate_epoch(self, epoch):
+        return 0.0
 
     def _write_prompt(self):
         self.logger.dump(input("\nWrite some notes: "))
 
     def run(self):
         if self.is_training:
-            self._write_prompt()
+            if self.log and not self.debug:
+                self._write_prompt()
             self.train()
         else:
             self.evaluate()
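Note: `Trainer` is now an abstract base class, so concrete trainers must override both hooks. A minimal sketch of a hypothetical subclass (the name and bodies are illustrative, not part of this MR):

class DummyTrainer(Trainer):
    def train_epoch(self, epoch):
        # Iterate over self.train_loader and update self.model here.
        for batch in self.train_loader:
            pass

    def evaluate_epoch(self, epoch):
        # Must return a scalar score; train() compares it against max_acc.
        return 0.0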
@@ -80,23 +86,20 @@ class Trainer:
         if self.load_checkpoint:
             self._resume_from_checkpoint()
-        max_acc, best_epoch = self._init_max_acc_and_epoch
+        max_acc, best_epoch = self._init_acc_epoch
 
+        lr = self.init_learning_rate()
         for epoch in range(self.start_epoch, self.num_epochs):
-            lr = self._adjust_learning_rate(epoch)
             self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
 
             # Train for one epoch
+            self.model.train()
             self.train_epoch(epoch)
 
-            # Clear the history of metric objects
-            for m in self.metrics:
-                m.reset()
-
-            # Evaluate the model on validation set
-            self.logger.show_nl("Validate")
-            acc = self.validate_epoch(epoch=epoch, store=self.save)
+            # Evaluate the model
+            self.logger.show_nl("Evaluate")
+            self.model.eval()
+            acc = self.evaluate_epoch(epoch=epoch)
 
             is_best = acc > max_acc
             if is_best:
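Note: the hard-coded lr schedule has moved out of the base class; `init_learning_rate()` and `adjust_learning_rate()` default to a constant rate, and subclasses override them. For instance, a subclass could reinstate the removed 'step' mode roughly like this (sketch only, mirroring the deleted `_adjust_learning_rate`):

    def adjust_learning_rate(self, epoch, acc):
        # Halve the base lr every self.ctx['step'] epochs, as the removed
        # 'step' branch did, and push the new value into the optimizer.
        lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr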
@@ -105,77 +108,79 @@ class Trainer:
             self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                 acc, epoch, max_acc, best_epoch))
 
-            # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
+            # Do not save checkpoints in debugging mode
+            if not self.debug:
+                self._save_checkpoint(
+                    self.model.state_dict(),
+                    self.optimizer.state_dict() if self.ctx['save_optim'] else {},
+                    (max_acc, best_epoch), epoch, is_best
+                )
+
+            lr = self.adjust_learning_rate(epoch, acc)
 
     def evaluate(self):
         if self.checkpoint:
             if self._resume_from_checkpoint():
-                self.validate_epoch(self.ckp_epoch, self.save)
+                self.model.eval()
+                self.evaluate_epoch(self.start_epoch)
         else:
-            self.logger.warning("Warning: no checkpoint assigned!")
+            self.logger.error("No checkpoint assigned!")
 
-    def _adjust_learning_rate(self, epoch):
-        if self.ctx['lr_mode'] == 'step':
-            lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
-        elif self.ctx['lr_mode'] == 'poly':
-            lr = self.lr * (1 - epoch / self.num_epochs) ** 1.1
-        elif self.ctx['lr_mode'] == 'const':
-            lr = self.lr
-        else:
-            raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode']))
+    def init_learning_rate(self):
+        return self.lr
 
-        for param_group in self.optimizer.param_groups:
-            param_group['lr'] = lr
-        return lr
+    def adjust_learning_rate(self, epoch, acc):
+        return self.lr
 
     def _resume_from_checkpoint(self):
-        ## XXX: This could be slow!
+        # XXX: This could be slow!
         if not os.path.isfile(self.checkpoint):
             self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
             return False
 
-        self.logger.show("=> Loading checkpoint '{}'".format(
-            self.checkpoint))
+        self.logger.show("=> Loading checkpoint '{}'...".format(self.checkpoint))
         checkpoint = torch.load(self.checkpoint, map_location=self.device)
 
         state_dict = self.model.state_dict()
         ckp_dict = checkpoint.get('state_dict', checkpoint)
-        update_dict = {k:v for k,v in ckp_dict.items()
-            if k in state_dict and state_dict[k].shape == v.shape}
+        update_dict = {
+            k:v for k,v in ckp_dict.items()
+            if k in state_dict and state_dict[k].shape == v.shape and state_dict[k].dtype == v.dtype
+        }
 
         num_to_update = len(update_dict)
         if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
             if not self.is_training and (num_to_update < len(state_dict)):
                 self.logger.error("=> Mismatched checkpoint for evaluation")
                 return False
-            self.logger.warning("Warning: trying to load an mismatched checkpoint.")
+            self.logger.warn("Trying to load a mismatched checkpoint.")
             if num_to_update == 0:
                 self.logger.error("=> No parameter is to be loaded.")
                 return False
             else:
-                self.logger.warning("=> {} params are to be loaded.".format(num_to_update))
-        elif (not self.ctx['anew']) or not self.is_training:
-            self.start_epoch = checkpoint.get('epoch', 0)
-            max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
-            # For backward compatibility
-            if isinstance(max_acc_and_epoch, (float, int)):
-                self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
-            else:
-                self._init_max_acc_and_epoch = max_acc_and_epoch
-            if self.ctx['load_optim'] and self.is_training:
-                try:
-                    # Note that weight decay might be modified here
-                    self.optimizer.load_state_dict(checkpoint['optimizer'])
-                except KeyError:
-                    self.logger.warning("Warning: failed to load optimizer parameters.")
+                self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
+            ckp_epoch = -1
+        else:
+            ckp_epoch = checkpoint.get('epoch', -1)
+            self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
+            if not self.is_training:
+                self.start_epoch = ckp_epoch
+            elif not self.ctx['anew']:
+                self.start_epoch = ckp_epoch+1
+                if self.ctx['load_optim']:
+                    # XXX: Note that weight decay might be modified here.
+                    self.optimizer.load_state_dict(checkpoint['optimizer'])
+                    self.logger.warn("Weight decay might have been modified.")
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
 
-        self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format(
-            self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch
-        ))
+        if ckp_epoch == -1:
+            self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
+        else:
+            self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
+                self.checkpoint, ckp_epoch, *self._init_acc_epoch
+            ))
         return True
 
     def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch, is_best):
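Note: the checkpoint filter is stricter now, skipping entries whose dtype differs in addition to shape mismatches and unknown keys. The idiom in isolation, as a self-contained sketch with toy tensors:

import torch

# Stand-ins for self.model.state_dict() and a loaded checkpoint.
state_dict = {'w': torch.zeros(2, 3), 'b': torch.zeros(3)}
ckp_dict = {'w': torch.ones(2, 3), 'b': torch.ones(4), 'extra': torch.ones(1)}

# Keep only entries the model knows, with matching shape and dtype.
update_dict = {
    k: v for k, v in ckp_dict.items()
    if k in state_dict and state_dict[k].shape == v.shape and state_dict[k].dtype == v.dtype
}
assert set(update_dict) == {'w'}  # 'b' has the wrong shape; 'extra' is unknown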
@@ -186,117 +191,46 @@ class Trainer:
             'max_acc': max_acc
         }
         # Save history
-        history_path = self.path('weight', constants.CKP_COUNTED.format(e=epoch), underline=True)
-        if epoch % self.trace_freq == 0:
+        # epoch+1 instead of epoch is contained in the checkpoint name so that it will be easy for
+        # one to recognize "the next start_epoch".
+        history_path = self.path(
+            'weight', constants.CKP_COUNTED.format(e=epoch+1),
+            suffix=True
+        )
+        if epoch % self.track_intvl == 0:
             torch.save(state, history_path)
         # Save latest
         latest_path = self.path(
             'weight', constants.CKP_LATEST,
-            underline=True
+            suffix=True
         )
         torch.save(state, latest_path)
         if is_best:
             shutil.copyfile(
                 latest_path, self.path(
                     'weight', constants.CKP_BEST,
-                    underline=True
+                    suffix=True
                 )
             )
 
-    @property
-    def ckp_epoch(self):
-        # Get current epoch of the checkpoint
-        # For dismatched ckp or no ckp, set to 0
-        return max(self.start_epoch-1, 0)
-
-    def save_image(self, file_name, image, epoch):
-        file_path = os.path.join(
-            'epoch_{}/'.format(epoch),
-            self.out_dir,
-            file_name
-        )
-        out_path = self.path(
-            'out', file_path,
-            suffix=not self.suffix_off,
-            auto_make=True,
-            underline=True
-        )
-        return io.imsave(out_path, image)
-
-
-class CDTrainer(Trainer):
-    def __init__(self, arch, dataset, optimizer, settings):
-        super().__init__(arch, dataset, 'NLL', optimizer, settings)
-
-    def train_epoch(self, epoch):
-        losses = AverageMeter()
-        len_train = len(self.train_loader)
-        pb = tqdm(self.train_loader)
-        self.model.train()
-        for i, (t1, t2, label) in enumerate(pb):
-            t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
-
-            prob = self.model(t1, t2)
-            loss = self.criterion(prob, label)
-            losses.update(loss.item(), n=self.batch_size)
-
-            # Compute gradients and do SGD step
-            self.optimizer.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-
-            desc = self.logger.make_desc(
-                i+1, len_train,
-                ('loss', losses, '.4f')
-            )
-            pb.set_description(desc)
-            self.logger.dump(desc)
-
-    def validate_epoch(self, epoch=0, store=False):
-        self.logger.show_nl("Epoch: [{0}]".format(epoch))
-        losses = AverageMeter()
-        len_val = len(self.val_loader)
-        pb = tqdm(self.val_loader)
-        self.model.eval()
-        with torch.no_grad():
-            for i, (name, t1, t2, label) in enumerate(pb):
-                if self.is_training and i >= 16:
-                    # Do not validate all images on training phase
-                    pb.close()
-                    self.logger.warning("validation ends early")
-                    break
-                t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
-                prob = self.model(t1, t2)
-                loss = self.criterion(prob, label)
-                losses.update(loss.item(), n=self.batch_size)
-
-                # Convert to numpy arrays
-                CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
-                label = to_array(label[0]).astype('uint8')
-                for m in self.metrics:
-                    m.update(CM, label)
-
-                desc = self.logger.make_desc(
-                    i+1, len_val,
-                    ('loss', losses, '.4f'),
-                    *(
-                        (m.__name__, m, '.4f')
-                        for m in self.metrics
-                    )
-                )
-                pb.set_description(desc)
-                self.logger.dump(desc)
-
-                if store:
-                    self.save_image(name[0], CM*255, epoch)
-
-        return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
+
+class TrainerSwitcher:
+    r"""A simple utility class to help dispatch actions to different trainers."""
+    def __init__(self, *pairs):
+        self._trainer_list = list(pairs)
+
+    def __call__(self, args, return_obj=True):
+        for p, t in self._trainer_list:
+            if p(args):
+                return t(args) if return_obj else t
+        return None
+
+    def add_item(self, predicate, trainer):
+        # Newly added items have higher priority
+        self._trainer_list.insert(0, (predicate, trainer))
+
+    def add_default(self, trainer):
+        self._trainer_list.append((lambda _: True, trainer))
+
+
+R.register('Trainer_switcher', TrainerSwitcher())
\ No newline at end of file
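Note: `TrainerSwitcher` dispatches on the parsed arguments: it walks its (predicate, trainer) pairs and returns the first trainer whose predicate accepts `args`. A usage sketch (the predicate condition and `MyCDTrainer` are hypothetical, for illustration only):

switcher = R['Trainer_switcher']
# Newly added items take priority over earlier ones.
switcher.add_item(
    lambda args: args.dataset == 'LEVIR-CD',   # hypothetical condition
    lambda args: MyCDTrainer(args)             # hypothetical trainer factory
)
trainer = switcher(args)  # returns None if no predicate matches
if trainer is not None:
    trainer.run()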