diff --git a/qim3d/tests/utils/test_connected_components.py b/qim3d/tests/utils/test_connected_components.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4763aa74a7206b17eaedd6a7d01a5ca32eca5d --- /dev/null +++ b/qim3d/tests/utils/test_connected_components.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest + +from qim3d.utils.connected_components import get_3d_connected_components + + +@pytest.fixture(scope="module") +def setup_data(): + components = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [1,1,0,0,1,0], + [0,0,0,1,0,0]]) + num_components = 4 + connected_components = get_3d_connected_components(components) + return connected_components, components, num_components + +def test_connected_components_property(setup_data): + connected_components, _, _ = setup_data + components = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [2,2,0,0,3,0], + [0,0,0,4,0,0]]) + assert np.array_equal(connected_components.connected_components, components) + +def test_num_connected_components_property(setup_data): + connected_components, _, num_components = setup_data + assert connected_components.num_connected_components == num_components + +def test_get_connected_component_with_index(setup_data): + connected_components, _, _ = setup_data + expected_component = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [0,0,0,0,0,0], + [0,0,0,0,0,0]], dtype=bool) + print(connected_components.get_connected_component(index=1)) + print(expected_component) + assert np.array_equal(connected_components.get_connected_component(index=1), expected_component) + +def test_get_connected_component_without_index(setup_data): + connected_components, _, _ = setup_data + component = connected_components.get_connected_component() + assert np.any(component) + +def test_get_connected_component_with_invalid_index(setup_data): + connected_components, _, num_components = setup_data + with pytest.raises(AssertionError): + connected_components.get_connected_component(index=0) + with pytest.raises(AssertionError): + connected_components.get_connected_component(index=num_components + 1) \ No newline at end of file diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 06fec0d0c41d0c0269ffc94291a77a8b63fa2b36..d200c12d03eeab972f026ed70fea8ed074f4812c 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -1,10 +1,11 @@ -import qim3d import matplotlib.pyplot as plt import pytest - from torch import ones + +import qim3d from qim3d.utils.internal_tools import temp_data + # unit tests for grid overview def test_grid_overview(): random_tuple = (ones(1,256,256),ones(256,256)) diff --git a/qim3d/tests/viz/test_visualizations.py b/qim3d/tests/viz/test_visualizations.py index 75e84eb9fe06cffe7fa62a600a4821ebf8265fae..c9a8beceac2ae601b03adbc77d473fd339b577d1 100644 --- a/qim3d/tests/viz/test_visualizations.py +++ b/qim3d/tests/viz/test_visualizations.py @@ -1,6 +1,8 @@ -import qim3d import pytest +import qim3d + + #unit test for plot_metrics() def test_plot_metrics(): metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} diff --git a/qim3d/utils/3d_connected_components.py b/qim3d/utils/connected_components.py similarity index 83% rename from qim3d/utils/3d_connected_components.py rename to qim3d/utils/connected_components.py index 36f986d2a886553a897ef73fe254821870f8d349..ff9e5dcb08c8b380d9f7706ee47bcdb2e3399120 100644 --- a/qim3d/utils/3d_connected_components.py +++ b/qim3d/utils/connected_components.py @@ -1,6 +1,7 @@ import numpy as np from scipy.ndimage import label +# TODO: implement find_objects and get_bounding_boxes methods class ConnectedComponents: def __init__(self, connected_components, num_connected_components): @@ -44,15 +45,14 @@ class ConnectedComponents: Returns: np.ndarray: The connected component as a binary mask. """ - assert 1 <= index <= self._num_connected_components, "Index out of range." - - if index: - return self._connected_components == index + if index is None: + return self.connected_components == np.random.randint(1, self.num_connected_components + 1) else: - return self._connected_components == np.random.randint(1, self._num_connected_components + 1) + assert 1 <= index <= self.num_connected_components, "Index out of range." + return self.connected_components == index -def get_3d_connected_components(image, connectivity=1): +def get_3d_connected_components(image): """Get the connected components of a 3D binary image. Args: @@ -62,5 +62,5 @@ def get_3d_connected_components(image, connectivity=1): Returns: class: Returns class object of the connected components. """ - connected_components, num_connected_components = label(image, connectivity) + connected_components, num_connected_components = label(image) return ConnectedComponents(connected_components, num_connected_components) diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index e3648f5445c7a4449c2b42057c78b2db9e48e381..8bcffba2e91a3046a5bd82166cc18b6ce53d2978 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -1,11 +1,14 @@ """ Provides a collection of visualization functions.""" import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap -from matplotlib import colormaps -import torch import numpy as np -from qim3d.io.logger import log +import torch +from matplotlib import colormaps +from matplotlib.colors import LinearSegmentedColormap + import qim3d.io +from qim3d.io.logger import log +from qim3d.utils.connected_components import ConnectedComponents + def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False): """Displays an overview grid of images, labels, and masks (if they exist). @@ -298,4 +301,42 @@ def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = Fal plt.show() plt.close() + return fig + +def plot_connected_components(connected_components: ConnectedComponents, show=False): + """ Plots the connected components in 3D. + + Args: + connected_components (ConnectedComponents): The connected components class from the qim3d.utils.connected_components module. + show (bool, optional): If matplotlib should show the plot. Defaults to False. + + Returns: + matplotlib.pyplot: the 3D plot of the connected components. + """ + # Begin plotting + fig, ax = plt.subplots(subplot_kw=dict(projection='3d')) + + # Define default color theme + colors = plt.cm.tab10(np.linspace(0, 1, connected_components.num_connected_components + 1)) + + # Plot each component with a different color + for label_num in range(1, connected_components.num_connected_components + 1): + # Find the voxels that belong to the current component + component_voxels = connected_components.get_connected_component(label_num) + + # Plot each voxel of the component + for voxel in zip(*component_voxels.nonzero()): + x, y, z = voxel + ax.bar3d(x, y, z, 1, 1, 1, color=colors[label_num], shade=True, alpha=0.5) + + # Set labels and titles if necessary + ax.set_xlabel('X axis') + ax.set_ylabel('Y axis') + ax.set_zlabel('Z axis') + ax.set_title('3D Visualization of Connected Components') + + if show: + plt.show() + plt.close() + return fig \ No newline at end of file