From 32fa3f08f167f1b48e43ae03e32c3b199e708aa0 Mon Sep 17 00:00:00 2001
From: s184364 <s184364@student.dtu.dk>
Date: Mon, 22 Jan 2024 16:47:37 +0100
Subject: [PATCH] Make the heavy augmentation less drastic; select the
 augmentation level by dataset split in prepare_datasets

---
 qim3d/utils/augmentations.py | 20 +++++++++++++++-----
 qim3d/utils/data.py          |  7 ++++---
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index 19154b70..ed93955f 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -40,7 +40,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h, im_w, type=None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
-        A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
+        A baseline augmentation is implemented regardless of the level, and a set of augmentations is added depending on the level.
@@ -49,12 +49,22 @@ class Augmentation:
         Args:
             im_h (int): image height for resize.
             im_w (int): image width for resize.
-            level (str, optional): level of augmentation.
+            type (str, optional): dataset split ('train', 'validation' or 'test') used to select the augmentation level.
 
         Raises:
-            ValueError: If `level` is neither None, light, moderate nor heavy.
+            ValueError: If the level configured for the split is neither None, 'light', 'moderate' nor 'heavy'.
         """
         
+        # Resolve the level for the requested split; unknown splits fall back to the baseline-only level (None).
+        if type == 'train':
+            level = self.transform_train
+        elif type == 'validation':
+            level = self.transform_validation
+        elif type == 'test':
+            level = self.transform_test
+        else:
+            level = None
+
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
@@ -98,7 +108,7 @@ class Augmentation:
                 A.HorizontalFlip(p = 0.7),
                 A.VerticalFlip(p = 0.7),
                 A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
-                A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
+                A.Affine(scale = [0.9,1.1], translate_percent = (-0.2,0.2), shear = (-5,5))
             ]
 
         augment = A.Compose(level_aug + resize_aug + baseline_aug)
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index fb2a0384..9dec6104 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -161,9 +161,10 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
         
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
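+    # Pass the dataset split name; augment() resolves the configured level for that split.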
+    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train'))
+    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation'))
+    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test'))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
-- 
GitLab
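
Note: for illustration, a minimal sketch of the reworked interface. The
Augmentation constructor keywords and the import path are assumptions
inferred from the attribute and file names visible in this patch, not
verified against the rest of the codebase:

    from qim3d.utils.augmentations import Augmentation

    # Per-split levels are configured once on the Augmentation object
    # (constructor keywords assumed to match the attributes in this patch).
    aug = Augmentation(transform_train='moderate',
                       transform_validation='light',
                       transform_test=None)

    # augment() now takes the split name instead of a level and resolves
    # the level internally; an unrecognized split falls back to the
    # baseline-only pipeline (level=None).
    train_transform = aug.augment(256, 256, 'train')       # 'moderate'
    val_transform   = aug.augment(256, 256, 'validation')  # 'light'
    test_transform  = aug.augment(256, 256, 'test')        # baseline only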