diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index 219c208c16c1b720c92e6db6f98722043d1be590..276e33f08304ec548bd637881a84934c4312cdac 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -11,7 +11,6 @@ class Augmentation:
         transform_test (str, optional): level of transformation for the test set.
         mean (float, optional): The mean value for normalizing pixel intensities.
         std (float, optional): The standard deviation value for normalizing pixel intensities.
-        is_3d (bool, optional): Specifies if the images are 3D or 2D. Default is True.
 
     Raises:
         ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'.
@@ -27,7 +26,6 @@ class Augmentation:
                  transform_test: str | None = None,
                  mean: float = 0.5, 
                  std: float = 0.5,
-                 is_3d: bool = True,
                 ):
 
         if resize not in ['crop', 'reshape', 'padding']:
@@ -39,7 +37,6 @@ class Augmentation:
         self.transform_train = transform_train
         self.transform_validation = transform_validation
         self.transform_test = transform_test
-        self.is_3d = is_3d
     
     def augment(self, img_shape: tuple, level: str | None = None):
         """
@@ -57,37 +54,31 @@ class Augmentation:
             RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd
         )
 
-        # Check if 2D or 3D
-        if len(img_shape) == 2:
-            im_h, im_w = img_shape
-        
-        elif len(img_shape) == 3:
+        # Check if image is 3D
+        if len(img_shape) == 3:
             im_d, im_h, im_w = img_shape
+        
+        else: 
+            msg = f"Invalid image shape: {img_shape}. Must be 3D."
+            raise ValueError(msg)
 
         # Check if one of standard augmentation levels
-        if level not in [None,'light','moderate','heavy']:
-            raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
+        if level not in [None,'light', 'moderate', 'heavy']:
+            msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
+            raise ValueError(msg)
 
         # Baseline augmentations
+        # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
         baseline_aug = [ToTensor()]
 
-        # For 2D, add normalization to the baseline augmentations
-        # TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
-        if not self.is_3d:
-            # baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
-            baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std))
-
         # Resize augmentations
         if self.resize == 'crop':
-            # resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))]
             resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))]
         
         elif self.resize == 'reshape':
-            # resize_aug = [Resize((im_d, im_h, im_w))]
             resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
         
         elif self.resize == 'padding':
-            # resize_aug = [SpatialPad((im_d, im_h, im_w))]
             resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
 
         # Level of augmentation
@@ -98,40 +89,25 @@ class Augmentation:
             resize_aug = []
 
         elif level == 'light':
-            # level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))]
+            # TODO: Do rotations along other axes?
             level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))]
         
         elif level == 'moderate':
-            # level_aug = [
-            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
-            #     RandFlip(prob=0.3, spatial_axis=0),
-            #     RandFlip(prob=0.3, spatial_axis=1),
-            #     RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
-            #     RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
-            # ]
             level_aug = [
-                    RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
-                    RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
-                    RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
-                    RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
-                    RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
                 ]
             
         elif level == 'heavy':
-            # level_aug = [
-            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
-            #     RandFlip(prob=0.7, spatial_axis=0),
-            #     RandFlip(prob=0.7, spatial_axis=1),
-            #     RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
-            #     RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
-            # ]
-
             level_aug = [
-                    RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
-                    RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
-                    RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
-                    RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
-                    RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
                 ]
             
         return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file