diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index 19154b7010f92a893bae81f8ea8dbce48e227e4f..ed93955ffbdd91cb02219d617c492c3aa14752cf 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -40,7 +40,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h, im_w, type=None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
         A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
@@ -49,12 +49,19 @@ class Augmentation:
         Args:
             im_h (int): image height for resize.
             im_w (int): image width for resize.
-            level (str, optional): level of augmentation.
+            type (str, optional): level of augmentation.
 
         Raises:
             ValueError: If `level` is neither None, light, moderate nor heavy.
         """
         
+        if type=='train':   
+            level = self.transform_train
+        elif type=='validation':
+            level = self.transform_validation
+        elif type=='test':
+            level = self.transform_test
+
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
@@ -98,7 +105,7 @@ class Augmentation:
                 A.HorizontalFlip(p = 0.7),
                 A.VerticalFlip(p = 0.7),
                 A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
-                A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
+                A.Affine(scale = [0.9,1.1], translate_percent = (-0.2,0.2), shear = (-5,5))
             ]
 
         augment = A.Compose(level_aug + resize_aug + baseline_aug)
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index fb2a03845870beabe1b1c23970c451a3750ba67e..9dec610470e511d9295d039f8009635431aa44c5 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -161,9 +161,9 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
         
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
+    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train'))
+    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation'))
+    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test'))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))