diff --git a/qim3d/utils/connected_components.py b/qim3d/utils/connected_components.py index ff9e5dcb08c8b380d9f7706ee47bcdb2e3399120..d0eba60b79d7d26c58e29e841f6fc178086e7044 100644 --- a/qim3d/utils/connected_components.py +++ b/qim3d/utils/connected_components.py @@ -1,7 +1,7 @@ import numpy as np -from scipy.ndimage import label +import torch +from scipy.ndimage import find_objects, label -# TODO: implement find_objects and get_bounding_boxes methods class ConnectedComponents: def __init__(self, connected_components, num_connected_components): @@ -36,7 +36,7 @@ class ConnectedComponents: return self._num_connected_components def get_connected_component(self, index=None): - """ + """ Get the connected component with the given index, if index is None selects a random component. Args: @@ -46,21 +46,44 @@ class ConnectedComponents: np.ndarray: The connected component as a binary mask. """ if index is None: - return self.connected_components == np.random.randint(1, self.num_connected_components + 1) + return self.connected_components == np.random.randint( + 1, self.num_connected_components + 1 + ) else: assert 1 <= index <= self.num_connected_components, "Index out of range." return self.connected_components == index + def get_bounding_box(self, index=None): + """Get the bounding boxes of the connected components. + + Args: + index (int, optional): The index of the connected component. If none selects all components. + + Returns: + list: A list of bounding boxes. + """ -def get_3d_connected_components(image): + if index: + assert 1 <= index <= self.num_connected_components, "Index out of range." + return find_objects(self.connected_components == index) + else: + return find_objects(self.connected_components) + + +def get_3d_connected_components(image: np.ndarray | torch.Tensor): """Get the connected components of a 3D binary image. Args: - image (np.ndarray): The 3D binary image. - connectivity (int, optional): The connectivity of the connected components. Defaults to 1. + image (np.ndarray | torch.Tensor): An array-like object to be labeled. Any non-zero values in `input` are + counted as features and zero values are considered the background. Returns: class: Returns class object of the connected components. """ + if image.ndim != 3: + raise ValueError( + f"Given array is not a volume! Current dimension: {image.ndim}" + ) + connected_components, num_connected_components = label(image) return ConnectedComponents(connected_components, num_connected_components) diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 8bcffba2e91a3046a5bd82166cc18b6ce53d2978..b2c2ad99b2eae59db3e9440e1940f39c70c23073 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -303,40 +303,16 @@ def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = Fal return fig -def plot_connected_components(connected_components: ConnectedComponents, show=False): - """ Plots the connected components in 3D. +def plot_connected_components(connected_components: ConnectedComponents, **kwargs): + """ Plots connected components from the ConnectedComponents class as 2d slices. Args: connected_components (ConnectedComponents): The connected components class from the qim3d.utils.connected_components module. show (bool, optional): If matplotlib should show the plot. Defaults to False. + **kwargs: Additional keyword arguments to pass to qim3d.viz.img.slice_viz. Returns: matplotlib.pyplot: the 3D plot of the connected components. """ - # Begin plotting - fig, ax = plt.subplots(subplot_kw=dict(projection='3d')) - - # Define default color theme - colors = plt.cm.tab10(np.linspace(0, 1, connected_components.num_connected_components + 1)) - - # Plot each component with a different color - for label_num in range(1, connected_components.num_connected_components + 1): - # Find the voxels that belong to the current component - component_voxels = connected_components.get_connected_component(label_num) - - # Plot each voxel of the component - for voxel in zip(*component_voxels.nonzero()): - x, y, z = voxel - ax.bar3d(x, y, z, 1, 1, 1, color=colors[label_num], shade=True, alpha=0.5) - - # Set labels and titles if necessary - ax.set_xlabel('X axis') - ax.set_ylabel('Y axis') - ax.set_zlabel('Z axis') - ax.set_title('3D Visualization of Connected Components') - - if show: - plt.show() - plt.close() - - return fig \ No newline at end of file + fig = slice_viz(connected_components.connected_components, **kwargs) + return fig