Skip to content
Snippets Groups Projects
Commit 3a95fb80 authored by Bobholamovic's avatar Bobholamovic
Browse files

Refactor framework

parent f68ed034
No related branches found
No related tags found
1 merge request!2Update outdated code
...@@ -53,16 +53,16 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta): ...@@ -53,16 +53,16 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta):
raise FileNotFoundError raise FileNotFoundError
# phase stands for the working mode, # phase stands for the working mode,
# 'train' for training and 'eval' for validating or testing. # 'train' for training and 'eval' for validating or testing.
if phase not in ('train', 'eval'): # if phase not in ('train', 'eval'):
raise ValueError("Invalid phase") # raise ValueError("Invalid phase")
# subset is the sub-dataset to use. # subset is the sub-dataset to use.
# For some datasets there are three subsets, # For some datasets there are three subsets,
# while for others there are only train and test(val). # while for others there are only train and test(val).
if subset not in ('train', 'val', 'test'): # if subset not in ('train', 'val', 'test'):
raise ValueError("Invalid subset") # raise ValueError("Invalid subset")
self.phase = phase self.phase = phase
self.transforms = transforms self.transforms = transforms
self.repeats = int(repeats) self.repeats = repeats
# Use 'train' subset during training. # Use 'train' subset during training.
self.subset = 'train' if self.phase == 'train' else subset self.subset = 'train' if self.phase == 'train' else subset
......
...@@ -31,8 +31,8 @@ def _isseq(x): return isinstance(x, (tuple, list)) ...@@ -31,8 +31,8 @@ def _isseq(x): return isinstance(x, (tuple, list))
class Transform: class Transform:
def __init__(self, rand_state=False, prob_apply=1.0): def __init__(self, rand_state=False, prob_apply=1.0):
self._rand_state = bool(rand_state) self._rand_state = rand_state
self.prob_apply = float(prob_apply) self.prob_apply = prob_apply
def _transform(self, x, params): def _transform(self, x, params):
raise NotImplementedError raise NotImplementedError
...@@ -100,7 +100,7 @@ class Scale(Transform): ...@@ -100,7 +100,7 @@ class Scale(Transform):
raise ValueError raise ValueError
self.scale = tuple(scale) self.scale = tuple(scale)
else: else:
self.scale = float(scale) self.scale = scale
def _transform(self, x, params): def _transform(self, x, params):
if self._rand_state: if self._rand_state:
...@@ -138,10 +138,7 @@ class FlipRotate(Transform): ...@@ -138,10 +138,7 @@ class FlipRotate(Transform):
_DIRECTIONS = ('ud', 'lr', '90', '180', '270') _DIRECTIONS = ('ud', 'lr', '90', '180', '270')
def __init__(self, direction=None, prob_apply=1.0): def __init__(self, direction=None, prob_apply=1.0):
super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply) super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply)
if direction is not None: self.direction = direction
if direction not in self._DIRECTIONS:
raise ValueError("Invalid direction")
self.direction = direction
def _transform(self, x, params): def _transform(self, x, params):
if self._rand_state: if self._rand_state:
...@@ -212,9 +209,6 @@ class Crop(Transform): ...@@ -212,9 +209,6 @@ class Crop(Transform):
if _no_bounds: if _no_bounds:
if crop_size is None: if crop_size is None:
raise TypeError("crop_size should be specified if bounds is set to None.") raise TypeError("crop_size should be specified if bounds is set to None.")
else:
if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)):
raise ValueError("Invalid bounds")
self.bounds = bounds self.bounds = bounds
self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size) self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size)
...@@ -287,12 +281,12 @@ class Shift(Transform): ...@@ -287,12 +281,12 @@ class Shift(Transform):
if _isseq(xshift): if _isseq(xshift):
self.xshift = tuple(xshift) self.xshift = tuple(xshift)
else: else:
self.xshift = float(xshift) self.xshift = xshift
if _isseq(yshift): if _isseq(yshift):
self.yshift = tuple(yshift) self.yshift = tuple(yshift)
else: else:
self.yshift = float(yshift) self.yshift = yshift
self.circular = circular self.circular = circular
...@@ -368,12 +362,12 @@ class ContrastBrightScale(_ValueTransform): ...@@ -368,12 +362,12 @@ class ContrastBrightScale(_ValueTransform):
if _isseq(alpha): if _isseq(alpha):
self.alpha = tuple(alpha) self.alpha = tuple(alpha)
else: else:
self.alpha = float(alpha) self.alpha = alpha
if _isseq(beta): if _isseq(beta):
self.beta = tuple(beta) self.beta = tuple(beta)
else: else:
self.beta = float(beta) self.beta = beta
@_ValueTransform.keep_range @_ValueTransform.keep_range
def _transform(self, x, params): def _transform(self, x, params):
...@@ -406,8 +400,8 @@ class BrightnessScale(ContrastBrightScale): ...@@ -406,8 +400,8 @@ class BrightnessScale(ContrastBrightScale):
class AddGaussNoise(_ValueTransform): class AddGaussNoise(_ValueTransform):
def __init__(self, mu=0.0, sigma=0.1, prob_apply=1.0, limit=(0, 255)): def __init__(self, mu=0.0, sigma=0.1, prob_apply=1.0, limit=(0, 255)):
super().__init__(True, prob_apply, limit) super().__init__(True, prob_apply, limit)
self.mu = float(mu) self.mu = mu
self.sigma = float(sigma) self.sigma = sigma
@_ValueTransform.keep_range @_ValueTransform.keep_range
def _transform(self, x, params): def _transform(self, x, params):
......
...@@ -58,7 +58,7 @@ class CDTrainer(Trainer): ...@@ -58,7 +58,7 @@ class CDTrainer(Trainer):
self.out_dir = self.ctx['out_dir'] self.out_dir = self.ctx['out_dir']
self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug
self.val_iters = float(self.ctx['val_iters']) self.val_iters = self.ctx['val_iters']
def init_learning_rate(self): def init_learning_rate(self):
# Set learning rate adjustment strategy # Set learning rate adjustment strategy
......
...@@ -9,7 +9,7 @@ class AverageMeter: ...@@ -9,7 +9,7 @@ class AverageMeter:
super().__init__() super().__init__()
if callback is not None: if callback is not None:
self.calculate = callback self.calculate = callback
self.calc_avg = bool(calc_avg) self.calc_avg = calc_avg
self.reset() self.reset()
def calculate(self, *args): def calculate(self, *args):
...@@ -43,10 +43,6 @@ class AverageMeter: ...@@ -43,10 +43,6 @@ class AverageMeter:
class Metric(AverageMeter): class Metric(AverageMeter):
__name__ = 'Metric' __name__ = 'Metric'
def __init__(self, n_classes=2, mode='separ', reduction='binary'): def __init__(self, n_classes=2, mode='separ', reduction='binary'):
if mode not in ('accum', 'separ'):
raise ValueError("Invalid working mode")
if reduction not in ('mean', 'none', 'binary'):
raise ValueError("Invalid reduction type")
self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False) self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False)
self.mode = mode self.mode = mode
if reduction == 'binary' and n_classes != 2: if reduction == 'binary' and n_classes != 2:
...@@ -63,6 +59,8 @@ class Metric(AverageMeter): ...@@ -63,6 +59,8 @@ class Metric(AverageMeter):
cm = self._cm.sum cm = self._cm.sum
elif self.mode == 'separ': elif self.mode == 'separ':
cm = self._cm.val cm = self._cm.val
else:
raise ValueError("Invalid working mode")
if self.reduction == 'none': if self.reduction == 'none':
# Do not reduce size # Do not reduce size
...@@ -73,6 +71,8 @@ class Metric(AverageMeter): ...@@ -73,6 +71,8 @@ class Metric(AverageMeter):
elif self.reduction == 'binary': elif self.reduction == 'binary':
# The pos_class be 1 # The pos_class be 1
return self._calculate_metric(cm)[1] return self._calculate_metric(cm)[1]
else:
raise ValueError("Invalid reduction type")
def reset(self): def reset(self):
super().reset() super().reset()
......
...@@ -55,9 +55,6 @@ class HookHelper: ...@@ -55,9 +55,6 @@ class HookHelper:
self.fetch_dict = fetch_dict self.fetch_dict = fetch_dict
self.out_dict = out_dict self.out_dict = out_dict
self._handles = [] self._handles = []
if hook_type not in ('forward_in', 'forward_out', 'backward'):
raise NotImplementedError("Hook type is not implemented.")
self.hook_type = hook_type self.hook_type = hook_type
def __enter__(self): def __enter__(self):
...@@ -106,7 +103,7 @@ class HookHelper: ...@@ -106,7 +103,7 @@ class HookHelper:
) )
) )
else: else:
raise NotImplementedError raise NotImplementedError("Hook type is not implemented.")
def __exit__(self, exc_type, exc_val, ext_tb): def __exit__(self, exc_type, exc_val, ext_tb):
for handle in self._handles: for handle in self._handles:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment