From ddc40406aa3768b7502174276dc4d34850257a68 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk>
Date: Mon, 17 Feb 2025 14:55:56 +0100
Subject: [PATCH] removed 2D dataloader

---
 qim3d/ml/_data.py | 147 +++++++++++-----------------------------------
 1 file changed, 33 insertions(+), 114 deletions(-)

diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index d2bd93a1..a5717b79 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
-        full_suffix = ''.join(image_path.suffixes)
-
-        if full_suffix in ['.nii', '.nii.gz']:
-
-            # Load 3D volume
-            image_data = nib.load(str(image_path))
-            target_data = nib.load(str(target_path))
-
-            # Get data from volume
-            image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
-            target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
-
-            # Add extra channel dimension
-            image = np.expand_dims(image, axis=0)
-            target = np.expand_dims(target, axis=0)
-
-        else:
-
-            # Load 2D image
-            image = Image.open(str(image_path))
-            image = np.array(image)
-            target = Image.open(str(target_path))
-            target = np.array(target)
-
-            # Grayscale image
-            if len(image.shape) == 2 and len(target.shape) == 2:
+        # Load 3D volume
+        image_data = nib.load(str(image_path))
+        target_data = nib.load(str(target_path))
 
-                # Add channel dimension
-                image = np.expand_dims(image, axis=0)
-                target = np.expand_dims(target, axis=0)
-            
-            # RGB image
-            elif len(image.shape) == 3 and len(target.shape) == 3:
+        # Get data from volume
+        image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
+        target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
 
-                # Convert to (C, H, W)
-                image = image.transpose((2, 0, 1))
-                target = target.transpose((2, 0, 1))
+        # Add extra channel dimension
+        image = np.expand_dims(image, axis=0)
+        target = np.expand_dims(target, axis=0)
 
         if self.transform:
             transformed = self.transform({"image": image, "label": target})
             image = transformed["image"]
             target = transformed["label"]
-
-            # image = self.transform(image) # uint8
-            # target = self.transform(target) # int32
             
-        # TODO: Which dtype?
         image = image.clone().detach().to(dtype=torch.float32)
         target = target.clone().detach().to(dtype=torch.float32)
 
@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset):
         return consistency_check
     
     def _get_shape(self, image_path):
-        full_suffix = ''.join(image_path.suffixes)
-
-        if full_suffix in ['.nii', '.nii.gz']:
-            # Load 3D volume
-            image = nib.load(str(image_path)).get_fdata()
-            return image.shape
-        
-        else:
-            # Load 2D image
-            image = Image.open(str(image_path))
-            return image.size
 
+        # Load 3D volume
+        image = nib.load(str(image_path)).get_fdata()
+        return image.shape
 
 def check_resize(
     orig_shape: tuple, 
     resize: tuple, 
-    n_channels: int, 
-    is_3d: bool
+    n_channels: int,
     ) -> tuple:
     """
     Checks and adjusts the resize dimensions based on the original shape and the number of channels.
@@ -164,7 +125,6 @@ def check_resize(
         orig_shape (tuple): Original shape of the image.
         resize (tuple): Desired resize dimensions.
         n_channels (int): Number of channels in the model.
-        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.
 
     Returns:
         tuple: Final resize dimensions.
@@ -174,66 +134,36 @@
     """
 
     # 3D images
-    if is_3d:
-        orig_d, orig_h, orig_w = orig_shape
-        final_d = resize[0] if resize[0] else orig_d
-        final_h = resize[1] if resize[1] else orig_h
-        final_w = resize[2] if resize[2] else orig_w
-   
-        # Finding suitable size to upsize with padding 
-        if resize == 'padding':
-            final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
-            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
-            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
-        
-        # Finding suitable size to downsize with crop / resize
-        else:
-            final_d = (orig_d // 2**n_channels) * 2**n_channels
-            final_h = (orig_h // 2**n_channels) * 2**n_channels
-            final_w = (orig_w // 2**n_channels) * 2**n_channels
+    orig_d, orig_h, orig_w = orig_shape
+    final_d = resize[0] if resize[0] else orig_d
+    final_h = resize[1] if resize[1] else orig_h
+    final_w = resize[2] if resize[2] else orig_w
+
+    # Finding suitable size to upsize with padding 
+    if resize == 'padding':
+        final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
+        final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+        final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+    
+    # Finding suitable size to downsize with crop / resize
+    else:
+        final_d = (orig_d // 2**n_channels) * 2**n_channels
+        final_h = (orig_h // 2**n_channels) * 2**n_channels
+        final_w = (orig_w // 2**n_channels) * 2**n_channels
 
-        # Check if the image size is too small compared to the model's depth
-        if final_d == 0 or final_h == 0 or final_w == 0:
-            msg = "The size of the image is too small compared to the depth of the UNet. \
-                   Choose a different 'resize' and/or a smaller model."
-            
-            raise ValueError(msg)
-        
-        if final_d != orig_d or final_h != orig_h or final_w != orig_w:
-            log.warning(f"The image size doesn't match the Unet model's depth. \
-                          The image is changed with '{resize}', from {orig_d, orig_h, orig_w} to {final_d, final_h, final_w}.")
+    # Check if the image size is too small compared to the model's depth
+    if final_d == 0 or final_h == 0 or final_w == 0:
+        msg = "The size of the image is too small compared to the depth of the UNet. \
+               Choose a different 'resize' and/or a smaller model."
+        
+        raise ValueError(msg)
+    
+    if final_d != orig_d or final_h != orig_h or final_w != orig_w:
+        log.warning(f"The image size doesn't match the Unet model's depth. \
+                      The image is changed with '{resize}', from {orig_d, orig_h, orig_w} to {final_d, final_h, final_w}.")
 
-        return final_d, final_h, final_w
+    return final_d, final_h, final_w
     
-    # 2D images
-    else:
-        orig_h, orig_w = orig_shape
-        final_h = resize[0] if resize[0] else orig_h
-        final_w = resize[1] if resize[1] else orig_w
-
-        # Finding suitable size to upsize with padding 
-        if resize == 'padding':
-            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
-            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
-
-        # Finding suitable size to downsize with crop / resize    
-        else:
-            final_h = (orig_h // 2**n_channels) * 2**n_channels
-            final_w = (orig_w // 2**n_channels) * 2**n_channels
-
-        # Check if the image size is too small compared to the model's depth
-        if final_h == 0 or final_w == 0:
-            msg = "The size of the image is too small compared to the depth of the UNet. \
-                   Choose a different 'resize' and/or a smaller model."
-            
-            raise ValueError(msg)
-        
-        if final_h != orig_h or final_w != orig_w:
-            log.warning(f"The image size doesn't match the Unet model's depth. \
-                          The image is changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.")
-
-        return final_h, final_w
-
+    # 2D images
+    else:
+        orig_h, orig_w = orig_shape
+        final_h = resize[0] if resize[0] else orig_h
+        final_w = resize[1] if resize[1] else orig_w
+
+        # Finding suitable size to upsize with padding 
+        if resize == 'padding':
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+
+        # Finding suitable size to downsize with crop / resize    
+        else:
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_h == 0 or final_w == 0:
+            msg = "The size of the image is too small compared to the depth of the UNet. \
+                   Choose a different 'resize' and/or a smaller model."
+            
+            raise ValueError(msg)
+        
+        if final_h != orig_h or final_w != orig_w:
+            log.warning(f"The image size doesn't match the Unet model's depth. \
+                          The image is changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.")
+
+        return final_h, final_w
+
 def prepare_datasets(
         path: str, 
         val_fraction: float, 
@@ -262,23 +192,12 @@ def prepare_datasets(
     # Determine if the dataset is 2D or 3D by checking the first image
     im_path = Path(path) / 'train'
     first_img = sorted((im_path / "images").iterdir())[0]
-    full_suffix = ''.join(first_img.suffixes)
 
-    # TODO: Support more formats for 3D images
-    if full_suffix in ['.nii', '.nii.gz']:
-
-        # Load 3D volume
-        image = nib.load(str(first_img)).get_fdata()
-        orig_shape = image.shape
-        is_3d = True
-    
-    else:
-        # Load 2D image
-        image = Image.open(str(first_img))
-        orig_shape = image.size[:2]
-        is_3d = False
+    # Load 3D volume
+    image = nib.load(str(first_img)).get_fdata()
+    orig_shape = image.shape
 
-    final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
+    final_shape = check_resize(orig_shape, resize, n_channels)
 
     train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train)) 
     val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation))
-- 
GitLab