Skip to content
Snippets Groups Projects

Viz add colorbar

+ 28
4
@@ -8,6 +8,7 @@ from typing import List, Optional, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import qim3d
@@ -28,6 +29,7 @@ def slices(
show_position: bool = True,
interpolation: Optional[str] = "none",
img_size=None,
cbar: bool = False,
**imshow_kwargs,
) -> plt.Figure:
"""Displays one or several slices from a 3d volume.
@@ -50,6 +52,7 @@ def slices(
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
@@ -134,6 +137,11 @@ def slices(
if isinstance(vol, da.core.Array):
vol = vol.compute()
if cbar:
# In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin else np.min(vol)
new_vmax = vmax if vmax else np.max(vol)
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)):
@@ -141,10 +149,12 @@ def slices(
try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
)
@@ -181,6 +191,17 @@ def slices(
# Hide the axis, so that we have a nice grid
ax.axis("off")
if cbar:
fig.tight_layout()
norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
if show:
plt.show()
@@ -216,6 +237,7 @@ def slicer(
show_position: bool = False,
interpolation: Optional[str] = "none",
img_size=None,
cbar: bool = False,
**imshow_kwargs,
) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume.
@@ -230,6 +252,7 @@ def slicer(
img_width (int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
@@ -263,6 +286,7 @@ def slicer(
position=position,
n_slices=1,
show=True,
cbar=cbar,
**imshow_kwargs,
)
return fig
Loading