Skip to content
Snippets Groups Projects
Commit 411cba30 authored by Christian Kento Rasmussen's avatar Christian Kento Rasmussen
Browse files

added tests and visualiser

parent a18ea21f
No related branches found
No related tags found
1 merge request!48Implemented 3D connected components as wrapper class for scipy.ndimage.label
import numpy as np
import pytest
from qim3d.utils.connected_components import get_3d_connected_components
@pytest.fixture(scope="module")
def setup_data():
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[1,1,0,0,1,0],
[0,0,0,1,0,0]])
num_components = 4
connected_components = get_3d_connected_components(components)
return connected_components, components, num_components
def test_connected_components_property(setup_data):
connected_components, _, _ = setup_data
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[2,2,0,0,3,0],
[0,0,0,4,0,0]])
assert np.array_equal(connected_components.connected_components, components)
def test_num_connected_components_property(setup_data):
connected_components, _, num_components = setup_data
assert connected_components.num_connected_components == num_components
def test_get_connected_component_with_index(setup_data):
connected_components, _, _ = setup_data
expected_component = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[0,0,0,0,0,0],
[0,0,0,0,0,0]], dtype=bool)
print(connected_components.get_connected_component(index=1))
print(expected_component)
assert np.array_equal(connected_components.get_connected_component(index=1), expected_component)
def test_get_connected_component_without_index(setup_data):
connected_components, _, _ = setup_data
component = connected_components.get_connected_component()
assert np.any(component)
def test_get_connected_component_with_invalid_index(setup_data):
connected_components, _, num_components = setup_data
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=0)
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=num_components + 1)
\ No newline at end of file
import qim3d
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pytest import pytest
from torch import ones from torch import ones
import qim3d
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview # unit tests for grid overview
def test_grid_overview(): def test_grid_overview():
random_tuple = (ones(1,256,256),ones(256,256)) random_tuple = (ones(1,256,256),ones(256,256))
......
import qim3d
import pytest import pytest
import qim3d
#unit test for plot_metrics() #unit test for plot_metrics()
def test_plot_metrics(): def test_plot_metrics():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
......
import numpy as np import numpy as np
from scipy.ndimage import label from scipy.ndimage import label
# TODO: implement find_objects and get_bounding_boxes methods
class ConnectedComponents: class ConnectedComponents:
def __init__(self, connected_components, num_connected_components): def __init__(self, connected_components, num_connected_components):
...@@ -44,15 +45,14 @@ class ConnectedComponents: ...@@ -44,15 +45,14 @@ class ConnectedComponents:
Returns: Returns:
np.ndarray: The connected component as a binary mask. np.ndarray: The connected component as a binary mask.
""" """
assert 1 <= index <= self._num_connected_components, "Index out of range." if index is None:
return self.connected_components == np.random.randint(1, self.num_connected_components + 1)
if index:
return self._connected_components == index
else: else:
return self._connected_components == np.random.randint(1, self._num_connected_components + 1) assert 1 <= index <= self.num_connected_components, "Index out of range."
return self.connected_components == index
def get_3d_connected_components(image, connectivity=1): def get_3d_connected_components(image):
"""Get the connected components of a 3D binary image. """Get the connected components of a 3D binary image.
Args: Args:
...@@ -62,5 +62,5 @@ def get_3d_connected_components(image, connectivity=1): ...@@ -62,5 +62,5 @@ def get_3d_connected_components(image, connectivity=1):
Returns: Returns:
class: Returns class object of the connected components. class: Returns class object of the connected components.
""" """
connected_components, num_connected_components = label(image, connectivity) connected_components, num_connected_components = label(image)
return ConnectedComponents(connected_components, num_connected_components) return ConnectedComponents(connected_components, num_connected_components)
""" Provides a collection of visualization functions.""" """ Provides a collection of visualization functions."""
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
import torch
import numpy as np import numpy as np
from qim3d.io.logger import log import torch
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
import qim3d.io import qim3d.io
from qim3d.io.logger import log
from qim3d.utils.connected_components import ConnectedComponents
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False): def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
"""Displays an overview grid of images, labels, and masks (if they exist). """Displays an overview grid of images, labels, and masks (if they exist).
...@@ -299,3 +302,41 @@ def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = Fal ...@@ -299,3 +302,41 @@ def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = Fal
plt.close() plt.close()
return fig return fig
def plot_connected_components(connected_components: ConnectedComponents, show=False):
""" Plots the connected components in 3D.
Args:
connected_components (ConnectedComponents): The connected components class from the qim3d.utils.connected_components module.
show (bool, optional): If matplotlib should show the plot. Defaults to False.
Returns:
matplotlib.pyplot: the 3D plot of the connected components.
"""
# Begin plotting
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
# Define default color theme
colors = plt.cm.tab10(np.linspace(0, 1, connected_components.num_connected_components + 1))
# Plot each component with a different color
for label_num in range(1, connected_components.num_connected_components + 1):
# Find the voxels that belong to the current component
component_voxels = connected_components.get_connected_component(label_num)
# Plot each voxel of the component
for voxel in zip(*component_voxels.nonzero()):
x, y, z = voxel
ax.bar3d(x, y, z, 1, 1, 1, color=colors[label_num], shade=True, alpha=0.5)
# Set labels and titles if necessary
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
ax.set_title('3D Visualization of Connected Components')
if show:
plt.show()
plt.close()
return fig
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment