diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index 19154b7010f92a893bae81f8ea8dbce48e227e4f..ed93955ffbdd91cb02219d617c492c3aa14752cf 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -40,7 +40,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
 
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h, im_w, type=None):
         """
-        Returns an albumentations.core.composition.Compose class depending on the augmentation level.
-        A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
+        Returns an albumentations.core.composition.Compose object depending on the augmentation level.
+        A baseline augmentation is applied regardless of the level, and further augmentations are added depending on the level.
@@ -49,12 +49,24 @@ class Augmentation:
         Args:
             im_h (int): image height for resize.
             im_w (int): image width for resize.
-            level (str, optional): level of augmentation.
+            type (str, optional): dataset split ('train', 'validation' or 'test'); selects the transform level configured for that split.
 
         Raises:
-            ValueError: If `level` is neither None, light, moderate nor heavy.
+            ValueError: If `type` is not one of None, 'train', 'validation', 'test', or if the resolved level is neither None, 'light', 'moderate' nor 'heavy'.
         """
 
+        # Map the dataset split to its configured augmentation level
+        if type == 'train':
+            level = self.transform_train
+        elif type == 'validation':
+            level = self.transform_validation
+        elif type == 'test':
+            level = self.transform_test
+        elif type is None:
+            level = None
+        else:
+            raise ValueError(f"Invalid transform type: {type}. Please choose one of the following: 'train', 'validation', 'test'.")
+
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
@@ -98,7 +110,7 @@ class Augmentation:
             A.HorizontalFlip(p = 0.7),
             A.VerticalFlip(p = 0.7),
             A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
-            A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
+            A.Affine(scale = [0.9,1.1], translate_percent = (-0.2,0.2), shear = (-5,5))
         ]
 
         augment = A.Compose(level_aug + resize_aug + baseline_aug)
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index fb2a03845870beabe1b1c23970c451a3750ba67e..9dec610470e511d9295d039f8009635431aa44c5 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -161,9 +161,9 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
 
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
+    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train'))
+    val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation'))
+    test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test'))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
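
Usage sketch (not part of the patch): after this change, callers pass the dataset split and augment() resolves the level from the instance attributes. The constructor signature below is an assumption inferred from the attribute names in the diff:

    from qim3d.utils.augmentations import Augmentation

    # Assumed constructor kwargs; the diff only shows that these three
    # attributes exist on the instance.
    aug = Augmentation(transform_train='moderate',
                       transform_validation='light',
                       transform_test=None)

    # augment() now maps the split name to the configured level internally.
    # Passing an unknown split (or misspelling one) raises a ValueError.
    train_transform = aug.augment(256, 256, 'train')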