Skip to content
Snippets Groups Projects

Simple 3d slicer

Merged s184058 requested to merge simple_3D_slicer into main

Files

+ 99
16
import pytest
import torch
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import qim3d
import pytest
from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview
@@ -40,19 +42,17 @@ def test_grid_pred():
temp_data(folder,remove = True)
# unit tests for slice visualization
# unit tests for slices function
def test_slices_numpy_array_input():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
assert fig.get_figwidth() == img_width
fig = qim3d.viz.slices(example_volume, n_slices=1)
assert isinstance(fig, plt.Figure)
def test_slices_torch_tensor_input():
example_volume = torch.ones((10,10,10))
img_width = 3
fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
assert fig.get_figwidth() == img_width
fig = qim3d.viz.slices(example_volume,n_slices = 1)
assert isinstance(fig, plt.Figure)
def test_slices_wrong_input_format():
input = 'not_a_volume'
@@ -84,13 +84,6 @@ def test_slices_invalid_axis_value():
with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
qim3d.viz.slices(example_volume, axis = 3)
def test_slices_show_title_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
# Assert that titles are not shown
assert all(ax.get_title() == '' for ax in fig.get_axes())
def test_slices_interpolation_option():
example_volume = torch.ones((10, 10, 10))
img_width = 3
@@ -125,4 +118,94 @@ def test_slices_axis_argument():
# Ensure that different axes result in different plots
assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
\ No newline at end of file
assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
# unit tests for slicer function
def test_slicer_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the slicer function with the NumPy array
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the slicer function with the PyTorch tensor
slicer_obj = qim3d.viz.slicer(vol)
# Assert that the slicer object is created successfully
assert isinstance(slicer_obj, widgets.interactive)
def test_slicer_with_different_parameters():
# Test with different axis values
for axis in range(3):
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), axis=axis)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(slicer_obj, widgets.interactive)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(slicer_obj, widgets.interactive)
# Test with show_position set to True and False
for show_position in [True, False]:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(slicer_obj, widgets.interactive)
# unit tests for orthogonal function
def test_orthogonal_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the orthogonal function with the PyTorch tensor
orthogonal_obj = qim3d.viz.orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_different_parameters():
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with show_position set to True and False
for show_position in [True, False]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position)
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_initial_slider_value():
# Create a sample NumPy array
vol = np.random.rand(10, 7, 19)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].value == vol.shape[idx]//2
def test_orthogonal_slider_description():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
for idx,slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].description == ['Z', 'Y', 'X'][idx]
Loading