Skip to content
Snippets Groups Projects
Commit 66801424 authored by s193396's avatar s193396
Browse files

added data preparation pipeline to docs

parent a065371b
No related branches found
No related tags found
No related merge requests found
--- ---
hide: hide:
- navigation - navigation
- toc
--- ---
# Machine learning models # Machine learning models
...@@ -17,3 +16,9 @@ The `qim3d` library aims to ease the creation of ML models for volumetric images ...@@ -17,3 +16,9 @@ The `qim3d` library aims to ease the creation of ML models for volumetric images
options: options:
members: members:
- UNet - UNet
::: qim3d.ml
options:
members:
- prepare_datasets
- prepare_dataloaders
This diff is collapsed.
"""Provides a custom Dataset class for building a PyTorch dataset.""" """Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path from pathlib import Path
from PIL import Image
from qim3d.utils import log from qim3d.utils import log
import torch import torch
import numpy as np import numpy as np
import nibabel as nib import nibabel as nib
from typing import Optional, Callable from typing import Optional, Callable
import torch.nn as nn
from ._augmentations import Augmentation from ._augmentations import Augmentation
class Dataset(torch.utils.data.Dataset): class Dataset(torch.utils.data.Dataset):
...@@ -33,11 +32,6 @@ class Dataset(torch.utils.data.Dataset): ...@@ -33,11 +32,6 @@ class Dataset(torch.utils.data.Dataset):
Methods: Methods:
__len__(): Returns the total number of samples in the dataset. __len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Returns the image and its target segmentation at the given index. __getitem__(idx): Returns the image and its target segmentation at the given index.
Usage:
dataset = Dataset(root_path="path/to/dataset", split="train",
transform=albumentations.Compose([ToTensorV2()]))
image, target = dataset[idx]
""" """
def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None): def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None):
super().__init__() super().__init__()
...@@ -169,7 +163,7 @@ def check_resize( ...@@ -169,7 +163,7 @@ def check_resize(
def prepare_datasets( def prepare_datasets(
path: str, path: str,
val_fraction: float, val_fraction: float,
model: nn.Module, model: torch.nn.Module,
augmentation: Augmentation, augmentation: Augmentation,
) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]: ) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
""" """
...@@ -179,10 +173,26 @@ def prepare_datasets( ...@@ -179,10 +173,26 @@ def prepare_datasets(
path (str): Path to the dataset. path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set. val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model. model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels. augmentation (monai.transforms.Compose): Augmentation class for the dataset with predefined augmentation levels.
Raises: Raises:
ValueError: if the validation fraction is not a float, and is not between 0 and 1. ValueError: If the validation fraction is not a float, and is not between 0 and 1.
Example:
```python
import qim3d
base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
```
""" """
if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1): if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
...@@ -230,8 +240,31 @@ def prepare_dataloaders(train_set: torch.utils.data, ...@@ -230,8 +240,31 @@ def prepare_dataloaders(train_set: torch.utils.data,
test_set (torch.utils.data): Testing dataset. test_set (torch.utils.data): Testing dataset.
batch_size (int): Size of the batches that should be trained upon. batch_size (int): Size of the batches that should be trained upon.
shuffle_train (bool, optional): Optional input to shuffle the training data (training robustness). shuffle_train (bool, optional): Optional input to shuffle the training data (training robustness).
num_workers (int, optional): Defines how many processes should be run in parallel. num_workers (int, optional): Defines how many processes should be run in parallel. Default is 8.
pin_memory (bool, optional): Loads the datasets as CUDA tensors. pin_memory (bool, optional): Loads the datasets as CUDA tensors. Default is False.
Example:
```python
import qim3d
base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
train_set = train_set,
val_set = val_set,
test_set = test_set,
batch_size = 1,
)
```
""" """
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
......
...@@ -6,9 +6,7 @@ from qim3d.utils import log ...@@ -6,9 +6,7 @@ from qim3d.utils import log
class UNet(nn.Module): class UNet(nn.Module):
""" """
3D UNet model for QIM imaging. 3D UNet model designed for imaging segmentation tasks.
This class represents a 3D UNet model designed for imaging segmentation tasks.
Args: Args:
size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'. size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
...@@ -21,6 +19,13 @@ class UNet(nn.Module): ...@@ -21,6 +19,13 @@ class UNet(nn.Module):
Raises: Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'. ValueError: If `size` is not one of 'small', 'medium', or 'large'.
Example:
```python
import qim3d
model = qim3d.ml.models.UNet(size = 'small')
```
""" """
def __init__( def __init__(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment