From b63bf3a21ce5102138a30678afa035aaf1165c16 Mon Sep 17 00:00:00 2001 From: s184058 <s184058@student.dtu.dk> Date: Wed, 6 Mar 2024 10:01:56 +0100 Subject: [PATCH] Simple 3d slicer --- qim3d/tests/viz/test_img.py | 115 +++++++++++++++++++++++++++----- qim3d/viz/__init__.py | 2 +- qim3d/viz/img.py | 127 ++++++++++++++++++++++++++++++++++-- requirements.txt | 3 +- setup.py | 3 +- 5 files changed, 226 insertions(+), 24 deletions(-) diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 17cc3072..82dd9882 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -1,7 +1,9 @@ +import pytest import torch import numpy as np +import ipywidgets as widgets +import matplotlib.pyplot as plt import qim3d -import pytest from qim3d.utils.internal_tools import temp_data # unit tests for grid overview @@ -40,19 +42,17 @@ def test_grid_pred(): temp_data(folder,remove = True) -# unit tests for slice visualization +# unit tests for slices function def test_slices_numpy_array_input(): example_volume = np.ones((10, 10, 10)) - img_width = 3 - fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width) - assert fig.get_figwidth() == img_width + fig = qim3d.viz.slices(example_volume, n_slices=1) + assert isinstance(fig, plt.Figure) def test_slices_torch_tensor_input(): example_volume = torch.ones((10,10,10)) img_width = 3 - fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width) - - assert fig.get_figwidth() == img_width + fig = qim3d.viz.slices(example_volume,n_slices = 1) + assert isinstance(fig, plt.Figure) def test_slices_wrong_input_format(): input = 'not_a_volume' @@ -84,13 +84,6 @@ def test_slices_invalid_axis_value(): with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"): qim3d.viz.slices(example_volume, axis = 3) -def test_slices_show_title_option(): - example_volume = np.ones((10, 10, 10)) - img_width = 3 - fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False) - # Assert that titles are not shown - assert all(ax.get_title() == '' for ax in fig.get_axes()) - def test_slices_interpolation_option(): example_volume = torch.ones((10, 10, 10)) img_width = 3 @@ -125,4 +118,94 @@ def test_slices_axis_argument(): # Ensure that different axes result in different plots assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array()) assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array()) - assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array()) \ No newline at end of file + assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array()) + +# unit tests for slicer function +def test_slicer_with_numpy_array(): + # Create a sample NumPy array + vol = np.random.rand(10, 10, 10) + # Call the slicer function with the NumPy array + slicer_obj = qim3d.viz.slicer(vol) + # Assert that the slicer object is created successfully + assert isinstance(slicer_obj, widgets.interactive) + +def test_slicer_with_torch_tensor(): + # Create a sample PyTorch tensor + vol = torch.rand(10, 10, 10) + # Call the slicer function with the PyTorch tensor + slicer_obj = qim3d.viz.slicer(vol) + # Assert that the slicer object is created successfully + assert isinstance(slicer_obj, widgets.interactive) + +def test_slicer_with_different_parameters(): + # Test with different axis values + for axis in range(3): + slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), axis=axis) + assert isinstance(slicer_obj, widgets.interactive) + + # Test with different colormaps + for cmap in ["viridis", "gray", "plasma"]: + slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap) + assert isinstance(slicer_obj, widgets.interactive) + + # Test with different image sizes + for img_height, img_width in [(2, 2), (4, 4)]: + slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width) + assert isinstance(slicer_obj, widgets.interactive) + + # Test with show_position set to True and False + for show_position in [True, False]: + slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position) + assert isinstance(slicer_obj, widgets.interactive) + +# unit tests for orthogonal function +def test_orthogonal_with_numpy_array(): + # Create a sample NumPy array + vol = np.random.rand(10, 10, 10) + # Call the orthogonal function with the NumPy array + orthogonal_obj = qim3d.viz.orthogonal(vol) + # Assert that the orthogonal object is created successfully + assert isinstance(orthogonal_obj, widgets.HBox) + +def test_orthogonal_with_torch_tensor(): + # Create a sample PyTorch tensor + vol = torch.rand(10, 10, 10) + # Call the orthogonal function with the PyTorch tensor + orthogonal_obj = qim3d.viz.orthogonal(vol) + # Assert that the orthogonal object is created successfully + assert isinstance(orthogonal_obj, widgets.HBox) + +def test_orthogonal_with_different_parameters(): + # Test with different colormaps + for cmap in ["viridis", "gray", "plasma"]: + orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap) + assert isinstance(orthogonal_obj, widgets.HBox) + + # Test with different image sizes + for img_height, img_width in [(2, 2), (4, 4)]: + orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width) + assert isinstance(orthogonal_obj, widgets.HBox) + + # Test with show_position set to True and False + for show_position in [True, False]: + orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position) + assert isinstance(orthogonal_obj, widgets.HBox) + +def test_orthogonal_initial_slider_value(): + # Create a sample NumPy array + vol = np.random.rand(10, 7, 19) + # Call the orthogonal function with the NumPy array + orthogonal_obj = qim3d.viz.orthogonal(vol) + for idx,slicer in enumerate(orthogonal_obj.children): + assert slicer.children[0].value == vol.shape[idx]//2 + +def test_orthogonal_slider_description(): + # Create a sample NumPy array + vol = np.random.rand(10, 10, 10) + # Call the orthogonal function with the NumPy array + orthogonal_obj = qim3d.viz.orthogonal(vol) + for idx,slicer in enumerate(orthogonal_obj.children): + assert slicer.children[0].description == ['Z', 'Y', 'X'][idx] + + + diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index ef978e88..5f1a2614 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,3 +1,3 @@ from .visualizations import plot_metrics -from .img import grid_pred, grid_overview, slices +from .img import grid_pred, grid_overview, slices, slicer, orthogonal from .k3d import vol diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 24c1d666..3cf4e0ca 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -1,17 +1,15 @@ """ Provides a collection of visualization functions. """ - +import math from typing import List, Optional, Union import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib import colormaps import torch import numpy as np +import ipywidgets as widgets from qim3d.io.logger import log -import math -import qim3d.io -import os def grid_overview( @@ -308,7 +306,6 @@ def slices( 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' ) - # Make grid nrows = math.ceil(n_slices / max_cols) ncols = min(n_slices, max_cols) @@ -389,3 +386,123 @@ def _get_slice_range(position: int, n_slices: int, n_total): slice_idxs = np.arange(n_total - n_slices, n_total) return slice_idxs + + +def slicer( + vol: Union[np.ndarray, torch.Tensor], + axis: int = 0, + cmap: str = "viridis", + img_height: int = 3, + img_width: int = 3, + show_position: bool = False, + interpolation: Optional[str] = None, +) -> widgets.interactive: + """Interactive widget for visualizing slices of a 3D volume. + + Args: + vol (np.ndarray or torch.Tensor): The 3D volume to be sliced. + axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". + img_height(int, optional): Height of the figure. Defaults to 3. + img_width(int, optional): Width of the figure. Defaults to 3. + show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. + + Returns: + slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. + + Example: + vol_path = '/my_vol_path/my_vol.tif' + vol = qim3d.io.load(vol_path) + slicer(vol, axis = 1) + """ + + # Create the interactive widget + def _slicer(position): + fig = slices( + vol, + axis=axis, + cmap=cmap, + img_height=img_height, + img_width=img_width, + show_position=show_position, + interpolation=interpolation, + position=position, + n_slices=1, + show=True, + ) + return fig + + position_slider = widgets.IntSlider( + value=vol.shape[axis] // 2, + min=0, + max=vol.shape[axis] - 1, + description="Slice", + continuous_update=True, + ) + slicer_obj = widgets.interactive(_slicer, position=position_slider) + slicer_obj.layout = widgets.Layout(align_items="flex-start") + + return slicer_obj + + +def orthogonal( + vol: Union[np.ndarray, torch.Tensor], + cmap: str = "viridis", + img_height: int = 3, + img_width: int = 3, + show_position: bool = False, + interpolation: Optional[str] = None, +): + """Interactive widget for visualizing orthogonal slices of a 3D volume. + + Args: + vol (np.ndarray or torch.Tensor): The 3D volume to be sliced. + cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". + img_height(int, optional): Height of the figure. + img_width(int, optional): Width of the figure. + show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. + + Returns: + orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. + + Example: + vol_path = '/my_vol_path/my_vol.tif' + vol = qim3d.io.load(vol_path) + orthogonal(vol) + """ + + z_slicer = slicer( + vol, + axis=0, + cmap=cmap, + img_height=img_height, + img_width=img_width, + show_position=show_position, + interpolation=interpolation, + ) + y_slicer = slicer( + vol, + axis=1, + cmap=cmap, + img_height=img_height, + img_width=img_width, + show_position=show_position, + interpolation=interpolation, + ) + x_slicer = slicer( + vol, + axis=2, + cmap=cmap, + img_height=img_height, + img_width=img_width, + show_position=show_position, + interpolation=interpolation, + ) + + z_slicer.children[0].description = "Z" + y_slicer.children[0].description = "Y" + x_slicer.children[0].description = "X" + + return widgets.HBox([z_slicer, y_slicer, x_slicer]) diff --git a/requirements.txt b/requirements.txt index 3e7ac3be..6a865988 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,8 @@ torchvision>=0.15.2, torchinfo>=1.8.0, tqdm>=4.65.0, nibabel>=5.2.0, +ipywidgets>=8.1.2, dask>=2023.6.0, k3d>=2.16.1 olefile>=0.46 -psutil>=5.9.0 \ No newline at end of file +psutil>=5.9.0 diff --git a/setup.py b/setup.py index 64c0bd12..406df355 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f: setup( name="qim3d", - version="0.3.2", + version="0.3.3", author="Felipe Delestro", author_email="fima@dtu.dk", description="QIM tools and user interfaces", @@ -56,6 +56,7 @@ setup( "torchinfo>=1.8.0", "tqdm>=4.65.0", "nibabel>=5.2.0", + "ipywidgets>=8.1.2", "dask>=2023.6.0", "k3d>=2.16.1", "olefile>=0.46", -- GitLab