From d6f89ddbd3272c6e23f4d9327a1b0016cb442d34 Mon Sep 17 00:00:00 2001
From: s184058 <s184058@student.dtu.dk>
Date: Mon, 19 Feb 2024 15:28:28 +0100
Subject: [PATCH] Slice viz overhaul

---
 README.md                   |   8 +-
 docs/index.md               |   8 +-
 docs/releases.md            |   2 +-
 docs/viz.md                 |   2 +-
 qim3d/tests/viz/test_img.py |  99 +++++++++---
 qim3d/viz/__init__.py       |   2 +-
 qim3d/viz/img.py            | 289 +++++++++++++++++++++++-------------
 7 files changed, 278 insertions(+), 132 deletions(-)

diff --git a/README.md b/README.md
index b2d8dd6b..899596a1 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
 ```
 
 ## Visualize data
-You can easily check slices from your volume using `slice_viz`
+You can easily check slices from your volume using `slices`
 
 ```python
 import qim3d
@@ -40,13 +40,13 @@ import qim3d
 img = qim3d.examples.fly_150x256x256
 
 # By default shows the middle slice
-qim3d.viz.slice_viz(img)
+qim3d.viz.slices(img)
 
 # Or we can specifly positions
-qim3d.viz.slice_viz(img, position=[0,32,128])
+qim3d.viz.slices(img, position=[0,32,128])
 
 # Parameters for size and colormap are also possible
-qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno")
+qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
 
 ```
 
diff --git a/docs/index.md b/docs/index.md
index b2d702e3..ced94f6c 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -31,7 +31,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
 ```
 
 ### Visualize data
-You can easily check slices from your volume using `slice_viz`
+You can easily check slices from your volume using `slices`
 
 ```python
 import qim3d
@@ -39,13 +39,13 @@ import qim3d
 img = qim3d.examples.fly_150x256x256
 
 # By default shows the middle slice
-qim3d.viz.slice_viz(img)
+qim3d.viz.slices(img)
 
 # Or we can specifly positions
-qim3d.viz.slice_viz(img, position=[0,32,128])
+qim3d.viz.slices(img, position=[0,32,128])
 
 # Parameters for size and colormap are also possible
-qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno")
+qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
 
 ```
 
diff --git a/docs/releases.md b/docs/releases.md
index 5a72d14c..f50c0240 100644
--- a/docs/releases.md
+++ b/docs/releases.md
@@ -31,7 +31,7 @@ Includes new develoments toward the usability of the library, as well as its int
     - For the local thicknes GUI, now it is possible to pass and receive numpy arrays instead of using the upload functionality.
 - Improved data loader
     - Now the extensions `tif`, `h5` and `txm` are supported.
-- Added `qim3d.viz.slice_viz` for easy slice visualization.
+- Added `qim3d.viz.slices` for easy slice visualization.
 - U-net model creation
     - Model availabe from `qim3d.models.UNet`
     - Data augmentation class at `qim3d.utils.Augmentation`
diff --git a/docs/viz.md b/docs/viz.md
index 5a92ef56..095e9e52 100644
--- a/docs/viz.md
+++ b/docs/viz.md
@@ -12,7 +12,7 @@ import qim3d
         members:
             - grid_overview
             - grid_pred
-            - slice_viz
+            - slices
 
 
 
diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index 06fec0d0..17cc3072 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -1,13 +1,12 @@
+import torch
+import numpy as np
 import qim3d
-import matplotlib.pyplot as plt
 import pytest
-
-from torch import ones
 from qim3d.utils.internal_tools import temp_data
 
 # unit tests for grid overview
 def test_grid_overview():
-    random_tuple = (ones(1,256,256),ones(256,256))
+    random_tuple = (torch.ones(1,256,256),torch.ones(256,256))
     n_images = 10 
     train_set = [random_tuple for t in range(n_images)]
 
@@ -16,7 +15,7 @@ def test_grid_overview():
 
 
 def test_grid_overview_tuple():
-    random_tuple = (ones(256,256),ones(256,256))
+    random_tuple = (torch.ones(256,256),torch.ones(256,256))
 
     with pytest.raises(ValueError,match="Data elements must be tuples"):
         qim3d.viz.grid_overview(random_tuple,num_images=1)
@@ -42,22 +41,88 @@ def test_grid_pred():
 
 
 # unit tests for slice visualization
-def test_slice_viz():
-    example_volume = ones(10,10,10)
+def test_slices_numpy_array_input():
+    example_volume = np.ones((10, 10, 10))
+    img_width = 3
+    fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
+    assert fig.get_figwidth() == img_width
+
+def test_slices_torch_tensor_input():
+    example_volume = torch.ones((10,10,10))
     img_width = 3
-    fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width)
+    fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
 
     assert fig.get_figwidth() == img_width
 
+def test_slices_wrong_input_format():
+    input = 'not_a_volume'
+    with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'):
+        qim3d.viz.slices(input)
+
+def test_slices_not_volume():
+    example_volume = np.ones((10,10))
+    with pytest.raises(ValueError, match = 'The provided object is not a volume as it has less than 3 dimensions.'):
+        qim3d.viz.slices(example_volume)
+
+def test_slices_wrong_position_format1():
+    example_volume = np.ones((10,10,10))
+    with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
+        qim3d.viz.slices(example_volume, position = 'invalid_slice')
+
+def test_slices_wrong_position_format2():
+    example_volume = np.ones((10,10,10))
+    with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
+        qim3d.viz.slices(example_volume, position = 1.5)
+
+def test_slices_wrong_position_format3():
+    example_volume = np.ones((10,10,10))
+    with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
+        qim3d.viz.slices(example_volume, position = [1, 2, 3.5])
+
+def test_slices_invalid_axis_value():
+    example_volume = np.ones((10,10,10))
+    with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
+        qim3d.viz.slices(example_volume, axis = 3)
+
+def test_slices_show_title_option():
+    example_volume = np.ones((10, 10, 10))
+    img_width = 3
+    fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
+    # Assert that titles are not shown
+    assert all(ax.get_title() == '' for ax in fig.get_axes())
 
-def test_slice_viz_not_volume():
-    example_volume = ones(10,10)
-    dim = example_volume.ndim
-    with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"):
-        qim3d.viz.slice_viz(example_volume)
+def test_slices_interpolation_option():
+    example_volume = torch.ones((10, 10, 10))
+    img_width = 3
+    interpolation_method = 'bilinear'
+    fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, interpolation=interpolation_method)
+
+    for ax in fig.get_axes():
+        # Access the interpolation method used for each Axes object
+        actual_interpolation = ax.images[0].get_interpolation()
+
+        # Assert that the actual interpolation method matches the expected method
+        assert actual_interpolation == interpolation_method
+
+def test_slices_multiple_slices():
+    example_volume = np.ones((10, 10, 10))
+    img_width = 3
+    n_slices = 3
+    fig = qim3d.viz.slices(example_volume, n_slices=n_slices, img_width=img_width)
+    # Add assertions for the expected number of subplots in the figure
+    assert len(fig.get_axes()) == n_slices
+
+def test_slices_axis_argument():
+    # Non-symmetric input
+    example_volume = np.arange(1000).reshape((10, 10, 10))
+    img_width = 3
 
+    # Call the function with different values of the axis
+    fig_axis_0 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=0)
+    fig_axis_1 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=1)
+    fig_axis_2 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=2)
 
-def test_slice_viz_wrong_slice():
-    example_volume = ones(10,10,10)
-    with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'):
-        qim3d.viz.slice_viz(example_volume, position = 'invalid_slice')
+    # Ensure that different axes result in different plots
+    assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
+    assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
+    assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
\ No newline at end of file
diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index 1c36c157..98132cca 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -1,2 +1,2 @@
 from .visualizations import plot_metrics
