diff --git a/docs/assets/screenshots/interactive_thresholding.gif b/docs/assets/screenshots/interactive_thresholding.gif new file mode 100644 index 0000000000000000000000000000000000000000..f7ec78d06c3608954446a03f0e6c9ce88dcaf59c Binary files /dev/null and b/docs/assets/screenshots/interactive_thresholding.gif differ diff --git a/docs/doc/visualization/viz.md b/docs/doc/visualization/viz.md index 7a4b39a8ec49409fe51b190a079971a7568f852b..2cdb54d6130c8ec7c47bc856a79bc75f7bcabcaa 100644 --- a/docs/doc/visualization/viz.md +++ b/docs/doc/visualization/viz.md @@ -24,6 +24,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - colormaps - fade_mask - line_profile + - threshold ::: qim3d.viz.colormaps options: diff --git a/qim3d/io/__init__.py b/qim3d/io/__init__.py index 23f02d4c4aceff410adf3daf8f615e425fbbb246..db64b581d28e70fbfe6e63ce58d787d862b8fba7 100644 --- a/qim3d/io/__init__.py +++ b/qim3d/io/__init__.py @@ -1,6 +1,6 @@ # from ._sync import Sync # this will be added back after future development -from ._convert import convert -from ._downloader import Downloader from ._loading import load, load_mesh -from ._ome_zarr import export_ome_zarr, import_ome_zarr +from ._downloader import Downloader from ._saving import save, save_mesh +from ._convert import convert +from ._ome_zarr import export_ome_zarr, import_ome_zarr diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py index 49d2633fcbe9df3676b4e9f79f3d8d34383f4d1b..fbfe75a9b9425246dedbd05198b61181a8c84f1a 100644 --- a/qim3d/io/_convert.py +++ b/qim3d/io/_convert.py @@ -7,9 +7,10 @@ import numpy as np import tifffile as tiff import zarr import zarr.core +import qim3d + from tqdm import tqdm -from qim3d.io import save from qim3d.utils._misc import stringify_path @@ -121,7 +122,7 @@ class Convert: """ z = zarr.open(zarr_path) - save(tif_path, z) + qim3d.io.save(tif_path, z) def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array: """ @@ -173,7 +174,7 @@ class Convert: """ z = zarr.open(zarr_path) - save(nifti_path, z, compression=compression) + qim3d.io.save(nifti_path, z, compression=compression) def convert( diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index a2be34e4c2addf8b4641d9d9c7d3a292e9423581..f785319d6eb90429839d5fc1a4770b4783e6c8d5 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -4,12 +4,11 @@ from ._data_exploration import ( chunks, fade_mask, histogram, + line_profile, slicer, slicer_orthogonal, slices_grid, - chunks, - histogram, - line_profile + threshold, ) from ._detection import circles from ._k3d import mesh, volumetric diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index b6ae03ac8f3e4494969ad84339cd3bad1bdd9de4..e8d3e0f4b4c72401db9dcb586e02c9406b27a92f 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -4,20 +4,27 @@ Provides a collection of visualization functions. import math import warnings - -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import dask.array as da import ipywidgets as widgets import matplotlib import matplotlib.figure import matplotlib.pyplot as plt -import matplotlib -from IPython.display import SVG, display, clear_output -import matplotlib import numpy as np import seaborn as sns import skimage.measure +from skimage.filters import ( + threshold_otsu, + threshold_isodata, + threshold_li, + threshold_mean, + threshold_minimum, + threshold_triangle, + threshold_yen, +) + +from IPython.display import clear_output, display import qim3d from qim3d.utils._logger import log @@ -29,7 +36,7 @@ def slices_grid( slice_positions: Optional[Union[str, int, List[int]]] = None, num_slices: int = 15, max_columns: int = 5, - color_map: str = 'magma', + color_map: str = "magma", value_min: float = None, value_max: float = None, image_size: int = None, @@ -39,7 +46,7 @@ def slices_grid( display_positions: bool = True, interpolation: Optional[str] = None, color_bar: bool = False, - color_bar_style: str = 'small', + color_bar_style: str = "small", **matplotlib_imshow_kwargs, ) -> matplotlib.figure.Figure: """ @@ -93,18 +100,18 @@ def slices_grid( # If we pass python None to the imshow function, it will set to # default value 'antialiased' if interpolation is None: - interpolation = 'none' + interpolation = "none" # Numpy array or Torch tensor input if not isinstance(volume, (np.ndarray, da.core.Array)): - raise ValueError('Data type not supported') + raise ValueError("Data type not supported") if volume.ndim < 3: raise ValueError( - 'The provided object is not a volume as it has less than 3 dimensions.' + "The provided object is not a volume as it has less than 3 dimensions." ) - color_bar_style_options = ['small', 'large'] + color_bar_style_options = ["small", "large"] if color_bar_style not in color_bar_style_options: raise ValueError( f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}." @@ -122,11 +129,11 @@ def slices_grid( # Here we deal with the case that the user wants to use the objects colormap directly if ( type(color_map) == matplotlib.colors.LinearSegmentedColormap - or color_map == 'segmentation' + or color_map == "segmentation" ): num_labels = len(np.unique(volume)) - if color_map == 'segmentation': + if color_map == "segmentation": color_map = qim3d.viz.colormaps.segmentation(num_labels) # If value_min and value_max are not set like this, then in case the # number of objects changes on new slice, objects might change @@ -143,15 +150,15 @@ def slices_grid( slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int) # Position is a string elif isinstance(slice_positions, str) and slice_positions.lower() in [ - 'start', - 'mid', - 'end', + "start", + "mid", + "end", ]: - if slice_positions.lower() == 'start': + if slice_positions.lower() == "start": slice_idxs = _get_slice_range(0, num_slices, n_total) - elif slice_positions.lower() == 'mid': + elif slice_positions.lower() == "mid": slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total) - elif slice_positions.lower() == 'end': + elif slice_positions.lower() == "end": slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total) # Position is an integer elif isinstance(slice_positions, int): @@ -232,25 +239,25 @@ def slices_grid( ax.text( 0.0, 1.0, - f'slice {slice_idxs[slice_idx]} ', + f"slice {slice_idxs[slice_idx]} ", transform=ax.transAxes, - color='white', + color="white", fontsize=8, - va='top', - ha='left', - bbox=dict(facecolor='#303030', linewidth=0, pad=0), + va="top", + ha="left", + bbox=dict(facecolor="#303030", linewidth=0, pad=0), ) ax.text( 1.0, 0.0, - f'axis {slice_axis} ', + f"axis {slice_axis} ", transform=ax.transAxes, - color='white', + color="white", fontsize=8, - va='bottom', - ha='right', - bbox=dict(facecolor='#303030', linewidth=0, pad=0), + va="bottom", + ha="right", + bbox=dict(facecolor="#303030", linewidth=0, pad=0), ) except IndexError: @@ -258,11 +265,11 @@ def slices_grid( pass # Hide the axis, so that we have a nice grid - ax.axis('off') + ax.axis("off") if color_bar: with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=UserWarning) + warnings.simplefilter("ignore", category=UserWarning) fig.tight_layout() norm = matplotlib.colors.Normalize( @@ -270,15 +277,15 @@ def slices_grid( ) mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map) - if color_bar_style == 'small': + if color_bar_style == "small": # Figure coordinates of top-right axis tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images color_bar_ax = fig.add_axes( [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] ) - fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical') - elif color_bar_style == 'large': + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") + elif color_bar_style == "large": # Figure coordinates of bottom- and top-right axis br_pos = np.atleast_1d(axs[-1])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position() @@ -291,7 +298,7 @@ def slices_grid( (tr_pos.y1 - br_pos.y0) - 0.0015, ] ) - fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical') + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") if display_figure: plt.show() @@ -322,7 +329,7 @@ def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray def slicer( volume: np.ndarray, slice_axis: int = 0, - color_map: str = 'magma', + color_map: str = "magma", value_min: float = None, value_max: float = None, image_height: int = 3, @@ -366,14 +373,14 @@ def slicer( image_height = image_size image_width = image_size - color_bar_options = [None, 'slices', 'volume'] + color_bar_options = [None, "slices", "volume"] if color_bar not in color_bar_options: raise ValueError( f"Unrecognized value '{color_bar}' for parameter color_bar. " - f'Expected one of {color_bar_options}.' + f"Expected one of {color_bar_options}." ) show_color_bar = color_bar is not None - if color_bar == 'slices': + if color_bar == "slices": # Precompute the minimum and maximum along each slice for faster widget sliding. non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis) slice_mins = np.min(volume, axis=non_slice_axes) @@ -381,7 +388,7 @@ def slicer( # Create the interactive widget def _slicer(slice_positions): - if color_bar == 'slices': + if color_bar == "slices": dynamic_min = slice_mins[slice_positions] dynamic_max = slice_maxs[slice_positions] else: @@ -410,18 +417,18 @@ def slicer( value=volume.shape[slice_axis] // 2, min=0, max=volume.shape[slice_axis] - 1, - description='Slice', + description="Slice", continuous_update=True, ) slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider) - slicer_obj.layout = widgets.Layout(align_items='flex-start') + slicer_obj.layout = widgets.Layout(align_items="flex-start") return slicer_obj def slicer_orthogonal( volume: np.ndarray, - color_map: str = 'magma', + color_map: str = "magma", value_min: float = None, value_max: float = None, image_height: int = 3, @@ -477,9 +484,9 @@ def slicer_orthogonal( y_slicer = get_slicer_for_axis(slice_axis=1) x_slicer = get_slicer_for_axis(slice_axis=2) - z_slicer.children[0].description = 'Z' - y_slicer.children[0].description = 'Y' - x_slicer.children[0].description = 'X' + z_slicer.children[0].description = "Z" + y_slicer.children[0].description = "Y" + x_slicer.children[0].description = "X" return widgets.HBox([z_slicer, y_slicer, x_slicer]) @@ -487,7 +494,7 @@ def slicer_orthogonal( def fade_mask( volume: np.ndarray, axis: int = 0, - color_map: str = 'magma', + color_map: str = "magma", value_min: float = None, value_max: float = None, ) -> widgets.interactive: @@ -537,8 +544,8 @@ def fade_mask( axes[0].imshow( slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max ) - axes[0].set_title('Original') - axes[0].axis('off') + axes[0].set_title("Original") + axes[0].axis("off") mask = qim3d.operations.fade_mask( np.ones_like(volume), @@ -549,8 +556,8 @@ def fade_mask( invert=invert, ) axes[1].imshow(mask[position, :, :], cmap=color_map) - axes[1].set_title('Mask') - axes[1].axis('off') + axes[1].set_title("Mask") + axes[1].axis("off") masked_volume = qim3d.operations.fade_mask( volume, @@ -576,22 +583,22 @@ def fade_mask( axes[2].imshow( slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max ) - axes[2].set_title('Masked') - axes[2].axis('off') + axes[2].set_title("Masked") + axes[2].axis("off") return fig shape_dropdown = widgets.Dropdown( - options=['spherical', 'cylindrical'], - value='spherical', # default value - description='Geometry', + options=["spherical", "cylindrical"], + value="spherical", # default value + description="Geometry", ) position_slider = widgets.IntSlider( value=volume.shape[0] // 2, min=0, max=volume.shape[0] - 1, - description='Slice', + description="Slice", continuous_update=False, ) decay_rate_slider = widgets.FloatSlider( @@ -599,7 +606,7 @@ def fade_mask( min=1, max=50, step=1.0, - description='Decay Rate', + description="Decay Rate", continuous_update=False, ) ratio_slider = widgets.FloatSlider( @@ -607,14 +614,14 @@ def fade_mask( min=0.1, max=1, step=0.01, - description='Ratio', + description="Ratio", continuous_update=False, ) # Create the Checkbox widget invert_checkbox = widgets.Checkbox( value=False, - description='Invert', # default value + description="Invert", # default value ) slicer_obj = widgets.interactive( @@ -625,7 +632,7 @@ def fade_mask( geometry=shape_dropdown, invert=invert_checkbox, ) - slicer_obj.layout = widgets.Layout(align_items='flex-start') + slicer_obj.layout = widgets.Layout(align_items="flex-start") return slicer_obj @@ -657,15 +664,15 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: """ # Load the Zarr dataset - zarr_data = zarr.open(zarr_path, mode='r') + zarr_data = zarr.open(zarr_path, mode="r") # Save arguments for later use # visualization_method = visualization_method # preserved_kwargs = kwargs # Create label to display the chunk coordinates - widget_title = widgets.HTML('<h2>Chunk Explorer</h2>') - chunk_info_label = widgets.HTML(value='Chunk info will be displayed here') + widget_title = widgets.HTML("<h2>Chunk Explorer</h2>") + chunk_info_label = widgets.HTML(value="Chunk info will be displayed here") def load_and_visualize( scale, z_coord, y_coord, x_coord, visualization_method, **kwargs @@ -699,13 +706,13 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: # Update the chunk info label with the chunk coordinates info_string = ( - f'<b>shape:</b> {chunk_shape}\n' - + f'<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n' - + f'<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n' - + f'<b>dtype:</b> {chunk.dtype}\n' - + f'<b>min value:</b> {np.min(chunk)}\n' - + f'<b>max value:</b> {np.max(chunk)}\n' - + f'<b>mean value:</b> {np.mean(chunk)}\n' + f"<b>shape:</b> {chunk_shape}\n" + + f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n" + + f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n" + + f"<b>dtype:</b> {chunk.dtype}\n" + + f"<b>min value:</b> {np.min(chunk)}\n" + + f"<b>max value:</b> {np.max(chunk)}\n" + + f"<b>mean value:</b> {np.mean(chunk)}\n" ) chunk_info_label.value = f""" @@ -719,22 +726,22 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: """ # Prepare chunk visualization based on the selected method - if visualization_method == 'slicer': # return a widget + if visualization_method == "slicer": # return a widget viz_widget = qim3d.viz.slicer(chunk, **kwargs) - elif visualization_method == 'slices': # return a plt.Figure + elif visualization_method == "slices": # return a plt.Figure viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) fig = qim3d.viz.slices_grid(chunk, **kwargs) display(fig) - elif visualization_method == 'volume': + elif visualization_method == "volume": viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) out = qim3d.viz.volumetric(chunk, show=False, **kwargs) display(out) else: - log.info(f'Invalid visualization method: {visualization_method}') + log.info(f"Invalid visualization method: {visualization_method}") return viz_widget @@ -743,16 +750,16 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)] scale_options = { - f'{i} {zarr_data[i].shape}': i for i in range(len(zarr_data)) + f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data)) } # len(zarr_data) gives number of scales - description_width = '128px' + description_width = "128px" # Create dropdown for scale scale_dropdown = widgets.Dropdown( options=scale_options, value=0, # Default to first scale - description='OME-Zarr scale', - style={'description_width': description_width, 'text_align': 'left'}, + description="OME-Zarr scale", + style={"description_width": description_width, "text_align": "left"}, ) # Initialize the options for x, y, and z based on the first scale by default @@ -763,44 +770,44 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: z_dropdown = widgets.Dropdown( options=list(range(num_chunks[0])), value=0, - description='First dimension (Z)', - style={'description_width': description_width, 'text_align': 'left'}, + description="First dimension (Z)", + style={"description_width": description_width, "text_align": "left"}, ) y_dropdown = widgets.Dropdown( options=list(range(num_chunks[1])), value=0, - description='Second dimension (Y)', - style={'description_width': description_width, 'text_align': 'left'}, + description="Second dimension (Y)", + style={"description_width": description_width, "text_align": "left"}, ) x_dropdown = widgets.Dropdown( options=list(range(num_chunks[2])), value=0, - description='Third dimension (X)', - style={'description_width': description_width, 'text_align': 'left'}, + description="Third dimension (X)", + style={"description_width": description_width, "text_align": "left"}, ) method_dropdown = widgets.Dropdown( - options=['slicer', 'slices', 'volume'], - value='slicer', - description='Visualization', - style={'description_width': description_width, 'text_align': 'left'}, + options=["slicer", "slices", "volume"], + value="slicer", + description="Visualization", + style={"description_width": description_width, "text_align": "left"}, ) # Funtion to temporarily disable observers def disable_observers(): - x_dropdown.unobserve(update_visualization, names='value') - y_dropdown.unobserve(update_visualization, names='value') - z_dropdown.unobserve(update_visualization, names='value') - method_dropdown.unobserve(update_visualization, names='value') + x_dropdown.unobserve(update_visualization, names="value") + y_dropdown.unobserve(update_visualization, names="value") + z_dropdown.unobserve(update_visualization, names="value") + method_dropdown.unobserve(update_visualization, names="value") # Funtion to enable observers def enable_observers(): - x_dropdown.observe(update_visualization, names='value') - y_dropdown.observe(update_visualization, names='value') - z_dropdown.observe(update_visualization, names='value') - method_dropdown.observe(update_visualization, names='value') + x_dropdown.observe(update_visualization, names="value") + y_dropdown.observe(update_visualization, names="value") + z_dropdown.observe(update_visualization, names="value") + method_dropdown.observe(update_visualization, names="value") # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0 def update_coordinate_dropdowns(scale): @@ -853,7 +860,7 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: # Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes scale_dropdown.observe( - lambda change: update_coordinate_dropdowns(scale_dropdown.value), names='value' + lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value" ) enable_observers() @@ -881,95 +888,87 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive: def histogram( volume: np.ndarray, - bins: Union[int, str] = 'auto', - slice_idx: Union[int, str] = None, + bins: Union[int, str] = "auto", + slice_idx: Union[int, str, None] = None, + vertical_line: int = None, axis: int = 0, kde: bool = True, log_scale: bool = False, despine: bool = True, show_title: bool = True, - color: str = 'qim3d', - edgecolor: str | None = None, - figsize: tuple[float, float] = (8, 4.5), - element: str = 'step', + color: str = "qim3d", + edgecolor: Optional[str] = None, + figsize: Tuple[float, float] = (8, 4.5), + element: str = "step", return_fig: bool = False, show: bool = True, - **sns_kwargs, -) -> None | matplotlib.figure.Figure: + ax: Optional[plt.Axes] = None, + **sns_kwargs: Union[str, float, int, bool] +) -> Optional[Union[plt.Figure, plt.Axes]]: """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. - + Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Args: volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. - bins (int or str, optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". + bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". axis (int, optional): Axis along which to take a slice. Default is 0. - slice_idx (int or str or None, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. + slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None. + vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None. kde (bool, optional): Whether to overlay a kernel density estimate. Default is True. log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False. despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True. show_title (bool, optional): If True, displays a title with slice information. Default is True. color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d". edgecolor (str, optional): Color for the edges of the histogram bars. Default is None. - figsize (tuple of floats, optional): Size of the figure (width, height). Default is (8, 4.5). + figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5). element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step". return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False. show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True. - **sns_kwargs (Any): Additional keyword arguments for `seaborn.histplot`. + ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None. + **sns_kwargs: Additional keyword arguments for `seaborn.histplot`. Returns: - fig (Optional[matplotlib.figure.Figure]): If `return_fig` is True, returns the generated figure object. Otherwise, returns None. + Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]: + If `return_fig` is True, returns the generated figure object. + If `return_fig` is False and `ax` is provided, returns the `Axes` object. + Otherwise, returns None. Raises: ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `slice_idx` is an integer and is out of range for the specified axis. - - Example: - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol) - ``` -  - - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True) - ``` -  - """ - if not (0 <= axis < volume.ndim): - raise ValueError(f'Axis must be an integer between 0 and {volume.ndim - 1}.') + raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.") - if slice_idx == 'middle': + if slice_idx == "middle": slice_idx = volume.shape[axis] // 2 - if slice_idx: + if slice_idx is not None: if 0 <= slice_idx < volume.shape[axis]: img_slice = np.take(volume, indices=slice_idx, axis=axis) data = img_slice.ravel() - title = f'Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}' + title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" else: raise ValueError( - f'Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}.' + f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}." ) else: data = volume.ravel() - title = f'Intensity histogram for whole volume {volume.shape}' + title = f"Intensity histogram for whole volume {volume.shape}" - fig, ax = plt.subplots(figsize=figsize) + # Use provided Axes or create new figure + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = None if log_scale: - plt.yscale('log') + ax.set_yscale("log") - if color == 'qim3d': + if color == "qim3d": color = qim3d.viz.colormaps.qim(1.0) sns.histplot( @@ -979,43 +978,64 @@ def histogram( color=color, element=element, edgecolor=edgecolor, + ax=ax, # Plot directly on the specified Axes **sns_kwargs, ) + if vertical_line is not None: + ax.axvline( + x=vertical_line, + color='red', + linestyle="--", + linewidth=2, + + ) + if despine: sns.despine( fig=None, - ax=None, + ax=ax, top=True, right=True, left=False, bottom=False, - offset={'left': 0, 'bottom': 18}, + offset={"left": 0, "bottom": 18}, trim=True, ) - plt.xlabel('Voxel Intensity') - plt.ylabel('Frequency') + ax.set_xlabel("Voxel Intensity") + ax.set_ylabel("Frequency") if show_title: - plt.title(title, fontsize=10) + ax.set_title(title, fontsize=10) # Handle show and return - if show: + if show and fig is not None: plt.show() - else: - plt.close(fig) if return_fig: return fig + elif ax is not None: + return ax + + class _LineProfile: - def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range): + def __init__( + self, + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ): self.volume = volume self.slice_axis = slice_axis self.dims = np.array(volume.shape) - self.pad = 1 # Padding on pivot point to avoid border issues + self.pad = 1 # Padding on pivot point to avoid border issues self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1] self.initialize_widgets() @@ -1025,7 +1045,7 @@ class _LineProfile: self.y_widget.value = vertical_position self.angle_widget.value = angle self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]] - + def update_slice_axis(self, slice_axis): self.slice_axis = slice_axis self.slice_index_widget.max = self.volume.shape[slice_axis] - 1 @@ -1038,36 +1058,45 @@ class _LineProfile: self.y_widget.value = self.y_max // 2 def initialize_widgets(self): - layout = widgets.Layout(width='300px', height='auto') - self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) - self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) - self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout) + layout = widgets.Layout(width="300px", height="auto") + self.x_widget = widgets.IntSlider( + min=self.pad, step=1, description="", layout=layout + ) + self.y_widget = widgets.IntSlider( + min=self.pad, step=1, description="", layout=layout + ) + self.angle_widget = widgets.IntSlider( + min=0, max=360, step=1, value=0, description="", layout=layout + ) self.line_fraction_widget = widgets.FloatRangeSlider( - min=0, max=1, step=0.01, value=[0, 1], - description="", layout=layout + min=0, max=1, step=0.01, value=[0, 1], description="", layout=layout ) - self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis') - self.slice_axis_widget.layout.width = '250px' + self.slice_axis_widget = widgets.Dropdown( + options=[0, 1, 2], value=self.slice_axis, description="Slice axis" + ) + self.slice_axis_widget.layout.width = "250px" + + self.slice_index_widget = widgets.IntSlider( + min=0, step=1, description="Slice index", layout=layout + ) + self.slice_index_widget.layout.width = "400px" - self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout) - self.slice_index_widget.layout.width = '400px' - def calculate_line_endpoints(self, x, y, angle): """ Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)] """ if np.isclose(angle, 0): return [0, y], [self.x_max, y] - elif np.isclose(angle, np.pi/2): + elif np.isclose(angle, np.pi / 2): return [x, 0], [x, self.y_max] elif np.isclose(angle, np.pi): return [self.x_max, y], [0, y] - elif np.isclose(angle, 3*np.pi/2): + elif np.isclose(angle, 3 * np.pi / 2): return [x, self.y_max], [x, 0] - elif np.isclose(angle, 2*np.pi): + elif np.isclose(angle, 2 * np.pi): return [0, y], [self.x_max, y] - + t_left = -x / np.cos(angle) t_bottom = -y / np.sin(angle) t_right = (self.x_max - x) / np.cos(angle) @@ -1075,23 +1104,26 @@ class _LineProfile: t_values = np.array([t_left, t_top, t_right, t_bottom]) t_pos = np.min(t_values[t_values > 0]) t_neg = np.max(t_values[t_values < 0]) - + src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)] dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)] return src, dst - + def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range): if slice_axis != self.slice_axis: self.update_slice_axis(slice_axis) x = self.x_widget.value y = self.y_widget.value slice_index = self.slice_index_widget.value - + clear_output(wait=True) - + image = np.take(self.volume, slice_index, slice_axis) angle = np.radians(angle_deg) - src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)] + src, dst = ( + np.array(point, dtype="float32") + for point in self.calculate_line_endpoints(x, y, angle) + ) # Rescale endpoints line_vec = dst - src @@ -1101,46 +1133,59 @@ class _LineProfile: y_pline = skimage.measure.profile_line(image, src, dst) fig, ax = plt.subplots(1, 2, figsize=(10, 5)) - + # Image with color-gradiented line num_segments = 100 x_seg = np.linspace(src[0], dst[0], num_segments) y_seg = np.linspace(src[1], dst[1], num_segments) - segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]), - np.column_stack([y_seg[2:], x_seg[2:]])], axis=1) - norm = plt.Normalize(vmin=0, vmax=num_segments-1) + segments = np.stack( + [ + np.column_stack([y_seg[:-2], x_seg[:-2]]), + np.column_stack([y_seg[2:], x_seg[2:]]), + ], + axis=1, + ) + norm = plt.Normalize(vmin=0, vmax=num_segments - 1) colors = self.cmap(norm(np.arange(num_segments - 1))) lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2) - ax[0].imshow(image,cmap='gray') + ax[0].imshow(image, cmap="gray") ax[0].add_collection(lc) # pivot point - ax[0].plot(y,x,marker='s', linestyle='', color='cyan', markersize=4) - ax[0].set_xlabel(f'axis {np.delete(np.arange(3), self.slice_axis)[1]}') - ax[0].set_ylabel(f'axis {np.delete(np.arange(3), self.slice_axis)[0]}') - + ax[0].plot(y, x, marker="s", linestyle="", color="cyan", markersize=4) + ax[0].set_xlabel(f"axis {np.delete(np.arange(3), self.slice_axis)[1]}") + ax[0].set_ylabel(f"axis {np.delete(np.arange(3), self.slice_axis)[0]}") + # Profile intensity plot norm = plt.Normalize(0, vmax=len(y_pline) - 1) x_pline = np.arange(len(y_pline)) points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :] segments = np.concatenate([points[:-1], points[1:]], axis=1) - lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2) + lc = matplotlib.collections.LineCollection( + segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2 + ) ax[1].add_collection(lc) ax[1].autoscale() - ax[1].set_xlabel('Distance along line') + ax[1].set_xlabel("Distance along line") ax[1].grid(True) plt.tight_layout() plt.show() - + def build_interactive(self): # Group widgets into two columns - title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" - title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>") - title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>") + title_style = ( + "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" + ) + title_column1 = widgets.HTML( + f"<div style='{title_style}'>Line parameterization</div>" + ) + title_column2 = widgets.HTML( + f"<div style='{title_style}'>Slice selection</div>" + ) # Make label widgets instead of descriptions which have different lengths. - label_layout = widgets.Layout(width='120px') + label_layout = widgets.Layout(width="120px") label_x = widgets.Label("Vertical position", layout=label_layout) label_y = widgets.Label("Horizontal position", layout=label_layout) label_angle = widgets.Label("Angle (°)", layout=label_layout) @@ -1151,29 +1196,40 @@ class _LineProfile: row_angle = widgets.HBox([label_angle, self.angle_widget]) row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget]) - controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction]) - controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget]) + controls_column1 = widgets.VBox( + [title_column1, row_x, row_y, row_angle, row_fraction] + ) + controls_column2 = widgets.VBox( + [title_column2, self.slice_axis_widget, self.slice_index_widget] + ) controls = widgets.HBox([controls_column1, controls_column2]) interactive_plot = widgets.interactive_output( - self.update, - {'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget, - 'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget, - 'fraction_range': self.line_fraction_widget} + self.update, + { + "slice_axis": self.slice_axis_widget, + "slice_index": self.slice_index_widget, + "x": self.x_widget, + "y": self.y_widget, + "angle_deg": self.angle_widget, + "fraction_range": self.line_fraction_widget, + }, ) return widgets.VBox([controls, interactive_plot]) + def line_profile( - volume: np.ndarray, - slice_axis: int=0, - slice_index: int | str='middle', - vertical_position: int | str='middle', - horizontal_position: int | str='middle', - angle: int=0, - fraction_range: Tuple[float,float]=(0.00, 1.00) - ) -> widgets.interactive: - """Returns an interactive widget for visualizing the intensity profiles of lines on slices. + volume: np.ndarray, + slice_axis: int = 0, + slice_index: int | str = "middle", + vertical_position: int | str = "middle", + horizontal_position: int | str = "middle", + angle: int = 0, + fraction_range: Tuple[float, float] = (0.00, 1.00), +) -> widgets.interactive: + """ + Returns an interactive widget for visualizing the intensity profiles of lines on slices. Args: volume (np.ndarray): The 3D volume of interest. @@ -1181,13 +1237,13 @@ def line_profile( slice_index (int or str, optional): Specifies the initial slice index along slice_axis. vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point. horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point. - angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. + angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0]. Returns: widget (widgets.widget_box.VBox): The interactive widget. - + Example: ```python import qim3d @@ -1198,41 +1254,270 @@ def line_profile(  """ + def parse_position(pos, pos_range, name): if isinstance(pos, int): if not pos_range[0] <= pos < pos_range[1]: - raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]') + raise ValueError( + f"Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]" + ) return pos elif isinstance(pos, str): pos = pos.lower() - if pos == 'start': return pos_range[0] - elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 - elif pos == 'end': return pos_range[1] + if pos == "start": + return pos_range[0] + elif pos == "middle": + return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 + elif pos == "end": + return pos_range[1] else: raise ValueError( f"Invalid string '{pos}' for {name}. " "Must be 'start', 'middle', or 'end'." ) else: - raise TypeError(f'Axis position must be of type int or str.') - + raise TypeError("Axis position must be of type int or str.") + if not isinstance(volume, (np.ndarray, da.core.Array)): raise ValueError("Data type for volume not supported.") if volume.ndim != 3: raise ValueError("Volume must be 3D.") - + dims = volume.shape - slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index') + slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), "slice_index") # the omission of the ends for the pivot point is due to border issues. - vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position') - horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position') - + vertical_position = parse_position( + vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), "vertical_position" + ) + horizontal_position = parse_position( + horizontal_position, + (1, np.delete(dims, slice_axis)[1] - 2), + "horizontal_position", + ) + if not isinstance(angle, int | float): raise ValueError("Invalid type for angle.") angle = round(angle) % 360 - if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]): + if not ( + 0.0 <= fraction_range[0] <= 1.0 + and 0.0 <= fraction_range[1] <= 1.0 + and fraction_range[0] <= fraction_range[1] + ): raise ValueError("Invalid values for fraction_range.") - lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range) + lp = _LineProfile( + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ) return lp.build_interactive() + + +def threshold( + volume: np.ndarray, + cmap_image: str = 'magma', + vmin: float = None, + vmax: float = None, +) -> widgets.VBox: + """ + This function provides an interactive interface to explore thresholding on a + 3D volume slice-by-slice. Users can either manually set the threshold value + using a slider or select an automatic thresholding method from `skimage`. + + The visualization includes the original image slice, a binary mask showing regions above the + threshold and an overlay combining the binary mask and the original image. + + Args: + volume (np.ndarray): 3D volume to threshold. + cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'. + cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'. + vmin (float, optional): Minimum value for the colormap. Defaults to None. + vmax (float, optional): Maximum value for the colormap. Defaults to None. + + Returns: + slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume. + + Interactivity: + - **Manual Thresholding**: + Select 'Manual' from the dropdown menu to manually adjust the threshold + using the slider. + - **Automatic Thresholding**: + Choose a method from the dropdown menu to apply an automatic thresholding + algorithm. Available methods include: + - Otsu + - Isodata + - Li + - Mean + - Minimum + - Triangle + - Yen + + The threshold slider will display the computed value and will be disabled + in this mode. + + + ```python + import qim3d + + # Load a sample volume + vol = qim3d.examples.bone_128x128x128 + + # Visualize interactive thresholding + qim3d.viz.threshold(vol) + ``` +  + + """ + + # Centralized state dictionary to track current parameters + state = { + "position": volume.shape[0] // 2, + "threshold": int((volume.min() + volume.max()) / 2), + "method": "Manual", + } + + threshold_methods = { + "Otsu": threshold_otsu, + "Isodata": threshold_isodata, + "Li": threshold_li, + "Mean": threshold_mean, + "Minimum": threshold_minimum, + "Triangle": threshold_triangle, + "Yen": threshold_yen, + } + + # Create an output widget to display the plot + output = widgets.Output() + + # Function to update the state and trigger visualization + def update_state(change): + # Update state based on widget values + state["position"] = position_slider.value + state["method"] = method_dropdown.value + + if state["method"] == "Manual": + state["threshold"] = threshold_slider.value + threshold_slider.disabled = False + else: + threshold_func = threshold_methods.get(state["method"]) + if threshold_func: + slice_img = volume[state["position"], :, :] + computed_threshold = threshold_func(slice_img) + state["threshold"] = computed_threshold + + # Programmatically update the slider without triggering callbacks + threshold_slider.unobserve_all() + threshold_slider.value = computed_threshold + threshold_slider.disabled = True + threshold_slider.observe(update_state, names="value") + else: + raise ValueError(f"Unsupported thresholding method: {state['method']}") + + # Trigger visualization + update_visualization() + + # Visualization function + def update_visualization(): + slice_img = volume[state["position"], :, :] + with output: + output.clear_output(wait=True) # Clear previous plot + fig, axes = plt.subplots(1, 4, figsize=(25, 5)) + + # Original image + new_vmin = ( + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin + ) + new_vmax = ( + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax + ) + axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) + axes[0].set_title("Original") + axes[0].axis("off") + + # Histogram + histogram( + volume=volume, + bins=32, + slice_idx=state["position"], + vertical_line=state["threshold"], + axis=1, + kde=False, + ax=axes[1], + show=False, + ) + axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}") + + # Binary mask + mask = slice_img > state["threshold"] + axes[2].imshow(mask, cmap="gray") + axes[2].set_title("Binary mask") + axes[2].axis("off") + + # Overlay + mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + mask_rgb[:, :, 0] = mask + masked_volume = qim3d.operations.overlay_rgb_images( + background=slice_img, + foreground=mask_rgb, + ) + axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax) + axes[3].set_title("Overlay") + axes[3].axis("off") + + plt.show() + + # Widgets + position_slider = widgets.IntSlider( + value=state["position"], + min=0, + max=volume.shape[0] - 1, + description="Slice", + ) + + threshold_slider = widgets.IntSlider( + value=state["threshold"], + min=volume.min(), + max=volume.max(), + description="Threshold", + ) + + method_dropdown = widgets.Dropdown( + options=[ + "Manual", + "Otsu", + "Isodata", + "Li", + "Mean", + "Minimum", + "Triangle", + "Yen", + ], + value=state["method"], + description="Method", + ) + + # Attach the state update function to widgets + position_slider.observe(update_state, names="value") + threshold_slider.observe(update_state, names="value") + method_dropdown.observe(update_state, names="value") + + # Layout + controls_left = widgets.VBox([position_slider, threshold_slider]) + controls_right = widgets.VBox([method_dropdown]) + controls_layout = widgets.HBox( + [controls_left, controls_right], + layout=widgets.Layout(justify_content="flex-start"), + ) + interactive_ui = widgets.VBox([controls_layout, output]) + update_visualization() + + return interactive_ui