From 0b8b467275ce917121650134a07028a8d13dfaa5 Mon Sep 17 00:00:00 2001
From: Christian Kento Rasmussen <christian.kento@gmail.com>
Date: Fri, 9 Feb 2024 09:25:13 +0100
Subject: [PATCH] added option to crop volume to cc

---
 qim3d/utils/connected_components.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/qim3d/utils/connected_components.py b/qim3d/utils/connected_components.py
index d0eba60b..6cccc489 100644
--- a/qim3d/utils/connected_components.py
+++ b/qim3d/utils/connected_components.py
@@ -35,23 +35,31 @@ class ConnectedComponents:
         """
         return self._num_connected_components
 
-    def get_connected_component(self, index=None):
+    def get_connected_component(self, index=None, crop=False):
         """
         Get the connected component with the given index, if index is None selects a random component.
 
         Args:
             index (int): The index of the connected component. If none selects a random component.
+            crop (bool): If True, the volume is cropped to the bounding box of the connected component.
 
         Returns:
             np.ndarray: The connected component as a binary mask.
         """
         if index is None:
-            return self.connected_components == np.random.randint(
+            volume =  self.connected_components == np.random.randint(
                 1, self.num_connected_components + 1
             )
         else:
             assert 1 <= index <= self.num_connected_components, "Index out of range."
-            return self.connected_components == index
+            volume = self.connected_components == index
+            
+        if crop:
+            # As we index get_bounding_box element 0 will be the bounding box for the connected component at index
+            bbox = self.get_bounding_box(index)[0] 
+            volume = volume[bbox]
+        
+        return volume
 
     def get_bounding_box(self, index=None):
         """Get the bounding boxes of the connected components.
-- 
GitLab