diff --git a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py
index 7361d8cdd03af2d818a464d498ac6c854946c7b6..9b08365be27631d622fbf6edfb29014330885fc9 100644
--- a/qim3d/utils/__init__.py
+++ b/qim3d/utils/__init__.py
@@ -3,6 +3,6 @@ from . import doi, internal_tools
 from .augmentations import Augmentation
 from .cc import get_3d_cc
 from .data import Dataset, prepare_dataloaders, prepare_datasets
+from .img import overlay_rgb_images
 from .models import inference, model_summary, train_model
 from .system import Memory
-from .img import overlay_rgb_images
diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index a641174d1fba75ebc66847d328e92e956b3df8c0..b770c3d0a51477bc17a241812a32f18aeba12657 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -1,4 +1,5 @@
 from .visualizations import plot_metrics
 from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
 from .k3d import vol
-from .detection import circles
\ No newline at end of file
+from .colormaps import objects
+from .detection import circles
diff --git a/qim3d/viz/colormaps.py b/qim3d/viz/colormaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ee1b5f7f2708ee376574b896758dff8d8f1939
--- /dev/null
+++ b/qim3d/viz/colormaps.py
@@ -0,0 +1,95 @@
+import colorsys
+
+import numpy as np
+from matplotlib.colors import LinearSegmentedColormap
+
+from qim3d.io.logger import log
+
+
+def objects(
+    nlabels,
+    style="bright",
+    first_color_background=True,
+    last_color_background=False,
+    background_color=(0.0, 0.0, 0.0),
+    seed=19,
+):
+    """
+    Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
+
+    Args:
+        nlabels (int): Number of labels (size of colormap)
+        style (str, optional): 'bright' for strong colors, 'soft' for pastel colors. Defaults to 'bright'.
+        first_color_background (bool, optional): Option to use first color as background. Defaults to True.
+        last_color_background (bool, optional): Option to use last color as background. Defaults to False.
+        seed (int, optional): Seed for random number generator. Defaults to 19.
+
+    Returns:
+        matplotlib.colors.LinearSegmentedColormap: Colormap for matplotlib
+    """
+    # Check style
+    if style not in ("bright", "soft"):
+        raise ValueError(
+            f'Please choose "bright" or "soft" for style in qim3dCmap not "{style}"'
+        )
+
+    # Translate strings to background color
+    color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)}
+    if not isinstance(background_color, tuple):
+        try:
+            background_color = color_dict[background_color]
+        except KeyError:
+            raise ValueError(
+                f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.'
+            )
+
+    # Add one to nlabels to include the background color
+    nlabels += 1
+
+    # Create a new random generator, to locally set seed
+    rng = np.random.default_rng(seed)
+
+    # Generate color map for bright colors, based on hsv
+    if style == "bright":
+        randHSVcolors = [
+            (
+                rng.uniform(low=0.0, high=1),
+                rng.uniform(low=0.4, high=1),
+                rng.uniform(low=0.9, high=1),
+            )
+            for i in range(nlabels)
+        ]
+
+        # Convert HSV list to RGB
+        randRGBcolors = []
+        for HSVcolor in randHSVcolors:
+            randRGBcolors.append(
+                colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])
+            )
+
+    # Generate soft pastel colors, by limiting the RGB spectrum
+    if style == "soft":
+        low = 0.6
+        high = 0.95
+        randRGBcolors = [
+            (
+                rng.uniform(low=low, high=high),
+                rng.uniform(low=low, high=high),
+                rng.uniform(low=low, high=high),
+            )
+            for i in range(nlabels)
+        ]
+
+    # Set first and last color to background
+    if first_color_background:
+        randRGBcolors[0] = background_color
+
+    if last_color_background:
+        randRGBcolors[-1] = background_color
+
+    # Create colormap
+    objects_cmap = LinearSegmentedColormap.from_list(
+        "objects_cmap", randRGBcolors, N=nlabels
+    )
+
+    return objects_cmap
diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index 812b2028621ff508a638b78b009d4fc69a7fd139..d3cf0bd4b9eb4fe7bd6224a8b00da126b11752aa 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -4,6 +4,8 @@ Provides a collection of visualization functions.
 
 import math
 from typing import List, Optional, Union, Tuple
+
+import ipywidgets as widgets
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -11,9 +13,9 @@ from matplotlib import colormaps
 from matplotlib.colors import LinearSegmentedColormap
 
 import qim3d.io
