Commit 3c284b4a authored by s193396

updated segmentation pipeline notebook

parent 968d7efe
%% Cell type:markdown id: tags:
# Deep learning volume segmentation
Authors: Alessia Saccardo (s212246@dtu.dk) & Felipe Delestro (fima@dtu.dk)
This notebook demonstrates that a complete deep learning segmentation pipeline can be built using only the `qim3d` library. It highlights the annotation tool and walks through creating and training a U-Net model.
%% Cell type:markdown id: tags:
### Imports
%% Cell type:code id: tags:
``` python
import qim3d
import numpy as np
import os
```
%% Cell type:markdown id: tags:
### Load data
The `qim3d` library contains a set of example volumes, which can be loaded using `qim3d.examples.{volume_name}`.
%% Cell type:code id: tags:
``` python
vol = qim3d.examples.bone_128x128x128
```
%% Cell type:markdown id: tags:
To get a quick look at the volume, we can browse it interactively with the `slicer` function from `qim3d`.
%% Cell type:code id: tags:
``` python
qim3d.viz.slicer(vol)
```
%% Output
interactive(children=(IntSlider(value=64, description='Slice', max=127), Output()), layout=Layout(align_items=…
%% Cell type:markdown id: tags:
# Generate dataset for training
To train the segmentation model, we need to create a dataset from the volume.
This means that we need a few slices for `training` and at least one for `test`.
%% Cell type:markdown id: tags:
The dataset for training the model is managed by `qim3d.ml.prepare_datasets`, which expects files to follow this structure:
<pre>
dataset
├── test
│   ├── images
│   │   └── FileA.png
│   └── labels
│       └── FileA.png
└── train
    ├── images
    │   ├── FileB.png
    │   └── FileC.png
    └── labels
        ├── FileB.png
        └── FileC.png
</pre>
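Before training, it can help to confirm that a folder actually follows this layout. The sketch below is an illustration under assumptions, not part of `qim3d`: the helper name `find_unpaired_files` is hypothetical, and it only checks that every file under `images` has a same-named file under `labels` (and vice versa) in both splits.

```python
import os
import tempfile


def find_unpaired_files(base_path, splits=("train", "test")):
    """Return (split, filename) pairs for files that exist in only one of `images` / `labels`."""
    unpaired = []
    for split in splits:
        images_dir = os.path.join(base_path, split, "images")
        labels_dir = os.path.join(base_path, split, "labels")
        # A missing folder is treated as an empty one
        images = set(os.listdir(images_dir)) if os.path.isdir(images_dir) else set()
        labels = set(os.listdir(labels_dir)) if os.path.isdir(labels_dir) else set()
        # Symmetric difference: names present on one side only
        for name in sorted(images ^ labels):
            unpaired.append((split, name))
    return unpaired


# Usage on a throwaway directory that mirrors the tree above
with tempfile.TemporaryDirectory() as tmp:
    for split, names in {"train": ["FileB.png", "FileC.png"], "test": ["FileA.png"]}.items():
        for kind in ("images", "labels"):
            folder = os.path.join(tmp, split, kind)
            os.makedirs(folder)
            for name in names:
                open(os.path.join(folder, name), "w").close()
    complete_result = find_unpaired_files(tmp)
    # Remove one label to simulate an incomplete dataset
    os.remove(os.path.join(tmp, "train", "labels", "FileC.png"))
    incomplete_result = find_unpaired_files(tmp)

print(complete_result)
print(incomplete_result)
```

An empty list means every image has a matching label; any entry points at a file that would need to be annotated or removed.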
%% Cell type:code id: tags:
``` python
# Number of slices that will be used
num_training = 4
num_test = 1
```
%% Cell type:markdown id: tags:
In the following cell, we get the slice indices, making sure that we're not using the same indices for training and test.
%% Cell type:code id: tags:
``` python
# Seed for random number generator
seed = 0
# Create a set with all the indices
all_idxs = set(range(vol.shape[0]))
# Get indices for training data
training_idxs = list(np.random.default_rng(seed).choice(list(all_idxs), size=num_training))
print(f"Slices for training data...: {training_idxs}")
# Get indices for test data
test_idxs = list(np.random.default_rng(seed).choice(list(all_idxs - set(training_idxs)), size=num_test, replace=False))
print(f"Slices for test data.......: {test_idxs}")
```
%% Output
Slices for training data...: [108, 81, 65, 34]
Slices for test data.......: [109]
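%% Cell type:markdown id: tags:
The training draw above calls `choice` with its default `replace=True`, so a repeated training index is possible in principle. As a hedged alternative, the sketch below draws both subsets from a single permutation, which rules out repeats within and across the two sets. The function name `split_slice_indices` is hypothetical, and the indices it returns differ from the ones printed above.

```python
import numpy as np


def split_slice_indices(num_slices, num_training, num_test, seed=0):
    """Draw disjoint training and test slice indices from one shuffled range."""
    if num_training + num_test > num_slices:
        raise ValueError("Requested more slices than the volume contains")
    # A single permutation guarantees that no index is drawn twice
    shuffled = np.random.default_rng(seed).permutation(num_slices)
    training = shuffled[:num_training].tolist()
    test = shuffled[num_training:num_training + num_test].tolist()
    return training, test


training_example, test_example = split_slice_indices(128, num_training=4, num_test=1)
print(training_example, test_example)
```

Because the result is converted with `tolist()`, the indices are plain Python integers, which keeps the printed output and any file names built from them free of NumPy type wrappers.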
%% Cell type:markdown id: tags:
## Create folder structure
Here we create the necessary directories.
%% Cell type:code id: tags:
``` python
# Base path for the training data
base_path = os.path.expanduser("~/dataset")
# Create directories
print("Creating directories:")
for folder_split in ["train", "test"]:
    for folder_type in ["images", "labels"]:
        path = os.path.join(base_path, folder_split, folder_type)
        os.makedirs(path, exist_ok=True)
        print(path)
# Here we have the option to remove any previous files
clean_files = True
if clean_files:
    for root, dirs, files in os.walk(base_path):
        for file in files:
            file_path = os.path.join(root, file)
            os.remove(file_path)
```
%% Output
Creating directories:
C:\Users\s193396/dataset\train\images
C:\Users\s193396/dataset\train\labels
C:\Users\s193396/dataset\test\images
C:\Users\s193396/dataset\test\labels
%% Cell type:markdown id: tags:
# Annotate data
The following cell will generate an annotation tool for each slice that was requested.
Use the tool to draw a mask over the structures you want to detect.
%% Cell type:code id: tags:
``` python
annotation_tools = {}
for idx in training_idxs + test_idxs:
    if idx in training_idxs:
        subset = "training"
    elif idx in test_idxs:
        subset = "test"
    annotation_tools[idx] = qim3d.gui.annotation_tool.Interface()
    annotation_tools[idx].name_suffix = f"_{idx}"
    print(f"Annotation for slice {idx} ({subset})")
    annotation_tools[idx].launch(vol[idx])
```
%% Output
Annotation for slice 108 (training)
Annotation for slice 81 (training)
Annotation for slice 65 (training)
c:\Users\s193396\AppData\Local\miniconda3\envs\qim3d\lib\site-packages\gradio\analytics.py:106: UserWarning: IMPORTANT: You are using gradio version 4.44.0, however version 4.44.1 is available, please upgrade.
--------
warnings.warn(
Annotation for slice 34 (training)
Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB
c:\Users\s193396\AppData\Local\miniconda3\envs\qim3d\lib\site-packages\gradio\analytics.py:106: UserWarning: IMPORTANT: You are using gradio version 4.44.0, however version 4.44.1 is available, please upgrade.
--------
warnings.warn(
Annotation for slice 109 (test)
c:\Users\s193396\AppData\Local\miniconda3\envs\qim3d\lib\site-packages\gradio\analytics.py:106: UserWarning: IMPORTANT: You are using gradio version 4.44.0, however version 4.44.1 is available, please upgrade.
--------
warnings.warn(
c:\Users\s193396\AppData\Local\miniconda3\envs\qim3d\lib\site-packages\gradio\analytics.py:106: UserWarning: IMPORTANT: You are using gradio version 4.44.0, however version 4.44.1 is available, please upgrade.
--------
warnings.warn(
%% Cell type:markdown id: tags:
### Getting masks from the annotation tool
The masks are stored in the annotation tool. Here we extract the masks and save them to disk, following the folder structure that the deep learning model expects.
%% Cell type:code id: tags:
``` python
print("Saving images and masks to disk")
for idx in training_idxs + test_idxs:
    if idx in training_idxs:
        folder_split = "train"
    elif idx in test_idxs:
        folder_split = "test"
    print(f"- slice {idx} ({folder_split})")
    mask_dict = annotation_tools[idx].get_result()
    mask = list(mask_dict.values())[0]
    # Save image
    qim3d.io.save(os.path.join(base_path, folder_split, "images", f"{idx}.png"), vol[idx], replace=True)
    # Save label
    qim3d.io.save(os.path.join(base_path, folder_split, "labels", f"{idx}.png"), mask, replace=True)
```
%% Output
Loaded shape: (128, 128)
INFO:qim3d:Loaded shape: (128, 128)
Volume using 16.0 KB of memory
INFO:qim3d:Volume using 16.0 KB of memory
System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
INFO:qim3d:System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
Loaded shape: (128, 128)
INFO:qim3d:Loaded shape: (128, 128)
Volume using 16.0 KB of memory
INFO:qim3d:Volume using 16.0 KB of memory
System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
INFO:qim3d:System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
Loaded shape: (128, 128)
INFO:qim3d:Loaded shape: (128, 128)
Volume using 16.0 KB of memory
INFO:qim3d:Volume using 16.0 KB of memory
System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
INFO:qim3d:System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
Loaded shape: (128, 128)
Saving images and masks to disk
- slice 108 (train)
- slice 81 (train)
- slice 65 (train)
- slice 34 (train)
INFO:qim3d:Loaded shape: (128, 128)
Volume using 16.0 KB of memory
INFO:qim3d:Volume using 16.0 KB of memory
System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
INFO:qim3d:System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
Loaded shape: (128, 128)
INFO:qim3d:Loaded shape: (128, 128)
Volume using 16.0 KB of memory
INFO:qim3d:Volume using 16.0 KB of memory
System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
INFO:qim3d:System memory:
• Total.: 31.6 GB
• Used..: 15.8 GB (50.0%)
• Free..: 15.8 GB (50.0%)
- slice 109 (test)
%% Cell type:markdown id: tags:
# Build Unet
%% Cell type:markdown id: tags:
Building and training the Unet model is straightforward using `qim3d`.
We first need to instantiate the model by defining its size, which can be *small*, *medium* or *large*.
%% Cell type:code id: tags:
``` python
# defining model (this commit changes `UNet` to `UNet2D`)
model = qim3d.ml.models.UNet2D(size = 'small', dropout = 0.25)
```
%% Cell type:markdown id: tags:
Then we need to decide which type of augmentation to apply to the data.
`qim3d.ml.Augmentation` specifies how the images are resized to the appropriate size and which level of transformation is applied to the train, validation and test sets.
The resize option must be chosen from [*crop*, *reshape*, *padding*] and the level of transformation must be chosen from [*None*, *light*, *moderate*, *heavy*]. The user can also specify the mean and standard deviation values for normalizing pixel intensities.
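To make the *crop* and *padding* options more concrete, here is a minimal NumPy sketch of a center crop and a zero padding to a square target size. It is an illustration under assumptions, not the `qim3d` implementation: the helper names `center_crop` and `zero_pad` are hypothetical, and the library may, for example, place the crop window differently.

```python
import numpy as np


def center_crop(image, size):
    """Cut a size x size window from the middle of a 2D image."""
    height, width = image.shape
    top = (height - size) // 2
    left = (width - size) // 2
    return image[top:top + size, left:left + size]


def zero_pad(image, size):
    """Place a 2D image in the middle of a size x size array of zeros."""
    height, width = image.shape
    top = (size - height) // 2
    left = (size - width) // 2
    padded = np.zeros((size, size), dtype=image.dtype)
    padded[top:top + height, left:left + width] = image
    return padded


image = np.arange(100 * 100).reshape(100, 100)
cropped = center_crop(image, 64)   # discards the border pixels
padded = zero_pad(image, 128)      # keeps every pixel, adds a zero border
print(cropped.shape, padded.shape)
```

Cropping loses information near the edges, while padding keeps every pixel but introduces an artificial border, which is the trade-off to weigh when choosing between the two.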
%% Cell type:code id: tags:
``` python
# defining augmentation
aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id: tags:
Then the datasets and dataloaders are instantiated.
%% Cell type:code id: tags:
``` python
# datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = aug)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set,
    val_set,
    test_set,
    batch_size = 1)
```
%% Cell type:markdown id: tags:
The hyperparameters are defined with `qim3d.ml.Hyperparameters`, and the model is trained by calling `qim3d.ml.train_model`, which can also plot the losses at the end of training when `plot=True` is passed.
%% Cell type:markdown id: tags:
# Train model
%% Cell type:code id: tags:
``` python
# model hyperparameters
hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=10,
                                           learning_rate = 5e-3, loss_function='DiceCE',
                                           weight_decay=1e-3)
# training model
qim3d.ml.train_model(model, hyperparameters, train_loader, val_loader, plot=True)
```
%% Output
Epoch 0, train loss: 1.9671, val loss: 1.5506
Epoch 5, train loss: 0.9591, val loss: 0.8574
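%% Cell type:markdown id: tags:
The `loss_function='DiceCE'` argument suggests a loss that combines a Dice term with a cross-entropy term. The sketch below illustrates that idea for a binary mask in plain Python. It rests on assumptions: the unweighted sum, the smoothing constants and all function names are hypothetical, and the loss that `qim3d` actually uses may be weighted or implemented differently.

```python
import math


def dice_loss(probs, targets, eps=1e-6):
    """Soft Dice loss: 1 - 2*sum(p*t) / (sum(p) + sum(t))."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (total + eps)


def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy, with probabilities clipped away from 0 and 1."""
    losses = []
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        losses.append(-(t * math.log(p) + (1 - t) * math.log(1.0 - p)))
    return sum(losses) / len(losses)


def dice_ce_loss(probs, targets):
    """Unweighted sum of the two terms (an assumption about what 'DiceCE' combines)."""
    return dice_loss(probs, targets) + binary_cross_entropy(probs, targets)


targets = [1, 1, 0, 0]              # flattened ground-truth mask
good_probs = [0.9, 0.8, 0.1, 0.2]   # prediction close to the mask
bad_probs = [0.2, 0.1, 0.9, 0.8]    # prediction far from the mask
print(dice_ce_loss(good_probs, targets), dice_ce_loss(bad_probs, targets))
```

The Dice term rewards overlap between prediction and mask regardless of how large the foreground is, while the cross-entropy term penalizes each pixel individually, so the two are often used together for segmentation.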
%% Cell type:markdown id: tags:
# Check results
To run inference, call `qim3d.ml.inference`.
The results can be visualized with the function `qim3d.viz.grid_pred`, which shows the predicted segmentation alongside the ground truth.
%% Cell type:code id: tags:
``` python
in_targ_preds_test = qim3d.ml.inference(test_set, model)
qim3d.viz.grid_pred(in_targ_preds_test, alpha=1)
```
%% Output
Not enough images in the dataset. Changing num_images=7 to num_images=1
<Figure size 200x1000 with 4 Axes>
%% Cell type:markdown id: tags:
# Compute inference on entire volume
Given that the input is a volume, the goal is to perform inference on the entire volume rather than on individual slices.
The function `qim3d.ml.volume_inference` returns the segmentation of the whole volume.
%% Cell type:code id: tags:
``` python
inference_vol = qim3d.ml.volume_inference(vol, model)
qim3d.viz.slicer(inference_vol)
```
%% Output
interactive(children=(IntSlider(value=64, description='Slice', max=127), Output()), layout=Layout(align_items=…
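%% Cell type:markdown id: tags:
One plausible way to obtain a volume segmentation from a 2D model is to predict each slice separately and stack the results. The sketch below shows that idea with a stand-in predictor. Whether `qim3d.ml.volume_inference` works this way internally is an assumption, and the names `segment_volume_slicewise` and `threshold_predictor` are hypothetical.

```python
import numpy as np


def segment_volume_slicewise(volume, predict_slice):
    """Apply a 2D predictor to each slice along axis 0 and stack the results."""
    masks = [predict_slice(volume[i]) for i in range(volume.shape[0])]
    return np.stack(masks, axis=0)


def threshold_predictor(image_slice):
    """Hypothetical stand-in for a trained model: marks pixels above the slice mean."""
    return (image_slice > image_slice.mean()).astype(np.uint8)


# Synthetic 8x16x16 volume used only to exercise the function
toy_volume = np.random.default_rng(0).random((8, 16, 16))
toy_segmentation = segment_volume_slicewise(toy_volume, threshold_predictor)
print(toy_segmentation.shape)
```

A slice-wise approach keeps memory use low, at the cost of ignoring context between neighbouring slices.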
%% Cell type:markdown id: tags:
We can also visualize the created mask together with the original volume.
%% Cell type:code id: tags:
``` python
vol_masked = qim3d.viz.vol_masked(vol, inference_vol, viz_delta=128)
qim3d.viz.slicer(vol_masked, color_map="PiYG")
```
%% Output
interactive(children=(IntSlider(value=64, description='Slice', max=127), Output()), layout=Layout(align_items=…