Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
Loading items

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
Loading items
Show changes
Commits on Source (2)
docs/assets/screenshots/interactive_thresholding.gif

2.57 MiB

...@@ -24,6 +24,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v ...@@ -24,6 +24,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v
- colormaps - colormaps
- fade_mask - fade_mask
- line_profile - line_profile
- threshold
::: qim3d.viz.colormaps ::: qim3d.viz.colormaps
options: options:
......
# from ._sync import Sync # this will be added back after future development # from ._sync import Sync # this will be added back after future development
from ._convert import convert
from ._downloader import Downloader
from ._loading import load, load_mesh from ._loading import load, load_mesh
from ._ome_zarr import export_ome_zarr, import_ome_zarr from ._downloader import Downloader
from ._saving import save, save_mesh from ._saving import save, save_mesh
from ._convert import convert
from ._ome_zarr import export_ome_zarr, import_ome_zarr
...@@ -7,9 +7,10 @@ import numpy as np ...@@ -7,9 +7,10 @@ import numpy as np
import tifffile as tiff import tifffile as tiff
import zarr import zarr
import zarr.core import zarr.core
import qim3d
from tqdm import tqdm from tqdm import tqdm
from qim3d.io import save
from qim3d.utils._misc import stringify_path from qim3d.utils._misc import stringify_path
...@@ -121,7 +122,7 @@ class Convert: ...@@ -121,7 +122,7 @@ class Convert:
""" """
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(tif_path, z) qim3d.io.save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array: def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
""" """
...@@ -173,7 +174,7 @@ class Convert: ...@@ -173,7 +174,7 @@ class Convert:
""" """
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(nifti_path, z, compression=compression) qim3d.io.save(nifti_path, z, compression=compression)
def convert( def convert(
......
...@@ -4,12 +4,11 @@ from ._data_exploration import ( ...@@ -4,12 +4,11 @@ from ._data_exploration import (
chunks, chunks,
fade_mask, fade_mask,
histogram, histogram,
line_profile,
slicer, slicer,
slicer_orthogonal, slicer_orthogonal,
slices_grid, slices_grid,
chunks, threshold,
histogram,
line_profile
) )
from ._detection import circles from ._detection import circles
from ._k3d import mesh, volumetric from ._k3d import mesh, volumetric
......
...@@ -4,20 +4,27 @@ Provides a collection of visualization functions. ...@@ -4,20 +4,27 @@ Provides a collection of visualization functions.
import math import math
import warnings import warnings
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union, Tuple
import dask.array as da import dask.array as da
import ipywidgets as widgets import ipywidgets as widgets
import matplotlib import matplotlib
import matplotlib.figure import matplotlib.figure
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib
from IPython.display import SVG, display, clear_output
import matplotlib
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns
import skimage.measure import skimage.measure
from skimage.filters import (
threshold_otsu,
threshold_isodata,
threshold_li,
threshold_mean,
threshold_minimum,
threshold_triangle,
threshold_yen,
)
from IPython.display import clear_output, display
import qim3d import qim3d
from qim3d.utils._logger import log from qim3d.utils._logger import log
...@@ -29,7 +36,7 @@ def slices_grid( ...@@ -29,7 +36,7 @@ def slices_grid(
slice_positions: Optional[Union[str, int, List[int]]] = None, slice_positions: Optional[Union[str, int, List[int]]] = None,
num_slices: int = 15, num_slices: int = 15,
max_columns: int = 5, max_columns: int = 5,
color_map: str = 'magma', color_map: str = "magma",
value_min: float = None, value_min: float = None,
value_max: float = None, value_max: float = None,
image_size: int = None, image_size: int = None,
...@@ -39,7 +46,7 @@ def slices_grid( ...@@ -39,7 +46,7 @@ def slices_grid(
display_positions: bool = True, display_positions: bool = True,
interpolation: Optional[str] = None, interpolation: Optional[str] = None,
color_bar: bool = False, color_bar: bool = False,
color_bar_style: str = 'small', color_bar_style: str = "small",
**matplotlib_imshow_kwargs, **matplotlib_imshow_kwargs,
) -> matplotlib.figure.Figure: ) -> matplotlib.figure.Figure:
""" """
...@@ -93,18 +100,18 @@ def slices_grid( ...@@ -93,18 +100,18 @@ def slices_grid(
# If we pass python None to the imshow function, it will set to # If we pass python None to the imshow function, it will set to
# default value 'antialiased' # default value 'antialiased'
if interpolation is None: if interpolation is None:
interpolation = 'none' interpolation = "none"
# Numpy array or Torch tensor input # Numpy array or Torch tensor input
if not isinstance(volume, (np.ndarray, da.core.Array)): if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError('Data type not supported') raise ValueError("Data type not supported")
if volume.ndim < 3: if volume.ndim < 3:
raise ValueError( raise ValueError(
'The provided object is not a volume as it has less than 3 dimensions.' "The provided object is not a volume as it has less than 3 dimensions."
) )
color_bar_style_options = ['small', 'large'] color_bar_style_options = ["small", "large"]
if color_bar_style not in color_bar_style_options: if color_bar_style not in color_bar_style_options:
raise ValueError( raise ValueError(
f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}." f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}."
...@@ -122,11 +129,11 @@ def slices_grid( ...@@ -122,11 +129,11 @@ def slices_grid(
# Here we deal with the case that the user wants to use the objects colormap directly # Here we deal with the case that the user wants to use the objects colormap directly
if ( if (
type(color_map) == matplotlib.colors.LinearSegmentedColormap type(color_map) == matplotlib.colors.LinearSegmentedColormap
or color_map == 'segmentation' or color_map == "segmentation"
): ):
num_labels = len(np.unique(volume)) num_labels = len(np.unique(volume))
if color_map == 'segmentation': if color_map == "segmentation":
color_map = qim3d.viz.colormaps.segmentation(num_labels) color_map = qim3d.viz.colormaps.segmentation(num_labels)
# If value_min and value_max are not set like this, then in case the # If value_min and value_max are not set like this, then in case the
# number of objects changes on new slice, objects might change # number of objects changes on new slice, objects might change
...@@ -143,15 +150,15 @@ def slices_grid( ...@@ -143,15 +150,15 @@ def slices_grid(
slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int) slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
# Position is a string # Position is a string
elif isinstance(slice_positions, str) and slice_positions.lower() in [ elif isinstance(slice_positions, str) and slice_positions.lower() in [
'start', "start",
'mid', "mid",
'end', "end",
]: ]:
if slice_positions.lower() == 'start': if slice_positions.lower() == "start":
slice_idxs = _get_slice_range(0, num_slices, n_total) slice_idxs = _get_slice_range(0, num_slices, n_total)
elif slice_positions.lower() == 'mid': elif slice_positions.lower() == "mid":
slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total) slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total)
elif slice_positions.lower() == 'end': elif slice_positions.lower() == "end":
slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total) slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total)
# Position is an integer # Position is an integer
elif isinstance(slice_positions, int): elif isinstance(slice_positions, int):
...@@ -232,25 +239,25 @@ def slices_grid( ...@@ -232,25 +239,25 @@ def slices_grid(
ax.text( ax.text(
0.0, 0.0,
1.0, 1.0,
f'slice {slice_idxs[slice_idx]} ', f"slice {slice_idxs[slice_idx]} ",
transform=ax.transAxes, transform=ax.transAxes,
color='white', color="white",
fontsize=8, fontsize=8,
va='top', va="top",
ha='left', ha="left",
bbox=dict(facecolor='#303030', linewidth=0, pad=0), bbox=dict(facecolor="#303030", linewidth=0, pad=0),
) )
ax.text( ax.text(
1.0, 1.0,
0.0, 0.0,
f'axis {slice_axis} ', f"axis {slice_axis} ",
transform=ax.transAxes, transform=ax.transAxes,
color='white', color="white",
fontsize=8, fontsize=8,
va='bottom', va="bottom",
ha='right', ha="right",
bbox=dict(facecolor='#303030', linewidth=0, pad=0), bbox=dict(facecolor="#303030", linewidth=0, pad=0),
) )
except IndexError: except IndexError:
...@@ -258,11 +265,11 @@ def slices_grid( ...@@ -258,11 +265,11 @@ def slices_grid(
pass pass
# Hide the axis, so that we have a nice grid # Hide the axis, so that we have a nice grid
ax.axis('off') ax.axis("off")
if color_bar: if color_bar:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout() fig.tight_layout()
norm = matplotlib.colors.Normalize( norm = matplotlib.colors.Normalize(
...@@ -270,15 +277,15 @@ def slices_grid( ...@@ -270,15 +277,15 @@ def slices_grid(
) )
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map) mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map)
if color_bar_style == 'small': if color_bar_style == "small":
# Figure coordinates of top-right axis # Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images # The width is divided by ncols to make it the same relative size to the images
color_bar_ax = fig.add_axes( color_bar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
) )
fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical') fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
elif color_bar_style == 'large': elif color_bar_style == "large":
# Figure coordinates of bottom- and top-right axis # Figure coordinates of bottom- and top-right axis
br_pos = np.atleast_1d(axs[-1])[-1].get_position() br_pos = np.atleast_1d(axs[-1])[-1].get_position()
tr_pos = np.atleast_1d(axs[0])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position()
...@@ -291,7 +298,7 @@ def slices_grid( ...@@ -291,7 +298,7 @@ def slices_grid(
(tr_pos.y1 - br_pos.y0) - 0.0015, (tr_pos.y1 - br_pos.y0) - 0.0015,
] ]
) )
fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical') fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
if display_figure: if display_figure:
plt.show() plt.show()
...@@ -322,7 +329,7 @@ def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray ...@@ -322,7 +329,7 @@ def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray
def slicer( def slicer(
volume: np.ndarray, volume: np.ndarray,
slice_axis: int = 0, slice_axis: int = 0,
color_map: str = 'magma', color_map: str = "magma",
value_min: float = None, value_min: float = None,
value_max: float = None, value_max: float = None,
image_height: int = 3, image_height: int = 3,
...@@ -366,14 +373,14 @@ def slicer( ...@@ -366,14 +373,14 @@ def slicer(
image_height = image_size image_height = image_size
image_width = image_size image_width = image_size
color_bar_options = [None, 'slices', 'volume'] color_bar_options = [None, "slices", "volume"]
if color_bar not in color_bar_options: if color_bar not in color_bar_options:
raise ValueError( raise ValueError(
f"Unrecognized value '{color_bar}' for parameter color_bar. " f"Unrecognized value '{color_bar}' for parameter color_bar. "
f'Expected one of {color_bar_options}.' f"Expected one of {color_bar_options}."
) )
show_color_bar = color_bar is not None show_color_bar = color_bar is not None
if color_bar == 'slices': if color_bar == "slices":
# Precompute the minimum and maximum along each slice for faster widget sliding. # Precompute the minimum and maximum along each slice for faster widget sliding.
non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis) non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis)
slice_mins = np.min(volume, axis=non_slice_axes) slice_mins = np.min(volume, axis=non_slice_axes)
...@@ -381,7 +388,7 @@ def slicer( ...@@ -381,7 +388,7 @@ def slicer(
# Create the interactive widget # Create the interactive widget
def _slicer(slice_positions): def _slicer(slice_positions):
if color_bar == 'slices': if color_bar == "slices":
dynamic_min = slice_mins[slice_positions] dynamic_min = slice_mins[slice_positions]
dynamic_max = slice_maxs[slice_positions] dynamic_max = slice_maxs[slice_positions]
else: else:
...@@ -410,18 +417,18 @@ def slicer( ...@@ -410,18 +417,18 @@ def slicer(
value=volume.shape[slice_axis] // 2, value=volume.shape[slice_axis] // 2,
min=0, min=0,
max=volume.shape[slice_axis] - 1, max=volume.shape[slice_axis] - 1,
description='Slice', description="Slice",
continuous_update=True, continuous_update=True,
) )
slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider) slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider)
slicer_obj.layout = widgets.Layout(align_items='flex-start') slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj return slicer_obj
def slicer_orthogonal( def slicer_orthogonal(
volume: np.ndarray, volume: np.ndarray,
color_map: str = 'magma', color_map: str = "magma",
value_min: float = None, value_min: float = None,
value_max: float = None, value_max: float = None,
image_height: int = 3, image_height: int = 3,
...@@ -477,9 +484,9 @@ def slicer_orthogonal( ...@@ -477,9 +484,9 @@ def slicer_orthogonal(
y_slicer = get_slicer_for_axis(slice_axis=1) y_slicer = get_slicer_for_axis(slice_axis=1)
x_slicer = get_slicer_for_axis(slice_axis=2) x_slicer = get_slicer_for_axis(slice_axis=2)
z_slicer.children[0].description = 'Z' z_slicer.children[0].description = "Z"
y_slicer.children[0].description = 'Y' y_slicer.children[0].description = "Y"
x_slicer.children[0].description = 'X' x_slicer.children[0].description = "X"
return widgets.HBox([z_slicer, y_slicer, x_slicer]) return widgets.HBox([z_slicer, y_slicer, x_slicer])
...@@ -487,7 +494,7 @@ def slicer_orthogonal( ...@@ -487,7 +494,7 @@ def slicer_orthogonal(
def fade_mask( def fade_mask(
volume: np.ndarray, volume: np.ndarray,
axis: int = 0, axis: int = 0,
color_map: str = 'magma', color_map: str = "magma",
value_min: float = None, value_min: float = None,
value_max: float = None, value_max: float = None,
) -> widgets.interactive: ) -> widgets.interactive:
...@@ -537,8 +544,8 @@ def fade_mask( ...@@ -537,8 +544,8 @@ def fade_mask(
axes[0].imshow( axes[0].imshow(
slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
) )
axes[0].set_title('Original') axes[0].set_title("Original")
axes[0].axis('off') axes[0].axis("off")
mask = qim3d.operations.fade_mask( mask = qim3d.operations.fade_mask(
np.ones_like(volume), np.ones_like(volume),
...@@ -549,8 +556,8 @@ def fade_mask( ...@@ -549,8 +556,8 @@ def fade_mask(
invert=invert, invert=invert,
) )
axes[1].imshow(mask[position, :, :], cmap=color_map) axes[1].imshow(mask[position, :, :], cmap=color_map)
axes[1].set_title('Mask') axes[1].set_title("Mask")
axes[1].axis('off') axes[1].axis("off")
masked_volume = qim3d.operations.fade_mask( masked_volume = qim3d.operations.fade_mask(
volume, volume,
...@@ -576,22 +583,22 @@ def fade_mask( ...@@ -576,22 +583,22 @@ def fade_mask(
axes[2].imshow( axes[2].imshow(
slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
) )
axes[2].set_title('Masked') axes[2].set_title("Masked")
axes[2].axis('off') axes[2].axis("off")
return fig return fig
shape_dropdown = widgets.Dropdown( shape_dropdown = widgets.Dropdown(
options=['spherical', 'cylindrical'], options=["spherical", "cylindrical"],
value='spherical', # default value value="spherical", # default value
description='Geometry', description="Geometry",
) )
position_slider = widgets.IntSlider( position_slider = widgets.IntSlider(
value=volume.shape[0] // 2, value=volume.shape[0] // 2,
min=0, min=0,
max=volume.shape[0] - 1, max=volume.shape[0] - 1,
description='Slice', description="Slice",
continuous_update=False, continuous_update=False,
) )
decay_rate_slider = widgets.FloatSlider( decay_rate_slider = widgets.FloatSlider(
...@@ -599,7 +606,7 @@ def fade_mask( ...@@ -599,7 +606,7 @@ def fade_mask(
min=1, min=1,
max=50, max=50,
step=1.0, step=1.0,
description='Decay Rate', description="Decay Rate",
continuous_update=False, continuous_update=False,
) )
ratio_slider = widgets.FloatSlider( ratio_slider = widgets.FloatSlider(
...@@ -607,14 +614,14 @@ def fade_mask( ...@@ -607,14 +614,14 @@ def fade_mask(
min=0.1, min=0.1,
max=1, max=1,
step=0.01, step=0.01,
description='Ratio', description="Ratio",
continuous_update=False, continuous_update=False,
) )
# Create the Checkbox widget # Create the Checkbox widget
invert_checkbox = widgets.Checkbox( invert_checkbox = widgets.Checkbox(
value=False, value=False,
description='Invert', # default value description="Invert", # default value
) )
slicer_obj = widgets.interactive( slicer_obj = widgets.interactive(
...@@ -625,7 +632,7 @@ def fade_mask( ...@@ -625,7 +632,7 @@ def fade_mask(
geometry=shape_dropdown, geometry=shape_dropdown,
invert=invert_checkbox, invert=invert_checkbox,
) )
slicer_obj.layout = widgets.Layout(align_items='flex-start') slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj return slicer_obj
...@@ -657,15 +664,15 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -657,15 +664,15 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
""" """
# Load the Zarr dataset # Load the Zarr dataset
zarr_data = zarr.open(zarr_path, mode='r') zarr_data = zarr.open(zarr_path, mode="r")
# Save arguments for later use # Save arguments for later use
# visualization_method = visualization_method # visualization_method = visualization_method
# preserved_kwargs = kwargs # preserved_kwargs = kwargs
# Create label to display the chunk coordinates # Create label to display the chunk coordinates
widget_title = widgets.HTML('<h2>Chunk Explorer</h2>') widget_title = widgets.HTML("<h2>Chunk Explorer</h2>")
chunk_info_label = widgets.HTML(value='Chunk info will be displayed here') chunk_info_label = widgets.HTML(value="Chunk info will be displayed here")
def load_and_visualize( def load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
...@@ -699,13 +706,13 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -699,13 +706,13 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
# Update the chunk info label with the chunk coordinates # Update the chunk info label with the chunk coordinates
info_string = ( info_string = (
f'<b>shape:</b> {chunk_shape}\n' f"<b>shape:</b> {chunk_shape}\n"
+ f'<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n' + f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n"
+ f'<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n' + f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n"
+ f'<b>dtype:</b> {chunk.dtype}\n' + f"<b>dtype:</b> {chunk.dtype}\n"
+ f'<b>min value:</b> {np.min(chunk)}\n' + f"<b>min value:</b> {np.min(chunk)}\n"
+ f'<b>max value:</b> {np.max(chunk)}\n' + f"<b>max value:</b> {np.max(chunk)}\n"
+ f'<b>mean value:</b> {np.mean(chunk)}\n' + f"<b>mean value:</b> {np.mean(chunk)}\n"
) )
chunk_info_label.value = f""" chunk_info_label.value = f"""
...@@ -719,22 +726,22 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -719,22 +726,22 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
""" """
# Prepare chunk visualization based on the selected method # Prepare chunk visualization based on the selected method
if visualization_method == 'slicer': # return a widget if visualization_method == "slicer": # return a widget
viz_widget = qim3d.viz.slicer(chunk, **kwargs) viz_widget = qim3d.viz.slicer(chunk, **kwargs)
elif visualization_method == 'slices': # return a plt.Figure elif visualization_method == "slices": # return a plt.Figure
viz_widget = widgets.Output() viz_widget = widgets.Output()
with viz_widget: with viz_widget:
viz_widget.clear_output(wait=True) viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices_grid(chunk, **kwargs) fig = qim3d.viz.slices_grid(chunk, **kwargs)
display(fig) display(fig)
elif visualization_method == 'volume': elif visualization_method == "volume":
viz_widget = widgets.Output() viz_widget = widgets.Output()
with viz_widget: with viz_widget:
viz_widget.clear_output(wait=True) viz_widget.clear_output(wait=True)
out = qim3d.viz.volumetric(chunk, show=False, **kwargs) out = qim3d.viz.volumetric(chunk, show=False, **kwargs)
display(out) display(out)
else: else:
log.info(f'Invalid visualization method: {visualization_method}') log.info(f"Invalid visualization method: {visualization_method}")
return viz_widget return viz_widget
...@@ -743,16 +750,16 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -743,16 +750,16 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)] return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]
scale_options = { scale_options = {
f'{i} {zarr_data[i].shape}': i for i in range(len(zarr_data)) f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))
} # len(zarr_data) gives number of scales } # len(zarr_data) gives number of scales
description_width = '128px' description_width = "128px"
# Create dropdown for scale # Create dropdown for scale
scale_dropdown = widgets.Dropdown( scale_dropdown = widgets.Dropdown(
options=scale_options, options=scale_options,
value=0, # Default to first scale value=0, # Default to first scale
description='OME-Zarr scale', description="OME-Zarr scale",
style={'description_width': description_width, 'text_align': 'left'}, style={"description_width": description_width, "text_align": "left"},
) )
# Initialize the options for x, y, and z based on the first scale by default # Initialize the options for x, y, and z based on the first scale by default
...@@ -763,44 +770,44 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -763,44 +770,44 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
z_dropdown = widgets.Dropdown( z_dropdown = widgets.Dropdown(
options=list(range(num_chunks[0])), options=list(range(num_chunks[0])),
value=0, value=0,
description='First dimension (Z)', description="First dimension (Z)",
style={'description_width': description_width, 'text_align': 'left'}, style={"description_width": description_width, "text_align": "left"},
) )
y_dropdown = widgets.Dropdown( y_dropdown = widgets.Dropdown(
options=list(range(num_chunks[1])), options=list(range(num_chunks[1])),
value=0, value=0,
description='Second dimension (Y)', description="Second dimension (Y)",
style={'description_width': description_width, 'text_align': 'left'}, style={"description_width": description_width, "text_align": "left"},
) )
x_dropdown = widgets.Dropdown( x_dropdown = widgets.Dropdown(
options=list(range(num_chunks[2])), options=list(range(num_chunks[2])),
value=0, value=0,
description='Third dimension (X)', description="Third dimension (X)",
style={'description_width': description_width, 'text_align': 'left'}, style={"description_width": description_width, "text_align": "left"},
) )
method_dropdown = widgets.Dropdown( method_dropdown = widgets.Dropdown(
options=['slicer', 'slices', 'volume'], options=["slicer", "slices", "volume"],
value='slicer', value="slicer",
description='Visualization', description="Visualization",
style={'description_width': description_width, 'text_align': 'left'}, style={"description_width": description_width, "text_align": "left"},
) )
# Funtion to temporarily disable observers # Funtion to temporarily disable observers
def disable_observers(): def disable_observers():
x_dropdown.unobserve(update_visualization, names='value') x_dropdown.unobserve(update_visualization, names="value")
y_dropdown.unobserve(update_visualization, names='value') y_dropdown.unobserve(update_visualization, names="value")
z_dropdown.unobserve(update_visualization, names='value') z_dropdown.unobserve(update_visualization, names="value")
method_dropdown.unobserve(update_visualization, names='value') method_dropdown.unobserve(update_visualization, names="value")
# Funtion to enable observers # Funtion to enable observers
def enable_observers(): def enable_observers():
x_dropdown.observe(update_visualization, names='value') x_dropdown.observe(update_visualization, names="value")
y_dropdown.observe(update_visualization, names='value') y_dropdown.observe(update_visualization, names="value")
z_dropdown.observe(update_visualization, names='value') z_dropdown.observe(update_visualization, names="value")
method_dropdown.observe(update_visualization, names='value') method_dropdown.observe(update_visualization, names="value")
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0 # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def update_coordinate_dropdowns(scale): def update_coordinate_dropdowns(scale):
...@@ -853,7 +860,7 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -853,7 +860,7 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
# Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes # Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
scale_dropdown.observe( scale_dropdown.observe(
lambda change: update_coordinate_dropdowns(scale_dropdown.value), names='value' lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value"
) )
enable_observers() enable_observers()
...@@ -881,21 +888,23 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: ...@@ -881,21 +888,23 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
def histogram( def histogram(
volume: np.ndarray, volume: np.ndarray,
bins: Union[int, str] = 'auto', bins: Union[int, str] = "auto",
slice_idx: Union[int, str] = None, slice_idx: Union[int, str, None] = None,
vertical_line: int = None,
axis: int = 0, axis: int = 0,
kde: bool = True, kde: bool = True,
log_scale: bool = False, log_scale: bool = False,
despine: bool = True, despine: bool = True,
show_title: bool = True, show_title: bool = True,
color: str = 'qim3d', color: str = "qim3d",
edgecolor: str | None = None, edgecolor: Optional[str] = None,
figsize: tuple[float, float] = (8, 4.5), figsize: Tuple[float, float] = (8, 4.5),
element: str = 'step', element: str = "step",
return_fig: bool = False, return_fig: bool = False,
show: bool = True, show: bool = True,
**sns_kwargs, ax: Optional[plt.Axes] = None,
) -> None | matplotlib.figure.Figure: **sns_kwargs: Union[str, float, int, bool]
) -> Optional[Union[plt.Figure, plt.Axes]]:
""" """
Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
...@@ -903,73 +912,63 @@ def histogram( ...@@ -903,73 +912,63 @@ def histogram(
Args: Args:
volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (int or str, optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0. axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (int or str or None, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None. If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
kde (bool, optional): Whether to overlay a kernel density estimate. Default is True. kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False. log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True. despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
show_title (bool, optional): If True, displays a title with slice information. Default is True. show_title (bool, optional): If True, displays a title with slice information. Default is True.
color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d". color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d".
edgecolor (str, optional): Color for the edges of the histogram bars. Default is None. edgecolor (str, optional): Color for the edges of the histogram bars. Default is None.
figsize (tuple of floats, optional): Size of the figure (width, height). Default is (8, 4.5). figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5).
element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step". element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False. return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True. show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
**sns_kwargs (Any): Additional keyword arguments for `seaborn.histplot`. ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
**sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
Returns: Returns:
fig (Optional[matplotlib.figure.Figure]): If `return_fig` is True, returns the generated figure object. Otherwise, returns None. Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
If `return_fig` is True, returns the generated figure object.
If `return_fig` is False and `ax` is provided, returns the `Axes` object.
Otherwise, returns None.
Raises: Raises:
ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis. ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol)
```
![viz histogram](../../assets/screenshots/viz-histogram-vol.png)
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True)
```
![viz histogram](../../assets/screenshots/viz-histogram-slice.png)
""" """
if not (0 <= axis < volume.ndim): if not (0 <= axis < volume.ndim):
raise ValueError(f'Axis must be an integer between 0 and {volume.ndim - 1}.') raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == 'middle': if slice_idx == "middle":
slice_idx = volume.shape[axis] // 2 slice_idx = volume.shape[axis] // 2
if slice_idx: if slice_idx is not None:
if 0 <= slice_idx < volume.shape[axis]: if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(volume, indices=slice_idx, axis=axis) img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel() data = img_slice.ravel()
title = f'Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}' title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
else: else:
raise ValueError( raise ValueError(
f'Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}.' f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
) )
else: else:
data = volume.ravel() data = volume.ravel()
title = f'Intensity histogram for whole volume {volume.shape}' title = f"Intensity histogram for whole volume {volume.shape}"
# Use provided Axes or create new figure
if ax is None:
fig, ax = plt.subplots(figsize=figsize) fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
if log_scale: if log_scale:
plt.yscale('log') ax.set_yscale("log")
if color == 'qim3d': if color == "qim3d":
color = qim3d.viz.colormaps.qim(1.0) color = qim3d.viz.colormaps.qim(1.0)
sns.histplot( sns.histplot(
...@@ -979,38 +978,59 @@ def histogram( ...@@ -979,38 +978,59 @@ def histogram(
color=color, color=color,
element=element, element=element,
edgecolor=edgecolor, edgecolor=edgecolor,
ax=ax, # Plot directly on the specified Axes
**sns_kwargs, **sns_kwargs,
) )
if vertical_line is not None:
ax.axvline(
x=vertical_line,
color='red',
linestyle="--",
linewidth=2,
)
if despine: if despine:
sns.despine( sns.despine(
fig=None, fig=None,
ax=None, ax=ax,
top=True, top=True,
right=True, right=True,
left=False, left=False,
bottom=False, bottom=False,
offset={'left': 0, 'bottom': 18}, offset={"left": 0, "bottom": 18},
trim=True, trim=True,
) )
plt.xlabel('Voxel Intensity') ax.set_xlabel("Voxel Intensity")
plt.ylabel('Frequency') ax.set_ylabel("Frequency")
if show_title: if show_title:
plt.title(title, fontsize=10) ax.set_title(title, fontsize=10)
# Handle show and return # Handle show and return
if show: if show and fig is not None:
plt.show() plt.show()
else:
plt.close(fig)
if return_fig: if return_fig:
return fig return fig
elif ax is not None:
return ax
class _LineProfile: class _LineProfile:
def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range): def __init__(
self,
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
):
self.volume = volume self.volume = volume
self.slice_axis = slice_axis self.slice_axis = slice_axis
...@@ -1038,20 +1058,29 @@ class _LineProfile: ...@@ -1038,20 +1058,29 @@ class _LineProfile:
self.y_widget.value = self.y_max // 2 self.y_widget.value = self.y_max // 2
def initialize_widgets(self): def initialize_widgets(self):
layout = widgets.Layout(width='300px', height='auto') layout = widgets.Layout(width="300px", height="auto")
self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) self.x_widget = widgets.IntSlider(
self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) min=self.pad, step=1, description="", layout=layout
self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout) )
self.y_widget = widgets.IntSlider(
min=self.pad, step=1, description="", layout=layout
)
self.angle_widget = widgets.IntSlider(
min=0, max=360, step=1, value=0, description="", layout=layout
)
self.line_fraction_widget = widgets.FloatRangeSlider( self.line_fraction_widget = widgets.FloatRangeSlider(
min=0, max=1, step=0.01, value=[0, 1], min=0, max=1, step=0.01, value=[0, 1], description="", layout=layout
description="", layout=layout
) )
self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis') self.slice_axis_widget = widgets.Dropdown(
self.slice_axis_widget.layout.width = '250px' options=[0, 1, 2], value=self.slice_axis, description="Slice axis"
)
self.slice_axis_widget.layout.width = "250px"
self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout) self.slice_index_widget = widgets.IntSlider(
self.slice_index_widget.layout.width = '400px' min=0, step=1, description="Slice index", layout=layout
)
self.slice_index_widget.layout.width = "400px"
def calculate_line_endpoints(self, x, y, angle): def calculate_line_endpoints(self, x, y, angle):
""" """
...@@ -1091,7 +1120,10 @@ class _LineProfile: ...@@ -1091,7 +1120,10 @@ class _LineProfile:
image = np.take(self.volume, slice_index, slice_axis) image = np.take(self.volume, slice_index, slice_axis)
angle = np.radians(angle_deg) angle = np.radians(angle_deg)
src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)] src, dst = (
np.array(point, dtype="float32")
for point in self.calculate_line_endpoints(x, y, angle)
)
# Rescale endpoints # Rescale endpoints
line_vec = dst - src line_vec = dst - src
...@@ -1106,41 +1138,54 @@ class _LineProfile: ...@@ -1106,41 +1138,54 @@ class _LineProfile:
num_segments = 100 num_segments = 100
x_seg = np.linspace(src[0], dst[0], num_segments) x_seg = np.linspace(src[0], dst[0], num_segments)
y_seg = np.linspace(src[1], dst[1], num_segments) y_seg = np.linspace(src[1], dst[1], num_segments)
segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]), segments = np.stack(
np.column_stack([y_seg[2:], x_seg[2:]])], axis=1) [
np.column_stack([y_seg[:-2], x_seg[:-2]]),
np.column_stack([y_seg[2:], x_seg[2:]]),
],
axis=1,
)
norm = plt.Normalize(vmin=0, vmax=num_segments - 1) norm = plt.Normalize(vmin=0, vmax=num_segments - 1)
colors = self.cmap(norm(np.arange(num_segments - 1))) colors = self.cmap(norm(np.arange(num_segments - 1)))
lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2) lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2)
ax[0].imshow(image,cmap='gray') ax[0].imshow(image, cmap="gray")
ax[0].add_collection(lc) ax[0].add_collection(lc)
# pivot point # pivot point
ax[0].plot(y,x,marker='s', linestyle='', color='cyan', markersize=4) ax[0].plot(y, x, marker="s", linestyle="", color="cyan", markersize=4)
ax[0].set_xlabel(f'axis {np.delete(np.arange(3), self.slice_axis)[1]}') ax[0].set_xlabel(f"axis {np.delete(np.arange(3), self.slice_axis)[1]}")
ax[0].set_ylabel(f'axis {np.delete(np.arange(3), self.slice_axis)[0]}') ax[0].set_ylabel(f"axis {np.delete(np.arange(3), self.slice_axis)[0]}")
# Profile intensity plot # Profile intensity plot
norm = plt.Normalize(0, vmax=len(y_pline) - 1) norm = plt.Normalize(0, vmax=len(y_pline) - 1)
x_pline = np.arange(len(y_pline)) x_pline = np.arange(len(y_pline))
points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :] points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :]
segments = np.concatenate([points[:-1], points[1:]], axis=1) segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2) lc = matplotlib.collections.LineCollection(
segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2
)
ax[1].add_collection(lc) ax[1].add_collection(lc)
ax[1].autoscale() ax[1].autoscale()
ax[1].set_xlabel('Distance along line') ax[1].set_xlabel("Distance along line")
ax[1].grid(True) ax[1].grid(True)
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
def build_interactive(self): def build_interactive(self):
# Group widgets into two columns # Group widgets into two columns
title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" title_style = (
title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>") "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;"
title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>") )
title_column1 = widgets.HTML(
f"<div style='{title_style}'>Line parameterization</div>"
)
title_column2 = widgets.HTML(
f"<div style='{title_style}'>Slice selection</div>"
)
# Make label widgets instead of descriptions which have different lengths. # Make label widgets instead of descriptions which have different lengths.
label_layout = widgets.Layout(width='120px') label_layout = widgets.Layout(width="120px")
label_x = widgets.Label("Vertical position", layout=label_layout) label_x = widgets.Label("Vertical position", layout=label_layout)
label_y = widgets.Label("Horizontal position", layout=label_layout) label_y = widgets.Label("Horizontal position", layout=label_layout)
label_angle = widgets.Label("Angle (°)", layout=label_layout) label_angle = widgets.Label("Angle (°)", layout=label_layout)
...@@ -1151,29 +1196,40 @@ class _LineProfile: ...@@ -1151,29 +1196,40 @@ class _LineProfile:
row_angle = widgets.HBox([label_angle, self.angle_widget]) row_angle = widgets.HBox([label_angle, self.angle_widget])
row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget]) row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget])
controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction]) controls_column1 = widgets.VBox(
controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget]) [title_column1, row_x, row_y, row_angle, row_fraction]
)
controls_column2 = widgets.VBox(
[title_column2, self.slice_axis_widget, self.slice_index_widget]
)
controls = widgets.HBox([controls_column1, controls_column2]) controls = widgets.HBox([controls_column1, controls_column2])
interactive_plot = widgets.interactive_output( interactive_plot = widgets.interactive_output(
self.update, self.update,
{'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget, {
'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget, "slice_axis": self.slice_axis_widget,
'fraction_range': self.line_fraction_widget} "slice_index": self.slice_index_widget,
"x": self.x_widget,
"y": self.y_widget,
"angle_deg": self.angle_widget,
"fraction_range": self.line_fraction_widget,
},
) )
return widgets.VBox([controls, interactive_plot]) return widgets.VBox([controls, interactive_plot])
def line_profile( def line_profile(
volume: np.ndarray, volume: np.ndarray,
slice_axis: int = 0, slice_axis: int = 0,
slice_index: int | str='middle', slice_index: int | str = "middle",
vertical_position: int | str='middle', vertical_position: int | str = "middle",
horizontal_position: int | str='middle', horizontal_position: int | str = "middle",
angle: int = 0, angle: int = 0,
fraction_range: Tuple[float,float]=(0.00, 1.00) fraction_range: Tuple[float, float] = (0.00, 1.00),
) -> widgets.interactive: ) -> widgets.interactive:
"""Returns an interactive widget for visualizing the intensity profiles of lines on slices. """
Returns an interactive widget for visualizing the intensity profiles of lines on slices.
Args: Args:
volume (np.ndarray): The 3D volume of interest. volume (np.ndarray): The 3D volume of interest.
...@@ -1198,23 +1254,29 @@ def line_profile( ...@@ -1198,23 +1254,29 @@ def line_profile(
![viz histogram](../../assets/screenshots/viz-line_profile.gif) ![viz histogram](../../assets/screenshots/viz-line_profile.gif)
""" """
def parse_position(pos, pos_range, name): def parse_position(pos, pos_range, name):
if isinstance(pos, int): if isinstance(pos, int):
if not pos_range[0] <= pos < pos_range[1]: if not pos_range[0] <= pos < pos_range[1]:
raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]') raise ValueError(
f"Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]"
)
return pos return pos
elif isinstance(pos, str): elif isinstance(pos, str):
pos = pos.lower() pos = pos.lower()
if pos == 'start': return pos_range[0] if pos == "start":
elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 return pos_range[0]
elif pos == 'end': return pos_range[1] elif pos == "middle":
return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
elif pos == "end":
return pos_range[1]
else: else:
raise ValueError( raise ValueError(
f"Invalid string '{pos}' for {name}. " f"Invalid string '{pos}' for {name}. "
"Must be 'start', 'middle', or 'end'." "Must be 'start', 'middle', or 'end'."
) )
else: else:
raise TypeError(f'Axis position must be of type int or str.') raise TypeError("Axis position must be of type int or str.")
if not isinstance(volume, (np.ndarray, da.core.Array)): if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError("Data type for volume not supported.") raise ValueError("Data type for volume not supported.")
...@@ -1222,17 +1284,240 @@ def line_profile( ...@@ -1222,17 +1284,240 @@ def line_profile(
raise ValueError("Volume must be 3D.") raise ValueError("Volume must be 3D.")
dims = volume.shape dims = volume.shape
slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index') slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), "slice_index")
# the omission of the ends for the pivot point is due to border issues. # the omission of the ends for the pivot point is due to border issues.
vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position') vertical_position = parse_position(
horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position') vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), "vertical_position"
)
horizontal_position = parse_position(
horizontal_position,
(1, np.delete(dims, slice_axis)[1] - 2),
"horizontal_position",
)
if not isinstance(angle, int | float): if not isinstance(angle, int | float):
raise ValueError("Invalid type for angle.") raise ValueError("Invalid type for angle.")
angle = round(angle) % 360 angle = round(angle) % 360
if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]): if not (
0.0 <= fraction_range[0] <= 1.0
and 0.0 <= fraction_range[1] <= 1.0
and fraction_range[0] <= fraction_range[1]
):
raise ValueError("Invalid values for fraction_range.") raise ValueError("Invalid values for fraction_range.")
lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range) lp = _LineProfile(
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
)
return lp.build_interactive() return lp.build_interactive()
def threshold(
volume: np.ndarray,
cmap_image: str = 'magma',
vmin: float = None,
vmax: float = None,
) -> widgets.VBox:
"""
This function provides an interactive interface to explore thresholding on a
3D volume slice-by-slice. Users can either manually set the threshold value
using a slider or select an automatic thresholding method from `skimage`.
The visualization includes the original image slice, a binary mask showing regions above the
threshold and an overlay combining the binary mask and the original image.
Args:
volume (np.ndarray): 3D volume to threshold.
cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'.
cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'.
vmin (float, optional): Minimum value for the colormap. Defaults to None.
vmax (float, optional): Maximum value for the colormap. Defaults to None.
Returns:
slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume.
Interactivity:
- **Manual Thresholding**:
Select 'Manual' from the dropdown menu to manually adjust the threshold
using the slider.
- **Automatic Thresholding**:
Choose a method from the dropdown menu to apply an automatic thresholding
algorithm. Available methods include:
- Otsu
- Isodata
- Li
- Mean
- Minimum
- Triangle
- Yen
The threshold slider will display the computed value and will be disabled
in this mode.
```python
import qim3d
# Load a sample volume
vol = qim3d.examples.bone_128x128x128
# Visualize interactive thresholding
qim3d.viz.threshold(vol)
```
![interactive threshold](../../assets/screenshots/interactive_thresholding.gif)
"""
# Centralized state dictionary to track current parameters
state = {
"position": volume.shape[0] // 2,
"threshold": int((volume.min() + volume.max()) / 2),
"method": "Manual",
}
threshold_methods = {
"Otsu": threshold_otsu,
"Isodata": threshold_isodata,
"Li": threshold_li,
"Mean": threshold_mean,
"Minimum": threshold_minimum,
"Triangle": threshold_triangle,
"Yen": threshold_yen,
}
# Create an output widget to display the plot
output = widgets.Output()
# Function to update the state and trigger visualization
def update_state(change):
# Update state based on widget values
state["position"] = position_slider.value
state["method"] = method_dropdown.value
if state["method"] == "Manual":
state["threshold"] = threshold_slider.value
threshold_slider.disabled = False
else:
threshold_func = threshold_methods.get(state["method"])
if threshold_func:
slice_img = volume[state["position"], :, :]
computed_threshold = threshold_func(slice_img)
state["threshold"] = computed_threshold
# Programmatically update the slider without triggering callbacks
threshold_slider.unobserve_all()
threshold_slider.value = computed_threshold
threshold_slider.disabled = True
threshold_slider.observe(update_state, names="value")
else:
raise ValueError(f"Unsupported thresholding method: {state['method']}")
# Trigger visualization
update_visualization()
# Visualization function
def update_visualization():
slice_img = volume[state["position"], :, :]
with output:
output.clear_output(wait=True) # Clear previous plot
fig, axes = plt.subplots(1, 4, figsize=(25, 5))
# Original image
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title("Original")
axes[0].axis("off")
# Histogram
histogram(
volume=volume,
bins=32,
slice_idx=state["position"],
vertical_line=state["threshold"],
axis=1,
kde=False,
ax=axes[1],
show=False,
)
axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")
# Binary mask
mask = slice_img > state["threshold"]
axes[2].imshow(mask, cmap="gray")
axes[2].set_title("Binary mask")
axes[2].axis("off")
# Overlay
mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
mask_rgb[:, :, 0] = mask
masked_volume = qim3d.operations.overlay_rgb_images(
background=slice_img,
foreground=mask_rgb,
)
axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
axes[3].set_title("Overlay")
axes[3].axis("off")
plt.show()
# Widgets
position_slider = widgets.IntSlider(
value=state["position"],
min=0,
max=volume.shape[0] - 1,
description="Slice",
)
threshold_slider = widgets.IntSlider(
value=state["threshold"],
min=volume.min(),
max=volume.max(),
description="Threshold",
)
method_dropdown = widgets.Dropdown(
options=[
"Manual",
"Otsu",
"Isodata",
"Li",
"Mean",
"Minimum",
"Triangle",
"Yen",
],
value=state["method"],
description="Method",
)
# Attach the state update function to widgets
position_slider.observe(update_state, names="value")
threshold_slider.observe(update_state, names="value")
method_dropdown.observe(update_state, names="value")
# Layout
controls_left = widgets.VBox([position_slider, threshold_slider])
controls_right = widgets.VBox([method_dropdown])
controls_layout = widgets.HBox(
[controls_left, controls_right],
layout=widgets.Layout(justify_content="flex-start"),
)
interactive_ui = widgets.VBox([controls_layout, output])
update_visualization()
return interactive_ui