-import ipywidgets as widgets
 from qim3d.io.logger import log
 from qim3d.utils.cc import CC
+from qim3d.viz.colormaps import objects
 
 
 def grid_overview(
@@ -231,6 +233,7 @@ def slices(
     show: bool = False,
     show_position: bool = True,
     interpolation: Optional[str] = "none",
+    **imshow_kwargs,
 ) -> plt.Figure:
     """Displays one or several slices from a 3d volume.
 
@@ -334,7 +337,7 @@ def slices(
             slice_idx = i * max_cols + j
             try:
                 slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
-                ax.imshow(slice_img, cmap=cmap, interpolation=interpolation)
+                ax.imshow(slice_img, cmap=cmap, interpolation=interpolation, **imshow_kwargs)
 
                 if show_position:
                     ax.text(
@@ -400,6 +403,7 @@ def slicer(
     img_width: int = 3,
     show_position: bool = False,
     interpolation: Optional[str] = "none",
+    **imshow_kwargs,
 ) -> widgets.interactive:
     """Interactive widget for visualizing slices of a 3D volume.
 
@@ -437,6 +441,7 @@ def slicer(
             position=position,
             n_slices=1,
             show=True,
+            **imshow_kwargs,
         )
         return fig
 
@@ -535,14 +540,13 @@ def plot_cc(
         components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
         max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
         overlay (optional): Overlay image. Defaults to None.
-        crop (bool, optional): Whether to crop the overlay image. Defaults to False.
+        crop (bool, optional): Whether to crop the image to the cc. Defaults to False.
         show (bool, optional): Whether to show the figure. Defaults to True.
         **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.
 
     Returns:
         figs (list[plt.Figure]): List of figures, if `show=False`.
     """
-    figs = []
     # if no components are given, plot the first max_cc_to_plot=32 components
     if component_indexs is None:
         if len(connected_components) > max_cc_to_plot:
@@ -552,12 +556,11 @@ def plot_cc(
         component_indexs = range(
             1, min(max_cc_to_plot + 1, len(connected_components) + 1)
         )
-
+        
+    figs = []
     for component in component_indexs:
         if overlay is not None:
-            assert (
-                overlay.shape == connected_components.shape
-            ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
+            assert (overlay.shape == connected_components.shape), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
 
             # plots overlay masked to connected component
             if crop:
@@ -573,10 +576,12 @@ def plot_cc(
                 overlay_crop = np.where(cc == 0, 0, overlay)
                 fig = slices(overlay_crop, show=show, **kwargs)
         else:
+            # assigns discrete color map to each connected component if not given 
+            if "cmap" not in kwargs:
+                kwargs["cmap"] = qim3dCmap(len(component_indexs))
+        
             # Plot the connected component without overlay
-            fig = slices(
-                connected_components.get_cc(component, crop=crop), show=show, **kwargs
-            )
+            fig = slices(connected_components.get_cc(component, crop=crop), show=show, **kwargs)
 
         figs.append(fig)
 
diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py
index c0759d92a7f7c686cc88519acaee9699594b2d90..d5f3ebbd6c9e236459df1348fa6bbe03ad1a5d56 100644
--- a/qim3d/viz/k3d.py
+++ b/qim3d/viz/k3d.py
@@ -11,15 +11,15 @@ import k3d
 import numpy as np
 
 
-def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwargs):
+def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, **kwargs):
     """
     Visualizes a 3D volume using volumetric rendering.
 
     Args:
         img (numpy.ndarray): The input 3D image data. It should be a 3D numpy array.
-        aspectmode (str, optional): Determines the proportions of the scene's axes. 
-            If "data", the axes are drawn in proportion with the axes' ranges. 
-            If "cube", the axes are drawn as a cube, regardless of the axes' ranges. 
+        aspectmode (str, optional): Determines the proportions of the scene's axes.
+            If "data", the axes are drawn in proportion with the axes' ranges.
+            If "cube", the axes are drawn as a cube, regardless of the axes' ranges.
             Defaults to "data".
         show (bool, optional): If True, displays the visualization inline. Defaults to True.
         save (bool or str, optional): If True, saves the visualization as an HTML file.
@@ -62,6 +62,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwa
             if aspectmode.lower() == "data"
             else None
         ),
+        color_map=cmap,
     )
     plot = k3d.plot(grid_visible=grid_visible, **kwargs)
     plot += plt_volume