From ed05746d4c0772bb197ac6b482bc3e08d7758bf1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk>
Date: Tue, 11 Feb 2025 14:59:13 +0100
Subject: [PATCH] fixing augmentation errors

---
 qim3d/ml/_augmentations.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index ede3796c..dc82e8e5 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -86,11 +86,11 @@ class Augmentation:
             level_aug = []
 
         elif level == 'light':
-            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1, 2))] if self.is_3d else [RandRotate90(prob=1)]
+            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)]
         
         elif level == 'moderate':
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
                 RandFlip(prob=0.3, spatial_axis=0),
                 RandFlip(prob=0.3, spatial_axis=1),
                 RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
@@ -99,7 +99,7 @@ class Augmentation:
 
         elif level == 'heavy':
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
                 RandFlip(prob=0.7, spatial_axis=0),
                 RandFlip(prob=0.7, spatial_axis=1),
                 RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
-- 
GitLab