diff --git a/qim3d/utils/connected_components.py b/qim3d/utils/connected_components.py index d0eba60b79d7d26c58e29e841f6fc178086e7044..6cccc4896f9b08456da0306e1474d05ac7a1d57c 100644 --- a/qim3d/utils/connected_components.py +++ b/qim3d/utils/connected_components.py @@ -35,23 +35,31 @@ class ConnectedComponents: """ return self._num_connected_components - def get_connected_component(self, index=None): + def get_connected_component(self, index=None, crop=False): """ Get the connected component with the given index, if index is None selects a random component. Args: index (int): The index of the connected component. If none selects a random component. + crop (bool): If True, the volume is cropped to the bounding box of the connected component. Returns: np.ndarray: The connected component as a binary mask. """ if index is None: - return self.connected_components == np.random.randint( + volume = self.connected_components == np.random.randint( 1, self.num_connected_components + 1 ) else: assert 1 <= index <= self.num_connected_components, "Index out of range." - return self.connected_components == index + volume = self.connected_components == index + + if crop: + # As we index get_bounding_box element 0 will be the bounding box for the connected component at index + bbox = self.get_bounding_box(index)[0] + volume = volume[bbox] + + return volume def get_bounding_box(self, index=None): """Get the bounding boxes of the connected components.