From 32fa3f08f167f1b48e43ae03e32c3b199e708aa0 Mon Sep 17 00:00:00 2001 From: s184364 <s184364@student.dtu.dk> Date: Mon, 22 Jan 2024 16:47:37 +0100 Subject: [PATCH] Changed to less drastic heavy augmentation, changed the way augmentation level is chosen in prepare_datasets. --- qim3d/utils/augmentations.py | 13 ++++++++++--- qim3d/utils/data.py | 6 +++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py index 19154b70..ed93955f 100644 --- a/qim3d/utils/augmentations.py +++ b/qim3d/utils/augmentations.py @@ -40,7 +40,7 @@ class Augmentation: self.transform_validation = transform_validation self.transform_test = transform_test - def augment(self, im_h, im_w, level=None): + def augment(self, im_h, im_w, type=None): """ Returns an albumentations.core.composition.Compose class depending on the augmentation level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. @@ -49,12 +49,19 @@ class Augmentation: Args: im_h (int): image height for resize. im_w (int): image width for resize. - level (str, optional): level of augmentation. + type (str, optional): level of augmentation. Raises: ValueError: If `level` is neither None, light, moderate nor heavy. """ + if type=='train': + level = self.transform_train + elif type=='validation': + level = self.transform_validation + elif type=='test': + level = self.transform_test + # Check if one of standard augmentation levels if level not in [None,'light','moderate','heavy']: raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") @@ -98,7 +105,7 @@ class Augmentation: A.HorizontalFlip(p = 0.7), A.VerticalFlip(p = 0.7), A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3), - A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15)) + A.Affine(scale = [0.9,1.1], translate_percent = (-0.2,0.2), shear = (-5,5)) ] augment = A.Compose(level_aug + resize_aug + baseline_aug) diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py index fb2a0384..9dec6104 100644 --- a/qim3d/utils/data.py +++ b/qim3d/utils/data.py @@ -161,9 +161,9 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation): final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels) - train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train)) - val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation)) - test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test)) + train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train')) + val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation')) + test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test')) split_idx = int(np.floor(val_fraction * len(train_set))) indices = torch.randperm(len(train_set)) -- GitLab