From ed05746d4c0772bb197ac6b482bc3e08d7758bf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk> Date: Tue, 11 Feb 2025 14:59:13 +0100 Subject: [PATCH] fixing augmentation errors --- qim3d/ml/_augmentations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index ede3796c..dc82e8e5 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -86,11 +86,11 @@ class Augmentation: level_aug = [] elif level == 'light': - level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1, 2))] if self.is_3d else [RandRotate90(prob=1)] + level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)] elif level == 'moderate': level_aug = [ - RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1), + RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1), RandFlip(prob=0.3, spatial_axis=0), RandFlip(prob=0.3, spatial_axis=1), RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1), @@ -99,7 +99,7 @@ class Augmentation: elif level == 'heavy': level_aug = [ - RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1), + RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1), RandFlip(prob=0.7, spatial_axis=0), RandFlip(prob=0.7, spatial_axis=1), RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3), -- GitLab