diff --git a/docs/assets/screenshots/interactive_thresholding.gif b/docs/assets/screenshots/interactive_thresholding.gif new file mode 100644 index 0000000000000000000000000000000000000000..80efb01b8e74334ff40277b6905a1a0f95589a95 Binary files /dev/null and b/docs/assets/screenshots/interactive_thresholding.gif differ diff --git a/docs/viz.md b/docs/viz.md index 8b6788b60610d4eee757505ca28bdac5b8691572..6e798f90181128735b93e54301547b7704159f87 100644 --- a/docs/viz.md +++ b/docs/viz.md @@ -10,6 +10,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - orthogonal - vol - chunks + - threshold - itk_vtk - mesh - local_thickness diff --git a/qim3d/processing/operations.py b/qim3d/processing/operations.py index e559a16d0100186e36ba96b597b6bdfa58f286fb..565d694ff18b9edb99412553f7aed0bb997e2f43 100644 --- a/qim3d/processing/operations.py +++ b/qim3d/processing/operations.py @@ -1,6 +1,8 @@ import numpy as np import qim3d.processing.filters as filters from qim3d.utils.logger import log +import skimage +import scipy def remove_background( @@ -88,8 +90,6 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i  """ - import skimage - import scipy if len(np.unique(bin_vol)) > 2: raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.") diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 33d94416d98cdf1af5d2c9fca548416eee65a359..dd46be61a889bf7eac970097d7b3815a993616cc 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -8,6 +8,7 @@ from .explore import ( slices, chunks, histogram, + threshold, ) from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError from .k3d import vol, mesh diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py index 8fe2a545adcf4540c8fdc50285759837e05b6fac..383a23c059a6faa2e888b8a4d1677d7477e5e5cf 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/explore.py @@ -5,7 +5,7 @@ Provides a collection of visualization functions. import math import warnings -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import dask.array as da import ipywidgets as widgets @@ -15,6 +15,8 @@ import matplotlib import numpy as np import zarr from qim3d.utils.logger import log +from ipywidgets import interact, IntSlider, FloatSlider, Dropdown +from skimage.filters import threshold_otsu, threshold_isodata, threshold_li, threshold_mean, threshold_minimum, threshold_triangle, threshold_yen import seaborn as sns import qim3d @@ -813,35 +815,37 @@ def chunks(zarr_path: str, **kwargs): # Display the VBox display(final_layout) - def histogram( - vol: np.ndarray, + volume: np.ndarray, bins: Union[int, str] = "auto", - slice_idx: Union[int, str] = None, + slice_idx: Union[int, str, None] = None, + vertical_line: int = None, axis: int = 0, kde: bool = True, log_scale: bool = False, despine: bool = True, show_title: bool = True, - color="qim3d", - edgecolor=None, - figsize=(8, 4.5), - element="step", - return_fig=False, - show=True, - **sns_kwargs, -): + color: str = "qim3d", + edgecolor: Optional[str] = None, + figsize: Tuple[float, float] = (8, 4.5), + element: str = "step", + return_fig: bool = False, + show: bool = True, + ax: Optional[plt.Axes] = None, + **sns_kwargs: Union[str, float, int, bool] +) -> Optional[Union[plt.Figure, plt.Axes]]: """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Args: - vol (np.ndarray): A 3D NumPy array representing the volume to be visualized. + volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". axis (int, optional): Axis along which to take a slice. Default is 0. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None. + vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None. kde (bool, optional): Whether to overlay a kernel density estimate. Default is True. log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False. despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True. @@ -852,56 +856,46 @@ def histogram( element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step". return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False. show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True. + ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None. **sns_kwargs: Additional keyword arguments for `seaborn.histplot`. Returns: - Optional[matplotlib.figure.Figure]: If `return_fig` is True, returns the generated figure object. Otherwise, returns None. + Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]: + If `return_fig` is True, returns the generated figure object. + If `return_fig` is False and `ax` is provided, returns the `Axes` object. + Otherwise, returns None. Raises: ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `slice_idx` is an integer and is out of range for the specified axis. - - Example: - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol) - ``` -  - - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True) - ``` -  """ - - if not (0 <= axis < vol.ndim): - raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") + if not (0 <= axis < volume.ndim): + raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.") if slice_idx == "middle": - slice_idx = vol.shape[axis] // 2 + slice_idx = volume.shape[axis] // 2 - if slice_idx: - if 0 <= slice_idx < vol.shape[axis]: - img_slice = np.take(vol, indices=slice_idx, axis=axis) + if slice_idx is not None: + if 0 <= slice_idx < volume.shape[axis]: + img_slice = np.take(volume, indices=slice_idx, axis=axis) data = img_slice.ravel() title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" else: raise ValueError( - f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}." + f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}." ) else: - data = vol.ravel() - title = f"Intensity histogram for whole volume {vol.shape}" + data = volume.ravel() + title = f"Intensity histogram for whole volume {volume.shape}" - fig, ax = plt.subplots(figsize=figsize) + # Use provided Axes or create new figure + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = None if log_scale: - plt.yscale("log") + ax.set_yscale("log") if color == "qim3d": color = qim3d.viz.colormaps.qim(1.0) @@ -913,13 +907,23 @@ def histogram( color=color, element=element, edgecolor=edgecolor, + ax=ax, # Plot directly on the specified Axes **sns_kwargs, ) + if vertical_line is not None: + ax.axvline( + x=vertical_line, + color='red', + linestyle="--", + linewidth=2, + + ) + if despine: sns.despine( fig=None, - ax=None, + ax=ax, top=True, right=True, left=False, @@ -928,17 +932,209 @@ def histogram( trim=True, ) - plt.xlabel("Voxel Intensity") - plt.ylabel("Frequency") + ax.set_xlabel("Voxel Intensity") + ax.set_ylabel("Frequency") if show_title: - plt.title(title, fontsize=10) + ax.set_title(title, fontsize=10) # Handle show and return - if show: + if show and fig is not None: plt.show() - else: - plt.close(fig) if return_fig: return fig + elif ax is not None: + return ax + +def threshold( + volume: np.ndarray, + cmap_image: str = 'viridis', + vmin: float = None, + vmax: float = None, +) -> widgets.VBox: + """ + This function provides an interactive interface to explore thresholding on a + 3D volume slice-by-slice. Users can either manually set the threshold value + using a slider or select an automatic thresholding method from `skimage`. + + The visualization includes the original image slice, a binary mask showing regions above the + threshold and an overlay combining the binary mask and the original image. + + Args: + volume (np.ndarray): 3D volume to threshold. + cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'. + cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'. + vmin (float, optional): Minimum value for the colormap. Defaults to None. + vmax (float, optional): Maximum value for the colormap. Defaults to None. + + Returns: + slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume. + + Interactivity: + - **Manual Thresholding**: + Select 'Manual' from the dropdown menu to manually adjust the threshold + using the slider. + - **Automatic Thresholding**: + Choose a method from the dropdown menu to apply an automatic thresholding + algorithm. Available methods include: + - Otsu + - Isodata + - Li + - Mean + - Minimum + - Triangle + - Yen + + The threshold slider will display the computed value and will be disabled + in this mode. + + + ```python + import qim3d + + # Load a sample volume + vol = qim3d.examples.bone_128x128x128 + + # Visualize interactive thresholding + qim3d.viz.threshold(vol) + ``` +  + """ + + # Centralized state dictionary to track current parameters + state = { + 'position': volume.shape[0] // 2, + 'threshold': int((volume.min() + volume.max()) / 2), + 'method': 'Manual', + } + + threshold_methods = { + 'Otsu': threshold_otsu, + 'Isodata': threshold_isodata, + 'Li': threshold_li, + 'Mean': threshold_mean, + 'Minimum': threshold_minimum, + 'Triangle': threshold_triangle, + 'Yen': threshold_yen, + } + + # Create an output widget to display the plot + output = widgets.Output() + + # Function to update the state and trigger visualization + def update_state(change): + # Update state based on widget values + state['position'] = position_slider.value + state['method'] = method_dropdown.value + + if state['method'] == 'Manual': + state['threshold'] = threshold_slider.value + threshold_slider.disabled = False + else: + threshold_func = threshold_methods.get(state['method']) + if threshold_func: + slice_img = volume[state['position'], :, :] + computed_threshold = threshold_func(slice_img) + state['threshold'] = computed_threshold + + # Programmatically update the slider without triggering callbacks + threshold_slider.unobserve_all() + threshold_slider.value = computed_threshold + threshold_slider.disabled = True + threshold_slider.observe(update_state, names='value') + else: + raise ValueError(f"Unsupported thresholding method: {state['method']}") + + # Trigger visualization + update_visualization() + + # Visualization function + def update_visualization(): + slice_img = volume[state['position'], :, :] + with output: + output.clear_output(wait=True) # Clear previous plot + fig, axes = plt.subplots(1, 4, figsize=(25, 5)) + + # Original image + new_vmin = ( + None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin + ) + new_vmax = ( + None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + ) + axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) + axes[0].set_title('Original') + axes[0].axis('off') + + # Histogram + histogram( + volume=volume, + bins=32, + slice_idx=state['position'], + vertical_line=state['threshold'], + axis=1, + kde=False, + ax=axes[1], + show=False, + ) + axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}") + + # Binary mask + mask = slice_img > state['threshold'] + axes[2].imshow(mask, cmap='gray') + axes[2].set_title('Binary mask') + axes[2].axis('off') + + # Overlay + mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + mask_rgb[:, :, 0] = mask + masked_volume = qim3d.processing.operations.overlay_rgb_images( + background=slice_img, + foreground=mask_rgb, + ) + axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax) + axes[3].set_title('Overlay') + axes[3].axis('off') + + plt.show() + + # Widgets + position_slider = widgets.IntSlider( + value=state['position'], + min=0, + max=volume.shape[0] - 1, + description='Slice', + ) + + threshold_slider = widgets.IntSlider( + value=state['threshold'], + min=volume.min(), + max=volume.max(), + description='Threshold', + ) + + method_dropdown = widgets.Dropdown( + options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'], + value=state['method'], + description='Method', + ) + + # Attach the state update function to widgets + position_slider.observe(update_state, names='value') + threshold_slider.observe(update_state, names='value') + method_dropdown.observe(update_state, names='value') + + # Layout + controls_left = widgets.VBox([position_slider, threshold_slider]) + controls_right = widgets.VBox([method_dropdown]) + controls_layout = widgets.HBox( + [controls_left, controls_right], + layout=widgets.Layout(justify_content='flex-start'), + ) + interactive_ui = widgets.VBox([controls_layout, output]) + update_visualization() + + return interactive_ui + + \ No newline at end of file