Skip to content
Snippets Groups Projects

Added cbar_style as option for slices()

+ 20
7
@@ -36,6 +36,7 @@ def slices(
interpolation: Optional[str] = "none",
img_size=None,
cbar: bool = False,
cbar_style: str = "small",
**imshow_kwargs,
) -> plt.Figure:
"""Displays one or several slices from a 3d volume.
@@ -59,6 +60,7 @@ def slices(
show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row and positioned in top-right. Option 'large' spans entire height of image grid. Defaults to 'small'.
Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
@@ -215,16 +217,27 @@ def slices(
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout()
norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
if cbar_style =="small":
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
elif cbar_style == "large":
# Figure coordinates of bottom- and top-right axis
br_pos = np.atleast_1d(axs[-1])[-1].get_position()
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes(
[br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
if show:
plt.show()
Loading