    {
     "cells": [
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "# **Deep Learning Volume Segmentation (3D UNet)**\n",
        "\n",
        "Author: Anna Ekner (s193396@dtu.dk)\n",
        "\n",
        "This notebook demonstrates that a complete deep learning segmentation pipeline can be built using only the `qim3d` library. It uses the synthetic data generation functionality to create a volumetric dataset with associated labels, and then walks through creating and training a 3D UNet model on that synthetic dataset."
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 13,
       "metadata": {},
       "outputs": [],
       "source": [
        "import qim3d\n",
        "import glob\n",
        "import os\n",
        "import numpy as np"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "### **1. Generate synthetic dataset**"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 1.1 Example of a data sample"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "Generate one synthetic volume together with its associated labels."
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 6,
       "metadata": {},
       "outputs": [
        {
         "data": {
          "application/vnd.jupyter.widget-view+json": {
           "model_id": "d9b9bb97fe7d4828b5b7f359cb8ac50a",
           "version_major": 2,
           "version_minor": 0
          },
          "text/plain": [
           "Objects placed:   0%|          | 0/5 [00:00<?, ?it/s]"
          ]
         },
         "metadata": {},
         "output_type": "display_data"
        }
       ],
       "source": [
        "num_objects = 5\n",
        "vol, labels = qim3d.generate.noise_object_collection(\n",
        "    num_objects = num_objects,\n",
        "    collection_shape = (128, 128, 128),\n",
        "    min_object_noise = 0.03, \n",
        "    max_object_noise = 0.08,\n",
        "    )"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 10,
       "metadata": {},
       "outputs": [],
       "source": [
        "# Visualize synthetic collection\n",
        "# qim3d.viz.volumetric(vol)"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 30,
       "metadata": {},
       "outputs": [],
       "source": [
        "# Visualize slices\n",
        "# qim3d.viz.slicer(vol)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "The label volume contains $N + 1$ unique labels: one per object plus one for the background.\n",
        "For segmentation we want only 2 labels: foreground and background."
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 3,
       "metadata": {},
       "outputs": [],
       "source": [
        "# Convert N + 1 labels into 2 labels (background and object)\n",
        "labels = (labels > 0).astype(int)"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 29,
       "metadata": {},
       "outputs": [],
       "source": [
        "# Visualize labels\n",
        "# qim3d.viz.slicer(labels)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 1.2 Create folder structure"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 12,
       "metadata": {},
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "Creating directories:\n",
          "C:\\Users\\s193396/dataset\\train\\images\n",
          "C:\\Users\\s193396/dataset\\train\\labels\n",
          "C:\\Users\\s193396/dataset\\test\\images\n",
          "C:\\Users\\s193396/dataset\\test\\labels\n"
         ]
        }
       ],
       "source": [
        "# Base path for the training data\n",
        "base_path = os.path.expanduser(\"~/dataset\")\n",
        "\n",
        "# Create directories\n",
        "print(\"Creating directories:\")\n",
        "for folder_split in [\"train\", \"test\"]:\n",
        "    for folder_type in [\"images\", \"labels\"]:\n",
        "        path = os.path.join(base_path, folder_split, folder_type)\n",
        "        os.makedirs(path, exist_ok=True)\n",
        "        print(path)\n",
        "\n",
        "# Here we have the option to remove any previous files\n",
        "clean_files = True\n",
        "if clean_files:\n",
        "    for root, dirs, files in os.walk(base_path):\n",
        "        for file in files:\n",
        "            file_path = os.path.join(root, file)\n",
        "            os.remove(file_path)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 1.3 Create dataset"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "Training requires more than one sample, so we generate and save a dataset of multiple volumes."
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
        "num_samples = 5\n",
        "\n",
        "for idx in range(num_samples):\n",
        "    # TODO: Check whether the seed makes all generated volumes identical\n",
        "\n",
        "    vol, label = qim3d.generate.noise_object_collection(\n",
        "        num_objects = num_objects,\n",
        "        collection_shape = (128, 128, 128),\n",
        "        min_object_noise = 0.03, \n",
        "        max_object_noise = 0.08,\n",
        "        )\n",
        "\n",
        "    # Convert N + 1 labels into 2 labels (background and object)\n",
        "    label = (label > 0).astype(int)\n",
        "\n",
        "    # Choose the split explicitly; otherwise folder_split keeps its last value (\"test\") from section 1.2\n",
        "    # Assumption: the last sample is held out for testing, the rest are used for training\n",
        "    folder_split = \"train\" if idx < num_samples - 1 else \"test\"\n",
        "\n",
        "    # Save volume\n",
        "    qim3d.io.save(os.path.join(base_path, folder_split, \"images\", f\"{idx}.nii.gz\"), vol, replace = True)\n",
        "\n",
        "    # Save label\n",
        "    qim3d.io.save(os.path.join(base_path, folder_split, \"labels\", f\"{idx}.nii.gz\"), label, replace = True)"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
        "# volumes = sorted(glob.glob(os.path.join(base_path, \"im*.nii.gz\")))\n",
        "# labels = sorted(glob.glob(os.path.join(base_path, \"seg*.nii.gz\")))"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "### **2. Build 3D UNet model**"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 2.1 Instantiate UNet model"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 4,
       "metadata": {},
       "outputs": [],
       "source": [
        "model = qim3d.ml.models.UNet(size = 'small', dropout = 0.25)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 2.2 Define augmentations"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 5,
       "metadata": {},
       "outputs": [],
       "source": [
        "augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 2.3 Divide dataset into train, validation and test splits"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
        "# datasets and dataloaders\n",
        "train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = base_path,\n",
        "                                                            val_fraction = 0.5,\n",
        "                                                            model = model,\n",
        "                                                            augmentation = augmentation)\n",
        "\n",
        "\n",
        "train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, \n",
        "                                                                        val_set,\n",
        "                                                                        test_set,\n",
        "                                                                        batch_size = 1)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "### **3. Train model**"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 3.1 Define training hyperparameters"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
        "# hyperparameters\n",
        "hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=10, \n",
        "                                               learning_rate = 5e-3, loss_function='DiceCE',\n",
        "                                               weight_decay=1e-3)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "#### 3.2 Train model"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
        "# training model\n",
        "qim3d.ml.train_model(model, hyperparameters, train_loader, val_loader, plot=True)"
       ]
      },
      {
       "cell_type": "markdown",
       "metadata": {},
       "source": [
        "### **4. Test model**"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": []
      }
     ],
     "metadata": {
      "kernelspec": {
       "display_name": "qim3d",
       "language": "python",
       "name": "python3"
      },
      "language_info": {
       "codemirror_mode": {
        "name": "ipython",
        "version": 3
       },
       "file_extension": ".py",
       "mimetype": "text/x-python",
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
       "version": "3.10.14"
      }
     },
     "nbformat": 4,
     "nbformat_minor": 2
    }