Skip to content
Snippets Groups Projects
Commit b63bf3a2 authored by s184058's avatar s184058 Committed by fima
Browse files

Simple 3d slicer

parent e5a790ed
Branches
Tags
1 merge request!56Simple 3d slicer
import pytest
import torch import torch
import numpy as np import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import qim3d import qim3d
import pytest
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview # unit tests for grid overview
...@@ -40,19 +42,17 @@ def test_grid_pred(): ...@@ -40,19 +42,17 @@ def test_grid_pred():
temp_data(folder,remove = True) temp_data(folder,remove = True)
# unit tests for slice visualization # unit tests for slices function
def test_slices_numpy_array_input(): def test_slices_numpy_array_input():
example_volume = np.ones((10, 10, 10)) example_volume = np.ones((10, 10, 10))
img_width = 3 fig = qim3d.viz.slices(example_volume, n_slices=1)
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width) assert isinstance(fig, plt.Figure)
assert fig.get_figwidth() == img_width
def test_slices_torch_tensor_input(): def test_slices_torch_tensor_input():
example_volume = torch.ones((10,10,10)) example_volume = torch.ones((10,10,10))
img_width = 3 img_width = 3
fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width) fig = qim3d.viz.slices(example_volume,n_slices = 1)
assert isinstance(fig, plt.Figure)
assert fig.get_figwidth() == img_width
def test_slices_wrong_input_format(): def test_slices_wrong_input_format():
input = 'not_a_volume' input = 'not_a_volume'
...@@ -84,13 +84,6 @@ def test_slices_invalid_axis_value(): ...@@ -84,13 +84,6 @@ def test_slices_invalid_axis_value():
with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"): with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
qim3d.viz.slices(example_volume, axis = 3) qim3d.viz.slices(example_volume, axis = 3)
def test_slices_show_title_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
# Assert that titles are not shown
assert all(ax.get_title() == '' for ax in fig.get_axes())
def test_slices_interpolation_option(): def test_slices_interpolation_option():
example_volume = torch.ones((10, 10, 10)) example_volume = torch.ones((10, 10, 10))
img_width = 3 img_width = 3
...@@ -126,3 +119,93 @@ def test_slices_axis_argument(): ...@@ -126,3 +119,93 @@ def test_slices_axis_argument():
assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array()) assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array()) assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array()) assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
# unit tests for slicer function
def test_slicer_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the slicer function with the NumPy array
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the slicer function with the PyTorch tensor
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_different_parameters():
# Test with different axis values
for axis in range(3):
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), axis=axis)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(slicer_obj, widgets.interactive)
# Test with show_position set to True and False
for show_position in [True, False]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(slicer_obj, widgets.interactive)
# unit tests for orthogonal function
def test_orthogonal_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the orthogonal function with the PyTorch tensor
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_different_parameters():
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with show_position set to True and False
for show_position in [True, False]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_initial_slider_value():
# Create a sample NumPy array
vol = np.random.rand(10, 7, 19)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].value == vol.shape[idx]//2
def test_orthogonal_slider_description():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].description == ['Z', 'Y', 'X'][idx]
from .visualizations import plot_metrics from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices from .img import grid_pred, grid_overview, slices, slicer, orthogonal
from .k3d import vol from .k3d import vol
""" """
Provides a collection of visualization functions. Provides a collection of visualization functions.
""" """
import math
from typing import List, Optional, Union from typing import List, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps from matplotlib import colormaps
import torch import torch
import numpy as np import numpy as np
import ipywidgets as widgets
from qim3d.io.logger import log from qim3d.io.logger import log
import math
import qim3d.io
import os
def grid_overview( def grid_overview(
...@@ -308,7 +306,6 @@ def slices( ...@@ -308,7 +306,6 @@ def slices(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
) )
# Make grid # Make grid
nrows = math.ceil(n_slices / max_cols) nrows = math.ceil(n_slices / max_cols)
ncols = min(n_slices, max_cols) ncols = min(n_slices, max_cols)
...@@ -389,3 +386,123 @@ def _get_slice_range(position: int, n_slices: int, n_total): ...@@ -389,3 +386,123 @@ def _get_slice_range(position: int, n_slices: int, n_total):
slice_idxs = np.arange(n_total - n_slices, n_total) slice_idxs = np.arange(n_total - n_slices, n_total)
return slice_idxs return slice_idxs
def slicer(
vol: Union[np.ndarray, torch.Tensor],
axis: int = 0,
cmap: str = "viridis",
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
interpolation: Optional[str] = None,
) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume.
Args:
vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height(int, optional): Height of the figure. Defaults to 3.
img_width(int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
Example:
vol_path = '/my_vol_path/my_vol.tif'
vol = qim3d.io.load(vol_path)
slicer(vol, axis = 1)
"""
# Create the interactive widget
def _slicer(position):
fig = slices(
vol,
axis=axis,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
position=position,
n_slices=1,
show=True,
)
return fig
position_slider = widgets.IntSlider(
value=vol.shape[axis] // 2,
min=0,
max=vol.shape[axis] - 1,
description="Slice",
continuous_update=True,
)
slicer_obj = widgets.interactive(_slicer, position=position_slider)
slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj
def orthogonal(
vol: Union[np.ndarray, torch.Tensor],
cmap: str = "viridis",
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
interpolation: Optional[str] = None,
):
"""Interactive widget for visualizing orthogonal slices of a 3D volume.
Args:
vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height(int, optional): Height of the figure.
img_width(int, optional): Width of the figure.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
Example:
vol_path = '/my_vol_path/my_vol.tif'
vol = qim3d.io.load(vol_path)
orthogonal(vol)
"""
z_slicer = slicer(
vol,
axis=0,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
y_slicer = slicer(
vol,
axis=1,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
x_slicer = slicer(
vol,
axis=2,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y"
x_slicer.children[0].description = "X"
return widgets.HBox([z_slicer, y_slicer, x_slicer])
...@@ -17,6 +17,7 @@ torchvision>=0.15.2, ...@@ -17,6 +17,7 @@ torchvision>=0.15.2,
torchinfo>=1.8.0, torchinfo>=1.8.0,
tqdm>=4.65.0, tqdm>=4.65.0,
nibabel>=5.2.0, nibabel>=5.2.0,
ipywidgets>=8.1.2,
dask>=2023.6.0, dask>=2023.6.0,
k3d>=2.16.1 k3d>=2.16.1
olefile>=0.46 olefile>=0.46
......
...@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f: ...@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
setup( setup(
name="qim3d", name="qim3d",
version="0.3.2", version="0.3.3",
author="Felipe Delestro", author="Felipe Delestro",
author_email="fima@dtu.dk", author_email="fima@dtu.dk",
description="QIM tools and user interfaces", description="QIM tools and user interfaces",
...@@ -56,6 +56,7 @@ setup( ...@@ -56,6 +56,7 @@ setup(
"torchinfo>=1.8.0", "torchinfo>=1.8.0",
"tqdm>=4.65.0", "tqdm>=4.65.0",
"nibabel>=5.2.0", "nibabel>=5.2.0",
"ipywidgets>=8.1.2",
"dask>=2023.6.0", "dask>=2023.6.0",
"k3d>=2.16.1", "k3d>=2.16.1",
"olefile>=0.46", "olefile>=0.46",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment