diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py index 54a4b03449dbb3e0c95d5238ae5fff9e0947d5ff..1b20178527e8844d30f8fdf2cafd62a8d4002716 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/explore.py @@ -3,11 +3,14 @@ Provides a collection of visualization functions. """ import math +import warnings + from typing import List, Optional, Union import dask.array as da import ipywidgets as widgets import matplotlib.pyplot as plt +import matplotlib import numpy as np import qim3d @@ -28,6 +31,7 @@ def slices( show_position: bool = True, interpolation: Optional[str] = "none", img_size=None, + cbar: bool = False, **imshow_kwargs, ) -> plt.Figure: """Displays one or several slices from a 3d volume. @@ -50,6 +54,7 @@ def slices( show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. show_position (bool, optional): If True, displays the position of the slices. Defaults to True. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. + cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. Returns: fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. @@ -127,6 +132,7 @@ def slices( figsize=(ncols * img_height, nrows * img_width), constrained_layout=True, ) + if nrows == 1: axs = [axs] # Convert to a list for uniformity @@ -134,6 +140,11 @@ def slices( if isinstance(vol, da.core.Array): vol = vol.compute() + if cbar: + # In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar. + new_vmin = vmin if vmin else np.min(vol) + new_vmax = vmax if vmax else np.max(vol) + # Run through each ax of the grid for i, ax_row in enumerate(axs): for j, ax in enumerate(np.atleast_1d(ax_row)): @@ -141,10 +152,12 @@ def slices( try: slice_img = vol.take(slice_idxs[slice_idx], axis=axis) - # If vmin is higher than the highest value in the image ValueError is raised - # We don't want to override the values because next slices might be okay - new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin - new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + if not cbar: + # If vmin is higher than the highest value in the image ValueError is raised + # We don't want to override the values because next slices might be okay + new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin + new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + ax.imshow( slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs ) @@ -181,6 +194,19 @@ def slices( # Hide the axis, so that we have a nice grid ax.axis("off") + if cbar: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + fig.tight_layout() + norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True) + mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + + # Figure coordinates of top-right axis + tr_pos = np.atleast_1d(axs[0])[-1].get_position() + # The width is divided by ncols to make it the same relative size to the images + cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height]) + fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical') + if show: plt.show() @@ -216,6 +242,7 @@ def slicer( show_position: bool = False, interpolation: Optional[str] = "none", img_size=None, + cbar: bool = False, **imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. @@ -230,6 +257,7 @@ def slicer( img_width (int, optional): Width of the figure. Defaults to 3. show_position (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. + cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. Returns: slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. @@ -263,6 +291,7 @@ def slicer( position=position, n_slices=1, show=True, + cbar=cbar, **imshow_kwargs, ) return fig