-from .img import grid_pred, grid_overview, slice_viz
+from .img import grid_pred, grid_overview, slices
diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index 804cdd28..f228800c 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -1,15 +1,22 @@
 """ 
 Provides a collection of visualization functions.
 """
+
+from typing import List, Optional, Union
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 import torch
 import numpy as np
 from qim3d.io.logger import log
+import math
 import qim3d.io
+import os
+
 
-def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
+def grid_overview(
+    data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
+):
     """Displays an overview grid of images, labels, and masks (if they exist).
 
     Labels are the annotated target segmentations
@@ -22,14 +29,14 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
         cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
         cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
         alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
-        show (bool, optional): If True, displays the plot. Defaults to False.
+        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
 
     Raises:
         ValueError: If the data elements are not tuples.
 
 
     Returns:
-        fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels.   
+        fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels.
 
     Example:
         ```python
@@ -75,7 +82,9 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
     # Make new list such that possible augmentations remain identical for all three rows
     plot_data = [data[idx] for idx in range(num_images)]
 
-    fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True)
+    fig = plt.figure(
+        figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True
+    )
 
     # create 2 (3) x 1 subfigs
     subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1)
@@ -92,15 +101,22 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
             else:
                 ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
                 ax.axis("off")
-    
+
     if show:
         plt.show()
     plt.close()
-    
+
     return fig
 
 
-def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False):
+def grid_pred(
+    in_targ_preds,
+    num_images=7,
+    cmap_im="gray",
+    cmap_segm="viridis",
+    alpha=0.5,
+    show=False,
+):
     """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
 
     Displays a grid of subplots representing different aspects of the input images and segmentations.
@@ -119,7 +135,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
         cmap_im (str, optional): Color map for input images. Defaults to "gray".
         cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
         alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
-        show (bool, optional): If True, displays the plot. Defaults to False.
+        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
 
     Returns:
         fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models.
@@ -131,9 +147,9 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
         dataset = MySegmentationDataset()
         model = MySegmentationModel()
         in_targ_preds = qim3d.utils.models.inference(dataset,model)
-        grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)        
+        grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
     """
-    
+
     # Check if dataset have at least specified number of images
     if len(in_targ_preds[0]) < num_images:
         log.warning(
@@ -144,21 +160,21 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
         num_images = len(in_targ_preds[0])
 
     # Take only the number of images from in_targ_preds
-    inputs,targets,preds = [items[:num_images] for items in in_targ_preds]
-    
+    inputs, targets, preds = [items[:num_images] for items in in_targ_preds]
+
     # Adapt segmentation cmap so that background is transparent
     colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
     colors_segm[:128, 3] = 0
     custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
-    
+
     N = num_images
     H = inputs[0].shape[-2]
     W = inputs[0].shape[-1]
 
-    comp_rgb = torch.zeros((N,4,H,W))
-    comp_rgb[:,1,:,:] = targets.logical_and(preds)
-    comp_rgb[:,0,:,:] = targets.logical_xor(preds)
-    comp_rgb[:,3,:,:] = targets.logical_or(preds)
+    comp_rgb = torch.zeros((N, 4, H, W))
+    comp_rgb[:, 1, :, :] = targets.logical_and(preds)
+    comp_rgb[:, 0, :, :] = targets.logical_xor(preds)
+    comp_rgb[:, 3, :, :] = targets.logical_or(preds)
 
     row_titles = [
         "Input images",
@@ -187,9 +203,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
                 ax.axis("off")
             elif row == 2:  # Ground truth segmentation
                 ax.imshow(inputs[col], cmap=cmap_im)
-                ax.imshow(
-                    targets[col], cmap=custom_cmap, alpha=alpha
-                )
+                ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
                 ax.axis("off")
             else:
                 ax.imshow(inputs[col], cmap=cmap_im)
@@ -202,105 +216,172 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
 
     return fig
 
-def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False):
-    """ Displays one or several slices from a 3d array.
 
