diff --git a/docs/notebooks/Unet.ipynb b/docs/notebooks/Unet.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d993b7d0dfd72611782b1a7ad47403a90c91a9ec --- /dev/null +++ b/docs/notebooks/Unet.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "dd6781ce", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa88080a", + "metadata": {}, + "outputs": [], + "source": [ + "from glob import glob\n", + "from os.path import join\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from skimage.io import imread\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm, trange\n", + "from monai.networks.nets import UNet\n", + "from torchvision import transforms\n", + "from monai.losses import FocalLoss, DiceLoss\n", + "import qim3d\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0a5eade", + "metadata": {}, + "outputs": [], + "source": [ + "# Define function for getting dataset path from string\n", + "def get_dataset_path(name: str):\n", + " datasets = [\n", + " 'belialev2020_side',\n", + " 'gaudez2022_3d',\n", + " 'guo2023_2d',\n", + " 'stan2020_2d',\n", + " 'reichardt2021_2d',\n", + " 'testcircles_2dbinary',\n", + " ]\n", + " assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)\n", + "\n", + " dataset_idx = datasets.index(name)\n", + "\n", + " datasets_path = [\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n", + " '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n", + " ]\n", + "\n", + " return datasets_path[dataset_idx]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ee235f48", + "metadata": {}, + "source": [ + "# Data loading" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c088ceb8", + "metadata": {}, + "source": [ + "### Check out https://albumentations.ai/docs/getting_started/transforms_and_targets/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ddfef29e", + "metadata": {}, + "outputs": [], + "source": [ + "# Training set transformation\n", + "aug_train = A.Compose([\n", + " A.Resize(832,832),\n", + " A.RandomRotate90(),\n", + " A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]\n", + " ToTensorV2()\n", + "])\n", + "\n", + "# Validation/test set transformation\n", + "aug_val_test = A.Compose([\n", + " A.Resize(832,832),\n", + " A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]\n", + " ToTensorV2()\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "766f0f8e", + "metadata": {}, + "outputs": [], + "source": [ + "### Possible datasets ####\n", + "\n", + "# 'belialev2020_side'\n", + "# 'gaudez2022_3d'\n", + "# 'guo2023_2d'\n", + "# 'stan2020_2d'\n", + "# 'reichardt2021_2d'\n", + "# 'testcircles_2dbinary'\n", + 
"\n", + "# Choose dataset\n", + "dataset = 'stan2020_2d'\n", + "\n", + "# Define class instances. First, both train and validation set is defined from train \n", + "# folder with different transformations and below divided into non-overlapping subsets\n", + "train_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_train)\n", + "val_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_val_test)\n", + "test_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),split='test',transform=aug_val_test)\n", + "\n", + "# Define fraction of training set used for validation\n", + "VAL_FRACTION = 0.3\n", + "split_idx = int(np.floor(VAL_FRACTION * len(train_set)))\n", + "\n", + "# Define seed\n", + "# torch.manual_seed(42)\n", + "\n", + "# Get randomly permuted indices \n", + "indices = torch.randperm(len(train_set))\n", + "\n", + "# Define train and validation sets as subsets\n", + "train_set = torch.utils.data.Subset(train_set, indices[split_idx:])\n", + "val_set = torch.utils.data.Subset(val_set, indices[:split_idx])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "321495cc", + "metadata": {}, + "source": [ + "### Data overview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a794b739", + "metadata": {}, + "outputs": [], + "source": [ + "# Check if data has mask\n", + "has_mask= False #True if train_set[0][-1] is not None else False\n", + "\n", + "print(f'No. of train images={len(train_set)}')\n", + "print(f'No. of validation images={len(val_set)}')\n", + "print(f'No. of test images={len(test_set)}')\n", + "print(f'{train_set[0][0].dtype=}')\n", + "print(f'{train_set[0][1].dtype=}')\n", + "print(f'image shape={train_set[0][0].shape}')\n", + "print(f'label shape={train_set[0][1].shape}')\n", + "print(f'Labels={np.unique(train_set[0][1])}')\n", + "print(f'Masked data? 
{has_mask}')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "5efa7d33", + "metadata": {}, + "source": [ + "### Data visualization\n", + "\n", + "Display first seven image, labels, and masks if they exist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "170577d3", + "metadata": {}, + "outputs": [], + "source": [ + "qim3d.qim3d.viz.grid_overview(train_set,num_images=6,alpha=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33368063", + "metadata": {}, + "outputs": [], + "source": [ + "# Define batch sizes\n", + "TRAIN_BATCH_SIZE = 4\n", + "VAL_BATCH_SIZE = 4\n", + "TEST_BATCH_SIZE = 4\n", + "\n", + "# Define dataloaders\n", + "train_loader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)\n", + "val_loader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, num_workers=8, pin_memory=True)\n", + "test_loader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, num_workers=8, pin_memory=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "35e83e38", + "metadata": {}, + "source": [ + "# Train U-Net" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36685b25", + "metadata": {}, + "outputs": [], + "source": [ + "# Define model\n", + "model = UNet(\n", + " spatial_dims=2,\n", + " in_channels=1, \n", + " out_channels=1, \n", + " channels=(64, 128, 256, 512, 1024), \n", + " strides=(2, 2, 2, 2), \n", + ")\n", + "\n", + "orig_state = model.state_dict() # Save, so we can reset model to original state later\n", + "\n", + "# Define loss function\n", + "#loss_fn = nn.CrossEntropyLoss()\n", + "loss_fn = FocalLoss()\n", + "\n", + "# Define device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "137be29b", + "metadata": {}, + "source": [ + "### Run training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13d8a9f3", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Define hyperparameters\n", + "NUM_EPOCHS = 5\n", + "EVAL_EVERY = 1\n", + "PRINT_EVERY = 1\n", + "LR = 3e-3\n", + "\n", + "\n", + "model.load_state_dict(orig_state) # Restart training every time\n", + "model.to(device)\n", + "\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n", + "\n", + "all_losses = []\n", + "all_val_loss = []\n", + "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", + " epoch_loss = 0\n", + " step = 0\n", + " for data in train_loader:\n", + " if has_mask:\n", + " inputs, targets, masks = data\n", + " masks = masks.to(device).float()\n", + " else:\n", + " inputs, targets = data\n", + "\n", + " inputs = inputs.to(device)\n", + " targets = targets.to(device).float().unsqueeze(1)\n", + " \n", + " # Forward -> Backward -> Step\n", + " optimizer.zero_grad()\n", + "\n", + " outputs = model(inputs)\n", + "\n", + " #print(f'input {inputs.shape}, target: {targets.shape}, output: {outputs.shape}')\n", + " \n", + " loss = loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epoch_loss += loss.detach()\n", + " step += 1\n", + " \n", + " # Log and store average epoch loss\n", + " epoch_loss = epoch_loss.item() / step\n", + " all_losses.append(epoch_loss)\n", + "\n", + " if epoch % EVAL_EVERY == 0:\n", + " model.eval()\n", + " with torch.no_grad(): # Do not need gradients for this 
part\n", + " loss_sum = 0\n", + " step = 0\n", + " for data in val_loader:\n", + " if has_mask:\n", + " inputs, targets, masks = data\n", + " masks = masks.to(device).float()\n", + " else:\n", + " inputs, targets = data\n", + " \n", + " inputs = inputs.to(device)\n", + " targets = targets.to(device).float().unsqueeze(1)\n", + " \n", + " outputs = model(inputs)\n", + " \n", + " loss_sum += loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)\n", + " step += 1\n", + " \n", + " val_loss = loss_sum.item() / step\n", + " all_val_loss.append(val_loss)\n", + "\n", + " # Log and store average accuracy\n", + " if epoch % PRINT_EVERY == 0:\n", + " print(f'Epoch {epoch: 3}, train loss: {epoch_loss:.4f}, val loss: {val_loss:.4f}')\n", + "\n", + "print('Min val loss:', min(all_val_loss))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a7a8e9d7", + "metadata": {}, + "source": [ + "### Plot train and validation loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "851463c8", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(16, 3))\n", + "plt.plot(all_losses, '-', label='Train')\n", + "plt.plot(all_val_loss, '-', label='Val.')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1a700f8a", + "metadata": {}, + "source": [ + "### Inspecting the Predicted Segmentations on training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ac83638", + "metadata": {}, + "outputs": [], + "source": [ + "qim3d.qim3d.viz.grid_pred(train_set,model,num_images=5,alpha=1)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a176ff96", + "metadata": {}, + "source": [ + "### Inspecting the Predicted Segmentations on test data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffb261c2", + "metadata": {}, + "outputs": [], + "source": [ + "qim3d.qim3d.viz.grid_pred(test_set,model,alpha=1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qim3d/__init__.py b/qim3d/__init__.py index f93f3086827d709fcb070c579467f7a506846b53..e80b817639db15f8abaf702b683244721ba0d7c3 100644 --- a/qim3d/__init__.py +++ b/qim3d/__init__.py @@ -1,5 +1,5 @@ import qim3d.io import qim3d.gui -import qim3d.tools +import qim3d.utils import qim3d.viz import logging \ No newline at end of file diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py index 597556ce146a90b8f1206676c6ffbfd9eb521d5b..57dab039201a5a9b5e39b8f2c067990b33276a59 100644 --- a/qim3d/gui/data_explorer.py +++ b/qim3d/gui/data_explorer.py @@ -1,7 +1,7 @@ import gradio as gr import numpy as np import os -from qim3d.tools import internal_tools +from qim3d.utils import internal_tools from qim3d.io import DataLoader from qim3d.io.logger import log import tifffile diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py index 5e1571913d9897243d685e1f38ca8753b3a509c8..728902f25479e55a59546aa4f62ae97cf8c9d40a 100644 --- a/qim3d/gui/iso3d.py +++ b/qim3d/gui/iso3d.py @@ -1,7 +1,7 @@ import gradio as gr import numpy as np import os -from qim3d.tools import internal_tools +from 
qim3d.utils import internal_tools from qim3d.io import DataLoader from qim3d.io.logger import log import plotly.graph_objects as go @@ -44,7 +44,6 @@ class Interface: return None def load_data(self, filepath): - # TODO: Add support for multiple files self.vol = DataLoader().load_tiff(filepath) diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py index b1368a30d92e087f2e2b0b4d7f387ccc6e97f1af..39fcfbb033ec710de044d26f7be12e74d134aff0 100644 --- a/qim3d/gui/local_thickness.py +++ b/qim3d/gui/local_thickness.py @@ -1,7 +1,7 @@ import gradio as gr import numpy as np import os -from qim3d.tools import internal_tools +from qim3d.utils import internal_tools from qim3d.io import DataLoader from qim3d.io.logger import log import tifffile diff --git a/qim3d/io/load.py b/qim3d/io/load.py index 4e92fff8aa4ccef3b12a94bcafd69182ca2dbdda..24fa1d331264dfc86f5f2ebe09679e4dd21327c6 100644 --- a/qim3d/io/load.py +++ b/qim3d/io/load.py @@ -6,7 +6,7 @@ import difflib import tifffile import h5py from qim3d.io.logger import log -from qim3d.tools.internal_tools import sizeof +from qim3d.utils.internal_tools import sizeof class DataLoader: diff --git a/qim3d/tools/__init__.py b/qim3d/utils/__init__.py similarity index 53% rename from qim3d/tools/__init__.py rename to qim3d/utils/__init__.py index 834fbdb8613bdae4eee01dc776bf878ee1097fcb..460718e9d2c5a27fbba7ba76320d833c5dfe92c3 100644 --- a/qim3d/tools/__init__.py +++ b/qim3d/utils/__init__.py @@ -1,2 +1,2 @@ -from . import * from . import internal_tools +from .data import Dataset \ No newline at end of file diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..5abcaf19f1a10cdc4a6177b028cd9901c8d52f79 --- /dev/null +++ b/qim3d/utils/data.py @@ -0,0 +1,71 @@ +"""Provides a custom Dataset class for building a PyTorch dataset""" +from pathlib import Path +from PIL import Image +import torch +import numpy as np + + +class Dataset(torch.utils.data.Dataset): + """ + Custom Dataset class for building a PyTorch dataset + + Args: + root_path (str): The root directory path of the dataset. + split (str, optional): The split of the dataset, either "train" or "test". + Default is "train". + transform (callable, optional): A callable function or transformation to + be applied to the data. Default is None. + + Raises: + ValueError: If the provided split is not valid (neither "train" nor "test"). + + Attributes: + split (str): The split of the dataset ("train" or "test"). + transform (callable): The transformation to be applied to the data. + sample_images (list): A list containing the paths to the sample images in the dataset. + sample_targets (list): A list containing the paths to the corresponding target images + in the dataset. + + Methods: + __len__(): Returns the total number of samples in the dataset. + __getitem__(idx): Returns the image and its target segmentation at the given index. 
+ + Usage: + dataset = Dataset(root_path="path/to/dataset", split="train", + transform=albumentations.Compose([ToTensorV2()])) + image, target = dataset[idx] + """ + def __init__(self, root_path: str, split="train", transform=None): + super().__init__() + + # Check if split is valid + if split not in ["train", "test"]: + raise ValueError("Split must be either train or test") + + self.split = split + self.transform = transform + + path = Path(root_path) / split + + self.sample_images = [file for file in sorted((path / "images").iterdir())] + self.sample_targets = [file for file in sorted((path / "labels").iterdir())] + assert len(self.sample_images) == len(self.sample_targets) + + def __len__(self): + return len(self.sample_images) + + def __getitem__(self, idx): + image_path = self.sample_images[idx] + target_path = self.sample_targets[idx] + + image = Image.open(str(image_path)) + image = np.array(image) + target = Image.open(str(target_path)) + target = np.array(target) + + if self.transform: + transformed = self.transform(image=image, mask=target) + image = transformed["image"] + target = transformed["mask"] + + return image, target diff --git a/qim3d/tools/internal_tools.py b/qim3d/utils/internal_tools.py similarity index 100% rename from qim3d/tools/internal_tools.py rename to qim3d/utils/internal_tools.py
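
Usage note: the short sketch below (not taken from the patch itself) illustrates how the renamed qim3d.utils.Dataset introduced above is consumed together with an albumentations pipeline and a PyTorch DataLoader, mirroring the notebook cells in docs/notebooks/Unet.ipynb. The dataset path is a placeholder; the class only assumes a layout of <root>/train/images and <root>/train/labels (and test/... for split="test").

# Minimal usage sketch -- "path/to/dataset" is a placeholder, not a real dataset.
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

import qim3d  # qim3d/__init__.py now imports qim3d.utils, which re-exports Dataset

# Same style of pipeline as the notebook: resize, normalize to roughly [-1, 1], convert to tensor
transform = A.Compose([
    A.Resize(832, 832),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
])

# Dataset expects <root>/<split>/images and <root>/<split>/labels with matching file counts
train_set = qim3d.utils.Dataset(root_path="path/to/dataset", split="train", transform=transform)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

# Each item is an (image, target) pair; for a grayscale input the image comes back
# as a channel-first float tensor and the target as the corresponding mask tensor.
image, target = train_set[0]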