diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index 9dec610470e511d9295d039f8009635431aa44c5..db4bcf6df691bb741555eecaaa8703b05f448889 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -2,64 +2,80 @@
 from pathlib import Path
 from PIL import Image
 from qim3d.io.logger import log
+from qim3d.utils.internal_tools import find_one_image
 from torch.utils.data import DataLoader
+import os
 import torch
 import numpy as np
 
 class Dataset(torch.utils.data.Dataset):
-    """
-    Custom Dataset class for building a PyTorch dataset.
-
-    Args:
-        root_path (str): The root directory path of the dataset.
-        split (str, optional): The split of the dataset, either "train" or "test".
-            Default is "train".
-        transform (callable, optional): A callable function or transformation to
-            be applied to the data. Default is None.
+    ''' Custom Dataset class for building a PyTorch dataset.
-
-    Raises:
-        ValueError: If the provided split is not valid (neither "train" nor "test").
-
-    Attributes:
-        split (str): The split of the dataset ("train" or "test").
-        transform (callable): The transformation to be applied to the data.
-        sample_images (list): A list containing the paths to the sample images in the dataset.
-        sample_targets (list): A list containing the paths to the corresponding target images
-            in the dataset.
-
-    Methods:
-        __len__(): Returns the total number of samples in the dataset.
-        __getitem__(idx): Returns the image and its target segmentation at the given index.
-
-    Usage:
-        dataset = Dataset(root_path="path/to/dataset", split="train",
-            transform=albumentations.Compose([ToTensorV2()]))
-        image, target = dataset[idx]
-    """
-    def __init__(self, root_path: str, split="train", transform=None):
+
+    Case 1: There are no folders - all images and targets are stored in the same data directory.
+        The image and the corresponding target have similar names (e.g. data1.tif, data1_mask.tif).
+
+        |-- data
+            |-- img1.tif
+            |-- img1_mask.tif
+            |-- img2.tif
+            |-- img2_mask.tif
+            |-- ...
+
+    Case 2: There are two folders - one with all the images and one with all the targets.
+
+        |-- data
+            |-- images
+                |-- img1.tif
+                |-- img2.tif
+                |-- ...
+            |-- masks
+                |-- img1_mask.tif
+                |-- img2_mask.tif
+                |-- ...
+
+    Case 3: There are many folders - each folder holds one case (e.g. one patient) with multiple images.
+
+        |-- data
+            |-- patient1
+                |-- p1_img1.tif
+                |-- p1_img1_mask.tif
+                |-- p1_img2.tif
+                |-- p1_img2_mask.tif
+                |-- p1_img3.tif
+                |-- p1_img3_mask.tif
+                |-- ...
+            |-- patient2
+                |-- p2_img1.tif
+                |-- p2_img1_mask.tif
+                |-- p2_img2.tif
+                |-- p2_img2_mask.tif
+                |-- p2_img3.tif
+                |-- p2_img3_mask.tif
+                |-- ...
+            |-- ...
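+
+    Usage (illustrative sketch, adapted from the previous docstring; assumes
+    an albumentations transform):
+
+        dataset = Dataset(root_path="path/to/data",
+                          transform=albumentations.Compose([ToTensorV2()]))
+        image, target = dataset[idx]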
+    '''
+    def __init__(self, root_path: str, transform=None):
         super().__init__()
 
-        # Check if split is valid
-        if split not in ["train", "test"]:
-            raise ValueError("Split must be either train or test")
-
-        self.split = split
+        self.root_path = root_path
         self.transform = transform
 
-        path = Path(root_path) / split
-
-        self.sample_images = [file for file in sorted((path / "images").iterdir())]
-        self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
-        assert len(self.sample_images) == len(self.sample_targets)
+        # scan the root path for folders
+        self._data_scan()
+        # find the images and targets given the folder setup
+        self._find_samples()
+        assert len(self.sample_images) == len(self.sample_targets)
 
+        # check the characteristics of the dataset
         self.check_shape_consistency(self.sample_images)
-
+
     def __len__(self):
         return len(self.sample_images)
-
+
     def __getitem__(self, idx):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
@@ -75,19 +91,110 @@ class Dataset(torch.utils.data.Dataset):
             target = transformed["mask"]
 
         return image, target
+
+
+    def _data_scan(self):
+        ''' Finds out which of the three folder layouts (see the class docstring) the data follows.
+        '''
+
+        # count the folders in the root path:
+        files = os.listdir(self.root_path)
+        n_folders = 0
+        folder_names = []
+        for f in files:
+            if os.path.isdir(Path(self.root_path, f)):
+                n_folders += 1
+                folder_names.append(f)
+        self.n_folders = n_folders
+        self.folder_names = folder_names
+
+    def _find_samples(self):
+        ''' Scans and retrieves the images and targets from the given folder configuration.
+        '''
+
+        target_folder_names = ['mask', 'label', 'target']
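+        # A file counts as a target when one of these markers occurs in its
+        # lowercased name, e.g. 'img1_mask.tif' is a target, 'img1.tif' an image.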
+
+        # Case 1
+        if self.n_folders == 0:
+            sample_images = []
+            sample_targets = []
+
+            for file in os.listdir(self.root_path):
+
+                # checks if a target marker is in the filename
+                if any(ext in file.lower() for ext in target_folder_names):
+                    sample_targets.append(Path(self.root_path, file))
+
+                # otherwise the file is assumed to be an image
+                else:
+                    sample_images.append(Path(self.root_path, file))
+
+            self.sample_images = sorted(sample_images)
+            self.sample_targets = sorted(sample_targets)
+
+        # Case 2
+        elif self.n_folders == 2:
+
+            # if the first folder contains the targets:
+            if any(ext in self.folder_names[0].lower() for ext in target_folder_names):
+                images = self.folder_names[1]
+                targets = self.folder_names[0]
+
+            # if the second folder contains the targets:
+            elif any(ext in self.folder_names[1].lower() for ext in target_folder_names):
+                images = self.folder_names[0]
+                targets = self.folder_names[1]
+
+            else:
+                raise ValueError('Folder names do not match categories such as "mask", "label" or "target".')
+
+            self.sample_images = [image for image in sorted(Path(self.root_path, images).iterdir())]
+            self.sample_targets = [target for target in sorted(Path(self.root_path, targets).iterdir())]
+
+        # Case 3
+        elif self.n_folders > 2:
+            sample_images = []
+            sample_targets = []
+
+            for folder in os.listdir(self.root_path):
+
+                # if an entry is not a folder
+                if not os.path.isdir(Path(self.root_path, folder)):
+                    raise NotImplementedError(
+                        f'The current data structure is not supported. '
+                        f'{Path(self.root_path, folder)} is not a folder.')
+
+                for file in os.listdir(Path(self.root_path, folder)):
+
+                    # if an entry is not a file:
+                    if not os.path.isfile(Path(self.root_path, folder, file)):
+                        raise NotImplementedError(
+                            f'The current data structure is not supported. '
+                            f'{Path(self.root_path, folder, file)} is not a file.')
+
+                    # checks if a target marker is in the filename
+                    if any(ext in file.lower() for ext in target_folder_names):
+                        sample_targets.append(Path(self.root_path, folder, file))
+
+                    # otherwise the file is assumed to be an image
+                    else:
+                        sample_images.append(Path(self.root_path, folder, file))
+
+            self.sample_images = sorted(sample_images)
+            self.sample_targets = sorted(sample_targets)
+
+        else:
+            raise NotImplementedError('The current data structure is not supported.')
+
     # TODO: working with images of different sizes
     def check_shape_consistency(self,sample_images):
-        image_shapes= []
-        for image_path in sample_images:
+        image_shapes = []
+        # only the first 100 images are checked, to keep initialization fast
+        for image_path in sample_images[:100]:
             image_shape = self._get_shape(image_path)
             image_shapes.append(image_shape)
 
         # check if all images have the same size.
-        consistency_check = all(i == image_shapes[0] for i in image_shapes)
-
-        if not consistency_check:
+        unique_shapes = len(set(image_shapes))
+
+        if unique_shapes > 1:
             raise NotImplementedError(
                 "Only images of all the same size can be processed at the moment"
             )
@@ -95,7 +202,6 @@ class Dataset(torch.utils.data.Dataset):
         log.debug(
             "Images are all the same size!"
         )
-        return consistency_check
 
     def _get_shape(self,image_path):
         return Image.open(str(image_path)).size
@@ -133,45 +239,98 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
 
     return h_adjust, w_adjust
 
-def prepare_datasets(path: str, val_fraction: float, model, augmentation):
-    """
-    Splits and augments the train/validation/test datasets.
-
-    Args:
-        path (str): Path to the dataset.
-        val_fraction (float): Fraction of the data for the validation set.
-        model (torch.nn.Module): PyTorch Model.
-        augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
-
-    Raises:
-        ValueError: if the validation fraction is not a float, and is not between 0 and 1.
-    """
+def prepare_datasets(
+    path: str,
+    val_fraction: float,
+    test_fraction: float,
+    model,
+    augmentation,
+    train_folder: str = None,
+    val_folder: str = None,
+    test_folder: str = None
+):
+    '''Splits and augments the train/validation/test datasets.
+
+    Args:
+        path (str): Path to the dataset.
+        val_fraction (float): Fraction of the data for the validation set.
+        test_fraction (float): Fraction of the data for the test set.
+        model (torch.nn.Module): PyTorch model.
+        augmentation: Augmentation object for the dataset with predefined augmentation levels.
+        train_folder (str, optional): Name of the subfolder with the training data. Default is None.
+        val_folder (str, optional): Name of the subfolder with the validation data. Default is None.
+        test_folder (str, optional): Name of the subfolder with the test data. Default is None.
+
+    Raises:
+        ValueError: If a fraction is not a float between 0 and 1, if the fractions
+            sum to 1 or more, or if the folder configuration is not recognized.
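+
+    Example (illustrative sketch; model and augmentation are assumed to be the
+    UNet-style model and Augmentation objects this module works with):
+
+        train_set, val_set, test_set = prepare_datasets(
+            path='path/to/data', val_fraction=0.1, test_fraction=0.1,
+            model=model, augmentation=augmentation)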
" + "Make sure to lower it below 100%, and include some place for the training data.") + + # find one image: + image = Image.open(find_one_image(path = path)) + orig_h,orig_w = image.size[:2] + resize = augmentation.resize n_channels = len(model.channels) - # taking the size of the 1st image in the dataset - im_path = Path(path) / 'train' - first_img = sorted((im_path / "images").iterdir())[0] - image = Image.open(str(first_img)) - orig_h, orig_w = image.size[:2] - final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels) - train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train')) - val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation')) - test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test')) + # change number of channels in UNet if needed + if len(np.array(image).shape)>2: + model.img_channels = np.array(image).shape[2] + model.update_params() + + # Only Train and Test folders are given, splits Train into Train/Val. + if train_folder and test_folder and not val_folder: + + log.info('Only train and test given, splitting train_folder with val fraction.') + train_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'train')) + val_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'validation')) + test_set = Dataset(root_path=Path(path,test_folder),transform=augmentation.augment(final_h, final_w,type = 'test')) + + indices = torch.randperm(len(train_set)) + split_idx = int(np.floor(val_fraction * len(train_set))) + + train_set = torch.utils.data.Subset(train_set,indices[split_idx:]) + val_set = torch.utils.data.Subset(val_set,indices[:split_idx]) + + # Only Train and Val folder are given. + elif train_folder and val_folder and not test_folder: + + log.info('Only train and validation folder provided, will not be able to make inference on test data.') + train_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'train')) + val_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'validation')) + test_set = None + + # All Train/Val/Test folders are given. + elif train_folder and val_folder and test_folder: + + log.info('Retrieving data from train, validation and test folder.') + train_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'train')) + val_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'validation')) + test_set = Dataset(root_path=Path(path,test_folder),transform=augmentation.augment(final_h, final_w,type = 'test')) + + # None of the train/val/test folders are given: + elif not(train_folder or val_folder or test_folder): + + log.info('No specific train/validation/test folders given. 
+        train_idx = int(np.floor((1 - val_fraction - test_fraction) * len(train_set)))
+        val_idx = train_idx + int(np.floor(val_fraction * len(train_set)))
 
-    train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
-    val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
+        train_set = torch.utils.data.Subset(train_set, indices[:train_idx])
+        val_set = torch.utils.data.Subset(val_set, indices[train_idx:val_idx])
+        test_set = torch.utils.data.Subset(test_set, indices[val_idx:])
 
-    return train_set, val_set, test_set
+    else:
+        raise ValueError("Your folder configuration cannot be recognized. "
+                         "Give a path to the dataset, or paths to the train/validation/test folders.")
+
+    return train_set, val_set, test_set
+
 
 def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 0, pin_memory = False):
diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index 465cd1cff51b86244813b2c02e2e80b5a142eed5..28bf66a99a22bd66f9f68dcfdc35f3988b3662b8 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -303,4 +303,15 @@ def get_css():
 
     with open(css_path,'r') as file:
         css_content = file.read()
 
-    return css_content
\ No newline at end of file
+    return css_content
+
+def find_one_image(path):
+    ''' Recursively scans path and returns the first image file found. '''
+
+    # check the files at this level first
+    for entry in os.scandir(path):
+        if entry.is_file() and entry.name.lower().endswith(('.jpg', '.jpeg', '.tif', '.tiff', '.png')):
+            return entry.path
+
+    # then descend into sub-folders, skipping those that contain no image
+    for entry in os.scandir(path):
+        if entry.is_dir():
+            try:
+                return find_one_image(entry.path)
+            except ValueError:
+                continue
+
+    # if no folder/sub-folder contains an image:
+    raise ValueError('No images found.')
\ No newline at end of file