diff --git a/src/core/data.py b/src/core/data.py index 13a2ef4b30f0aa9590c0b7701ad56772eb467849..e5e7032e81f76313c3bf972d6de28f72923dbcce 100644 --- a/src/core/data.py +++ b/src/core/data.py @@ -53,16 +53,16 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta): raise FileNotFoundError # phase stands for the working mode, # 'train' for training and 'eval' for validating or testing. - if phase not in ('train', 'eval'): - raise ValueError("Invalid phase") + # if phase not in ('train', 'eval'): + # raise ValueError("Invalid phase") # subset is the sub-dataset to use. # For some datasets there are three subsets, # while for others there are only train and test(val). - if subset not in ('train', 'val', 'test'): - raise ValueError("Invalid subset") + # if subset not in ('train', 'val', 'test'): + # raise ValueError("Invalid subset") self.phase = phase self.transforms = transforms - self.repeats = int(repeats) + self.repeats = repeats # Use 'train' subset during training. self.subset = 'train' if self.phase == 'train' else subset diff --git a/src/data/augmentations.py b/src/data/augmentations.py index 826182944dea5bc6dc6d4a7cc48d151269177c46..45b697a451c5519bb030bbf9a06829ad5f6c3617 100644 --- a/src/data/augmentations.py +++ b/src/data/augmentations.py @@ -31,8 +31,8 @@ def _isseq(x): return isinstance(x, (tuple, list)) class Transform: def __init__(self, rand_state=False, prob_apply=1.0): - self._rand_state = bool(rand_state) - self.prob_apply = float(prob_apply) + self._rand_state = rand_state + self.prob_apply = prob_apply def _transform(self, x, params): raise NotImplementedError @@ -100,7 +100,7 @@ class Scale(Transform): raise ValueError self.scale = tuple(scale) else: - self.scale = float(scale) + self.scale = scale def _transform(self, x, params): if self._rand_state: @@ -138,10 +138,7 @@ class FlipRotate(Transform): _DIRECTIONS = ('ud', 'lr', '90', '180', '270') def __init__(self, direction=None, prob_apply=1.0): super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply) - if direction is not None: - if direction not in self._DIRECTIONS: - raise ValueError("Invalid direction") - self.direction = direction + self.direction = direction def _transform(self, x, params): if self._rand_state: @@ -212,9 +209,6 @@ class Crop(Transform): if _no_bounds: if crop_size is None: raise TypeError("crop_size should be specified if bounds is set to None.") - else: - if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)): - raise ValueError("Invalid bounds") self.bounds = bounds self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size) @@ -287,12 +281,12 @@ class Shift(Transform): if _isseq(xshift): self.xshift = tuple(xshift) else: - self.xshift = float(xshift) + self.xshift = xshift if _isseq(yshift): self.yshift = tuple(yshift) else: - self.yshift = float(yshift) + self.yshift = yshift self.circular = circular @@ -368,12 +362,12 @@ class ContrastBrightScale(_ValueTransform): if _isseq(alpha): self.alpha = tuple(alpha) else: - self.alpha = float(alpha) + self.alpha = alpha if _isseq(beta): self.beta = tuple(beta) else: - self.beta = float(beta) + self.beta = beta @_ValueTransform.keep_range def _transform(self, x, params): @@ -406,8 +400,8 @@ class BrightnessScale(ContrastBrightScale): class AddGaussNoise(_ValueTransform): def __init__(self, mu=0.0, sigma=0.1, prob_apply=1.0, limit=(0, 255)): super().__init__(True, prob_apply, limit) - self.mu = float(mu) - self.sigma = float(sigma) + self.mu = mu + self.sigma = sigma @_ValueTransform.keep_range def _transform(self, x, params): diff --git a/src/impl/trainers/cd_trainer.py b/src/impl/trainers/cd_trainer.py index e0710868f9b7529f3c8b56aa67e6aa609fa5273f..85a885432ae21e957c6203dbae5bbad5c785f63e 100644 --- a/src/impl/trainers/cd_trainer.py +++ b/src/impl/trainers/cd_trainer.py @@ -58,7 +58,7 @@ class CDTrainer(Trainer): self.out_dir = self.ctx['out_dir'] self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug - self.val_iters = float(self.ctx['val_iters']) + self.val_iters = self.ctx['val_iters'] def init_learning_rate(self): # Set learning rate adjustment strategy diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 291151e62ff723f8f7524bbef94a0288e45fb7b9..13a9baae4602caf5830838250935d54a100c1b30 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -9,7 +9,7 @@ class AverageMeter: super().__init__() if callback is not None: self.calculate = callback - self.calc_avg = bool(calc_avg) + self.calc_avg = calc_avg self.reset() def calculate(self, *args): @@ -43,10 +43,6 @@ class AverageMeter: class Metric(AverageMeter): __name__ = 'Metric' def __init__(self, n_classes=2, mode='separ', reduction='binary'): - if mode not in ('accum', 'separ'): - raise ValueError("Invalid working mode") - if reduction not in ('mean', 'none', 'binary'): - raise ValueError("Invalid reduction type") self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False) self.mode = mode if reduction == 'binary' and n_classes != 2: @@ -63,6 +59,8 @@ class Metric(AverageMeter): cm = self._cm.sum elif self.mode == 'separ': cm = self._cm.val + else: + raise ValueError("Invalid working mode") if self.reduction == 'none': # Do not reduce size @@ -73,6 +71,8 @@ class Metric(AverageMeter): elif self.reduction == 'binary': # The pos_class be 1 return self._calculate_metric(cm)[1] + else: + raise ValueError("Invalid reduction type") def reset(self): super().reset() diff --git a/src/utils/utils.py b/src/utils/utils.py index 6bdadb11b67d6030c3da6338ed023c1f6ec89f3b..f6afd01a1a90562d27f4abe39be155fb4e468f7f 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -55,9 +55,6 @@ class HookHelper: self.fetch_dict = fetch_dict self.out_dict = out_dict self._handles = [] - - if hook_type not in ('forward_in', 'forward_out', 'backward'): - raise NotImplementedError("Hook type is not implemented.") self.hook_type = hook_type def __enter__(self): @@ -106,7 +103,7 @@ class HookHelper: ) ) else: - raise NotImplementedError + raise NotImplementedError("Hook type is not implemented.") def __exit__(self, exc_type, exc_val, ext_tb): for handle in self._handles: