diff --git a/docs/notebooks/Unet.ipynb b/docs/notebooks/Unet.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d993b7d0dfd72611782b1a7ad47403a90c91a9ec
--- /dev/null
+++ b/docs/notebooks/Unet.ipynb
@@ -0,0 +1,437 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dd6781ce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa88080a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from glob import glob\n",
+    "from os.path import join\n",
+    "import os\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from skimage.io import imread\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm, trange\n",
+    "from monai.networks.nets import UNet\n",
+    "from torchvision import transforms\n",
+    "from monai.losses import FocalLoss, DiceLoss\n",
+    "import qim3d\n",
+    "import albumentations as A\n",
+    "from albumentations.pytorch import ToTensorV2\n",
+    "\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d0a5eade",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define function for getting dataset path from string\n",
+    "def get_dataset_path(name: str):\n",
+    "    datasets = [\n",
+    "        'belialev2020_side',\n",
+    "        'gaudez2022_3d',\n",
+    "        'guo2023_2d',\n",
+    "        'stan2020_2d',\n",
+    "        'reichardt2021_2d',\n",
+    "        'testcircles_2dbinary',\n",
+    "    ]\n",
+    "    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)\n",
+    "\n",
+    "    dataset_idx = datasets.index(name)\n",
+    "\n",
+    "    datasets_path = [\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
+    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
+    "    ]\n",
+    "\n",
+    "    return datasets_path[dataset_idx]"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "ee235f48",
+   "metadata": {},
+   "source": [
+    "# Data loading"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "c088ceb8",
+   "metadata": {},
+   "source": [
+    "### Check out https://albumentations.ai/docs/getting_started/transforms_and_targets/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ddfef29e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Training set transformation\n",
+    "aug_train = A.Compose([\n",
+    "    A.Resize(832,832),\n",
+    "    A.RandomRotate90(),\n",
+    "    A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]\n",
+    "    ToTensorV2()\n",
+    "])\n",
+    "\n",
+    "# Validation/test set transformation\n",
+    "aug_val_test = A.Compose([\n",
+    "    A.Resize(832,832),\n",
+    "    A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]\n",
+    "    ToTensorV2()\n",
+    "])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "766f0f8e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Possible datasets ####\n",
+    "\n",
+    "# 'belialev2020_side'\n",
+    "# 'gaudez2022_3d'\n",
+    "# 'guo2023_2d'\n",
+    "# 'stan2020_2d'\n",
+    "# 'reichardt2021_2d'\n",
+    "# 'testcircles_2dbinary'\n",
+    "\n",
+    "# Choose dataset\n",
+    "dataset = 'stan2020_2d'\n",
+    "\n",
+    "# Define class instances. First, both train and validation set is defined from train \n",
+    "# folder with different transformations and below divided into non-overlapping subsets\n",
+    "train_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_train)\n",
+    "val_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_val_test)\n",
+    "test_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),split='test',transform=aug_val_test)\n",
+    "\n",
+    "# Define fraction of training set used for validation\n",
+    "VAL_FRACTION = 0.3\n",
+    "split_idx = int(np.floor(VAL_FRACTION * len(train_set)))\n",
+    "\n",
+    "# Define seed\n",
+    "# torch.manual_seed(42)\n",
+    "\n",
+    "# Get randomly permuted indices \n",
+    "indices = torch.randperm(len(train_set))\n",
+    "\n",
+    "# Define train and validation sets as subsets\n",
+    "train_set = torch.utils.data.Subset(train_set, indices[split_idx:])\n",
+    "val_set = torch.utils.data.Subset(val_set, indices[:split_idx])\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "321495cc",
+   "metadata": {},
+   "source": [
+    "### Data overview"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a794b739",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Check if data has mask\n",
+    "has_mask= False #True if train_set[0][-1] is not None else False\n",
+    "\n",
+    "print(f'No. of train images={len(train_set)}')\n",
+    "print(f'No. of validation images={len(val_set)}')\n",
+    "print(f'No. of test images={len(test_set)}')\n",
+    "print(f'{train_set[0][0].dtype=}')\n",
+    "print(f'{train_set[0][1].dtype=}')\n",
+    "print(f'image shape={train_set[0][0].shape}')\n",
+    "print(f'label shape={train_set[0][1].shape}')\n",
+    "print(f'Labels={np.unique(train_set[0][1])}')\n",
+    "print(f'Masked data? {has_mask}')"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "5efa7d33",
+   "metadata": {},
+   "source": [
+    "### Data visualization\n",
+    "\n",
+    "Display first seven image, labels, and masks if they exist"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "170577d3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "qim3d.qim3d.viz.grid_overview(train_set,num_images=6,alpha=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "33368063",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define batch sizes\n",
+    "TRAIN_BATCH_SIZE = 4\n",
+    "VAL_BATCH_SIZE = 4\n",
+    "TEST_BATCH_SIZE = 4\n",
+    "\n",
+    "# Define dataloaders\n",
+    "train_loader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)\n",
+    "val_loader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, num_workers=8, pin_memory=True)\n",
+    "test_loader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, num_workers=8, pin_memory=True)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "35e83e38",
+   "metadata": {},
+   "source": [
+    "# Train U-Net"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "36685b25",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define model\n",
+    "model = UNet(\n",
+    "    spatial_dims=2,\n",
+    "    in_channels=1, \n",
+    "    out_channels=1, \n",
+    "    channels=(64, 128, 256, 512, 1024), \n",
+    "    strides=(2, 2, 2, 2), \n",
+    ")\n",
+    "\n",
+    "orig_state = model.state_dict()  # Save, so we can reset model to original state later\n",
+    "\n",
+    "# Define loss function\n",
+    "#loss_fn = nn.CrossEntropyLoss()\n",
+    "loss_fn = FocalLoss()\n",
+    "\n",
+    "# Define device\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "137be29b",
+   "metadata": {},
+   "source": [
+    "### Run training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "13d8a9f3",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# Define hyperparameters\n",
+    "NUM_EPOCHS = 5\n",
+    "EVAL_EVERY = 1\n",
+    "PRINT_EVERY = 1\n",
+    "LR = 3e-3\n",
+    "\n",
+    "\n",
+    "model.load_state_dict(orig_state)  # Restart training every time\n",
+    "model.to(device)\n",
+    "\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
+    "\n",
+    "all_losses = []\n",
+    "all_val_loss = []\n",
+    "for epoch in range(NUM_EPOCHS):\n",
+    "    model.train()\n",
+    "    epoch_loss = 0\n",
+    "    step = 0\n",
+    "    for data in train_loader:\n",
+    "        if has_mask:\n",
+    "            inputs, targets, masks = data\n",
+    "            masks = masks.to(device).float()\n",
+    "        else:\n",
+    "            inputs, targets = data\n",
+    "\n",
+    "        inputs = inputs.to(device)\n",
+    "        targets = targets.to(device).float().unsqueeze(1)\n",
+    "        \n",
+    "        # Forward -> Backward -> Step\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        outputs = model(inputs)\n",
+    "\n",
+    "        #print(f'input {inputs.shape}, target: {targets.shape}, output: {outputs.shape}')\n",
+    "    \n",
+    "        loss = loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)\n",
+    "\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "\n",
+    "        epoch_loss += loss.detach()\n",
+    "        step += 1\n",
+    "        \n",
+    "    # Log and store average epoch loss\n",
+    "    epoch_loss = epoch_loss.item() / step\n",
+    "    all_losses.append(epoch_loss)\n",
+    "\n",
+    "    if epoch % EVAL_EVERY == 0:\n",
+    "        model.eval()\n",
+    "        with torch.no_grad():  # Do not need gradients for this part\n",
+    "            loss_sum = 0\n",
+    "            step = 0\n",
+    "            for data in val_loader:\n",
+    "                if has_mask:\n",
+    "                    inputs, targets, masks = data\n",
+    "                    masks = masks.to(device).float()\n",
+    "                else:\n",
+    "                    inputs, targets = data\n",
+    "                \n",
+    "                inputs = inputs.to(device)\n",
+    "                targets = targets.to(device).float().unsqueeze(1)\n",
+    "                \n",
+    "                outputs = model(inputs)\n",
+    "                \n",
+    "                loss_sum += loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)\n",
+    "                step += 1\n",
+    "                \n",
+    "            val_loss = loss_sum.item() / step\n",
+    "            all_val_loss.append(val_loss)\n",
+    "\n",
+    "        # Log and store average accuracy\n",
+    "        if epoch % PRINT_EVERY == 0:\n",
+    "            print(f'Epoch {epoch: 3}, train loss: {epoch_loss:.4f}, val loss: {val_loss:.4f}')\n",
+    "\n",
+    "print('Min val loss:', min(all_val_loss))"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a7a8e9d7",
+   "metadata": {},
+   "source": [
+    "### Plot train and validation loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "851463c8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.figure(figsize=(16, 3))\n",
+    "plt.plot(all_losses, '-', label='Train')\n",
+    "plt.plot(all_val_loss, '-', label='Val.')\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "1a700f8a",
+   "metadata": {},
+   "source": [
+    "### Inspecting the Predicted Segmentations on training data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2ac83638",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "qim3d.qim3d.viz.grid_pred(train_set,model,num_images=5,alpha=1)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a176ff96",
+   "metadata": {},
+   "source": [
+    "### Inspecting the Predicted Segmentations on test data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ffb261c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "qim3d.qim3d.viz.grid_pred(test_set,model,alpha=1)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/qim3d/__init__.py b/qim3d/__init__.py
index f93f3086827d709fcb070c579467f7a506846b53..e80b817639db15f8abaf702b683244721ba0d7c3 100644
--- a/qim3d/__init__.py
+++ b/qim3d/__init__.py
@@ -1,5 +1,5 @@
 import qim3d.io
 import qim3d.gui
-import qim3d.tools
+import qim3d.utils
 import qim3d.viz
 import logging
\ No newline at end of file
diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py
index 597556ce146a90b8f1206676c6ffbfd9eb521d5b..57dab039201a5a9b5e39b8f2c067990b33276a59 100644
--- a/qim3d/gui/data_explorer.py
+++ b/qim3d/gui/data_explorer.py
@@ -1,7 +1,7 @@
 import gradio as gr
 import numpy as np
 import os
-from qim3d.tools import internal_tools
+from qim3d.utils import internal_tools
 from qim3d.io import DataLoader
 from qim3d.io.logger import log
 import tifffile
diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py
index 5e1571913d9897243d685e1f38ca8753b3a509c8..728902f25479e55a59546aa4f62ae97cf8c9d40a 100644
--- a/qim3d/gui/iso3d.py
+++ b/qim3d/gui/iso3d.py
@@ -1,7 +1,7 @@
 import gradio as gr
 import numpy as np
 import os
-from qim3d.tools import internal_tools
+from qim3d.utils import internal_tools
 from qim3d.io import DataLoader
 from qim3d.io.logger import log
 import plotly.graph_objects as go
@@ -44,7 +44,6 @@ class Interface:
         return None
 
     def load_data(self, filepath):
-
         # TODO: Add support for multiple files
         self.vol = DataLoader().load_tiff(filepath)
 
diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py
index b1368a30d92e087f2e2b0b4d7f387ccc6e97f1af..39fcfbb033ec710de044d26f7be12e74d134aff0 100644
--- a/qim3d/gui/local_thickness.py
+++ b/qim3d/gui/local_thickness.py
@@ -1,7 +1,7 @@
 import gradio as gr
 import numpy as np
 import os
-from qim3d.tools import internal_tools
+from qim3d.utils import internal_tools
 from qim3d.io import DataLoader
 from qim3d.io.logger import log
 import tifffile
diff --git a/qim3d/io/load.py b/qim3d/io/load.py
index 4e92fff8aa4ccef3b12a94bcafd69182ca2dbdda..24fa1d331264dfc86f5f2ebe09679e4dd21327c6 100644
--- a/qim3d/io/load.py
+++ b/qim3d/io/load.py
@@ -6,7 +6,7 @@ import difflib
 import tifffile
 import h5py
 from qim3d.io.logger import log
-from qim3d.tools.internal_tools import sizeof
+from qim3d.utils.internal_tools import sizeof
 
 
 class DataLoader:
diff --git a/qim3d/tools/__init__.py b/qim3d/utils/__init__.py
similarity index 53%
rename from qim3d/tools/__init__.py
rename to qim3d/utils/__init__.py
index 834fbdb8613bdae4eee01dc776bf878ee1097fcb..460718e9d2c5a27fbba7ba76320d833c5dfe92c3 100644
--- a/qim3d/tools/__init__.py
+++ b/qim3d/utils/__init__.py
@@ -1,2 +1,2 @@
-from . import *
 from . import internal_tools
+from .data import Dataset
\ No newline at end of file
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..5abcaf19f1a10cdc4a6177b028cd9901c8d52f79
--- /dev/null
+++ b/qim3d/utils/data.py
@@ -0,0 +1,71 @@
+"""Provides a custom Dataset class for building a PyTorch dataset"""
+from pathlib import Path
+from PIL import Image
+import torch
+import numpy as np
+
+
+class Dataset(torch.utils.data.Dataset):
+    """
+    Custom Dataset class for building a PyTorch dataset
+
+    Args:
+        root_path (str): The root directory path of the dataset.
+        split (str, optional): The split of the dataset, either "train" or "test".
+            Default is "train".
+        transform (callable, optional): A callable function or transformation to 
+            be applied to the data. Default is None.
+
+    Raises:
+        ValueError: If the provided split is not valid (neither "train" nor "test").
+
+    Attributes:
+        split (str): The split of the dataset ("train" or "test").
+        transform (callable): The transformation to be applied to the data.
+        sample_images (list): A list containing the paths to the sample images in the dataset.
+        sample_targets (list): A list containing the paths to the corresponding target images
+            in the dataset.
+
+    Methods:
+        __len__(): Returns the total number of samples in the dataset.
+        __getitem__(idx): Returns the image and its target segmentation at the given index.
+
+    Usage:
+        dataset = Dataset(root_path="path/to/dataset", split="train", 
+            transform=albumentations.Compose([ToTensorV2()]))
+        image, target = dataset[idx]
+    """
+    def __init__(self, root_path: str, split="train", transform=None):
+        super().__init__()
+
+        # Check if split is valid
+        if split not in ["train", "test"]:
+            raise ValueError("Split must be either train or test")
+
+        self.split = split
+        self.transform = transform
+
+        path = Path(root_path) / split
+
+        self.sample_images = sorted((path / "images").iterdir())
+        self.sample_targets = sorted((path / "labels").iterdir())
+        assert len(self.sample_images) == len(self.sample_targets)
+
+    def __len__(self):
+        return len(self.sample_images)
+
+    def __getitem__(self, idx):
+        image_path = self.sample_images[idx]
+        target_path = self.sample_targets[idx]
+
+        image = Image.open(str(image_path))
+        image = np.array(image)
+        target = Image.open(str(target_path))
+        target = np.array(target)
+
+        if self.transform:
+            transformed = self.transform(image=image, mask=target)
+            image = transformed["image"]
+            target = transformed["mask"]
+
+        return image, target
diff --git a/qim3d/tools/internal_tools.py b/qim3d/utils/internal_tools.py
similarity index 100%
rename from qim3d/tools/internal_tools.py
rename to qim3d/utils/internal_tools.py