-    By default if `position` is None, slice_viz plots an overview of the entire stack.
-    If `position` is given as a string or integer, slice_viz will plot an overview with `n_slices` figures around that position.
-    If `position` is given as a list or array, `n_slices` will be ignored and the idxs from `position` will be plotted.
-    
+def slices(
+    vol: Union[np.ndarray, torch.Tensor],
+    axis: int = 0,
+    position: Optional[Union[str, int, List[int]]] = None,
+    n_slices: int = 5,
+    max_cols: int = 5,
+    cmap: str = "viridis",
+    img_height: int = 2,
+    img_width: int = 2,
+    show: bool = False,
+    show_position: bool = True,
+    interpolation: Optional[str] = None,
+) -> plt.Figure:
+    """Displays one or several slices from a 3d volume.
+
+    By default if `position` is None, slices plots `n_slices` linearly spaced slices.
+    If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position.
+    If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
+
     Args:
-        input (str, numpy.ndarray): Path to the file or 3-dimensional array.
-        position (str, int, list, array, optional): One or several slicing levels.
-        n_slices (int, optional): Defines how many slices the user wants.
-        cmap (str, optional): Specifies the color map for the image.
-        axis (bool, optional): Specifies whether the axes should be included.
+        vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
+        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
+        position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
+        n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
+        max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
+        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
         img_height(int, optional): Height of the figure.
         img_width(int, optional): Width of the figure.
-        show (bool, optional): If True, displays the plot. Defaults to False.
+        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
+        show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
+        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
 
     Returns:
         fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
 
     Raises:
-        ValueError: If the file or array is not a 3D volume.
-        ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom).
-    
+        ValueError: If the input is not a numpy.ndarray or torch.Tensor.
+        ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
+        ValueError: If the file or array is not a volume with at least 3 dimensions.
+        ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
+
     Example:
-        image_path = '/my_image_path/my_image.tif'
-        slice_viz(image_path)
+        vol_path = '/my_vol_path/my_vol.tif'
+        vol = qim3d.io.load(vol_path)
+        slices(vol, axis = 1, position = 'mid', n_slices = 3, cmap = 'viridis', img_height = 4, img_width = 4, show = True, show_position = True, interpolation = None)
     """
-    
-    # Filepath input
-    if isinstance(input,str):
-        vol = qim3d.io.load(input) # Function has its own ValueErrors
-        dim = vol.ndim        
-        
-    # Numpy array input
-    elif isinstance(input,(np.ndarray,torch.Tensor)):
-        vol = input
-        dim = input.ndim
-        
-    if dim != 3:
-        raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
 
+    # Numpy array or Torch tensor input
+    if not isinstance(vol, (np.ndarray, torch.Tensor)):
+        raise ValueError("Input must be a numpy.ndarray or torch.Tensor")
+
+    if vol.ndim < 3:
+        raise ValueError(
+            "The provided object is not a volume as it has less than 3 dimensions."
+        )
+
+    # Ensure axis is a valid choice
+    if not (0 <= axis < vol.ndim):
+        raise ValueError(
+            f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
+        )
+
+    # Get total number of slices in the specified dimension
+    n_total = vol.shape[axis]
+
+    # Position is not provided - will use linearly spaced slices
     if position is None:
-        height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int)
-    
+        slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
     # Position is a string
-    elif isinstance(position,str):
-        
-        if position.lower() in ['mid','middle']:
-            expansion_start = int(vol.shape[0]/2)
-            height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
-            
-        elif position.lower() in ['top','upper', 'start']:
-            expansion_start = 0
-            height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int)
-            
-        elif position.lower() in ['bot','bottom', 'end']:
-            expansion_start = vol.shape[0]-1
-            height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int)
-        
-        else:
-            raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".')
-
-    
-    # Position is an integer
-    elif isinstance(position,int):
-        expansion_start = position
-        n_stacks = vol.shape[0]-1
-
-        # if linspace would extend beyond n_stacks
-        if expansion_start + n_slices > n_stacks:
-            height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int)
-        
-        # if linspace would extend below 0 
-        elif expansion_start - n_slices < 0:
-            height = np.linspace(0,n_slices-1,n_slices).astype(int)
-
-        else:
-            height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
-
-    
-    # Position is a list or array of integers
-    elif isinstance(position,(list,np.ndarray)):
-        height = position
-
-    num_images = len(height)
-
-
-    fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True)
-    axs = fig.subplots(nrows = 1, ncols = num_images)
-    
-    for col, ax in enumerate(np.atleast_1d(axs)):
-        ax.imshow(vol[height[col],:,:],cmap = cmap)
-        ax.set_title(f'Slice {height[col]}', fontsize=6*img_height)
-        if not axis:
-            ax.axis('off')
-    
+    elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
+        if position.lower() == "start":
+            slice_idxs = _get_slice_range(0, n_slices, n_total)
+        elif position.lower() == "mid":
+            slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
+        elif position.lower() == "end":
+            slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
+    #  Position is an integer
+    elif isinstance(position, int):
+        slice_idxs = _get_slice_range(position, n_slices, n_total)
+    # Position is a list of integers
+    elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
+        slice_idxs = position
+    else:
+        raise ValueError(
+            'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
+        )
+
+
+    # Make grid
+    nrows = math.ceil(n_slices / max_cols)
+    ncols = min(n_slices, max_cols)
+
+    # Generate figure
+    fig, axs = plt.subplots(
+        nrows=nrows,
+        ncols=ncols,
+        figsize=(ncols * img_height, nrows * img_width),
+        constrained_layout=True,
+    )
+    if nrows == 1:
+        axs = [axs]  # Convert to a list for uniformity
+
+    # Convert Torch tensor to NumPy array in order to use the numpy.take method
+    if isinstance(vol, torch.Tensor):
+        vol = vol.numpy()
+
+    # Run through each ax of the grid
+    for i, ax_row in enumerate(axs):
+        for j, ax in enumerate(np.atleast_1d(ax_row)):
+            slice_idx = i * max_cols + j
+            try:
+                slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
+                ax.imshow(slice_img, cmap=cmap, interpolation=interpolation)
+
+                if show_position:
+                    ax.text(
+                        0.0,
+                        1.0,
+                        f"slice {slice_idxs[slice_idx]} ",
+                        transform=ax.transAxes,
+                        color="white",
+                        fontsize=8,
+                        va="top",
+                        ha="left",
+                        bbox=dict(facecolor="#303030", linewidth=0, pad=0),
+                    )
+
+                    ax.text(
+                        1.0,
+                        0.0,
+                        f"axis {axis} ",
+                        transform=ax.transAxes,
+                        color="white",
+                        fontsize=8,
+                        va="bottom",
+                        ha="right",
+                        bbox=dict(facecolor="#303030", linewidth=0, pad=0),
+                    )
+
+            except IndexError:
+                # Not a problem, because we simply do not have a slice to show
+                pass
+
+            # Hide the axis, so that we have a nice grid
+            ax.axis("off")
+
     if show:
         plt.show()
+
     plt.close()
 
-    return fig
\ No newline at end of file
+    return fig
+
+
+def _get_slice_range(position: int, n_slices: int, n_total):
+    """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
+    start_idx = position - n_slices // 2
+    end_idx = (
+        position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1
+    )
+    slice_idxs = np.arange(start_idx, end_idx)
+
+    if slice_idxs[0] < 0:
+        slice_idxs = np.arange(0, n_slices)
+    elif slice_idxs[-1] > n_total:
+        slice_idxs = np.arange(n_total - n_slices, n_total)
+
+    return slice_idxs
-- 
GitLab