diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index 9dec610470e511d9295d039f8009635431aa44c5..db4bcf6df691bb741555eecaaa8703b05f448889 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -2,64 +2,80 @@
 from pathlib import Path
 from PIL import Image
 from qim3d.io.logger import log
+from qim3d.utils.internal_tools import find_one_image
 from torch.utils.data import DataLoader
 
+import os
 import torch
 import numpy as np
 
 
 class Dataset(torch.utils.data.Dataset):
-    """
-    Custom Dataset class for building a PyTorch dataset.
-
-    Args:
-        root_path (str): The root directory path of the dataset.
-        split (str, optional): The split of the dataset, either "train" or "test".
-            Default is "train".
-        transform (callable, optional): A callable function or transformation to 
-            be applied to the data. Default is None.
+    ''' Custom Dataset class for building a PyTorch dataset.
 
-    Raises:
-        ValueError: If the provided split is not valid (neither "train" nor "test").
-
-    Attributes:
-        split (str): The split of the dataset ("train" or "test").
-        transform (callable): The transformation to be applied to the data.
-        sample_images (list): A list containing the paths to the sample images in the dataset.
-        sample_targets (list): A list containing the paths to the corresponding target images
-            in the dataset.
-
-    Methods:
-        __len__(): Returns the total number of samples in the dataset.
-        __getitem__(idx): Returns the image and its target segmentation at the given index.
-
-    Usage:
-        dataset = Dataset(root_path="path/to/dataset", split="train", 
-            transform=albumentations.Compose([ToTensorV2()]))
-        image, target = dataset[idx]
-    """
-    def __init__(self, root_path: str, split="train", transform=None):
+    Case 1: There are no folders - all images and targets are stored in the same data directory.
+            An image and its corresponding target have similar names (e.g. img1.tif and img1_mask.tif).
+    
+    |-- data
+        |-- img1.tif
+        |-- img1_mask.tif
+        |-- img2.tif
+        |-- img2_mask.tif
+        |-- ...
+    
+    Case 2: There are two folders - one with all the images and one with all the targets.
+
+    |-- data
+        |-- images
+            |-- img1.tif
+            |-- img2.tif
+            |-- ...
+        |-- masks
+            |-- img1_mask.tif
+            |-- img2_mask.tif
+            |-- ...
+    
+    Case 3: There are many folders - each folder holds one case (e.g. a patient) with multiple images and targets.
+
+    |-- data
+        |-- patient1
+            |-- p1_img1.tif
+            |-- p1_img1_mask.tif
+            |-- p1_img2.tif
+            |-- p1_img2_mask.tif
+            |-- p1_img3.tif
+            |-- p1_img3_mask.tif
+            |-- ...
+
+        |-- patient2
+            |-- p2_img1.tif
+            |-- p2_img1_mask.tif
+            |-- p2_img2.tif
+            |-- p2_img2_mask.tif
+            |-- p2_img3.tif
+            |-- p2_img3_mask.tif
+            |-- ...
+        |-- ...
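+
+    Args:
+        root_path (str): The root directory path of the dataset.
+        transform (callable, optional): A callable function or transformation to
+            be applied to the data. Default is None.
+
+    Usage (the albumentations transform is just an example):
+        dataset = Dataset(root_path="path/to/dataset",
+                          transform=albumentations.Compose([ToTensorV2()]))
+        image, target = dataset[idx]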
+    '''
+    def __init__(self, root_path: str, transform=None):
         super().__init__()
 
-        # Check if split is valid
-        if split not in ["train", "test"]:
-            raise ValueError("Split must be either train or test")
-
-        self.split = split
+        self.root_path = root_path
         self.transform = transform
 
-        path = Path(root_path) / split
-
-        self.sample_images = [file for file in sorted((path / "images").iterdir())]
-        self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
-        assert len(self.sample_images) == len(self.sample_targets)
+        # scans folders
+        self._data_scan()
+        # finds the images and targets given the folder setup
+        self._find_samples()
 
+        assert len(self.sample_images) == len(self.sample_targets)
+        
         # checking the characteristics of the dataset
         self.check_shape_consistency(self.sample_images)
-
+    
     def __len__(self):
         return len(self.sample_images)
-
+    
     def __getitem__(self, idx):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
@@ -75,19 +91,110 @@ class Dataset(torch.utils.data.Dataset):
             target = transformed["mask"]
 
         return image, target
+    
+
+    def _data_scan(self):
+        ''' Finds out which of the three supported folder layouts the data follows.
+        '''
+
+        # how many folders there are:
+        files = os.listdir(self.root_path)
+        n_folders = 0
+        folder_names = []
+        for f in files:
+            if os.path.isdir(Path(self.root_path,f)):
+                n_folders += 1
+                folder_names.append(f)
 
+        self.n_folders = n_folders
+        self.folder_names = folder_names
 
+
+    def _find_samples(self):
+        ''' Scans the folders and collects the image and target paths according to the detected layout.
+        '''
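+        # n_folders == 0 -> Case 1, n_folders == 2 -> Case 2, n_folders > 2 -> Case 3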
+        
+        # substrings that identify a file or folder as containing targets
+        target_folder_names = ['mask','label','target']
+
+        # Case 1
+        if self.n_folders == 0:
+            sample_images = []
+            sample_targets = []
+
+            for file in os.listdir(self.root_path):
+
+                # checks if a target keyword is in the filename
+                if any(ext in file.lower() for ext in target_folder_names):
+                    sample_targets.append(Path(self.root_path,file))
+                
+                # otherwise the file is assumed to be an image
+                else:
+                    sample_images.append(Path(self.root_path,file))
+            
+            self.sample_images = sorted(sample_images)
+            self.sample_targets = sorted(sample_targets)
+        
+        # Case 2
+        elif self.n_folders == 2:
+
+            # if the first folder contains the targets:
+            if any(ext in self.folder_names[0].lower() for ext in target_folder_names):
+                images = self.folder_names[1]
+                targets = self.folder_names[0]
+            
+            # if the second folder contains the targets:
+            elif any(ext in self.folder_names[1].lower() for ext in target_folder_names):
+                images = self.folder_names[0]
+                targets = self.folder_names[1]
+
+            else:
+                raise ValueError('Folder names do not match categories such as "mask", "label" or "target".')
+
+            self.sample_images = sorted(Path(self.root_path, images).iterdir())
+            self.sample_targets = sorted(Path(self.root_path, targets).iterdir())
+
+        # Case 3
+        elif self.n_folders > 2:
+            sample_images = []
+            sample_targets = []
+
+            for folder in os.listdir(self.root_path):
+                
+                # if an entry is not a folder, the layout is not supported
+                if not os.path.isdir(Path(self.root_path,folder)):
+                    raise NotImplementedError(f'The current data structure is not supported. {Path(self.root_path,folder)} is not a folder.')
+
+                for file in os.listdir(Path(self.root_path,folder)):
+
+                    # if an entry is not a file, the layout is not supported
+                    if not os.path.isfile(Path(self.root_path,folder,file)):
+                        raise NotImplementedError(f'The current data structure is not supported. {Path(self.root_path,folder,file)} is not a file.')
+
+                    # checks if a target keyword is in the filename
+                    if any(ext in file.lower() for ext in target_folder_names):
+                        sample_targets.append(Path(self.root_path,folder,file))
+                    
+                    # otherwise the file is assumed to be an image
+                    else:
+                        sample_images.append(Path(self.root_path,folder,file))
+
+            self.sample_images = sorted(sample_images)
+            self.sample_targets = sorted(sample_targets)
+        
+        # exactly one folder is not a supported layout
+        else:
+            raise NotImplementedError('The current data structure is not supported.')
+    
     # TODO: working with images of different sizes
     def check_shape_consistency(self,sample_images):
-        image_shapes= []
-        for image_path in sample_images:
+        image_shapes = []
+        # only the first 100 images are checked, to keep the consistency check fast
+        for image_path in sample_images[:100]:
             image_shape = self._get_shape(image_path)
             image_shapes.append(image_shape)
 
         # check if all images have the same size.
-        consistency_check = all(i == image_shapes[0] for i in image_shapes)
-
-        if not consistency_check:
+        unique_shapes = len(set(image_shapes))
+
+        if unique_shapes > 1:
             raise NotImplementedError(
                 "Only images of all the same size can be processed at the moment"
             )
@@ -95,7 +202,6 @@ class Dataset(torch.utils.data.Dataset):
             log.debug(
                 "Images are all the same size!"
             )
-        return consistency_check
     
     def _get_shape(self,image_path):
         return Image.open(str(image_path)).size
@@ -133,45 +239,98 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
     return h_adjust, w_adjust 
 
 
-def prepare_datasets(path: str, val_fraction: float, model, augmentation):
-    """
-    Splits and augments the train/validation/test datasets.
+def prepare_datasets(
+    path: str,
+    val_fraction: float,
+    test_fraction: float,
+    model,
+    augmentation,
+    train_folder: str = None,
+    val_folder: str = None,
+    test_folder: str = None
+):
+    '''Splits and augments the train/validation/test datasets.
 
-    Args:
-        path (str): Path to the dataset.
-        val_fraction (float): Fraction of the data for the validation set.
-        model (torch.nn.Module): PyTorch Model.
-        augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
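+    Args:
+        path (str): Path to the dataset.
+        val_fraction (float): Fraction of the data for the validation set.
+        test_fraction (float): Fraction of the data for the test set.
+        model (torch.nn.Module): PyTorch model; its number of input channels is
+            updated if the images are multi-channel.
+        augmentation (albumentations.core.composition.Compose): Augmentation class
+            for the dataset with predefined augmentation levels.
+        train_folder (str, optional): Name of the folder containing the training data.
+        val_folder (str, optional): Name of the folder containing the validation data.
+        test_folder (str, optional): Name of the folder containing the test data.
+
+    Raises:
+        ValueError: If the validation or test fraction is not a float between 0 and 1,
+            or if the folder configuration is not recognized.
+
+    Usage (a sketch; `my_model` and `my_augmentation` stand in for a qim3d model
+    and augmentation object):
+        train_set, val_set, test_set = prepare_datasets(
+            path="path/to/dataset", val_fraction=0.2, test_fraction=0.1,
+            model=my_model, augmentation=my_augmentation)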
+    '''
 
-    Raises:
-        ValueError: if the validation fraction is not a float, and is not between 0 and 1. 
-    """
-    
     if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
         raise ValueError("The validation fraction must be a float between 0 and 1.")
+    
+    if not isinstance(test_fraction,float) or not (0 <= test_fraction < 1):
+        raise ValueError("The test fraction must be a float between 0 and 1.")
 
+    if (val_fraction + test_fraction) >= 1:
+        raise ValueError(f"The validation and test fractions sum to {int((val_fraction+test_fraction)*100)}%. "
+                         "Lower them so some data is left for the training set.")
+    
+    # find one image to infer the input size and number of channels:
+    image = Image.open(find_one_image(path))
+    orig_w, orig_h = image.size  # PIL's Image.size is (width, height)
+    
     resize = augmentation.resize
     n_channels = len(model.channels)
 
-    # taking the size of the 1st image in the dataset
-    im_path = Path(path) / 'train'
-    first_img = sorted((im_path / "images").iterdir())[0]
-    image = Image.open(str(first_img))
-    orig_h, orig_w = image.size[:2]
-        
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train'))
-    val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation'))
-    test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test'))
+    # change the number of channels in the UNet if the image is multi-channel
+    img_array = np.array(image)
+    if img_array.ndim > 2:
+        model.img_channels = img_array.shape[2]
+        model.update_params()
+
+    # Only Train and Test folders are given, splits Train into Train/Val.
+    if train_folder and test_folder and not val_folder:
+        
+        log.info('Only train and test folders given; splitting train_folder using the validation fraction.')
+        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
+        val_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
+        test_set = Dataset(root_path=Path(path, test_folder), transform=augmentation.augment(final_h, final_w, type='test'))
+
+        indices = torch.randperm(len(train_set))
+        split_idx = int(np.floor(val_fraction * len(train_set)))
+
+        train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
+        val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
+
+    # Only Train and Val folders are given.
+    elif train_folder and val_folder and not test_folder:
+
+        log.info('Only train and validation folders provided; inference on test data will not be possible.')
+        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
+        val_set = Dataset(root_path=Path(path, val_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
+        test_set = None
+
+    # All Train/Val/Test folders are given.
+    elif train_folder and val_folder and test_folder:
+        
+        log.info('Retrieving data from the train, validation and test folders.')
+        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
+        val_set = Dataset(root_path=Path(path, val_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
+        test_set = Dataset(root_path=Path(path, test_folder), transform=augmentation.augment(final_h, final_w, type='test'))
+
+    # None of the train/val/test folders are given:
+    elif not(train_folder or val_folder or test_folder):
+
+        log.info('No specific train/validation/test folders given. Splitting the data into train/validation/test sets.')
+        train_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='train'))
+        val_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='validation'))
+        test_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='test'))
 
-    split_idx = int(np.floor(val_fraction * len(train_set)))
-    indices = torch.randperm(len(train_set))
+        indices = torch.randperm(len(train_set))
+
+        train_idx = int(np.floor((1 - val_fraction - test_fraction) * len(train_set)))
+        val_idx = train_idx + int(np.floor(val_fraction * len(train_set)))
 
-    train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
-    val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
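+        # e.g. with 100 samples, val_fraction=0.2 and test_fraction=0.1:
+        # train_idx = 70 and val_idx = 90, giving a 70/20/10 train/val/test split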
+        train_set = torch.utils.data.Subset(train_set, indices[:train_idx])
+        val_set = torch.utils.data.Subset(val_set, indices[train_idx:val_idx])
+        test_set = torch.utils.data.Subset(test_set, indices[val_idx:])
     
-    return train_set, val_set, test_set
+    else:
+        raise ValueError("Your folder configuration cannot be recognized. "
+                         "Give a path to the dataset, or paths to the train/validation/test folders.")
+
+    return train_set, val_set, test_set
 
 
 def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 0, pin_memory = False):  
diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index 465cd1cff51b86244813b2c02e2e80b5a142eed5..28bf66a99a22bd66f9f68dcfdc35f3988b3662b8 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -303,4 +303,15 @@ def get_css():
     with open(css_path,'r') as file:
         css_content = file.read()
     
-    return css_content
\ No newline at end of file
+    return css_content
+
+
+def find_one_image(path):
+    ''' Recursively finds and returns the path of the first image file under path. '''
+    image_extensions = ('.jpg', '.jpeg', '.tif', '.tiff', '.png')
+
+    # check the files at the current level before descending into sub-folders
+    for entry in os.scandir(path):
+        if entry.is_file() and entry.name.lower().endswith(image_extensions):
+            return entry.path
+    for entry in os.scandir(path):
+        if entry.is_dir():
+            try:
+                return find_one_image(entry.path)
+            except ValueError:
+                # this sub-folder has no images, keep looking in the next one
+                continue
+
+    # If all folders/sub-folders do not contain any images:
+    raise ValueError('No images found.')
\ No newline at end of file