%% Cell type:markdown id: tags:
# **Deep Learning Volume Segmentation (3D UNet)**
Authors: Anna Ekner (s193396@dtu.dk)
This notebook demonstrates how to implement a complete deep learning segmentation pipeline using only the capabilities of the `qim3d` library. Specifically, it uses the synthetic data generation functionality to create a volumetric dataset with associated labels, and walks through creating and training a 3D UNet model on this synthetic dataset.
%% Cell type:code id: tags:
``` python
import qim3d
import glob
import os
import numpy as np
```
%% Cell type:markdown id: tags:
### **1. Generate synthetic dataset**
%% Cell type:markdown id: tags:
#### 1.1 Example of a data sample
%% Cell type:markdown id: tags:
First, we generate a single synthetic volume together with its associated labels.
%% Cell type:code id: tags:
``` python
num_objects = 5

vol, labels = qim3d.generate.noise_object_collection(
    num_objects = num_objects,
    collection_shape = (128, 128, 128),
    min_object_noise = 0.03,
    max_object_noise = 0.08,
)
```
%% Output
%% Cell type:code id: tags:
``` python
# Visualize synthetic collection
# qim3d.viz.volumetric(vol)
```
%% Cell type:code id: tags:
``` python
# Visualize slices
# qim3d.viz.slicer(vol)
```
%% Cell type:markdown id: tags:
Since one extra label is used for the background, there will be $N + 1$ unique labels in total.
However, we only want 2 labels: foreground and background.
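%% Cell type:markdown id: tags:
We can verify this by inspecting the unique label values before binarizing:
%% Cell type:code id: tags:
``` python
# Inspect the unique label values: 0 is the background, 1..N are the individual objects
print(np.unique(labels))
```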
%% Cell type:code id: tags:
``` python
# Convert N + 1 labels into 2 labels (background and object)
labels = (labels > 0).astype(int)
```
%% Cell type:code id: tags:
``` python
# Visualize labels
# qim3d.viz.slicer(labels)
```
%% Cell type:markdown id: tags:
#### 1.2 Create folder structure
%% Cell type:code id: tags:
``` python
# Base path for the training data
base_path = os.path.expanduser("~/dataset")

# Create directories
print("Creating directories:")
for folder_split in ["train", "test"]:
    for folder_type in ["images", "labels"]:
        path = os.path.join(base_path, folder_split, folder_type)
        os.makedirs(path, exist_ok=True)
        print(path)

# Here we have the option to remove any previous files
clean_files = True
if clean_files:
    for root, dirs, files in os.walk(base_path):
        for file in files:
            file_path = os.path.join(root, file)
            os.remove(file_path)
```
%% Output
Creating directories:
C:\Users\s193396/dataset\train\images
C:\Users\s193396/dataset\train\labels
C:\Users\s193396/dataset\test\images
C:\Users\s193396/dataset\test\labels
%% Cell type:markdown id: tags:
#### 1.3 Create dataset
%% Cell type:markdown id: tags:
We now create a dataset of multiple volumes, saving each volume and its corresponding label volume into the folder structure defined above.
%% Cell type:code id: tags:
``` python
num_samples = 5

for idx in range(num_samples):
    # Generate a new synthetic volume with associated labels
    vol, label = qim3d.generate.noise_object_collection(
        num_objects = num_objects,
        collection_shape = (128, 128, 128),
        min_object_noise = 0.03,
        max_object_noise = 0.08,
    )

    # Convert N + 1 labels into 2 labels (background and object)
    label = (label > 0).astype(int)

    # Save the last sample in the test split and the rest in the train split
    folder_split = "train" if idx < num_samples - 1 else "test"

    # Save volume
    qim3d.io.save(os.path.join(base_path, folder_split, "images", f"{idx}.nii.gz"), vol, replace = True)

    # Save label
    qim3d.io.save(os.path.join(base_path, folder_split, "labels", f"{idx}.nii.gz"), label, replace = True)
```
%% Cell type:code id: tags:
``` python
# List the saved training volumes and labels to verify the dataset was created
volumes = sorted(glob.glob(os.path.join(base_path, "train", "images", "*.nii.gz")))
segmentations = sorted(glob.glob(os.path.join(base_path, "train", "labels", "*.nii.gz")))
print(volumes)
print(segmentations)
```
%% Cell type:markdown id: tags:
### **2. Build 3D UNet model**
%% Cell type:markdown id: tags:
#### 2.1 Instantiate UNet model
%% Cell type:code id: tags:
``` python
model = qim3d.ml.models.UNet(size = 'small', dropout = 0.25)
```
%% Cell type:markdown id: tags:
#### 2.2 Define augmentations
%% Cell type:code id: tags:
``` python
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id: tags:
#### 2.3 Divide dataset into train and test splits
%% Cell type:code id: tags:
``` python
# Datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation,
)

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set,
    val_set,
    test_set,
    batch_size = 1,
)
```
%% Cell type:markdown id: tags:
### **3. Train model**
%% Cell type:markdown id: tags:
#### 3.1 Define training hyperparameters
%% Cell type:code id: tags:
``` python
# Hyperparameters
hyperparameters = qim3d.ml.Hyperparameters(
    model,
    n_epochs = 10,
    learning_rate = 5e-3,
    loss_function = 'DiceCE',
    weight_decay = 1e-3,
)
```
%% Cell type:markdown id: tags:
#### 3.2 Train model
%% Cell type:code id: tags:
``` python
# training model
qim3d.ml.train_model(model, hyperparameters, train_loader, val_loader, plot=True)
```
%% Cell type:markdown id: tags:
### **4. Test model**
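%% Cell type:markdown id: tags:
As a minimal evaluation sketch, we can run the trained model over the test loader and compute a Dice score per batch. This assumes that `model` behaves as a standard `torch.nn.Module` with a single-channel output and that `test_loader` yields `(volume, target)` batches; adjust the thresholding if the model instead outputs per-class logits.
%% Cell type:code id: tags:
``` python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

dice_scores = []
with torch.no_grad():
    for vols, targets in test_loader:
        vols, targets = vols.to(device), targets.float().to(device)

        # Threshold the (assumed single-channel) output to get a binary prediction
        preds = (torch.sigmoid(model(vols)) > 0.5).float()

        # Dice coefficient between binary prediction and target
        intersection = (preds * targets).sum()
        dice = 2 * intersection / (preds.sum() + targets.sum() + 1e-8)
        dice_scores.append(dice.item())

print(f"Mean Dice score on the test set: {np.mean(dice_scores):.3f}")
```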
%% Cell type:code id: tags:
``` python
```