diff --git a/docs/assets/screenshots/chunks_visualization.gif b/docs/assets/screenshots/chunks_visualization.gif new file mode 100644 index 0000000000000000000000000000000000000000..a4679b7b368d22f4f83dde90404ac7e467525457 Binary files /dev/null and b/docs/assets/screenshots/chunks_visualization.gif differ diff --git a/docs/viz.md b/docs/viz.md index 6a38714001dca3771ed99f9746175b738dc945fa..3efe4276bbb332c3b31e01dfb3c9d0a0dabf9e8c 100644 --- a/docs/viz.md +++ b/docs/viz.md @@ -8,6 +8,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - slicer - orthogonal - vol + - chunks - itk_vtk - mesh - local_thickness diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 444c190c1b545642a1e5c3bc491e3ea0dce8cc87..84b2e8356e5f9dd082ef3df17c8960c14f352b9e 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -6,6 +6,7 @@ from .explore import ( orthogonal, slicer, slices, + chunks, ) from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError from .k3d import vol, mesh diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py index 1b20178527e8844d30f8fdf2cafd62a8d4002716..d7392cbdb92ca249c872f97f45ff138216c66940 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/explore.py @@ -10,8 +10,12 @@ from typing import List, Optional, Union import dask.array as da import ipywidgets as widgets import matplotlib.pyplot as plt +from IPython.display import SVG, display import matplotlib import numpy as np +import zarr +from qim3d.utils.logger import log + import qim3d @@ -23,8 +27,8 @@ def slices( n_slices: int = 5, max_cols: int = 5, cmap: str = "viridis", - vmin:float = None, - vmax:float = None, + vmin: float = None, + vmax: float = None, img_height: int = 2, img_width: int = 2, show: bool = False, @@ -144,7 +148,7 @@ def slices( # In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar. new_vmin = vmin if vmin else np.min(vol) new_vmax = vmax if vmax else np.max(vol) - + # Run through each ax of the grid for i, ax_row in enumerate(axs): for j, ax in enumerate(np.atleast_1d(ax_row)): @@ -155,11 +159,24 @@ def slices( if not cbar: # If vmin is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin - new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax - + new_vmin = ( + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin + ) + new_vmax = ( + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax + ) + ax.imshow( - slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs + slice_img, + cmap=cmap, + interpolation=interpolation, + vmin=new_vmin, + vmax=new_vmax, + **imshow_kwargs, ) if show_position: @@ -194,7 +211,7 @@ def slices( # Hide the axis, so that we have a nice grid ax.axis("off") - if cbar: + if cbar: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) fig.tight_layout() @@ -204,8 +221,10 @@ def slices( # Figure coordinates of top-right axis tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images - cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height]) - fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical') + cbar_ax = fig.add_axes( + [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] + ) + fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") if show: plt.show() @@ -235,8 +254,8 @@ def slicer( vol: np.ndarray, axis: int = 0, cmap: str = "viridis", - vmin:float = None, - vmax:float = None, + vmin: float = None, + vmax: float = None, img_height: int = 3, img_width: int = 3, show_position: bool = False, @@ -282,8 +301,8 @@ def slicer( vol, axis=axis, cmap=cmap, - vmin = vmin, - vmax = vmax, + vmin=vmin, + vmax=vmax, img_height=img_height, img_width=img_width, show_position=show_position, @@ -312,8 +331,8 @@ def slicer( def orthogonal( vol: np.ndarray, cmap: str = "viridis", - vmin:float = None, - vmax:float = None, + vmin: float = None, + vmax: float = None, img_height: int = 3, img_width: int = 3, show_position: bool = False, @@ -351,19 +370,19 @@ def orthogonal( get_slicer_for_axis = lambda axis: slicer( vol, - axis = axis, - cmap = cmap, - vmin = vmin, - vmax = vmax, - img_height = img_height, - img_width = img_width, - show_position = show_position, - interpolation = interpolation, - ) + axis=axis, + cmap=cmap, + vmin=vmin, + vmax=vmax, + img_height=img_height, + img_width=img_width, + show_position=show_position, + interpolation=interpolation, + ) - z_slicer = get_slicer_for_axis(axis = 0) - y_slicer = get_slicer_for_axis(axis = 1) - x_slicer = get_slicer_for_axis(axis = 2) + z_slicer = get_slicer_for_axis(axis=0) + y_slicer = get_slicer_for_axis(axis=1) + x_slicer = get_slicer_for_axis(axis=2) z_slicer.children[0].description = "Z" y_slicer.children[0].description = "Y" @@ -372,7 +391,13 @@ def orthogonal( return widgets.HBox([z_slicer, y_slicer, x_slicer]) -def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', vmin:float = None, vmax:float = None): +def interactive_fade_mask( + vol: np.ndarray, + axis: int = 0, + cmap: str = "viridis", + vmin: float = None, + vmax: float = None, +): """Interactive widget for visualizing the effect of edge fading on a 3D volume. This can be used to select the best parameters before applying the mask. @@ -401,10 +426,18 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v slice_img = vol[position, :, :] # If vmin is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin - new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + new_vmin = ( + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin + ) + new_vmax = ( + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax + ) - axes[0].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax) + axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) axes[0].set_title("Original") axes[0].axis("off") @@ -431,16 +464,24 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v # If vmin is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay slice_img = masked_vol[position, :, :] - new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin - new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax - axes[2].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax) + new_vmin = ( + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin + ) + new_vmax = ( + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax + ) + axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) axes[2].set_title("Masked") axes[2].axis("off") return fig shape_dropdown = widgets.Dropdown( - options=["spherical", "cylindrical"], + options=["spherical", "cylindrical"], value="spherical", # default value description="Geometry", ) @@ -485,3 +526,252 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v slicer_obj.layout = widgets.Layout(align_items="flex-start") return slicer_obj + + +def chunks(zarr_path: str, **kwargs): + """ + Function to visualize chunks of a Zarr dataset using the specified visualization method. + + Args: + zarr_path (str): Path to the Zarr dataset. + **kwargs: Additional keyword arguments to pass to the visualization method. + + Example: + ```python + import qim3d + + # Download dataset + downloader = qim3d.io.Downloader() + data = downloader.Snail.Escargot(load_file=True) + + # Export as OME-Zarr + qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True) + + # Explore chunks + qim3d.viz.chunks("Escargot.zarr") + ``` +  + """ + + # Load the Zarr dataset + zarr_data = zarr.open(zarr_path, mode="r") + + # Save arguments for later use + # visualization_method = visualization_method + # preserved_kwargs = kwargs + + # Create label to display the chunk coordinates + widget_title = widgets.HTML("<h2>Chunk Explorer</h2>") + chunk_info_label = widgets.HTML(value="Chunk info will be displayed here") + + def load_and_visualize( + scale, z_coord, y_coord, x_coord, visualization_method, **kwargs + ): + # Get chunk shape for the selected scale + chunk_shape = zarr_data[scale].chunks + + # Calculate slice indices for the selected chunk + slices = ( + slice( + z_coord * chunk_shape[0], + min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]), + ), + slice( + y_coord * chunk_shape[1], + min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]), + ), + slice( + x_coord * chunk_shape[2], + min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]), + ), + ) + + # Extract start and stop values from each slice object + z_start, z_stop = slices[0].start, slices[0].stop + y_start, y_stop = slices[1].start, slices[1].stop + x_start, x_stop = slices[2].start, slices[2].stop + + # Extract the chunk + chunk = zarr_data[scale][slices] + + # Update the chunk info label with the chunk coordinates + info_string = ( + f"<b>shape:</b> {chunk_shape}\n" + + f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n" + + f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n" + + f"<b>dtype:</b> {chunk.dtype}\n" + + f"<b>min value:</b> {np.min(chunk)}\n" + + f"<b>max value:</b> {np.max(chunk)}\n" + + f"<b>mean value:</b> {np.mean(chunk)}\n" + ) + + chunk_info_label.value = f""" + <div style="font-size: 14px; text-align: left; margin-left:32px"> + <h3 style="margin: 0px">Chunk Info</h3> + <div style="font-size: 14px; text-align: left;"> + <pre>{info_string}</pre> + </div> + </div> + + """ + + # Prepare chunk visualization based on the selected method + if visualization_method == "slicer": # return a widget + viz_widget = qim3d.viz.slicer(chunk, **kwargs) + elif visualization_method == "slices": # return a plt.Figure + viz_widget = widgets.Output() + with viz_widget: + viz_widget.clear_output(wait=True) + fig = qim3d.viz.slices(chunk, **kwargs) + display(fig) + elif visualization_method == "vol": + viz_widget = widgets.Output() + with viz_widget: + viz_widget.clear_output(wait=True) + out = qim3d.viz.vol(chunk, show=False, **kwargs) + display(out) + else: + log.info(f"Invalid visualization method: {visualization_method}") + + return viz_widget + + # Function to calculate the number of chunks for each dimension, including partial chunks + def get_num_chunks(shape, chunk_size): + return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)] + + scale_options = { + f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data)) + } # len(zarr_data) gives number of scales + + description_width = "128px" + # Create dropdown for scale + scale_dropdown = widgets.Dropdown( + options=scale_options, + value=0, # Default to first scale + description="OME-Zarr scale", + style={"description_width": description_width, "text_align": "left"}, + ) + + # Initialize the options for x, y, and z based on the first scale by default + multiscale_shape = zarr_data[0].shape + chunk_shape = zarr_data[0].chunks + num_chunks = get_num_chunks(multiscale_shape, chunk_shape) + + z_dropdown = widgets.Dropdown( + options=list(range(num_chunks[0])), + value=0, + description="First dimension (Z)", + style={"description_width": description_width, "text_align": "left"}, + ) + + y_dropdown = widgets.Dropdown( + options=list(range(num_chunks[1])), + value=0, + description="Second dimension (Y)", + style={"description_width": description_width, "text_align": "left"}, + ) + + x_dropdown = widgets.Dropdown( + options=list(range(num_chunks[2])), + value=0, + description="Third dimension (X)", + style={"description_width": description_width, "text_align": "left"}, + ) + + method_dropdown = widgets.Dropdown( + options=["slicer", "slices", "vol"], + value="slicer", + description="Visualization", + style={"description_width": description_width, "text_align": "left"}, + ) + + # Funtion to temporarily disable observers + def disable_observers(): + x_dropdown.unobserve(update_visualization, names="value") + y_dropdown.unobserve(update_visualization, names="value") + z_dropdown.unobserve(update_visualization, names="value") + method_dropdown.unobserve(update_visualization, names="value") + + # Funtion to enable observers + def enable_observers(): + x_dropdown.observe(update_visualization, names="value") + y_dropdown.observe(update_visualization, names="value") + z_dropdown.observe(update_visualization, names="value") + method_dropdown.observe(update_visualization, names="value") + + # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0 + def update_coordinate_dropdowns(scale): + + disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns + + multiscale_shape = zarr_data[scale].shape + chunk_shape = zarr_data[scale].chunks + num_chunks = get_num_chunks( + multiscale_shape, chunk_shape + ) # Calculate new chunk options + + # Reset X, Y, Z dropdowns to 0 + z_dropdown.options = list(range(num_chunks[0])) + z_dropdown.value = 0 # Reset to 0 + z_dropdown.disabled = ( + len(z_dropdown.options) == 1 + ) # Disable if only one option (0) is available + + y_dropdown.options = list(range(num_chunks[1])) + y_dropdown.value = 0 # Reset to 0 + y_dropdown.disabled = ( + len(y_dropdown.options) == 1 + ) # Disable if only one option (0) is available + + x_dropdown.options = list(range(num_chunks[2])) + x_dropdown.value = 0 # Reset to 0 + x_dropdown.disabled = ( + len(x_dropdown.options) == 1 + ) # Disable if only one option (0) is available + + enable_observers() + + update_visualization() + + # Function to update the visualization when any dropdown value changes + def update_visualization(*args): + scale = scale_dropdown.value + x_coord = x_dropdown.value + y_coord = y_dropdown.value + z_coord = z_dropdown.value + visualization_method = method_dropdown.value + + # Clear and update the chunk visualization + slicer_widget = load_and_visualize( + scale, z_coord, y_coord, x_coord, visualization_method, **kwargs + ) + + # Recreate the layout and display the new visualization + final_layout.children = [widget_title, hbox_layout, slicer_widget] + + # Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes + scale_dropdown.observe( + lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value" + ) + + enable_observers() + + # Create first visualization + slicer_widget = load_and_visualize( + scale_dropdown.value, + z_dropdown.value, + y_dropdown.value, + x_dropdown.value, + method_dropdown.value, + **kwargs, + ) + + # Create the layout + vbox_dropbox = widgets.VBox( + [scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown] + ) + hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label]) + final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget]) + + # Display the VBox + display(final_layout)