diff --git a/qim3d/tests/utils/test_connected_components.py b/qim3d/tests/utils/test_connected_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c4763aa74a7206b17eaedd6a7d01a5ca32eca5d
--- /dev/null
+++ b/qim3d/tests/utils/test_connected_components.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+
+from qim3d.utils.connected_components import get_3d_connected_components
+
+
+@pytest.fixture(scope="module")
+def setup_data():
+    components = np.array([[0,0,1,1,0,0],
+                           [0,0,0,1,0,0],
+                           [1,1,0,0,1,0],
+                           [0,0,0,1,0,0]])
+    num_components = 4
+    connected_components = get_3d_connected_components(components)
+    return connected_components, components, num_components
+
+def test_connected_components_property(setup_data):
+    connected_components, _, _ = setup_data
+    components = np.array([[0,0,1,1,0,0],
+                            [0,0,0,1,0,0],
+                            [2,2,0,0,3,0],
+                            [0,0,0,4,0,0]])
+    assert np.array_equal(connected_components.connected_components, components)
+
+def test_num_connected_components_property(setup_data):
+    connected_components, _, num_components = setup_data
+    assert connected_components.num_connected_components == num_components
+
+def test_get_connected_component_with_index(setup_data):
+    connected_components, _, _ = setup_data
+    expected_component = np.array([[0,0,1,1,0,0],
+                                    [0,0,0,1,0,0],
+                                    [0,0,0,0,0,0],
+                                    [0,0,0,0,0,0]], dtype=bool)
+    print(connected_components.get_connected_component(index=1))
+    print(expected_component)
+    assert np.array_equal(connected_components.get_connected_component(index=1), expected_component)
+
+def test_get_connected_component_without_index(setup_data):
+    connected_components, _, _ = setup_data
+    component = connected_components.get_connected_component()
+    assert np.any(component)
+
+def test_get_connected_component_with_invalid_index(setup_data):
+    connected_components, _, num_components = setup_data
+    with pytest.raises(AssertionError):
+        connected_components.get_connected_component(index=0)
+    with pytest.raises(AssertionError):
+        connected_components.get_connected_component(index=num_components + 1)
\ No newline at end of file
diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index 06fec0d0c41d0c0269ffc94291a77a8b63fa2b36..d200c12d03eeab972f026ed70fea8ed074f4812c 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -1,10 +1,11 @@
-import qim3d
 import matplotlib.pyplot as plt
 import pytest
-
 from torch import ones
+
+import qim3d
 from qim3d.utils.internal_tools import temp_data
 
+
 # unit tests for grid overview
 def test_grid_overview():
     random_tuple = (ones(1,256,256),ones(256,256))
diff --git a/qim3d/tests/viz/test_visualizations.py b/qim3d/tests/viz/test_visualizations.py
index 75e84eb9fe06cffe7fa62a600a4821ebf8265fae..c9a8beceac2ae601b03adbc77d473fd339b577d1 100644
--- a/qim3d/tests/viz/test_visualizations.py
+++ b/qim3d/tests/viz/test_visualizations.py
@@ -1,6 +1,8 @@
-import qim3d
 import pytest
 
+import qim3d
+
+
 #unit test for plot_metrics()
 def test_plot_metrics():
     metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
diff --git a/qim3d/utils/3d_connected_components.py b/qim3d/utils/connected_components.py
similarity index 83%
rename from qim3d/utils/3d_connected_components.py
rename to qim3d/utils/connected_components.py
index 36f986d2a886553a897ef73fe254821870f8d349..ff9e5dcb08c8b380d9f7706ee47bcdb2e3399120 100644
--- a/qim3d/utils/3d_connected_components.py
+++ b/qim3d/utils/connected_components.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.ndimage import label
 
+# TODO: implement find_objects and get_bounding_boxes methods
 
 class ConnectedComponents:
     def __init__(self, connected_components, num_connected_components):
@@ -44,15 +45,14 @@ class ConnectedComponents:
         Returns:
             np.ndarray: The connected component as a binary mask.
         """
-        assert 1 <= index <= self._num_connected_components, "Index out of range."
-
-        if index:
-            return self._connected_components == index
+        if index is None:
+            return self.connected_components == np.random.randint(1, self.num_connected_components + 1)
         else:
-            return self._connected_components == np.random.randint(1, self._num_connected_components + 1)
+            assert 1 <= index <= self.num_connected_components, "Index out of range."
+            return self.connected_components == index
 
 
-def get_3d_connected_components(image, connectivity=1):
+def get_3d_connected_components(image):
     """Get the connected components of a 3D binary image.
 
     Args:
@@ -62,5 +62,5 @@ def get_3d_connected_components(image, connectivity=1):
     Returns:
         class: Returns class object of the connected components.
     """
-    connected_components, num_connected_components = label(image, connectivity)
+    connected_components, num_connected_components = label(image)
     return ConnectedComponents(connected_components, num_connected_components)
diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index e3648f5445c7a4449c2b42057c78b2db9e48e381..8bcffba2e91a3046a5bd82166cc18b6ce53d2978 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -1,11 +1,14 @@
 """ Provides a collection of visualization functions."""
 import matplotlib.pyplot as plt
-from matplotlib.colors import LinearSegmentedColormap
-from matplotlib import colormaps
-import torch
 import numpy as np
-from qim3d.io.logger import log
+import torch
+from matplotlib import colormaps
+from matplotlib.colors import LinearSegmentedColormap
+
 import qim3d.io
+from qim3d.io.logger import log
+from qim3d.utils.connected_components import ConnectedComponents
+
 
 def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
     """Displays an overview grid of images, labels, and masks (if they exist).
@@ -298,4 +301,42 @@ def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = Fal
         plt.show()
     plt.close()
 
+    return fig
+
+def plot_connected_components(connected_components: ConnectedComponents, show=False):
+    """ Plots the connected components in 3D.
+
+    Args:
+        connected_components (ConnectedComponents): The connected components class from the qim3d.utils.connected_components module.
+        show (bool, optional): If matplotlib should show the plot. Defaults to False.
+
+    Returns:
+        matplotlib.pyplot: the 3D plot of the connected components.
+    """
+    # Begin plotting
+    fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
+
+    # Define default color theme
+    colors = plt.cm.tab10(np.linspace(0, 1, connected_components.num_connected_components + 1))
+
+    # Plot each component with a different color
+    for label_num in range(1, connected_components.num_connected_components + 1):
+        # Find the voxels that belong to the current component
+        component_voxels = connected_components.get_connected_component(label_num)
+        
+        # Plot each voxel of the component
+        for voxel in zip(*component_voxels.nonzero()):
+            x, y, z = voxel
+            ax.bar3d(x, y, z, 1, 1, 1, color=colors[label_num], shade=True, alpha=0.5)
+
+    # Set labels and titles if necessary
+    ax.set_xlabel('X axis')
+    ax.set_ylabel('Y axis')
+    ax.set_zlabel('Z axis')
+    ax.set_title('3D Visualization of Connected Components')
+    
+    if show:
+        plt.show()
+    plt.close()
+
     return fig
\ No newline at end of file