diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py
index 54a4b03449dbb3e0c95d5238ae5fff9e0947d5ff..1b20178527e8844d30f8fdf2cafd62a8d4002716 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/explore.py
@@ -3,11 +3,14 @@ Provides a collection of visualization functions.
 """
 
 import math
+import warnings
+
 from typing import List, Optional, Union
 
 import dask.array as da
 import ipywidgets as widgets
 import matplotlib.pyplot as plt
+import matplotlib
 import numpy as np
 
 import qim3d
@@ -28,6 +31,7 @@ def slices(
     show_position: bool = True,
     interpolation: Optional[str] = "none",
     img_size=None,
+    cbar: bool = False,
     **imshow_kwargs,
 ) -> plt.Figure:
     """Displays one or several slices from a 3d volume.
@@ -50,6 +54,7 @@ def slices(
         show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
         show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
+        cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
 
     Returns:
         fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
@@ -127,6 +132,7 @@ def slices(
         figsize=(ncols * img_height, nrows * img_width),
         constrained_layout=True,
     )
+
     if nrows == 1:
         axs = [axs]  # Convert to a list for uniformity
 
@@ -134,6 +140,11 @@ def slices(
     if isinstance(vol, da.core.Array):
         vol = vol.compute()
 
+    if cbar:
+        # In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
+        new_vmin = vmin if vmin else np.min(vol)
+        new_vmax = vmax if vmax else np.max(vol)
+    
     # Run through each ax of the grid
     for i, ax_row in enumerate(axs):
         for j, ax in enumerate(np.atleast_1d(ax_row)):
@@ -141,10 +152,12 @@ def slices(
             try:
                 slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
 
-                # If vmin is higher than the highest value in the image ValueError is raised
-                # We don't want to override the values because next slices might be okay
-                new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
-                new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
+                if not cbar:
+                    # If vmin is higher than the highest value in the image ValueError is raised
+                    # We don't want to override the values because next slices might be okay
+                    new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
+                    new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
+                
                 ax.imshow(
                     slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
                 )
@@ -181,6 +194,19 @@ def slices(
             # Hide the axis, so that we have a nice grid
             ax.axis("off")
 
+    if cbar:     
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            fig.tight_layout()
+        norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
+        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
+
+        # Figure coordinates of top-right axis
+        tr_pos = np.atleast_1d(axs[0])[-1].get_position()
+        # The width is divided by ncols to make it the same relative size to the images
+        cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
+        fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
+
     if show:
         plt.show()
 
@@ -216,6 +242,7 @@ def slicer(
     show_position: bool = False,
     interpolation: Optional[str] = "none",
     img_size=None,
+    cbar: bool = False,
     **imshow_kwargs,
 ) -> widgets.interactive:
     """Interactive widget for visualizing slices of a 3D volume.
@@ -230,6 +257,7 @@ def slicer(
         img_width (int, optional): Width of the figure. Defaults to 3.
         show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
+        cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
 
     Returns:
         slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
@@ -263,6 +291,7 @@ def slicer(
             position=position,
             n_slices=1,
             show=True,
+            cbar=cbar,
             **imshow_kwargs,
         )
         return fig