diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index ede3796ca5692d2ed2b230d6c14a679ebe72f59f..dc82e8e5a208ff10a0b9b7e51373208a881ac0f3 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -86,11 +86,11 @@ class Augmentation:
             level_aug = []
 
         elif level == 'light':
-            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1, 2))] if self.is_3d else [RandRotate90(prob=1)]
+            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)]
         
         elif level == 'moderate':
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
                 RandFlip(prob=0.3, spatial_axis=0),
                 RandFlip(prob=0.3, spatial_axis=1),
                 RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
@@ -99,7 +99,7 @@ class Augmentation:
 
         elif level == 'heavy':
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
                 RandFlip(prob=0.7, spatial_axis=0),
                 RandFlip(prob=0.7, spatial_axis=1),
                 RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),