Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
66801424
Commit
66801424
authored
4 months ago
by
s193396
Browse files
Options
Downloads
Patches
Plain Diff
added data preparation pipeline to docs
parent
a065371b
No related branches found
No related tags found
No related merge requests found
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
docs/doc/ml/models.md
+7
-2
7 additions, 2 deletions
docs/doc/ml/models.md
docs/notebooks/UNet.ipynb
+10
-21
10 additions, 21 deletions
docs/notebooks/UNet.ipynb
qim3d/ml/_data.py
+49
-16
49 additions, 16 deletions
qim3d/ml/_data.py
qim3d/ml/models/_unet.py
+8
-3
8 additions, 3 deletions
qim3d/ml/models/_unet.py
with
74 additions
and
42 deletions
docs/doc/ml/models.md
+
7
−
2
View file @
66801424
---
---
hide
:
hide
:
-
navigation
-
navigation
-
toc
---
---
# Machine learning models
# Machine learning models
...
@@ -17,3 +16,9 @@ The `qim3d` library aims to ease the creation of ML models for volumetric images
...
@@ -17,3 +16,9 @@ The `qim3d` library aims to ease the creation of ML models for volumetric images
options:
options:
members:
members:
-
UNet
-
UNet
::: qim3d.ml
options:
members:
-
prepare_datasets
-
prepare_dataloaders
This diff is collapsed.
Click to expand it.
docs/notebooks/UNet.ipynb
+
10
−
21
View file @
66801424
This diff is collapsed.
Click to expand it.
qim3d/ml/_data.py
+
49
−
16
View file @
66801424
"""
Provides a custom Dataset class for building a PyTorch dataset.
"""
"""
Provides a custom Dataset class for building a PyTorch dataset.
"""
from
pathlib
import
Path
from
pathlib
import
Path
from
PIL
import
Image
from
qim3d.utils
import
log
from
qim3d.utils
import
log
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
nibabel
as
nib
import
nibabel
as
nib
from
typing
import
Optional
,
Callable
from
typing
import
Optional
,
Callable
import
torch.nn
as
nn
from
._augmentations
import
Augmentation
from
._augmentations
import
Augmentation
class
Dataset
(
torch
.
utils
.
data
.
Dataset
):
class
Dataset
(
torch
.
utils
.
data
.
Dataset
):
...
@@ -33,11 +32,6 @@ class Dataset(torch.utils.data.Dataset):
...
@@ -33,11 +32,6 @@ class Dataset(torch.utils.data.Dataset):
Methods:
Methods:
__len__(): Returns the total number of samples in the dataset.
__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Returns the image and its target segmentation at the given index.
__getitem__(idx): Returns the image and its target segmentation at the given index.
Usage:
dataset = Dataset(root_path=
"
path/to/dataset
"
, split=
"
train
"
,
transform=albumentations.Compose([ToTensorV2()]))
image, target = dataset[idx]
"""
"""
def
__init__
(
self
,
root_path
:
str
,
split
:
str
=
"
train
"
,
transform
:
Optional
[
Callable
]
=
None
):
def
__init__
(
self
,
root_path
:
str
,
split
:
str
=
"
train
"
,
transform
:
Optional
[
Callable
]
=
None
):
super
().
__init__
()
super
().
__init__
()
...
@@ -169,7 +163,7 @@ def check_resize(
...
@@ -169,7 +163,7 @@ def check_resize(
def
prepare_datasets
(
def
prepare_datasets
(
path
:
str
,
path
:
str
,
val_fraction
:
float
,
val_fraction
:
float
,
model
:
nn
.
Module
,
model
:
torch
.
nn
.
Module
,
augmentation
:
Augmentation
,
augmentation
:
Augmentation
,
)
->
tuple
[
torch
.
utils
.
data
.
Subset
,
torch
.
utils
.
data
.
Subset
,
torch
.
utils
.
data
.
Subset
]:
)
->
tuple
[
torch
.
utils
.
data
.
Subset
,
torch
.
utils
.
data
.
Subset
,
torch
.
utils
.
data
.
Subset
]:
"""
"""
...
@@ -179,10 +173,26 @@ def prepare_datasets(
...
@@ -179,10 +173,26 @@ def prepare_datasets(
path (str): Path to the dataset.
path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set.
val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model.
model (torch.nn.Module): PyTorch Model.
augmentation (
albumentations.core.composition
.Compose): Augmentation class for the dataset with predefined augmentation levels.
augmentation (
monai.transforms
.Compose): Augmentation class for the dataset with predefined augmentation levels.
Raises:
Raises:
ValueError: if the validation fraction is not a float, and is not between 0 and 1.
ValueError: If the validation fraction is not a float, and is not between 0 and 1.
Example:
```python
import qim3d
base_path =
"
C:/dataset/
"
model = qim3d.ml.models.UNet(size =
'
small
'
)
augmentation = qim3d.ml.Augmentation(resize =
'
crop
'
, transform_train =
'
light
'
)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
```
"""
"""
if
not
isinstance
(
val_fraction
,
float
)
or
not
(
0
<=
val_fraction
<
1
):
if
not
isinstance
(
val_fraction
,
float
)
or
not
(
0
<=
val_fraction
<
1
):
...
@@ -230,8 +240,31 @@ def prepare_dataloaders(train_set: torch.utils.data,
...
@@ -230,8 +240,31 @@ def prepare_dataloaders(train_set: torch.utils.data,
test_set (torch.utils.data): Testing dataset.
test_set (torch.utils.data): Testing dataset.
batch_size (int): Size of the batches that should be trained upon.
batch_size (int): Size of the batches that should be trained upon.
shuffle_train (bool, optional): Optional input to shuffle the training data (training robustness).
shuffle_train (bool, optional): Optional input to shuffle the training data (training robustness).
num_workers (int, optional): Defines how many processes should be run in parallel.
num_workers (int, optional): Defines how many processes should be run in parallel. Default is 8.
pin_memory (bool, optional): Loads the datasets as CUDA tensors.
pin_memory (bool, optional): Loads the datasets as CUDA tensors. Default is False.
Example:
```python
import qim3d
base_path =
"
C:/dataset/
"
model = qim3d.ml.models.UNet(size =
'
small
'
)
augmentation = qim3d.ml.Augmentation(resize =
'
crop
'
, transform_train =
'
light
'
)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
path = base_path,
val_fraction = 0.5,
model = model,
augmentation = augmentation
)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
train_set = train_set,
val_set = val_set,
test_set = test_set,
batch_size = 1,
)
```
"""
"""
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
...
...
This diff is collapsed.
Click to expand it.
qim3d/ml/models/_unet.py
+
8
−
3
View file @
66801424
...
@@ -6,9 +6,7 @@ from qim3d.utils import log
...
@@ -6,9 +6,7 @@ from qim3d.utils import log
class
UNet
(
nn
.
Module
):
class
UNet
(
nn
.
Module
):
"""
"""
3D UNet model for QIM imaging.
3D UNet model designed for imaging segmentation tasks.
This class represents a 3D UNet model designed for imaging segmentation tasks.
Args:
Args:
size (
'
small
'
or
'
medium
'
or
'
large
'
, optional): Size of the UNet model. Must be one of
'
small
'
,
'
medium
'
, or
'
large
'
. Defaults to
'
medium
'
.
size (
'
small
'
or
'
medium
'
or
'
large
'
, optional): Size of the UNet model. Must be one of
'
small
'
,
'
medium
'
, or
'
large
'
. Defaults to
'
medium
'
.
...
@@ -21,6 +19,13 @@ class UNet(nn.Module):
...
@@ -21,6 +19,13 @@ class UNet(nn.Module):
Raises:
Raises:
ValueError: If `size` is not one of
'
small
'
,
'
medium
'
, or
'
large
'
.
ValueError: If `size` is not one of
'
small
'
,
'
medium
'
, or
'
large
'
.
Example:
```python
import qim3d
model = qim3d.ml.models.UNet(size =
'
small
'
)
```
"""
"""
def
__init__
(
def
__init__
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment