Skip to content
Snippets Groups Projects
Unet.ipynb 5.56 KiB
Newer Older
  • Learn to ignore specific revisions
  • s184058's avatar
    s184058 committed
    {
     "cells": [
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "be66055b-8ee9-46be-ad9d-f15edf2654a4",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
    
       "outputs": [],
    
    s184058's avatar
    s184058 committed
       "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "0c61dd11-5a2b-44ff-b0e5-989360bbb677",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "outputs": [],
       "source": [
        "from os.path import join\n",
        "import qim3d\n",
        "\n",
        "%matplotlib inline"
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "cd6bb832-1297-462f-8d35-1738a9c37ffd",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "outputs": [],
       "source": [
        "# Define function for getting dataset path from string\n",
    
        "def get_dataset_path(name: str, datasets):\n",
    
    s184058's avatar
    s184058 committed
        "    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)\n",
        "    dataset_idx = datasets.index(name)\n",
        "    datasets_path = [\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
        "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
        "    ]\n",
        "    return datasets_path[dataset_idx]"
       ]
      },
      {
       "cell_type": "markdown",
    
       "id": "7d07077a-cce3-4448-89f5-02413345becc",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "source": [
    
        "### Datasets"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "9a3b9c3c-4bbb-4a19-9685-f68c437e8bee",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "outputs": [],
       "source": [
    
        "datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']\n",
        "dataset = datasets[3] \n",
        "root = get_dataset_path(dataset,datasets)\n",
        "\n",
        "# should not use gaudez2022: 3d image\n",
        "# reichardt2021: multiclass segmentation"
    
       "cell_type": "markdown",
       "id": "254dc8cb-6f24-4b57-91c0-98fb6f62602c",
    
    ofhkr's avatar
    ofhkr committed
       "metadata": {},
       "source": [
    
        "### Model and Augmentation"
    
    ofhkr's avatar
    ofhkr committed
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "30098003-ec06-48e0-809f-82f44166fb2b",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "outputs": [],
       "source": [
    
        "# defining model\n",
    
        "my_model = qim3d.models.UNet(size = 'medium', dropout = 0.25)\n",
    
        "# defining augmentation\n",
    
        "my_aug = qim3d.utils.Augmentation(resize = 'crop', transform_train = 'light')"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "markdown",
    
       "id": "7b56c654-720d-4c5f-8545-749daa5dbaf2",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "source": [
    
        "### Loading the data"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "84141298-054d-4322-8bda-5ec514528985",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
    
    s184058's avatar
    s184058 committed
       "source": [
    
        "# level of logging\n",
        "qim3d.io.logger.level('info')\n",
    
    s184058's avatar
    s184058 committed
        "\n",
    
        "# datasets and dataloaders\n",
    
        "train_set, val_set, test_set = qim3d.utils.prepare_datasets(path = root, val_fraction = 0.3,\n",
        "                                                            model = my_model , augmentation = my_aug)\n",
        "\n",
        "train_loader, val_loader, test_loader = qim3d.utils.prepare_dataloaders(train_set, val_set,\n",
        "                                                                        test_set, batch_size = 6)"
       ]
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "id": "f320a4ae-f063-430c-b5a0-0d9fb64c2725",
       "metadata": {},
       "outputs": [],
       "source": [
        "qim3d.viz.grid_overview(train_set,alpha = 1)"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "7fa3aa57-ba61-4c9a-934c-dce26bbc9e97",
    
       "outputs": [],
    
    s184058's avatar
    s184058 committed
       "source": [
    
        "# Summary of model\n",
        "model_s = qim3d.utils.model_summary(train_loader,my_model)\n",
    
        "print(model_s)"
    
       "cell_type": "markdown",
       "id": "a665ae28-d9a6-419f-9131-54283b47582c",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "source": [
    
        "### Hyperparameters and training"
    
    ofhkr's avatar
    ofhkr committed
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "ce64ae65-01fb-45a9-bdcb-a3806de8469e",
    
    ofhkr's avatar
    ofhkr committed
       "metadata": {},
    
    s184058's avatar
    s184058 committed
       "source": [
    
        "# model hyperparameters\n",
    
        "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=25,\n",
        "                                                  learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)\n",
    
        "\n",
        "# training model\n",
        "qim3d.utils.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "markdown",
    
       "id": "7e14fac8-4fd3-4725-bd0d-9e2a95552278",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
       "source": [
    
    "### Test-set predictions"
    
    s184058's avatar
    s184058 committed
       ]
      },
      {
       "cell_type": "code",
    
       "execution_count": null,
    
       "id": "f8684cb0-5673-4409-8d22-f00b7d099ca4",
    
    s184058's avatar
    s184058 committed
       "metadata": {},
    
    s184058's avatar
    s184058 committed
       "source": [
    
        "in_targ_preds_test = qim3d.utils.inference(test_set,my_model)\n",
        "qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)"
    
    s184058's avatar
    s184058 committed
       ]
      }
     ],
     "metadata": {
      "kernelspec": {
       "display_name": "Python 3 (ipykernel)",
       "language": "python",
       "name": "python3"
      },
      "language_info": {
       "codemirror_mode": {
        "name": "ipython",
        "version": 3
       },
       "file_extension": ".py",
       "mimetype": "text/x-python",
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
    
    fima's avatar
    fima committed
       "version": "3.10.11"
    
    s184058's avatar
    s184058 committed
      }
     },
     "nbformat": 4,
     "nbformat_minor": 5
    }