Skip to content
Snippets Groups Projects
Commit fb0797c5 authored by ofhkr's avatar ofhkr
Browse files

update: train val test dataloader for several types of stored data.

parent ba974ff9
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
......@@ -2,64 +2,80 @@
from pathlib import Path
from PIL import Image
from qim3d.io.logger import log
from qim3d.utils.internal_tools import find_one_image
from torch.utils.data import DataLoader
import os
import torch
import numpy as np
class Dataset(torch.utils.data.Dataset):
    '''Custom Dataset class for building a PyTorch dataset.

    The class adapts to three different ways the data can be stored on
    disk, detected at construction time by scanning ``root_path``:

    Case 1: There are no folders - all images and targets are stored in the
    same data directory. The image and corresponding target have similar
    names (eg: data1.tif, data1_mask.tif)
        |-- data
            |-- img1.tif
            |-- img1_mask.tif
            |-- img2.tif
            |-- img2_mask.tif
            |-- ...

    Case 2: There are two folders - one with all the images and one with
    all the targets.
        |-- data
            |-- images
                |-- img1.tif
                |-- img2.tif
                |-- ...
            |-- masks
                |-- img1_mask.tif
                |-- img2_mask.tif
                |-- ...

    Case 3: There are many folders - each folder with a case (eg. patient)
    holding that case's images and targets side by side.
        |-- data
            |-- patient1
                |-- p1_img1.tif
                |-- p1_img1_mask.tif
                |-- p1_img2.tif
                |-- p1_img2_mask.tif
                |-- ...
            |-- patient2
                |-- p2_img1.tif
                |-- p2_img1_mask.tif
                |-- ...
            |-- ...

    Args:
        root_path (str): The root directory path of the dataset.
        transform (callable, optional): A callable function or transformation
            to be applied to the data. Default is None.

    Raises:
        ValueError: If two folders are found but neither name contains
            "mask", "label" or "target".
        NotImplementedError: If the folder structure is not supported, or
            the images do not all share the same shape.

    Attributes:
        root_path (str): The root directory path of the dataset.
        transform (callable): The transformation to be applied to the data.
        sample_images (list): Paths to the sample images in the dataset.
        sample_targets (list): Paths to the corresponding target images.

    Usage:
        dataset = Dataset(root_path="path/to/dataset",
                          transform=albumentations.Compose([ToTensorV2()]))
        image, target = dataset[idx]
    '''
    def __init__(self, root_path: str, transform=None):
        super().__init__()

        self.root_path = root_path
        self.transform = transform

        # scans folders to detect which of the three layouts is present
        self._data_scan()
        # finds the images and targets given the folder setup
        self._find_samples()

        # every image must have exactly one matching target
        assert len(self.sample_images) == len(self.sample_targets)

        # checking the characteristics of the dataset
        self.check_shape_consistency(self.sample_images)
def __len__(self):
    '''Return the total number of image/target pairs in the dataset.'''
    total_samples = len(self.sample_images)
    return total_samples
def __getitem__(self, idx):
image_path = self.sample_images[idx]
target_path = self.sample_targets[idx]
......@@ -75,19 +91,110 @@ class Dataset(torch.utils.data.Dataset):
target = transformed["mask"]
return image, target
def _data_scan(self):
    '''Find out which of the three storage categories the data belongs to.

    Counts the directories directly under ``root_path`` and records both
    the count (``self.n_folders``) and their names (``self.folder_names``).
    '''
    entries = os.listdir(self.root_path)

    # keep only the entries that are directories
    folder_names = [
        entry for entry in entries
        if os.path.isdir(Path(self.root_path, entry))
    ]

    self.folder_names = folder_names
    self.n_folders = len(folder_names)
