diff --git a/docs/notebooks/UNet.ipynb b/docs/notebooks/UNet.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1c39366d87b1e4a8b59febc2464df99c60122780 --- /dev/null +++ b/docs/notebooks/UNet.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **Deep Learning Volume Segmentation (3D UNet)**\n", + "\n", + "Authors: Anna Ekner (s193396@dtu.dk)\n", + "\n", + "This notebook aims to demonstrate the feasibility of implementing a comprehensive deep learning segmentation pipeline solely leveraging the capabilities offered by the `qim3d` library. Specifically, it highlights the use of the synthetic data generation functionalities to create a volumetric dataset with associated labels, and walks through the process of creating and training a 3D UNet model using this synthetic dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import qim3d\n", + "import glob\n", + "import os\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **1. Generate synthetic dataset**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.1 Example of data sample (probably should be after creating the dataset??)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Synthetic dataset and associated labels." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d9b9bb97fe7d4828b5b7f359cb8ac50a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Objects placed: 0%| | 0/5 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "num_objects = 5\n", + "vol, labels = qim3d.generate.noise_object_collection(\n", + " num_objects = num_objects,\n", + " collection_shape = (128, 128, 128),\n", + " min_object_noise = 0.03, \n", + " max_object_noise = 0.08,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize synthetic collection\n", + "# qim3d.viz.volumetric(vol)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize slices\n", + "# qim3d.viz.slicer(vol)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There will be $N + 1$ unique labels, because one extra for background.\n", + "But we want only 2 labels: foreground and background." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert N + 1 labels into 2 labels (background and object)\n", + "labels = (labels > 0).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize labels\n", + "# qim3d.viz.slicer(labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.2 Create folder structure" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating directories:\n", + "C:\\Users\\s193396/dataset\\train\\images\n", + "C:\\Users\\s193396/dataset\\train\\labels\n", + "C:\\Users\\s193396/dataset\\test\\images\n", + "C:\\Users\\s193396/dataset\\test\\labels\n" + ] + } + ], + "source": [ + "# Base path for the training data\n", + "base_path = os.path.expanduser(\"~/dataset\")\n", + "\n", + "# Create directories\n", + "print(\"Creating directories:\")\n", + "for folder_split in [\"train\", \"test\"]:\n", + " for folder_type in [\"images\", \"labels\"]:\n", + " path = os.path.join(base_path, folder_split, folder_type)\n", + " os.makedirs(path, exist_ok=True)\n", + " print(path)\n", + "\n", + "# Here we have the option to remove any previous files\n", + "clean_files = True\n", + "if clean_files:\n", + " for root, dirs, files in os.walk(base_path):\n", + " for file in files:\n", + " file_path = os.path.join(root, file)\n", + " os.remove(file_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.3 Create dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to create a dataset of multiple volumes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_samples = 5\n", + "\n", + "for idx in range(num_samples):\n", + " # TODO: Figure out whether or not the seed makes it such that all volumes are identical?\n", + "\n", + " vol, label = qim3d.generate.noise_object_collection(\n", + " num_objects = num_objects,\n", + " collection_shape = (128, 128, 128),\n", + " min_object_noise = 0.03, \n", + " max_object_noise = 0.08,\n", + " )\n", + "\n", + " # Convert N + 1 labels into 2 labels (background and object)\n", + " label = (labels > 0).astype(int)\n", + "\n", + " # Save volume\n", + " qim3d.io.save(os.path.join(base_path, folder_split, \"images\", f\"{idx}.nii.gz\"), vol, replace = True)\n", + "\n", + " # Save label\n", + " qim3d.io.save(os.path.join(base_path, folder_split, \"labels\", f\"{idx}.nii.gz\"), label, replace = True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# volumes = sorted(glob.glob(os.path.join(base_path, \"im*.nii.gz\")))\n", + "# labels = sorted(glob.glob(os.path.join(base_path, \"seg*.nii.gz\")))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **2. Build 3D UNet model**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.1 Instantiate UNet model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = qim3d.ml.models.UNet(size = 'small', dropout = 0.25)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 Define augmentations" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.3 Divide dataset into train and test splits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# datasets and dataloaders\n", + "train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = base_path,\n", + " val_fraction = 0.5,\n", + " model = model,\n", + " augmentation = augmentation)\n", + "\n", + "\n", + "train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, \n", + " val_set,\n", + " test_set,\n", + " batch_size = 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **3. Train model**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.1 Define training hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# hyperparameters\n", + "hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=10, \n", + " learning_rate = 5e-3, loss_function='DiceCE',\n", + " weight_decay=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 Train model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# training model\n", + "qim3d.ml.train_model(model, hyperparameters, train_loader, val_loader, plot=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **4. Test model**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "qim3d", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}