Skip to content
Snippets Groups Projects
Commit 20edba7f authored by fima's avatar fima :beers:
Browse files

Merge branch 'simple_3D_slicer' into 'main'

Simple 3d slicer

See merge request !56
parents e5a790ed b63bf3a2
Branches
No related tags found
1 merge request!56Simple 3d slicer
import pytest
import torch
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import qim3d
import pytest
from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview
......@@ -40,19 +42,17 @@ def test_grid_pred():
temp_data(folder,remove = True)
# unit tests for slice visualization
# unit tests for slices function
def test_slices_numpy_array_input():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
assert fig.get_figwidth() == img_width
fig = qim3d.viz.slices(example_volume, n_slices=1)
assert isinstance(fig, plt.Figure)
def test_slices_torch_tensor_input():
example_volume = torch.ones((10,10,10))
img_width = 3
fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
assert fig.get_figwidth() == img_width
fig = qim3d.viz.slices(example_volume,n_slices = 1)
assert isinstance(fig, plt.Figure)
def test_slices_wrong_input_format():
input = 'not_a_volume'
......@@ -84,13 +84,6 @@ def test_slices_invalid_axis_value():
with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
qim3d.viz.slices(example_volume, axis = 3)
def test_slices_show_title_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
# Assert that titles are not shown
assert all(ax.get_title() == '' for ax in fig.get_axes())
def test_slices_interpolation_option():
example_volume = torch.ones((10, 10, 10))
img_width = 3
......@@ -126,3 +119,93 @@ def test_slices_axis_argument():
assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
# unit tests for slicer function
def test_slicer_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the slicer function with the NumPy array
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the slicer function with the PyTorch tensor
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_different_parameters():
# Test with different axis values
for axis in range(3):
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), axis=axis)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(slicer_obj, widgets.interactive)
# Test with show_position set to True and False
for show_position in [True, False]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(slicer_obj, widgets.interactive)
# unit tests for orthogonal function
def test_orthogonal_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the orthogonal function with the PyTorch tensor
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_different_parameters():
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with show_position set to True and False
for show_position in [True, False]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_initial_slider_value():
# Create a sample NumPy array
vol = np.random.rand(10, 7, 19)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].value == vol.shape[idx]//2
def test_orthogonal_slider_description():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].description == ['Z', 'Y', 'X'][idx]
from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices
from .img import grid_pred, grid_overview, slices, slicer, orthogonal
from .k3d import vol
"""
Provides a collection of visualization functions.
"""
import math
from typing import List, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
import torch
import numpy as np
import ipywidgets as widgets
from qim3d.io.logger import log
import math
import qim3d.io
import os
def grid_overview(
......@@ -308,7 +306,6 @@ def slices(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
)
# Make grid
nrows = math.ceil(n_slices / max_cols)
ncols = min(n_slices, max_cols)
......@@ -389,3 +386,123 @@ def _get_slice_range(position: int, n_slices: int, n_total):
slice_idxs = np.arange(n_total - n_slices, n_total)
return slice_idxs
def slicer(
vol: Union[np.ndarray, torch.Tensor],
axis: int = 0,
cmap: str = "viridis",
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
interpolation: Optional[str] = None,
) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume.
Args:
vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height(int, optional): Height of the figure. Defaults to 3.
img_width(int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
Example:
vol_path = '/my_vol_path/my_vol.tif'
vol = qim3d.io.load(vol_path)
slicer(vol, axis = 1)
"""
# Create the interactive widget
def _slicer(position):
fig = slices(
vol,
axis=axis,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
position=position,
n_slices=1,
show=True,
)
return fig
position_slider = widgets.IntSlider(
value=vol.shape[axis] // 2,
min=0,
max=vol.shape[axis] - 1,
description="Slice",
continuous_update=True,
)
slicer_obj = widgets.interactive(_slicer, position=position_slider)
slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj
def orthogonal(
vol: Union[np.ndarray, torch.Tensor],
cmap: str = "viridis",
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
interpolation: Optional[str] = None,
):
"""Interactive widget for visualizing orthogonal slices of a 3D volume.
Args:
vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height(int, optional): Height of the figure.
img_width(int, optional): Width of the figure.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
Example:
vol_path = '/my_vol_path/my_vol.tif'
vol = qim3d.io.load(vol_path)
orthogonal(vol)
"""
z_slicer = slicer(
vol,
axis=0,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
y_slicer = slicer(
vol,
axis=1,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
x_slicer = slicer(
vol,
axis=2,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y"
x_slicer.children[0].description = "X"
return widgets.HBox([z_slicer, y_slicer, x_slicer])
......@@ -17,6 +17,7 @@ torchvision>=0.15.2,
torchinfo>=1.8.0,
tqdm>=4.65.0,
nibabel>=5.2.0,
ipywidgets>=8.1.2,
dask>=2023.6.0,
k3d>=2.16.1
olefile>=0.46
......
......@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
setup(
name="qim3d",
version="0.3.2",
version="0.3.3",
author="Felipe Delestro",
author_email="fima@dtu.dk",
description="QIM tools and user interfaces",
......@@ -56,6 +56,7 @@ setup(
"torchinfo>=1.8.0",
"tqdm>=4.65.0",
"nibabel>=5.2.0",
"ipywidgets>=8.1.2",
"dask>=2023.6.0",
"k3d>=2.16.1",
"olefile>=0.46",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment