diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 253da4a168fa6ebfea0120fb8d725b6fdf5591e4..7bba1467fd1687f62a11dd7e18c162ddf9f69c13 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -4,6 +4,7 @@ from PIL import Image
 from qim3d.utils import log
 import torch
 import numpy as np
+import nibabel as nib
 from typing import Optional, Callable
 import torch.nn as nn
 from ._augmentations import Augmentation
@@ -54,7 +55,7 @@ class Dataset(torch.utils.data.Dataset):
         self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
         assert len(self.sample_images) == len(self.sample_targets)
 
-        # checking the characteristics of the dataset
+        # Checking the characteristics of the dataset
         self.check_shape_consistency(self.sample_images)
 
     def __len__(self):
@@ -64,10 +65,22 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
-        image = Image.open(str(image_path))
-        image = np.array(image)
-        target = Image.open(str(target_path))
-        target = np.array(target)
+        # Path.suffix only yields the last extension ('.gz' for '.nii.gz'),
+        # so join all suffixes to also detect compressed NIfTI files
+        if ''.join(image_path.suffixes) in ['.nii', '.nii.gz']:
+            # Load 3D volume
+            image_data = nib.load(str(image_path))
+            target_data = nib.load(str(target_path))
+
+            # Get data from volume
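+            # Note: np.asarray over .dataobj reads through nibabel's array
+            # proxy and keeps the header's dtype, whereas get_fdata() would
+            # cast the whole volume to floating point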
+            image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
+            target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
+
+        else:
+            # Load 2D image
+            image = Image.open(str(image_path))
+            image = np.array(image)
+            target = Image.open(str(target_path))
+            target = np.array(target)
 
         if self.transform:
             transformed = self.transform(image=image, mask=target)
@@ -78,13 +91,13 @@ class Dataset(torch.utils.data.Dataset):
 
 
     # TODO: working with images of different sizes
-    def check_shape_consistency(self,sample_images: tuple[str]):
+    def check_shape_consistency(self, sample_images: tuple[str]):
         image_shapes= []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
             image_shapes.append(image_shape)
 
-        # check if all images have the same size.
+        # Check if all images have the same size
         consistency_check = all(i == image_shapes[0] for i in image_shapes)
 
         if not consistency_check:
@@ -97,41 +110,102 @@ class Dataset(torch.utils.data.Dataset):
             )
         return consistency_check
     
-    def _get_shape(self,image_path):
-        return Image.open(str(image_path)).size
+    def _get_shape(self, image_path):
+        full_suffix = ''.join(image_path.suffixes)
+
+        if full_suffix in ['.nii', '.nii.gz']:
+            # Read the shape from the NIfTI header, without loading the volume
+            return nib.load(str(image_path)).shape
+
+        else:
+            # Load 2D image
+            image = Image.open(str(image_path))
+            return image.size
 
 
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
+def check_resize(
+    orig_shape: tuple,
+    resize: str,
+    n_channels: int,
+    is_3d: bool,
+) -> tuple:
     """
-    Checks the compatibility of the image shape with the depth of the model.
-    If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
-    If so, the image is changed to the closest appropriate dimension, and the user is notified with a warning.
+    Checks the compatibility of the image shape with the depth of the model.
+    Each dimension is adjusted to the nearest multiple of 2**n_channels: 'padding'
+    upsizes the image, any other mode downsizes it, and the user is notified with a warning.
 
     Args:
-        im_height (int) : Height of the original image from the dataset.
-        im_width (int)  : Width of the original image from the dataset.
-        resize (str)    : Type of resize to be used on the image if the size doesn't fit the model.
+        orig_shape (tuple): Original shape of the image: (height, width) for 2D, (depth, height, width) for 3D.
+        resize (str): Type of resize to be used if the image shape doesn't fit the model ('padding' upsizes; other modes downsize).
         n_channels (int): Number of channels in the model.
+        is_3d (bool): Whether the data is 3D or not.
+
+    Returns:
+        tuple: Final shape; (depth, height, width) for 3D, (height, width) for 2D.
 
     Raises:
         ValueError: If the image size is smaller than minimum required for the model's depth.               
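+
+    Example:
+        With n_channels = 3, every dimension must be divisible by 2**3 = 8.
+        A hypothetical (65, 65) image is cropped down to (64, 64) by default,
+        while 'padding' would instead upsize it to (72, 72):
+
+            final_shape = check_resize((65, 65), resize='crop', n_channels=3, is_3d=False)
+            # final_shape == (64, 64)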
     """
-    # finding suitable size to upsize with padding 
-    if resize == 'padding':
-        h_adjust, w_adjust = (im_height // 2**n_channels+1) * 2**n_channels , (im_width // 2**n_channels+1) * 2**n_channels
+
+    # 3D images
+    if is_3d:
+        orig_d, orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
         
-    # finding suitable size to downsize with crop / resize
-    else:
-        h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels , (im_width // 2**n_channels) * 2**n_channels    
-    
-    if h_adjust == 0 or w_adjust == 0:
-        raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
-    
-    elif h_adjust != im_height or w_adjust != im_width:
-        log.warning(f"The image size doesn't match the Unet model's depth. The image is changed with '{resize}', from {im_height, im_width} to {h_adjust, w_adjust}.")
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_d = (orig_d // 2**n_channels) * 2**n_channels
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_d == 0 or final_h == 0 or final_w == 0:
+            msg = (
+                "The size of the image is too small compared to the depth of the UNet. "
+                "Choose a different 'resize' and/or a smaller model."
+            )
+            raise ValueError(msg)
+
+        if final_d != orig_d or final_h != orig_h or final_w != orig_w:
+            log.warning(
+                f"The image size doesn't match the UNet model's depth. The image is changed "
+                f"with '{resize}', from {(orig_d, orig_h, orig_w)} to {(final_d, final_h, final_w)}."
+            )
+
+        return final_d, final_h, final_w
     
-    return h_adjust, w_adjust 
+    # 2D images
+    else:
+        orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_h == 0 or final_w == 0:
+            msg = (
+                "The size of the image is too small compared to the depth of the UNet. "
+                "Choose a different 'resize' and/or a smaller model."
+            )
+            raise ValueError(msg)
+
+        if final_h != orig_h or final_w != orig_w:
+            log.warning(
+                f"The image size doesn't match the UNet model's depth. The image is changed "
+                f"with '{resize}', from {(orig_h, orig_w)} to {(final_h, final_w)}."
+            )
+
+        return final_h, final_w
 
 def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
@@ -153,17 +227,29 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
     resize = augmentation.resize
     n_channels = len(model.channels)
 
-    # taking the size of the 1st image in the dataset
+    # Determine if the dataset is 2D or 3D by checking the first image
     im_path = Path(path) / 'train'
     first_img = sorted((im_path / "images").iterdir())[0]
-    image = Image.open(str(first_img))
-    orig_h, orig_w = image.size[:2]
-        
-    final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
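+    # e.g. Path('vol.nii.gz').suffixes == ['.nii', '.gz'], so the join gives '.nii.gz'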
+    full_suffix = ''.join(first_img.suffixes)
+
+    # TODO: Support more formats for 3D images
+    if full_suffix in ['.nii', '.nii.gz']:
+        # Read the volume shape from the NIfTI header, without loading voxel data
+        orig_shape = nib.load(str(first_img)).shape
+        is_3d = True
+    
+    else:
+        # Load 2D image
+        image = Image.open(str(first_img))
+        # PIL's size is (width, height), so flip it to match (height, width)
+        orig_shape = (image.height, image.width)
+        is_3d = False
+
+    final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
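+    # e.g. with resize='crop' and len(model.channels) == 4 (multiples of 2**4 = 16),
+    # a (160, 160, 160) volume passes through unchanged: final_shape == (160, 160, 160)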
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
+    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))