Skip to content
Snippets Groups Projects
Commit 66801424 authored by s193396's avatar s193396
Browse files

added data preparation pipeline to docs

parent a065371b
No related branches found
No related tags found
No related merge requests found
---
hide:
- navigation
- toc
---
# Machine learning models
......@@ -17,3 +16,9 @@ The `qim3d` library aims to ease the creation of ML models for volumetric images
options:
members:
- UNet
::: qim3d.ml
options:
members:
- prepare_datasets
- prepare_dataloaders
This diff is collapsed.
"""Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path
from PIL import Image
from qim3d.utils import log
import torch
import numpy as np
import nibabel as nib
from typing import Optional, Callable
import torch.nn as nn
from ._augmentations import Augmentation
class Dataset(torch.utils.data.Dataset):
......@@ -33,11 +32,6 @@ class Dataset(torch.utils.data.Dataset):
Methods:
__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Returns the image and its target segmentation at the given index.
Usage:
dataset = Dataset(root_path="path/to/dataset", split="train",
transform=albumentations.Compose([ToTensorV2()]))
image, target = dataset[idx]
"""
def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None):
super().__init__()
......@@ -169,7 +163,7 @@ def check_resize(
def prepare_datasets(
path: str,
val_fraction: float,
model: nn.Module,
model: torch.nn.Module,
augmentation: Augmentation,
) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
"""
......@@ -179,10 +173,26 @@ def prepare_datasets(
path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
augmentation (monai.transforms.Compose): Augmentation class for the dataset with predefined augmentation levels.
Raises:
ValueError: if the validation fraction is not a float, and is not between 0 and 1.
ValueError: If the validation fraction is not a float, and is not between 0 and 1.
Example:
```python
import qim3d
base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
```
"""
if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
......@@ -230,8 +240,31 @@ def prepare_dataloaders(train_set: torch.utils.data,
test_set (torch.utils.data): Testing dataset.
batch_size (int): Size of the batches that should be trained upon.
shuffle_train (bool, optional): Optional input to shuffle the training data (training robustness).
num_workers (int, optional): Defines how many processes should be run in parallel.
pin_memory (bool, optional): Loads the datasets as CUDA tensors.
num_workers (int, optional): Defines how many processes should be run in parallel. Default is 8.
pin_memory (bool, optional): Loads the datasets as CUDA tensors. Default is False.
Example:
```python
import qim3d
base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
train_set = train_set,
val_set = val_set,
test_set = test_set,
batch_size = 1,
)
```
"""
from torch.utils.data import DataLoader
......
......@@ -6,9 +6,7 @@ from qim3d.utils import log
class UNet(nn.Module):
"""
3D UNet model for QIM imaging.
This class represents a 3D UNet model designed for imaging segmentation tasks.
3D UNet model designed for imaging segmentation tasks.
Args:
size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
......@@ -21,6 +19,13 @@ class UNet(nn.Module):
Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
Example:
```python
import qim3d
model = qim3d.ml.models.UNet(size = 'small')
```
"""
def __init__(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment