Skip to content
Snippets Groups Projects

Chunk visualization

Merged s212246 requested to merge chunk_visualization into main
1 file
+ 184
115
Compare changes
  • Side-by-side
  • Inline
+ 184
115
@@ -10,7 +10,7 @@ from typing import List, Optional, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
from IPython.display import SVG, display
import matplotlib
import numpy as np
import zarr
@@ -27,8 +27,8 @@ def slices(
n_slices: int = 5,
max_cols: int = 5,
cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
vmin: float = None,
vmax: float = None,
img_height: int = 2,
img_width: int = 2,
show: bool = False,
@@ -148,7 +148,7 @@ def slices(
# In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin else np.min(vol)
new_vmax = vmax if vmax else np.max(vol)
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)):
@@ -159,11 +159,24 @@ def slices(
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
slice_img,
cmap=cmap,
interpolation=interpolation,
vmin=new_vmin,
vmax=new_vmax,
**imshow_kwargs,
)
if show_position:
@@ -198,7 +211,7 @@ def slices(
# Hide the axis, so that we have a nice grid
ax.axis("off")
if cbar:
if cbar:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout()
@@ -208,8 +221,10 @@ def slices(
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
cbar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
if show:
plt.show()
@@ -239,8 +254,8 @@ def slicer(
vol: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
vmin: float = None,
vmax: float = None,
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
@@ -286,8 +301,8 @@ def slicer(
vol,
axis=axis,
cmap=cmap,
vmin = vmin,
vmax = vmax,
vmin=vmin,
vmax=vmax,
img_height=img_height,
img_width=img_width,
show_position=show_position,
@@ -316,8 +331,8 @@ def slicer(
def orthogonal(
vol: np.ndarray,
cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
vmin: float = None,
vmax: float = None,
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
@@ -355,19 +370,19 @@ def orthogonal(
get_slicer_for_axis = lambda axis: slicer(
vol,
axis = axis,
cmap = cmap,
vmin = vmin,
vmax = vmax,
img_height = img_height,
img_width = img_width,
show_position = show_position,
interpolation = interpolation,
)
axis=axis,
cmap=cmap,
vmin=vmin,
vmax=vmax,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
z_slicer = get_slicer_for_axis(axis = 0)
y_slicer = get_slicer_for_axis(axis = 1)
x_slicer = get_slicer_for_axis(axis = 2)
z_slicer = get_slicer_for_axis(axis=0)
y_slicer = get_slicer_for_axis(axis=1)
x_slicer = get_slicer_for_axis(axis=2)
z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y"
@@ -376,7 +391,13 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', vmin:float = None, vmax:float = None):
def interactive_fade_mask(
vol: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask.
@@ -405,10 +426,18 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
slice_img = vol[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax)
axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title("Original")
axes[0].axis("off")
@@ -435,16 +464,24 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :]
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
axes[2].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax)
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[2].set_title("Masked")
axes[2].axis("off")
return fig
shape_dropdown = widgets.Dropdown(
options=["spherical", "cylindrical"],
options=["spherical", "cylindrical"],
value="spherical", # default value
description="Geometry",
)
@@ -491,45 +528,55 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
return slicer_obj
def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
def chunks(zarr_path: str, **kwargs):
"""
Function to visualize chunks of a Zarr dataset using the specified visualization method.
Args:
zarr_path (str): Path to the Zarr dataset.
visualization_method (str, optional): The visualization method to use ('slicer', 'slices', or 'vol'). Each method leverages the corresponding qim3d visualization function. Defaults to 'slicer'.
**kwargs: Additional keyword arguments to pass to the visualization method.
Example:
```python
import qim3d
zarr_path = "path/to/zarr/dataset.zarr"
qim3d.viz.chunks(zarr_path, visualization_method='vol')
qim3d.viz.chunks(zarr_path)
```
![chunks-visualization](assets/screenshots/chunks_visualization.gif)
"""
# Load the Zarr dataset
zarr_data = zarr.open(zarr_path, mode='r')
zarr_data = zarr.open(zarr_path, mode="r")
# Save arguments for later use
visualization_method = visualization_method
preserved_kwargs = kwargs
# visualization_method = visualization_method
# preserved_kwargs = kwargs
# Create label to display the chunk coordinates
chunk_info_label = widgets.HTML(value="Chunk coordinates and size will appear here")
widget_title = widgets.HTML("<h2>Chunk Explorer</h2>")
chunk_info_label = widgets.HTML(value="Chunk info will be displayed here")
def load_and_visualize(scale, z_coord, y_coord, x_coord, visualization_method, **kwargs):
def load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
):
# Get chunk shape for the selected scale
chunk_shape = zarr_data[scale].chunks
# Calculate slice indices for the selected chunk
slices = (
slice(z_coord * chunk_shape[0], min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0])),
slice(y_coord * chunk_shape[1], min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1])),
slice(x_coord * chunk_shape[2], min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]))
slice(
z_coord * chunk_shape[0],
min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]),
),
slice(
y_coord * chunk_shape[1],
min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]),
),
slice(
x_coord * chunk_shape[2],
min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]),
),
)
# Extract start and stop values from each slice object
@@ -537,57 +584,44 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
y_start, y_stop = slices[1].start, slices[1].stop
x_start, x_stop = slices[2].start, slices[2].stop
# Extract the chunk
chunk = zarr_data[scale][slices]
# Update the chunk info label with the chunk coordinates
info_string = (
f"<b>shape:</b> {chunk_shape}\n"
+ f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n"
+ f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n"
+ f"<b>dtype:</b> {chunk.dtype}\n"
+ f"<b>min value:</b> {np.min(chunk)}\n"
+ f"<b>max value:</b> {np.max(chunk)}\n"
+ f"<b>mean value:</b> {np.mean(chunk)}\n"
)
chunk_info_label.value = f"""
<div style="font-size: 14px; text-align: center;">
<b>Chunk Info</b>
<div style="font-size: 14px; text-align: left; margin-left:32px">
<h3 style="margin: 0px">Chunk Info</h3>
<div style="font-size: 14px; text-align: left;">
<pre>{info_string}</pre>
</div>
</div>
<div style="font-size: 14px; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: left;">
<table style="font-size: 13px; border-collapse: collapse; width: 100%;">
<tr style="background-color: #f5f5f5;">
<td colspan="2" style="text-align: center;"><b>Range:</b></td>
</tr>
<tr>
<td style="text-align: right; padding-right: 10px;">Z:</td>
<td>{z_start}-{z_stop}</td>
</tr>
<tr style="background-color: #f5f5f5;">
<td style="text-align: right; padding-right: 10px;">Y:</td>
<td>{y_start}-{y_stop}</td>
</tr>
<tr>
<td style="text-align: right; padding-right: 10px;">X:</td>
<td>{x_start}-{x_stop}</td>
</tr>
</table>
</div>
<div style="flex: 1; text-align: right; padding-left: 20px; padding-top: 3px; font-size: 13px; white-space: nowrap;">
<b>Chunk size:</b> ({z_stop - z_start}, {y_stop - y_start}, {x_stop - x_start})
</div>
</div>
"""
# Extract the chunk
chunk = zarr_data[scale][slices]
"""
# Prepare chunk visualization based on the selected method
if visualization_method == 'slicer': # return a widget
if visualization_method == "slicer": # return a widget
viz_widget = qim3d.viz.slicer(chunk, **kwargs)
elif visualization_method == 'slices': # return a plt.Figure
elif visualization_method == "slices": # return a plt.Figure
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices(chunk, **kwargs)
display(fig)
elif visualization_method == 'vol':
elif visualization_method == "vol":
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
out = qim3d.viz.vol(chunk, show = False, **kwargs)
out = qim3d.viz.vol(chunk, show=False, **kwargs)
display(out)
else:
log.info(f"Invalid visualization method: {visualization_method}")
@@ -598,14 +632,17 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
def get_num_chunks(shape, chunk_size):
return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]
scale_options = {f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))} # len(zarr_data) gives number of scales
scale_options = {
f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))
} # len(zarr_data) gives number of scales
description_width = "128px"
# Create dropdown for scale
scale_dropdown = widgets.Dropdown(
options=scale_options,
options=scale_options,
value=0, # Default to first scale
description='Scale:',
description="OME-Zarr scale",
style={"description_width": description_width, "text_align": "left"},
)
# Initialize the options for x, y, and z based on the first scale by default
@@ -616,60 +653,78 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
z_dropdown = widgets.Dropdown(
options=list(range(num_chunks[0])),
value=0,
description='Z:',
description="First dimension (Z)",
style={"description_width": description_width, "text_align": "left"},
)
y_dropdown = widgets.Dropdown(
options=list(range(num_chunks[1])),
value=0,
description='Y:',
description="Second dimension (Y)",
style={"description_width": description_width, "text_align": "left"},
)
x_dropdown = widgets.Dropdown(
options=list(range(num_chunks[2])),
value=0,
description='X:',
description="Third dimension (X)",
style={"description_width": description_width, "text_align": "left"},
)
method_dropdown = widgets.Dropdown(
options=["slicer", "slices", "vol"],
value="slicer",
description="Visualization",
style={"description_width": description_width, "text_align": "left"},
)
# Funtion to temporarily disable observers
def disable_observers():
x_dropdown.unobserve(update_visualization, names='value')
y_dropdown.unobserve(update_visualization, names='value')
z_dropdown.unobserve(update_visualization, names='value')
x_dropdown.unobserve(update_visualization, names="value")
y_dropdown.unobserve(update_visualization, names="value")
z_dropdown.unobserve(update_visualization, names="value")
method_dropdown.unobserve(update_visualization, names="value")
# Funtion to enable observers
def enable_observers():
x_dropdown.observe(update_visualization, names='value')
y_dropdown.observe(update_visualization, names='value')
z_dropdown.observe(update_visualization, names='value')
x_dropdown.observe(update_visualization, names="value")
y_dropdown.observe(update_visualization, names="value")
z_dropdown.observe(update_visualization, names="value")
method_dropdown.observe(update_visualization, names="value")
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def update_coordinate_dropdowns(scale):
disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns
disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns
multiscale_shape = zarr_data[scale].shape
chunk_shape = zarr_data[scale].chunks
num_chunks = get_num_chunks(multiscale_shape, chunk_shape) # Calculate new chunk options
num_chunks = get_num_chunks(
multiscale_shape, chunk_shape
) # Calculate new chunk options
# Reset X, Y, Z dropdowns to 0
z_dropdown.options = list(range(num_chunks[0]))
z_dropdown.value = 0 # Reset to 0
z_dropdown.disabled = len(z_dropdown.options) == 1 # Disable if only one option (0) is available
z_dropdown.disabled = (
len(z_dropdown.options) == 1
) # Disable if only one option (0) is available
y_dropdown.options = list(range(num_chunks[1]))
y_dropdown.value = 0 # Reset to 0
y_dropdown.disabled = len(y_dropdown.options) == 1 # Disable if only one option (0) is available
y_dropdown.disabled = (
len(y_dropdown.options) == 1
) # Disable if only one option (0) is available
x_dropdown.options = list(range(num_chunks[2]))
x_dropdown.value = 0 # Reset to 0
x_dropdown.disabled = len(x_dropdown.options) == 1 # Disable if only one option (0) is available
x_dropdown.disabled = (
len(x_dropdown.options) == 1
) # Disable if only one option (0) is available
enable_observers()
update_visualization()
update_visualization()
# Function to update the visualization when any dropdown value changes
def update_visualization(*args):
@@ -677,25 +732,39 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
x_coord = x_dropdown.value
y_coord = y_dropdown.value
z_coord = z_dropdown.value
visualization_method = method_dropdown.value
# Clear and update the chunk visualization
slicer_widget = load_and_visualize(scale, z_coord, y_coord, x_coord, visualization_method, **preserved_kwargs)
slicer_widget = load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
)
# Recreate the layout and display the new visualization
vbox_layout.children = [hbox_layout, slicer_widget]
final_layout.children = [widget_title, hbox_layout, slicer_widget]
# Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
scale_dropdown.observe(lambda change: update_coordinate_dropdowns(scale_dropdown.value), names='value')
scale_dropdown.observe(
lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value"
)
enable_observers()
# Create first visualization
slicer_widget = load_and_visualize(scale_dropdown.value, z_dropdown.value, y_dropdown.value, x_dropdown.value, visualization_method, **preserved_kwargs)
slicer_widget = load_and_visualize(
scale_dropdown.value,
z_dropdown.value,
y_dropdown.value,
x_dropdown.value,
method_dropdown.value,
**kwargs,
)
# Create the layout
vbox_dropbox = widgets.VBox([scale_dropdown, z_dropdown, y_dropdown, x_dropdown], layout=widgets.Layout(margin='35px 80px 0 0'))
vbox_dropbox = widgets.VBox(
[scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown]
)
hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label])
vbox_layout = widgets.VBox([hbox_layout, slicer_widget])
final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget])
# Display the VBox
display(vbox_layout)
\ No newline at end of file
display(final_layout)
Loading