Skip to content
Snippets Groups Projects
Commit 9a9ada7d authored by ofhkr's avatar ofhkr
Browse files

Added docstring to modified files.

parent fb0797c5
Branches
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
This diff is collapsed.
......@@ -12,7 +12,33 @@ import numpy as np
class Dataset(torch.utils.data.Dataset):
''' Custom Dataset class for building a PyTorch dataset.
Args:
root_path (str): The root directory path of the dataset.
transform (callable, optional): A callable function or transformation to
be applied to the data. Default is None.
Raises:
ValueError: If the provided split is not valid (neither "train" nor "test").
Attributes:
root_path (str): root directory path to the dataset.
transform (callable): The transformation to be applied to the data.
sample_images (list): A list containing the paths to the sample images in the dataset.
sample_targets (list): A list containing the paths to the corresponding target images
in the dataset.
Methods:
__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Returns the image and its target segmentation at the given index.
_data_scan(): Finds how many folders are in the directory path as well as their names.
_find_samples(): Finds the images and targets according to one of the 3 datastructure cases.
Usage:
dataset = Dataset(root_path="path/to/dataset",
transform=albumentations.Compose([ToTensorV2()]))
image, target = dataset[idx]
Notes:
Case 1: There are no folder - all images and targets are stored in the same data directory.
The image and corresponding target have similar names (eg: data1.tif, data1mask.tif)
......@@ -112,6 +138,12 @@ class Dataset(torch.utils.data.Dataset):
def _find_samples(self):
''' Scans and retrieves the images and targets from their given folder configuration.
Raises:
ValueError: in Case 2, if no folder contains any of the labels 'mask', 'label', 'target'.
NotImplementedError: in Case 3, if a file is found among the list of folders.
NotImplementedError: in Case 3, if a folder is found among the list of files.
NotImplementedError: If the data structure does not fall into one of the three cases.
'''
target_folder_names = ['mask','label','target']
......@@ -241,16 +273,41 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
def prepare_datasets(
path:str,
val_fraction: float,
test_fraction: float,
model,
augmentation,
val_fraction: float = 0.1,
test_fraction: float = 0.1,
train_folder:str = None,
val_folder:str = None,
test_folder:str = None
):
'''Splits and augments the train/validation/test datasets
Args:
path (str): Path to the dataset.
model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
val_fraction (float, optional): Fraction of the data for the validation set.
test_fraction (float, optional): Fraction of the data for the test set.
train_folder (str, optional): Can be used to specify where the data for training data is located.
val_folder (str, optional): Can be used to specify where the data for validation data is located.
test_folder (str, optional): Can be used to specify where the data for testing data is located.
Raises:
ValueError: If the validation fraction is not a float, and is not between 0 and 1.
ValueError: If the test fraction is not a float, and is not between 0 and 1.
ValueError: If the sum of the validation and test fractions is equal or larger than 1.
ValueError: If the combination of train/val/test_folder strings isn't enough to prepare the data for model training.
Usage:
# if all data stored together:
prepare_datasets(path="path/to/dataset", val_fraction = 0.2, test_fraction = 0.1,
model = qim3d.models.UNet(), augmentation = qim3d.utils.Augmentation())
# if data has be pre-configured into training/testing:
prepare_datasets(path="path/to/dataset", val_fraction = 0.2, test_fraction = 0.1,
model = qim3d.models.UNet(), augmentation = qim3d.utils.Augmentation(),
train_folder = 'training_folder_name', test_folder = 'test_folder_name')
'''
if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
......@@ -264,7 +321,7 @@ def prepare_datasets(
raise ValueError(f"The validation and test fractions cover {int((val_fraction+test_fraction)*100)}%. "
"Make sure to lower it below 100%, and include some place for the training data.")
# find one image:
# Finds one image:
image = Image.open(find_one_image(path = path))
orig_h,orig_w = image.size[:2]
......@@ -273,7 +330,7 @@ def prepare_datasets(
final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
# change number of channels in UNet if needed
# Change number of channels in UNet if needed
if len(np.array(image).shape)>2:
model.img_channels = np.array(image).shape[2]
model.update_params()
......@@ -281,7 +338,7 @@ def prepare_datasets(
# Only Train and Test folders are given, splits Train into Train/Val.
if train_folder and test_folder and not val_folder:
log.info('Only train and test given, splitting train_folder with val fraction.')
log.info('Only train and test given, splitting train_folder with val_fraction.')
train_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'train'))
val_set = Dataset(root_path=Path(path,train_folder),transform=augmentation.augment(final_h, final_w,type = 'validation'))
test_set = Dataset(root_path=Path(path,test_folder),transform=augmentation.augment(final_h, final_w,type = 'test'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment