diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 7bba1467fd1687f62a11dd7e18c162ddf9f69c13..03c3c399a1b70aa6f70b9e551d0d179b70590a5d 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -65,7 +65,9 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
-        if image_path.suffix in ['.nii', '.nii.gz']:
+        full_suffix = ''.join(image_path.suffixes)
+
+        if full_suffix in ['.nii', '.nii.gz']:
 
             # Load 3D volume
             image_data = nib.load(str(image_path))
@@ -75,21 +77,42 @@ class Dataset(torch.utils.data.Dataset):
             image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
             target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
 
+            # Add extra channel dimension
+            image = np.expand_dims(image, axis=0)
+            target = np.expand_dims(target, axis=0)
+
         else:
+
             # Load 2D image
             image = Image.open(str(image_path))
             image = np.array(image)
             target = Image.open(str(target_path))
             target = np.array(target)
 
+            # Grayscale image
+            if len(image.shape) == 2 and len(target.shape) == 2:
+
+                # Add channel dimension
+                image = np.expand_dims(image, axis=0)
+                target = np.expand_dims(target, axis=0)
+            
+            # RGB image
+            elif len(image.shape) == 3 and len(target.shape) == 3:
+
+                # Convert to (C, H, W)
+                image = image.transpose((2, 0, 1))
+                target = target.transpose((2, 0, 1))
+
         if self.transform:
-            transformed = self.transform(image=image, mask=target)
-            image = transformed["image"]
-            target = transformed["mask"]
+            image = self.transform(image) # uint8
+            target = self.transform(target) # int32
+            
+        # TODO: Which dtype?
+        image = image.clone().detach().to(dtype=torch.float32)
+        target = target.clone().detach().to(dtype=torch.float32)
 
         return image, target
 
-
     # TODO: working with images of different sizes
     def check_shape_consistency(self, sample_images: tuple[str]):
         image_shapes= []
@@ -234,6 +257,7 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
 
     # TODO: Support more formats for 3D images
     if full_suffix in ['.nii', '.nii.gz']:
+
         # Load 3D volume
         image = nib.load(str(first_img)).get_fdata()
         orig_shape = image.shape
@@ -247,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
 
     final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
 
-    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
-    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
-    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train)) 
+    val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
@@ -265,7 +289,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
                         test_set: torch.utils.data, 
                         batch_size: int, 
                         shuffle_train: bool = True, 
-                        num_workers: int = 8, 
+                        num_workers: int = 8,
                         pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:  
     """
     Prepares the dataloaders for model training.