From ab581084a838825a059217363e60ae13d0947e13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk>
Date: Tue, 11 Feb 2025 11:01:28 +0100
Subject: [PATCH] Refactor augmentations to use MONAI instead of Albumentations

---
 qim3d/ml/_augmentations.py | 89 ++++++++++++++++++++------------------
 1 file changed, 46 insertions(+), 43 deletions(-)

diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index e12b1a51..ede3796c 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -1,8 +1,8 @@
-"""Class for choosing the level of data augmentations with albumentations"""
+"""Class for choosing the level of data augmentations with MONAI"""
 
 class Augmentation:
     """
-    Class for defining image augmentation transformations using the Albumentations library.
+    Class for defining image augmentation transformations using the MONAI library.
         
     Args:
         resize (str, optional): Specifies how the images should be reshaped to the appropriate size.
@@ -11,6 +11,7 @@ class Augmentation:
         transform_test (str, optional): level of transformation for the test set.
         mean (float, optional): The mean value for normalizing pixel intensities.
         std (float, optional): The standard deviation value for normalizing pixel intensities.
+        is_3d (bool, optional): Whether the images are 3D (True) or 2D (False). Defaults to True.
 
     Raises:
         ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'.
@@ -25,7 +26,8 @@ class Augmentation:
                  transform_validation: str | None = None,
                  transform_test: str | None = None,
                  mean: float = 0.5, 
-                 std: float = 0.5
+                 std: float = 0.5,
+                 is_3d: bool = True,
                 ):
 
         if resize not in ['crop', 'reshape', 'padding']:
@@ -37,70 +39,71 @@ class Augmentation:
         self.transform_train = transform_train
         self.transform_validation = transform_validation
         self.transform_test = transform_test
+        self.is_3d = is_3d
     
-    def augment(self, im_h: int, im_w: int, level: bool | None = None):
+    def augment(self, im_h: int, im_w: int, im_d: int | None = None, level: str | None = None):
         """
-        Returns an albumentations.core.composition.Compose class depending on the augmentation level.
-        A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
-        The A.Resize() function is used if the user has specified a 'resize' int or tuple at the creation of the Augmentation class.
+        Creates and returns a MONAI `Compose` pipeline of baseline, resize and level-dependent augmentations.
 
         Args:
-            im_h (int): image height for resize.
-            im_w (int): image width for resize.
-            level (str, optional): level of augmentation.
+            im_h (int): Height of the image.
+            im_w (int): Width of the image.
+            im_d (int, optional): Depth of the image. Only used when `is_3d` is True.
+            level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy'].
 
         Raises:
             ValueError: If `level` is neither None, light, moderate nor heavy.
         """
-        import albumentations as A
-        from albumentations.pytorch import ToTensorV2
+        from monai.transforms import (
+            Compose, RandRotate90, RandFlip, RandAffine, ToTensor, \
+            RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad
+        )
 
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
 
-        # Baseline
-        baseline_aug = [
-            A.Normalize(mean = (self.mean),std = (self.std)),
-            ToTensorV2()
-        ]
+        # Baseline augmentations
+        baseline_aug = [ToTensor()]
+
+        # For 2D, add normalization to the baseline augmentations
+        # TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
+        if not self.is_3d:
+            baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
+
+        # Resize augmentations
         if self.resize == 'crop':
-            resize_aug = [
-                A.CenterCrop(im_h,im_w)
-            ]
+            resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))] if self.is_3d else [CenterSpatialCrop((im_h, im_w))]
+        
         elif self.resize == 'reshape':
-            resize_aug =[
-                A.Resize(im_h,im_w)
-            ]
-        elif self.resize == 'padding':
-            resize_aug = [
-                A.PadIfNeeded(im_h,im_w,border_mode = 0) # OpenCV border mode
-            ]
+            resize_aug = [Resize((im_d, im_h, im_w))] if self.is_3d else [Resize((im_h, im_w))]
         
+        elif self.resize == 'padding':
+            resize_aug = [SpatialPad((im_d, im_h, im_w))] if self.is_3d else [SpatialPad((im_h, im_w))]
+
         # Level of augmentation
         if level == None:
             level_aug = []
+
         elif level == 'light':
-            level_aug = [
-                A.RandomRotate90()
-            ]
+            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1, 2))] if self.is_3d else [RandRotate90(prob=1)]
+        
         elif level == 'moderate':
             level_aug = [
-                A.RandomRotate90(),
-                A.HorizontalFlip(p = 0.3),
-                A.VerticalFlip(p = 0.3),
-                A.GlassBlur(sigma = 0.7, p = 0.1),
-                A.Affine(scale = [0.9,1.1], translate_percent = (0.1,0.1))
+                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandFlip(prob=0.3, spatial_axis=0),
+                RandFlip(prob=0.3, spatial_axis=1),
+                RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
+                RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
             ]
+
         elif level == 'heavy':
             level_aug = [
-                A.RandomRotate90(),
-                A.HorizontalFlip(p = 0.7),
-                A.VerticalFlip(p = 0.7),
-                A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
-                A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
+                RandRotate90(prob=1, spatial_axes=(0, 1, 2)) if self.is_3d else RandRotate90(prob=1),
+                RandFlip(prob=0.7, spatial_axis=0),
+                RandFlip(prob=0.7, spatial_axis=1),
+                RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
+                RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
             ]
 
-        augment = A.Compose(level_aug + resize_aug + baseline_aug)
-        
-        return augment
\ No newline at end of file
+        return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file
-- 
GitLab