import os
import os.path as osp
from random import randint
from functools import partial
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler
from skimage import io
from tqdm import tqdm
from core.trainer import Trainer
from utils.data_utils import (
to_array, to_pseudo_color,
normalize_8bit,
quantize_8bit as quantize
)
from utils.utils import mod_crop, HookHelper
from utils.metrics import (AverageMeter, Precision, Recall, Accuracy, F1Score)
class CDTrainer(Trainer):
    # Trainer specialized for change detection (CD): the model consumes a
    # bi-temporal image pair (t1, t2) and is trained against change labels
    # with an NLL criterion (see the super().__init__ call below).
    def __init__(self, settings):
        """Configure tensorboard logging and output-saving from `settings`.

        Args:
            settings: dict-like configuration; must provide at least the
                'model', 'dataset', and 'optimizer' entries forwarded to the
                base Trainer, plus the context keys read below
                ('tb_on', 'tb_intvl', 'out_dir', 'save_on', ...).
        """
        super().__init__(settings['model'], settings['dataset'], 'NLL', settings['optimizer'], settings)

        # Tensorboard is enabled only when requested via the context AND we
        # either have a real log file to name the run after, or are debugging.
        self.tb_on = (hasattr(self.logger, 'log_path') or self.debug) and self.ctx['tb_on']
        if self.tb_on:
            # Initialize tensorboard
            if hasattr(self.logger, 'log_path'):
                # Name the tb run directory after the logger's log file
                # (basename without extension); trailing '.' keeps the path
                # pointing at the directory itself — presumably required by
                # self.path's suffix handling (TODO confirm against Trainer).
                tb_dir = self.path(
                    'log',
                    osp.join('tb', osp.splitext(osp.basename(self.logger.log_path))[0], '.'),
                    name='tb',
                    auto_make=True,
                    suffix=False
                )
            else:
                # Debug mode: reuse a fixed 'debug' run directory...
                tb_dir = self.path(
                    'log',
                    osp.join('tb', 'debug', '.'),
                    name='tb',
                    auto_make=True,
                    suffix=False
                )
                # ...and wipe any stale events from previous debug runs.
                # Bottom-up walk (topdown=False) so files are removed before
                # their now-empty parent directories.
                for root, dirs, files in os.walk(self.gpc.get_dir('tb'), False):
                    for f in files:
                        os.remove(osp.join(root, f))
                    for d in dirs:
                        os.rmdir(osp.join(root, d))
            self.tb_writer = SummaryWriter(tb_dir)
            self.logger.show_nl("\nTensorboard logdir: {}".format(osp.abspath(self.gpc.get_dir('tb'))))
            # How often (in iterations) to dump images to tensorboard.
            self.tb_intvl = int(self.ctx['tb_intvl'])

            # Global steps
            self.train_step = 0
            self.eval_step = 0

        # Whether to save network output
        self.out_dir = self.ctx['out_dir']
        # Never save outputs in debug mode, even if requested.
        self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug
    def init_learning_rate(self):
        """Set up the lr-adjustment strategy from ctx['lr_mode'] and return the initial lr.

        Side effects: for non-const modes, creates self.scheduler and binds
        self.adjust_learning_rate to a per-epoch stepping function.

        Returns:
            The learning rate to start training with (a float).

        Raises:
            NotImplementedError: if ctx['lr_mode'] is not one of
                'const', 'step', 'exp', or 'plateau'.
        """
        # Set learning rate adjustment strategy
        if self.ctx['lr_mode'] == 'const':
            return self.lr
        else:
            # These two helpers take an explicit `self` so they can be bound
            # to the instance via functools.partial below; epoch/acc mirror
            # the adjust_learning_rate(epoch, acc) call signature.
            def _simple_scheduler_step(self, epoch, acc):
                # Step unconditionally; accuracy is ignored.
                self.scheduler.step()
                # NOTE(review): get_lr() is deprecated in newer pytorch in
                # favor of get_last_lr() — confirm the targeted version.
                return self.scheduler.get_lr()[0]
            def _scheduler_step_with_acc(self, epoch, acc):
                # ReduceLROnPlateau needs the monitored metric.
                self.scheduler.step(acc)
                # Only return the lr of the first param group
                return self.optimizer.param_groups[0]['lr']
            lr_mode = self.ctx['lr_mode']
            if lr_mode == 'step':
                # Halve the lr every ctx['step'] epochs.
                self.scheduler = lr_scheduler.StepLR(
                    self.optimizer, self.ctx['step'], gamma=0.5
                )
                self.adjust_learning_rate = partial(_simple_scheduler_step, self)
            elif lr_mode == 'exp':
                # Multiply the lr by 0.9 every epoch.
                self.scheduler = lr_scheduler.ExponentialLR(
                    self.optimizer, gamma=0.9
                )
                self.adjust_learning_rate = partial(_simple_scheduler_step, self)
            elif lr_mode == 'plateau':
                if self.load_checkpoint:
                    self.logger.warn("The old state of the lr scheduler will not be restored.")
                # Maximize the monitored metric; halve the lr on stagnation.
                self.scheduler = lr_scheduler.ReduceLROnPlateau(
                    self.optimizer, mode='max', factor=0.5, threshold=1e-4
                )
                self.adjust_learning_rate = partial(_scheduler_step_with_acc, self)
                # Early return: the fast-forward loop below would require a
                # metric value, which ReduceLROnPlateau.step() demands.
                return self.optimizer.param_groups[0]['lr']
            else:
                raise NotImplementedError

            if self.start_epoch > 0:
                # Restore previous state
                # FIXME: This will trigger pytorch warning "Detected call of `lr_scheduler.step()`
                # before `optimizer.step()`" in pytorch 1.1.0 and later.
                # Perhaps I should store the state of scheduler to a checkpoint file and restore it from disk.
                last_epoch = self.start_epoch
                # Replay epoch steps so the scheduler's internal counter
                # catches up with the resumed epoch.
                while self.scheduler.last_epoch < last_epoch:
                    self.scheduler.step()
            return self.scheduler.get_lr()[0]
