Skip to content
Snippets Groups Projects
Commit 93088abb authored by fima's avatar fima :beers:
Browse files

Merge branch 'chunk_visualization' into 'main'

Chunk visualization

See merge request !128
parents f990ec24 1b165d02
Branches
No related tags found
1 merge request!128Chunk visualization
docs/assets/screenshots/chunks_visualization.gif

1.58 MiB

......@@ -8,6 +8,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v
- slicer
- orthogonal
- vol
- chunks
- itk_vtk
- mesh
- local_thickness
......
......@@ -6,6 +6,7 @@ from .explore import (
orthogonal,
slicer,
slices,
chunks,
)
from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
from .k3d import vol, mesh
......
......@@ -10,8 +10,12 @@ from typing import List, Optional, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import SVG, display
import matplotlib
import numpy as np
import zarr
from qim3d.utils.logger import log
import qim3d
......@@ -155,11 +159,24 @@ def slices(
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
slice_img,
cmap=cmap,
interpolation=interpolation,
vmin=new_vmin,
vmax=new_vmax,
**imshow_kwargs,
)
if show_position:
......@@ -204,8 +221,10 @@ def slices(
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
cbar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
if show:
plt.show()
......@@ -372,7 +391,13 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', vmin:float = None, vmax:float = None):
def interactive_fade_mask(
vol: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask.
......@@ -401,8 +426,16 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
slice_img = vol[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title("Original")
......@@ -431,8 +464,16 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :]
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[2].set_title("Masked")
axes[2].axis("off")
......@@ -485,3 +526,252 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj
def chunks(zarr_path: str, **kwargs):
"""
Function to visualize chunks of a Zarr dataset using the specified visualization method.
Args:
zarr_path (str): Path to the Zarr dataset.
**kwargs: Additional keyword arguments to pass to the visualization method.
Example:
```python
import qim3d
# Download dataset
downloader = qim3d.io.Downloader()
data = downloader.Snail.Escargot(load_file=True)
# Export as OME-Zarr
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True)
# Explore chunks
qim3d.viz.chunks("Escargot.zarr")
```
![chunks-visualization](assets/screenshots/chunks_visualization.gif)
"""
# Load the Zarr dataset
zarr_data = zarr.open(zarr_path, mode="r")
# Save arguments for later use
# visualization_method = visualization_method
# preserved_kwargs = kwargs
# Create label to display the chunk coordinates
widget_title = widgets.HTML("<h2>Chunk Explorer</h2>")
chunk_info_label = widgets.HTML(value="Chunk info will be displayed here")
def load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
):
# Get chunk shape for the selected scale
chunk_shape = zarr_data[scale].chunks
# Calculate slice indices for the selected chunk
slices = (
slice(
z_coord * chunk_shape[0],
min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]),
),
slice(
y_coord * chunk_shape[1],
min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]),
),
slice(
x_coord * chunk_shape[2],
min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]),
),
)
# Extract start and stop values from each slice object
z_start, z_stop = slices[0].start, slices[0].stop
y_start, y_stop = slices[1].start, slices[1].stop
x_start, x_stop = slices[2].start, slices[2].stop
# Extract the chunk
chunk = zarr_data[scale][slices]
# Update the chunk info label with the chunk coordinates
info_string = (
f"<b>shape:</b> {chunk_shape}\n"
+ f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n"
+ f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n"
+ f"<b>dtype:</b> {chunk.dtype}\n"
+ f"<b>min value:</b> {np.min(chunk)}\n"
+ f"<b>max value:</b> {np.max(chunk)}\n"
+ f"<b>mean value:</b> {np.mean(chunk)}\n"
)
chunk_info_label.value = f"""
<div style="font-size: 14px; text-align: left; margin-left:32px">
<h3 style="margin: 0px">Chunk Info</h3>
<div style="font-size: 14px; text-align: left;">
<pre>{info_string}</pre>
</div>
</div>
"""
# Prepare chunk visualization based on the selected method
if visualization_method == "slicer": # return a widget
viz_widget = qim3d.viz.slicer(chunk, **kwargs)
elif visualization_method == "slices": # return a plt.Figure
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices(chunk, **kwargs)
display(fig)
elif visualization_method == "vol":
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
out = qim3d.viz.vol(chunk, show=False, **kwargs)
display(out)
else:
log.info(f"Invalid visualization method: {visualization_method}")
return viz_widget
# Function to calculate the number of chunks for each dimension, including partial chunks
def get_num_chunks(shape, chunk_size):
return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]
scale_options = {
f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))
} # len(zarr_data) gives number of scales
description_width = "128px"
# Create dropdown for scale
scale_dropdown = widgets.Dropdown(
options=scale_options,
value=0, # Default to first scale
description="OME-Zarr scale",
style={"description_width": description_width, "text_align": "left"},
)
# Initialize the options for x, y, and z based on the first scale by default
multiscale_shape = zarr_data[0].shape
chunk_shape = zarr_data[0].chunks
num_chunks = get_num_chunks(multiscale_shape, chunk_shape)
z_dropdown = widgets.Dropdown(
options=list(range(num_chunks[0])),
value=0,
description="First dimension (Z)",
style={"description_width": description_width, "text_align": "left"},
)
y_dropdown = widgets.Dropdown(
options=list(range(num_chunks[1])),
value=0,
description="Second dimension (Y)",
style={"description_width": description_width, "text_align": "left"},
)
x_dropdown = widgets.Dropdown(
options=list(range(num_chunks[2])),
value=0,
description="Third dimension (X)",
style={"description_width": description_width, "text_align": "left"},
)
method_dropdown = widgets.Dropdown(
options=["slicer", "slices", "vol"],
value="slicer",
description="Visualization",
style={"description_width": description_width, "text_align": "left"},
)
# Funtion to temporarily disable observers
def disable_observers():
x_dropdown.unobserve(update_visualization, names="value")
y_dropdown.unobserve(update_visualization, names="value")
z_dropdown.unobserve(update_visualization, names="value")
method_dropdown.unobserve(update_visualization, names="value")
# Funtion to enable observers
def enable_observers():
x_dropdown.observe(update_visualization, names="value")
y_dropdown.observe(update_visualization, names="value")
z_dropdown.observe(update_visualization, names="value")
method_dropdown.observe(update_visualization, names="value")
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def update_coordinate_dropdowns(scale):
disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns
multiscale_shape = zarr_data[scale].shape
chunk_shape = zarr_data[scale].chunks
num_chunks = get_num_chunks(
multiscale_shape, chunk_shape
) # Calculate new chunk options
# Reset X, Y, Z dropdowns to 0
z_dropdown.options = list(range(num_chunks[0]))
z_dropdown.value = 0 # Reset to 0
z_dropdown.disabled = (
len(z_dropdown.options) == 1
) # Disable if only one option (0) is available
y_dropdown.options = list(range(num_chunks[1]))
y_dropdown.value = 0 # Reset to 0
y_dropdown.disabled = (
len(y_dropdown.options) == 1
) # Disable if only one option (0) is available
x_dropdown.options = list(range(num_chunks[2]))
x_dropdown.value = 0 # Reset to 0
x_dropdown.disabled = (
len(x_dropdown.options) == 1
) # Disable if only one option (0) is available
enable_observers()
update_visualization()
# Function to update the visualization when any dropdown value changes
def update_visualization(*args):
scale = scale_dropdown.value
x_coord = x_dropdown.value
y_coord = y_dropdown.value
z_coord = z_dropdown.value
visualization_method = method_dropdown.value
# Clear and update the chunk visualization
slicer_widget = load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
)
# Recreate the layout and display the new visualization
final_layout.children = [widget_title, hbox_layout, slicer_widget]
# Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
scale_dropdown.observe(
lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value"
)
enable_observers()
# Create first visualization
slicer_widget = load_and_visualize(
scale_dropdown.value,
z_dropdown.value,
y_dropdown.value,
x_dropdown.value,
method_dropdown.value,
**kwargs,
)
# Create the layout
vbox_dropbox = widgets.VBox(
[scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown]
)
hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label])
final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget])
# Display the VBox
display(final_layout)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment