From b63bf3a21ce5102138a30678afa035aaf1165c16 Mon Sep 17 00:00:00 2001
From: s184058 <s184058@student.dtu.dk>
Date: Wed, 6 Mar 2024 10:01:56 +0100
Subject: [PATCH] Simple 3d slicer

---
 qim3d/tests/viz/test_img.py | 115 +++++++++++++++++++++++++++-----
 qim3d/viz/__init__.py       |   2 +-
 qim3d/viz/img.py            | 127 ++++++++++++++++++++++++++++++++++--
 requirements.txt            |   3 +-
 setup.py                    |   3 +-
 5 files changed, 226 insertions(+), 24 deletions(-)

diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index 17cc3072..82dd9882 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -1,7 +1,9 @@
+import pytest
 import torch
 import numpy as np
+import ipywidgets as widgets
+import matplotlib.pyplot as plt
 import qim3d
-import pytest
 from qim3d.utils.internal_tools import temp_data
 
 # unit tests for grid overview
@@ -40,19 +42,17 @@ def test_grid_pred():
     temp_data(folder,remove = True)
 
 
-# unit tests for slice visualization
+# unit tests for slices function
 def test_slices_numpy_array_input():
     example_volume = np.ones((10, 10, 10))
-    img_width = 3
-    fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
-    assert fig.get_figwidth() == img_width
+    fig = qim3d.viz.slices(example_volume, n_slices=1)
+    assert isinstance(fig, plt.Figure)
 
 def test_slices_torch_tensor_input():
     example_volume = torch.ones((10,10,10))
     img_width = 3
-    fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
-
-    assert fig.get_figwidth() == img_width
+    fig = qim3d.viz.slices(example_volume,n_slices = 1)
+    assert isinstance(fig, plt.Figure)
 
 def test_slices_wrong_input_format():
     input = 'not_a_volume'
@@ -84,13 +84,6 @@ def test_slices_invalid_axis_value():
     with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
         qim3d.viz.slices(example_volume, axis = 3)
 
-def test_slices_show_title_option():
-    example_volume = np.ones((10, 10, 10))
-    img_width = 3
-    fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
-    # Assert that titles are not shown
-    assert all(ax.get_title() == '' for ax in fig.get_axes())
-
 def test_slices_interpolation_option():
     example_volume = torch.ones((10, 10, 10))
     img_width = 3
@@ -125,4 +118,94 @@ def test_slices_axis_argument():
     # Ensure that different axes result in different plots
     assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
     assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
-    assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
\ No newline at end of file
+    assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
+
+# unit tests for slicer function
+def test_slicer_with_numpy_array():
+    # Create a sample NumPy array
+    vol = np.random.rand(10, 10, 10)
+    # Call the slicer function with the NumPy array
+    slicer_obj = qim3d.viz.slicer(vol)
+    # Assert that the slicer object is created successfully
+    assert isinstance(slicer_obj, widgets.interactive)
+
+def test_slicer_with_torch_tensor():
+    # Create a sample PyTorch tensor
+    vol = torch.rand(10, 10, 10)
+    # Call the slicer function with the PyTorch tensor
+    slicer_obj = qim3d.viz.slicer(vol)
+    # Assert that the slicer object is created successfully
+    assert isinstance(slicer_obj, widgets.interactive)
+
+def test_slicer_with_different_parameters():
+    # Test with different axis values
+    for axis in range(3):
+        slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), axis=axis)
+        assert isinstance(slicer_obj, widgets.interactive)
+
+    # Test with different colormaps
+    for cmap in ["viridis", "gray", "plasma"]:
+        slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap)
+        assert isinstance(slicer_obj, widgets.interactive)
+
+    # Test with different image sizes
+    for img_height, img_width in [(2, 2), (4, 4)]:
+        slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
+        assert isinstance(slicer_obj, widgets.interactive)
+
+    # Test with show_position set to True and False
+    for show_position in [True, False]:
+        slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position)
+        assert isinstance(slicer_obj, widgets.interactive)
+
+# unit tests for orthogonal function
+def test_orthogonal_with_numpy_array():
+    # Create a sample NumPy array
+    vol = np.random.rand(10, 10, 10)
+    # Call the orthogonal function with the NumPy array
+    orthogonal_obj = qim3d.viz.orthogonal(vol)
+    # Assert that the orthogonal object is created successfully
+    assert isinstance(orthogonal_obj, widgets.HBox)
+
+def test_orthogonal_with_torch_tensor():
+    # Create a sample PyTorch tensor
+    vol = torch.rand(10, 10, 10)
+    # Call the orthogonal function with the PyTorch tensor
+    orthogonal_obj = qim3d.viz.orthogonal(vol)
+    # Assert that the orthogonal object is created successfully
+    assert isinstance(orthogonal_obj, widgets.HBox)
+
+def test_orthogonal_with_different_parameters():
+    # Test with different colormaps
+    for cmap in ["viridis", "gray", "plasma"]:
+        orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap)
+        assert isinstance(orthogonal_obj, widgets.HBox)
+
+    # Test with different image sizes
+    for img_height, img_width in [(2, 2), (4, 4)]:
+        orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
+        assert isinstance(orthogonal_obj, widgets.HBox)
+
+    # Test with show_position set to True and False
+    for show_position in [True, False]:
+        orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position)
+        assert isinstance(orthogonal_obj, widgets.HBox)
+
+def test_orthogonal_initial_slider_value():
+    # Create a sample NumPy array
+    vol = np.random.rand(10, 7, 19)
+    # Call the orthogonal function with the NumPy array
+    orthogonal_obj = qim3d.viz.orthogonal(vol)
+    for idx,slicer in enumerate(orthogonal_obj.children):
+        assert slicer.children[0].value == vol.shape[idx]//2
+
+def test_orthogonal_slider_description():
+    # Create a sample NumPy array
+    vol = np.random.rand(10, 10, 10)
+    # Call the orthogonal function with the NumPy array
+    orthogonal_obj = qim3d.viz.orthogonal(vol)
+    for idx,slicer in enumerate(orthogonal_obj.children):
+        assert slicer.children[0].description == ['Z', 'Y', 'X'][idx]
+
+
+
diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index ef978e88..5f1a2614 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -1,3 +1,3 @@
 from .visualizations import plot_metrics
-from .img import grid_pred, grid_overview, slices
+from .img import grid_pred, grid_overview, slices, slicer, orthogonal
 from .k3d import vol
diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index 24c1d666..3cf4e0ca 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -1,17 +1,15 @@
 """ 
 Provides a collection of visualization functions.
 """
-
+import math
 from typing import List, Optional, Union
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 import torch
 import numpy as np
+import ipywidgets as widgets
 from qim3d.io.logger import log
-import math
-import qim3d.io
-import os
 
 
 def grid_overview(
@@ -308,7 +306,6 @@ def slices(
             'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
         )
 
-
     # Make grid
     nrows = math.ceil(n_slices / max_cols)
     ncols = min(n_slices, max_cols)
@@ -389,3 +386,123 @@ def _get_slice_range(position: int, n_slices: int, n_total):
         slice_idxs = np.arange(n_total - n_slices, n_total)
 
     return slice_idxs
+
+
+def slicer(
+    vol: Union[np.ndarray, torch.Tensor],
+    axis: int = 0,
+    cmap: str = "viridis",
+    img_height: int = 3,
+    img_width: int = 3,
+    show_position: bool = False,
+    interpolation: Optional[str] = None,
+) -> widgets.interactive:
+    """Interactive widget for visualizing slices of a 3D volume.
+
+    Args:
+        vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
+        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
+        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        img_height(int, optional): Height of the figure. Defaults to 3.
+        img_width(int, optional): Width of the figure. Defaults to 3.
+        show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
+        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
+
+    Returns:
+        slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
+
+    Example:
+        vol_path = '/my_vol_path/my_vol.tif'
+        vol = qim3d.io.load(vol_path)
+        slicer(vol, axis = 1)
+    """
+
+    # Create the interactive widget
+    def _slicer(position):
+        fig = slices(
+            vol,
+            axis=axis,
+            cmap=cmap,
+            img_height=img_height,
+            img_width=img_width,
+            show_position=show_position,
+            interpolation=interpolation,
+            position=position,
+            n_slices=1,
+            show=True,
+        )
+        return fig
+
+    position_slider = widgets.IntSlider(
+        value=vol.shape[axis] // 2,
+        min=0,
+        max=vol.shape[axis] - 1,
+        description="Slice",
+        continuous_update=True,
+    )
+    slicer_obj = widgets.interactive(_slicer, position=position_slider)
+    slicer_obj.layout = widgets.Layout(align_items="flex-start")
+
+    return slicer_obj
+
+
+def orthogonal(
+    vol: Union[np.ndarray, torch.Tensor],
+    cmap: str = "viridis",
+    img_height: int = 3,
+    img_width: int = 3,
+    show_position: bool = False,
+    interpolation: Optional[str] = None,
+):
+    """Interactive widget for visualizing orthogonal slices of a 3D volume.
+
+    Args:
+        vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
+        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        img_height(int, optional): Height of the figure.
+        img_width(int, optional): Width of the figure.
+        show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
+        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
+
+    Returns:
+        orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
+
+    Example:
+        vol_path = '/my_vol_path/my_vol.tif'
+        vol = qim3d.io.load(vol_path)
+        orthogonal(vol)
+    """
+
+    z_slicer = slicer(
+        vol,
+        axis=0,
+        cmap=cmap,
+        img_height=img_height,
+        img_width=img_width,
+        show_position=show_position,
+        interpolation=interpolation,
+    )
+    y_slicer = slicer(
+        vol,
+        axis=1,
+        cmap=cmap,
+        img_height=img_height,
+        img_width=img_width,
+        show_position=show_position,
+        interpolation=interpolation,
+    )
+    x_slicer = slicer(
+        vol,
+        axis=2,
+        cmap=cmap,
+        img_height=img_height,
+        img_width=img_width,
+        show_position=show_position,
+        interpolation=interpolation,
+    )
+
+    z_slicer.children[0].description = "Z"
+    y_slicer.children[0].description = "Y"
+    x_slicer.children[0].description = "X"
+
+    return widgets.HBox([z_slicer, y_slicer, x_slicer])
diff --git a/requirements.txt b/requirements.txt
index 3e7ac3be..6a865988 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,7 +17,8 @@ torchvision>=0.15.2,
 torchinfo>=1.8.0,
 tqdm>=4.65.0,
 nibabel>=5.2.0,
+ipywidgets>=8.1.2,
 dask>=2023.6.0,
 k3d>=2.16.1
 olefile>=0.46
-psutil>=5.9.0
\ No newline at end of file
+psutil>=5.9.0
diff --git a/setup.py b/setup.py
index 64c0bd12..406df355 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
 
 setup(
     name="qim3d",
-    version="0.3.2",
+    version="0.3.3",
     author="Felipe Delestro",
     author_email="fima@dtu.dk",
     description="QIM tools and user interfaces",
@@ -56,6 +56,7 @@ setup(
         "torchinfo>=1.8.0",
         "tqdm>=4.65.0",
         "nibabel>=5.2.0",
+        "ipywidgets>=8.1.2",
         "dask>=2023.6.0",
         "k3d>=2.16.1",
         "olefile>=0.46",
-- 
GitLab