def train_epoch(self, epoch):
losses = AverageMeter()
len_train = len(self.train_loader)
width = len(str(len_train))
start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
pb = tqdm(self.train_loader)
self.model.train()
for i, (t1, t2, tar) in enumerate(pb):
t1, t2, tar = t1.to(self.device), t2.to(self.device), tar.to(self.device)
show_imgs_on_tb = self.tb_on and (i%self.tb_intvl == 0)
prob = self.model(t1, t2)
loss = self.criterion(prob, tar)
losses.update(loss.item(), n=self.batch_size)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
desc = (start_pattern+" Loss: {:.4f} ({:.4f})").format(i+1, len_train, losses.val, losses.avg)
pb.set_description(desc)
if i % max(1, len_train//10) == 0:
self.logger.dump(desc)
if self.tb_on:
# Write to tensorboard
self.tb_writer.add_scalar("Train/loss", losses.val, self.train_step)
if show_imgs_on_tb:
self.tb_writer.add_image("Train/t1_picked", normalize_8bit(t1.detach()[0]), self.train_step)
self.tb_writer.add_image("Train/t2_picked", normalize_8bit(t2.detach()[0]), self.train_step)
self.tb_writer.add_image("Train/labels_picked", tar[0].unsqueeze(0), self.train_step)
self.tb_writer.flush()
self.train_step += 1
def evaluate_epoch(self, epoch):
self.logger.show_nl("Epoch: [{0}]".format(epoch))
losses = AverageMeter()
len_eval = len(self.eval_loader)
width = len(str(len_eval))
start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
pb = tqdm(self.eval_loader)
# Construct metrics
metrics = (Precision(), Recall(), F1Score(), Accuracy())
self.model.eval()
with torch.no_grad():
for i, (name, t1, t2, tar) in enumerate(pb):
if self.is_training and i >= self.val_iters:
# This saves time
pb.close()
self.logger.warn("Evaluation ends early.")
break
t1, t2, tar = t1.to(self.device), t2.to(self.device), tar.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, tar)
losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays
cm = to_array(torch.argmax(prob[0], 0)).astype('uint8')
tar = to_array(tar[0]).astype('uint8')
for m in metrics:
m.update(cm, tar)
desc = (start_pattern+" Loss: {:.4f} ({:.4f})").format(i+1, len_eval, losses.val, losses.avg)
for m in metrics:
desc += " {} {:.4f} ({:.4f})".format(m.__name__, m.val, m.avg)
pb.set_description(desc)
self.logger.dump(desc)
if self.tb_on:
self.tb_writer.add_image("Eval/t1", normalize_8bit(t1[0]), self.eval_step)
self.tb_writer.add_image("Eval/t2", normalize_8bit(t2[0]), self.eval_step)
self.tb_writer.add_image("Eval/labels", quantize(tar), self.eval_step, dataformats='HW')
prob = quantize(to_array(torch.exp(prob[0,1])))
self.tb_writer.add_image("Eval/prob", to_pseudo_color(prob), self.eval_step, dataformats='HWC')
self.tb_writer.add_image("Eval/cm", quantize(cm), self.eval_step, dataformats='HW')
self.eval_step += 1
if self.save:
self.save_image(name[0], quantize(cm), epoch)
if self.tb_on:
self.tb_writer.add_scalar("Eval/loss", losses.avg, self.eval_step)
self.tb_writer.add_scalars("Eval/metrics", {m.__name__.lower(): m.avg for m in metrics}, self.eval_step)
return metrics[2].avg # F1-score
def save_image(self, file_name, image, epoch):
file_path = osp.join(
'epoch_{}'.format(epoch),
self.out_dir,
file_name
)
out_path = self.path(
'out', file_path,
suffix=not self.ctx['suffix_off'],
auto_make=True,
underline=True
)
return io.imsave(out_path, image)
# def __del__(self):
# if self.tb_on:
# self.tb_writer.close()