diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index dc82e8e5a208ff10a0b9b7e51373208a881ac0f3..30a2a7f0df34ccea9917e51d6abd0d0f8aeff7b5 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -41,14 +41,12 @@ class Augmentation:
         self.transform_test = transform_test
         self.is_3d = is_3d
     
-    def augment(self, im_h: int, im_w: int, im_d: int | None = None, level: str | None = None):
+    def augment(self, img_shape: tuple[int, ...], level: str | None = None):
         """
         Creates an augmentation pipeline based on the specified level.
 
         Args:
-            im_h (int): Height of the image.
-            im_w (int): Width of the image.
-            im_d (int, optional): Depth of the image (for 3D).
+            img_shape (tuple): Shape of the image, (H, W) for 2D or (D, H, W) for 3D.
             level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy'].
 
         Raises:
@@ -59,6 +57,14 @@ class Augmentation:
             RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad
         )
 
+        # Determine whether the image is 2D or 3D from the length of the shape tuple
+        if len(img_shape) == 2:
+            im_h, im_w = img_shape
+        elif len(img_shape) == 3:
+            im_d, im_h, im_w = img_shape
+        else:
+            raise ValueError(f"Invalid image shape: {img_shape}. Please provide (H, W) for 2D or (D, H, W) for 3D.")
+
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 03c3c399a1b70aa6f70b9e551d0d179b70590a5d..0a8fca48a613dee02c361e987c54b9a7d678279b 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -271,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
 
     final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
 
-    train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train)) 
-    val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
-    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level=augmentation.transform_train))
+    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level=augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, level=augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
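
On the caller side, prepare_datasets now passes the shape tuple returned by check_resize straight through instead of star-unpacking it into separate height, width and depth arguments. Roughly, with final_shape being that tuple:

    # before: dimensions were unpacked into im_h, im_w (and im_d for 3D)
    transform = augmentation.augment(*final_shape, level=augmentation.transform_train)

    # after: the tuple is passed as-is and augment() unpacks it internally
    transform = augmentation.augment(final_shape, level=augmentation.transform_train)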