Skip to content
Snippets Groups Projects
Commit f990ec24 authored by fima's avatar fima :beers:
Browse files

Merge branch 'viz-add-colorbar' into 'main'

Viz add colorbar

See merge request !124
parents e353df0c e692b028
Branches
Tags
1 merge request!124Viz add colorbar
...@@ -3,11 +3,14 @@ Provides a collection of visualization functions. ...@@ -3,11 +3,14 @@ Provides a collection of visualization functions.
""" """
import math import math
import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import dask.array as da import dask.array as da
import ipywidgets as widgets import ipywidgets as widgets
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib
import numpy as np import numpy as np
import qim3d import qim3d
...@@ -28,6 +31,7 @@ def slices( ...@@ -28,6 +31,7 @@ def slices(
show_position: bool = True, show_position: bool = True,
interpolation: Optional[str] = "none", interpolation: Optional[str] = "none",
img_size=None, img_size=None,
cbar: bool = False,
**imshow_kwargs, **imshow_kwargs,
) -> plt.Figure: ) -> plt.Figure:
"""Displays one or several slices from a 3d volume. """Displays one or several slices from a 3d volume.
...@@ -50,6 +54,7 @@ def slices( ...@@ -50,6 +54,7 @@ def slices(
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True. show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
Returns: Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
...@@ -127,6 +132,7 @@ def slices( ...@@ -127,6 +132,7 @@ def slices(
figsize=(ncols * img_height, nrows * img_width), figsize=(ncols * img_height, nrows * img_width),
constrained_layout=True, constrained_layout=True,
) )
if nrows == 1: if nrows == 1:
axs = [axs] # Convert to a list for uniformity axs = [axs] # Convert to a list for uniformity
...@@ -134,6 +140,11 @@ def slices( ...@@ -134,6 +140,11 @@ def slices(
if isinstance(vol, da.core.Array): if isinstance(vol, da.core.Array):
vol = vol.compute() vol = vol.compute()
if cbar:
# In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin else np.min(vol)
new_vmax = vmax if vmax else np.max(vol)
# Run through each ax of the grid # Run through each ax of the grid
for i, ax_row in enumerate(axs): for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)): for j, ax in enumerate(np.atleast_1d(ax_row)):
...@@ -141,10 +152,12 @@ def slices( ...@@ -141,10 +152,12 @@ def slices(
try: try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis) slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised # If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay # We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
ax.imshow( ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
) )
...@@ -181,6 +194,19 @@ def slices( ...@@ -181,6 +194,19 @@ def slices(
# Hide the axis, so that we have a nice grid # Hide the axis, so that we have a nice grid
ax.axis("off") ax.axis("off")
if cbar:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout()
norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
if show: if show:
plt.show() plt.show()
...@@ -216,6 +242,7 @@ def slicer( ...@@ -216,6 +242,7 @@ def slicer(
show_position: bool = False, show_position: bool = False,
interpolation: Optional[str] = "none", interpolation: Optional[str] = "none",
img_size=None, img_size=None,
cbar: bool = False,
**imshow_kwargs, **imshow_kwargs,
) -> widgets.interactive: ) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume. """Interactive widget for visualizing slices of a 3D volume.
...@@ -230,6 +257,7 @@ def slicer( ...@@ -230,6 +257,7 @@ def slicer(
img_width (int, optional): Width of the figure. Defaults to 3. img_width (int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False. show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
Returns: Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
...@@ -263,6 +291,7 @@ def slicer( ...@@ -263,6 +291,7 @@ def slicer(
position=position, position=position,
n_slices=1, n_slices=1,
show=True, show=True,
cbar=cbar,
**imshow_kwargs, **imshow_kwargs,
) )
return fig return fig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment