diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index ede3796ca5692d2ed2b230d6c14a679ebe72f59f..dc82e8e5a208ff10a0b9b7e51373208a881ac0f3 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -86,11 +86,11 @@ class Augmentation: level_aug = [] elif level == 'light': - level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1, 2))] if self.is_3d else [RandRotate90(prob=1)] + level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)] elif level == 'moderate': level_aug = [ - RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1), + RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1), RandFlip(prob=0.3, spatial_axis=0), RandFlip(prob=0.3, spatial_axis=1), RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1), @@ -99,7 +99,7 @@ class Augmentation: elif level == 'heavy': level_aug = [ - RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1), + RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1), RandFlip(prob=0.7, spatial_axis=0), RandFlip(prob=0.7, spatial_axis=1), RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),