%% Cell type:markdown id: tags:
# **Deep Learning Volume Segmentation (3D UNet)**
Author: Anna Ekner (s193396@dtu.dk)
This notebook demonstrates how to implement a complete deep learning segmentation pipeline using only the capabilities of the `qim3d` library. Specifically, it shows how to use the synthetic data generation functionality to create a volumetric dataset with associated labels, and walks through creating and training a 3D UNet model on this synthetic dataset.
%% Cell type:code id: tags:
``` python
import qim3d
import glob
import os
import numpy as np
```
%% Cell type:markdown id: tags:
### **1. Generate synthetic dataset**
%% Cell type:markdown id: tags:
#### 1.1 Example of a data sample
%% Cell type:markdown id: tags:
First, generate a single synthetic volume with its associated labels to see what a data sample looks like.
%% Cell type:code id: tags:
``` python
num_objects = 5

# Generate a collection of synthetic blob-like objects and their labels
vol, labels = qim3d.generate.noise_object_collection(
    num_objects=num_objects,
    collection_shape=(128, 128, 128),
    min_object_noise=0.03,
    max_object_noise=0.08,
)
```
%% Output
%% Cell type:code id: tags:
``` python
# Visualize synthetic collection
# qim3d.viz.volumetric(vol)
```
%% Cell type:code id: tags:
``` python
# Visualize slices
# qim3d.viz.slicer(vol)
```
%% Cell type:markdown id: tags:
The label volume contains $N + 1$ unique labels: one per object, plus one extra for the background.
For binary segmentation, however, we only want 2 labels: foreground and background.
%% Cell type:code id: tags:
``` python
# Convert N + 1 labels into 2 labels (background and foreground)
labels = (labels > 0).astype(int)
```
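%% Cell type:markdown id: tags:
As a quick sanity check, the binarized label volume should now contain exactly two values:
%% Cell type:code id: tags:
``` python
# Verify that only background (0) and foreground (1) labels remain
print(np.unique(labels))  # expected: [0 1]
```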
%% Cell type:code id: tags:
``` python
# Visualize labels
# qim3d.viz.slicer(labels)
```
%% Cell type:markdown id: tags:
#### 1.2 Create folder structure
%% Cell type:code id: tags:
``` python
# Base path for the training data
base_path = os.path.expanduser("~/dataset")

# Create directories
print("Creating directories:")
for folder_split in ["train", "test"]:
    for folder_type in ["images", "labels"]:
        path = os.path.join(base_path, folder_split, folder_type)
        os.makedirs(path, exist_ok=True)
        print(path)

# Here we have the option to remove any previous files
clean_files = True
if clean_files:
    for root, dirs, files in os.walk(base_path):
        for file in files:
            file_path = os.path.join(root, file)
            os.remove(file_path)
```
%% Output
Creating directories:
C:\Users\s193396/dataset\train\images
C:\Users\s193396/dataset\train\labels
C:\Users\s193396/dataset\test\images
C:\Users\s193396/dataset\test\labels
%% Cell type:markdown id: tags:
#### 1.3 Create dataset
%% Cell type:markdown id: tags:
To train the model, we need a dataset of multiple volumes, split into training and testing sets.
%% Cell type:code id: tags:
``` python
# Specify the total number of samples
num_samples = 5

# Specify the number of training and testing samples
num_train = 4
num_test = 1
```
%% Cell type:code id: tags:
``` python
# Specify the number of objects in each volume
num_objects = 5

# Generate training and testing samples
for idx in range(num_train + num_test):

    # Determine the folder split (train or test)
    if idx < num_train:
        folder_split = "train"
    else:
        folder_split = "test"

    # TODO: Check whether a fixed seed would make all generated volumes identical
    vol, label = qim3d.generate.noise_object_collection(
        num_objects=num_objects,
        collection_shape=(128, 128, 128),
        min_object_noise=0.03,
        max_object_noise=0.08,
    )

    # Convert N + 1 labels into 2 labels (background and object)
    label = (label > 0).astype(int)

    # Save volume
    qim3d.io.save(os.path.join(base_path, folder_split, "images", f"{idx + 1}.nii.gz"), vol, compression=True, replace=True)

    # Save label
    qim3d.io.save(os.path.join(base_path, folder_split, "labels", f"{idx + 1}.nii.gz"), label, compression=True, replace=True)
```
%% Output
%% Cell type:code id: tags:
``` python
# Check the image indices for training and testing data
for folder_split in ["train", "test"]:
    path = os.path.join(base_path, folder_split, "images")
    files = os.listdir(path)
    indices = sorted(int(f.split(".")[0]) for f in files)
    if folder_split == "train":
        print(f"Image indices for training data.....: {indices}")
    elif folder_split == "test":
        print(f"Image indices for testing data......: {indices}")
```
%% Output
Image indices for training data.....: [1, 2, 3, 4]
Image indices for testing data......: [5]
%% Cell type:markdown id: tags:
### **2. Build 3D UNet model**
%% Cell type:markdown id: tags:
#### 2.1 Instantiate UNet model
%% Cell type:code id: tags:
``` python
model = qim3d.ml.models.UNet(size = 'small', dropout = 0.25)
```
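%% Cell type:markdown id: tags:
Optionally, we can inspect the size of the model. A minimal sketch, assuming the `qim3d` UNet subclasses `torch.nn.Module`:
%% Cell type:code id: tags:
``` python
import torch

# Count trainable parameters (assumes the UNet is a standard torch.nn.Module)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")
```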
%% Cell type:markdown id: tags:
#### 2.2 Define augmentations
%% Cell type:code id: tags:
``` python
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id: tags:
#### 2.3 Divide dataset into train and test splits
%% Cell type:code id: tags:
``` python
# Debug: inspect the augmentation pipeline on a single volume
# import qim3d
# import numpy as np

# vol, labels = qim3d.generate.noise_object_collection(
#     num_objects = 1,
#     collection_shape = (128, 128, 128),
#     min_object_noise = 0.03,
#     max_object_noise = 0.04,
# )
# augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

# print(f"Before augmentation: {vol.shape} and {vol.dtype}")
# print(f"Augmentation: {augmentation}\n")

# compose_augmentation = augmentation.augment(*vol.shape, level = augmentation.transform_train)

# vol_expanded = np.expand_dims(vol, axis = 0)
# transformed_vol = compose_augmentation(vol_expanded)
# print(f"After channel expansion and augmentation: {transformed_vol.shape}")
# print(f"Dtype after augmentation: {transformed_vol.dtype}\n")

# # Remove the channel dimension
# vol_reduced = np.squeeze(transformed_vol)
# print(f"After reducing channel dimension: {vol_reduced.shape}")

# qim3d.viz.volumetric(vol_reduced)
```
%% Output
Before augmentation: (128, 128, 128) and uint8
Augmentation: <qim3d.ml._augmentations.Augmentation object at 0x0000028B367D1870>
After channel expansion and augmentation: torch.Size([1, 128, 128, 128])
Dtype after augmentation: torch.uint8
After reducing channel dimension: (128, 128, 128)
%% Cell type:markdown id: tags:
Prepare the datasets and dataloaders from the folder structure created above, using the model and augmentation defined earlier.
%% Cell type:code id: tags:
``` python
# Datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path=base_path,
    val_fraction=0.5,
    model=model,
    augmentation=augmentation,
)

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set,
    val_set,
    test_set,
    batch_size=1,
)
```
%% Output
3D image shape: (128, 128, 128)
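%% Cell type:markdown id: tags:
A quick check of the resulting split sizes, assuming the returned datasets follow the standard PyTorch `Dataset` interface:
%% Cell type:code id: tags:
``` python
# Number of samples in each split
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
```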
%% Cell type:markdown id: tags:
### **3. Train model**
%% Cell type:markdown id: tags:
#### 3.1 Define training hyperparameters
%% Cell type:code id: tags:
``` python
# Hyperparameters
hyperparameters = qim3d.ml.Hyperparameters(
    model,
    n_epochs=10,
    learning_rate=5e-3,
    loss_function='DiceCE',
    weight_decay=1e-3,
)
```
%% Cell type:markdown id: tags:
#### 3.2 Train model
%% Cell type:code id: tags:
``` python
# training model
qim3d.ml.train_model(model, hyperparameters, train_loader, val_loader, plot=True)
```
%% Output
Epoch 0, train loss: 1.6603, val loss: 1.3486
Epoch 5, train loss: 1.4026, val loss: 0.4789
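%% Cell type:markdown id: tags:
Optionally, the trained weights can be saved for later reuse. A minimal sketch, assuming the model is a standard `torch.nn.Module` (the filename `unet_small.pt` is just an example):
%% Cell type:code id: tags:
``` python
import torch

# Persist the trained weights next to the dataset
torch.save(model.state_dict(), os.path.join(base_path, "unet_small.pt"))
```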
%% Cell type:markdown id: tags:
### **4. Test model**
%% Cell type:code id: tags:
``` python
# NOTE: qim3d.ml.inference needs to be updated to handle 3D inputs as well
in_targ_preds_test = qim3d.ml.inference(test_set, model)
```
%% Output
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[24], line 1
----> 1 in_targ_preds_test = qim3d.ml.inference(test_set, model)
File c:\s193396\qim3d\qim3d\ml\_ml_utils.py:214, in inference(data, model)
212 pass
213 else:
--> 214 raise ValueError("Input image must be (C,H,W) format")
216 model.to(device)
217 model.eval()
ValueError: Input image must be (C,H,W) format
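%% Cell type:markdown id: tags:
Until `qim3d.ml.inference` handles 3D inputs, a manual forward pass can serve as a workaround. This is a hypothetical sketch, assuming the model is a standard PyTorch module expecting a `(batch, channel, D, H, W)` tensor and that `test_set` yields `(volume, target)` pairs:
%% Cell type:code id: tags:
``` python
import torch

model.eval()
with torch.no_grad():
    vol, target = test_set[0]                  # assumed (volume, target) pair
    vol = torch.as_tensor(vol).float()
    if vol.ndim == 3:                          # add channel dimension if missing
        vol = vol.unsqueeze(0)
    pred = model(vol.unsqueeze(0))             # add batch dimension
    pred_label = pred.argmax(dim=1).squeeze()  # predicted class per voxel
print(pred_label.shape)
```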