def _find_samples(self):
    '''Scans and retrieves the images and targets from their given folder configuration.

    Populates ``self.sample_images`` and ``self.sample_targets`` (both
    sorted) according to the layout detected by ``_data_scan``.

    Raises:
        ValueError: If two folders exist but neither name contains
            "mask", "label" or "target".
        NotImplementedError: If the folder structure is not supported.
    '''
    target_folder_names = ['mask', 'label', 'target']

    # Case 1: flat directory, targets identified by keyword in the filename.
    if self.n_folders == 0:
        sample_images = []
        sample_targets = []
        for file in os.listdir(self.root_path):
            # checks if a label keyword is in the filename
            if any(ext in file.lower() for ext in target_folder_names):
                sample_targets.append(Path(self.root_path, file))
            # otherwise the file is assumed to be the image
            else:
                sample_images.append(Path(self.root_path, file))
        self.sample_images = sorted(sample_images)
        self.sample_targets = sorted(sample_targets)

    # Case 2: one folder of images, one folder of targets.
    elif self.n_folders == 2:
        # if the first folder contains the targets:
        if any(ext in self.folder_names[0].lower() for ext in target_folder_names):
            # BUG FIX: was `self.folders_names[1]` (AttributeError at runtime).
            images = self.folder_names[1]
            targets = self.folder_names[0]
        # if the second folder contains the targets:
        elif any(ext in self.folder_names[1].lower() for ext in target_folder_names):
            images = self.folder_names[0]
            targets = self.folder_names[1]
        else:
            raise ValueError('Folder names do not match categories such as "mask", "label" or "target".')
        self.sample_images = [image for image in sorted(Path(self.root_path, images).iterdir())]
        self.sample_targets = [target for target in sorted(Path(self.root_path, targets).iterdir())]

    # Case 3: one folder per case (e.g. patient), images and targets side by side.
    elif self.n_folders > 2:
        sample_images = []
        sample_targets = []
        for folder in os.listdir(self.root_path):
            # if some files are not a folder
            if not os.path.isdir(Path(self.root_path, folder)):
                raise NotImplementedError(f'The current data structure is not supported. {Path(self.root_path,folder)} is not a folder.')
            for file in os.listdir(Path(self.root_path, folder)):
                # if files are not images:
                if not os.path.isfile(Path(self.root_path, folder, file)):
                    raise NotImplementedError(f'The current data structure is not supported. {Path(self.root_path,folder,file)} is not a file.')
                # checks if a label keyword is in the filename
                # BUG FIX: lowercase the filename before matching, consistent with Case 1.
                if any(ext in file.lower() for ext in target_folder_names):
                    sample_targets.append(Path(self.root_path, folder, file))
                # otherwise the file is assumed to be the image
                else:
                    sample_images.append(Path(self.root_path, folder, file))
        self.sample_images = sorted(sample_images)
        self.sample_targets = sorted(sample_targets)

    # exactly one folder (or any other arrangement) is not supported
    else:
        raise NotImplementedError('The current data structure is not supported.')
# TODO: working with images of different sizes
def check_shape_consistency(self, sample_images):
    '''Check that the dataset images all share one shape.

    Only the first 100 images are inspected to keep construction cheap
    on large datasets.

    Args:
        sample_images (list): Paths to the sample images.

    Returns:
        bool: True when all inspected images share the same shape.

    Raises:
        NotImplementedError: If more than one distinct shape is found.
    '''
    image_shapes = []
    for image_path in sample_images[:100]:
        image_shape = self._get_shape(image_path)
        image_shapes.append(image_shape)

    # check if all inspected images have the same size
    unique_shapes = len(set(image_shapes))
    if unique_shapes > 1:
        raise NotImplementedError(
            "Only images of all the same size can be processed at the moment"
        )
    log.debug(
        "Images are all the same size!"
    )
    # single (or zero) distinct shape found: the dataset is consistent
    return True
def _get_shape(self, image_path):
    '''Return the (width, height) of the image at ``image_path``.'''
    # PIL reads the header lazily, so this does not load pixel data.
    img = Image.open(str(image_path))
    return img.size
......@@ -133,45 +239,98 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
return h_adjust, w_adjust
def prepare_datasets(
    path: str,
    val_fraction: float,
    test_fraction: float,
    model,
    augmentation,
    train_folder: str = None,
    val_folder: str = None,
    test_folder: str = None,
):
    '''Splits and augments the train/validation/test datasets.

    Args:
        path (str): Path to the dataset.
        val_fraction (float): Fraction of the data for the validation set.
        test_fraction (float): Fraction of the data for the test set.
        model (torch.nn.Module): PyTorch Model.
        augmentation (albumentations.core.composition.Compose): Augmentation
            class for the dataset with predefined augmentation levels.
        train_folder (str, optional): Name of the train folder under ``path``.
        val_folder (str, optional): Name of the validation folder under ``path``.
        test_folder (str, optional): Name of the test folder under ``path``.

    Returns:
        tuple: (train_set, val_set, test_set); ``test_set`` is None when only
            train and validation folders are provided.

    Raises:
        ValueError: If a fraction is not a float in [0, 1), if the fractions
            sum to 1 or more, or if the folder configuration is unrecognized.
    '''
    if not isinstance(val_fraction, float) or not (0 <= val_fraction < 1):
        raise ValueError("The validation fraction must be a float between 0 and 1.")
    if not isinstance(test_fraction, float) or not (0 <= test_fraction < 1):
        raise ValueError("The test fraction must be a float between 0 and 1.")
    if (val_fraction + test_fraction) >= 1:
        # NOTE: a stray debug print with wrong arithmetic was removed here.
        raise ValueError(f"The validation and test fractions cover {int((val_fraction+test_fraction)*100)}%. "
                         "Make sure to lower it below 100%, and include some place for the training data.")

    # find one image to infer the original size and channel count:
    image = Image.open(find_one_image(path=path))
    orig_h, orig_w = image.size[:2]

    resize = augmentation.resize
    n_channels = len(model.channels)

    final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)

    # change number of channels in UNet if needed
    if len(np.array(image).shape) > 2:
        model.img_channels = np.array(image).shape[2]
        model.update_params()

    # Only Train and Test folders are given, splits Train into Train/Val.
    if train_folder and test_folder and not val_folder:
        log.info('Only train and test given, splitting train_folder with val fraction.')
        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
        val_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
        test_set = Dataset(root_path=Path(path, test_folder), transform=augmentation.augment(final_h, final_w, type='test'))

        # random train/val split of the train folder
        indices = torch.randperm(len(train_set))
        split_idx = int(np.floor(val_fraction * len(train_set)))
        train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
        val_set = torch.utils.data.Subset(val_set, indices[:split_idx])

    # Only Train and Val folders are given.
    elif train_folder and val_folder and not test_folder:
        log.info('Only train and validation folder provided, will not be able to make inference on test data.')
        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
        val_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
        test_set = None

    # All Train/Val/Test folders are given.
    elif train_folder and val_folder and test_folder:
        log.info('Retrieving data from train, validation and test folder.')
        train_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='train'))
        val_set = Dataset(root_path=Path(path, train_folder), transform=augmentation.augment(final_h, final_w, type='validation'))
        test_set = Dataset(root_path=Path(path, test_folder), transform=augmentation.augment(final_h, final_w, type='test'))

    # None of the train/val/test folders are given:
    elif not (train_folder or val_folder or test_folder):
        log.info('No specific train/validation/test folders given. Splitting the data into train/validation/test sets.')
        train_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='train'))
        val_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='validation'))
        test_set = Dataset(root_path=path, transform=augmentation.augment(final_h, final_w, type='test'))

        # one shared permutation, partitioned into train/val/test slices
        indices = torch.randperm(len(train_set))
        train_idx = int(np.floor((1 - val_fraction - test_fraction) * len(train_set)))
        val_idx = train_idx + int(np.floor(val_fraction * len(train_set)))

        train_set = torch.utils.data.Subset(train_set, indices[:train_idx])
        val_set = torch.utils.data.Subset(val_set, indices[train_idx:val_idx])
        test_set = torch.utils.data.Subset(test_set, indices[val_idx:])

    else:
        raise ValueError("Your folder configuration cannot be recognized. "
                         "Give a path to the dataset, or paths to the train/validation/test folders.")

    return train_set, val_set, test_set
def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 0, pin_memory = False):
......
......@@ -303,4 +303,15 @@ def get_css():
with open(css_path,'r') as file:
css_content = file.read()
return css_content
\ No newline at end of file
return css_content
def find_one_image(path):
    '''Return the path of the first image file found under ``path``.

    Searches files and subdirectories depth-first.

    BUG FIX: the original returned directly from the first subdirectory it
    met, so an image-less subdirectory raised ValueError and aborted the
    whole search instead of letting the remaining entries be tried.

    Args:
        path (str): Directory to search.

    Returns:
        str: Path of the first image found.

    Raises:
        ValueError: If no image exists anywhere under ``path``.
    '''
    image_suffixes = ('jpg', 'jpeg', 'tif', 'tiff', 'png', 'PNG')
    for entry in os.scandir(path):
        if entry.is_dir():
            try:
                return find_one_image(entry.path)
            except ValueError:
                # this subtree had no images; keep trying the siblings
                continue
        elif entry.is_file() and entry.path.endswith(image_suffixes):
            return entry.path
    # If all folders/sub-folders do not have anything:
    raise ValueError('No Images Found.')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment