diff --git a/docs/assets/screenshots/interactive_thresholding.gif b/docs/assets/screenshots/interactive_thresholding.gif new file mode 100644 index 0000000000000000000000000000000000000000..80efb01b8e74334ff40277b6905a1a0f95589a95 Binary files /dev/null and b/docs/assets/screenshots/interactive_thresholding.gif differ diff --git a/docs/doc/visualization/viz.md b/docs/doc/visualization/viz.md index 7a4b39a8ec49409fe51b190a079971a7568f852b..2cdb54d6130c8ec7c47bc856a79bc75f7bcabcaa 100644 --- a/docs/doc/visualization/viz.md +++ b/docs/doc/visualization/viz.md @@ -24,6 +24,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - colormaps - fade_mask - line_profile + - threshold ::: qim3d.viz.colormaps options: diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index a2be34e4c2addf8b4641d9d9c7d3a292e9423581..e7ba84a55f22e9b7589228823a7e69f548870baa 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -9,7 +9,8 @@ from ._data_exploration import ( slices_grid, chunks, histogram, - line_profile + line_profile, + threshold ) from ._detection import circles from ._k3d import mesh, volumetric diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index b6ae03ac8f3e4494969ad84339cd3bad1bdd9de4..855d730bf6487869f3e64c5e4b5ac90c1a03bdda 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -1236,3 +1236,193 @@ def line_profile( lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range) return lp.build_interactive() + +def threshold( + volume: np.ndarray, + cmap_image: str = 'viridis', + vmin: float = None, + vmax: float = None, +) -> widgets.VBox: + """ + This function provides an interactive interface to explore thresholding on a + 3D volume slice-by-slice. Users can either manually set the threshold value + using a slider or select an automatic thresholding method from `skimage`. + + The visualization includes the original image slice, a binary mask showing regions above the + threshold and an overlay combining the binary mask and the original image. + + Args: + volume (np.ndarray): 3D volume to threshold. + cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'. + cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'. + vmin (float, optional): Minimum value for the colormap. Defaults to None. + vmax (float, optional): Maximum value for the colormap. Defaults to None. + + Returns: + slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume. + + Interactivity: + - **Manual Thresholding**: + Select 'Manual' from the dropdown menu to manually adjust the threshold + using the slider. + - **Automatic Thresholding**: + Choose a method from the dropdown menu to apply an automatic thresholding + algorithm. Available methods include: + - Otsu + - Isodata + - Li + - Mean + - Minimum + - Triangle + - Yen + + The threshold slider will display the computed value and will be disabled + in this mode. + + + ```python + import qim3d + + # Load a sample volume + vol = qim3d.examples.bone_128x128x128 + + # Visualize interactive thresholding + qim3d.viz.threshold(vol) + ``` +  + """ + + # Centralized state dictionary to track current parameters + state = { + 'position': volume.shape[0] // 2, + 'threshold': int((volume.min() + volume.max()) / 2), + 'method': 'Manual', + } + + threshold_methods = { + 'Otsu': threshold_otsu, + 'Isodata': threshold_isodata, + 'Li': threshold_li, + 'Mean': threshold_mean, + 'Minimum': threshold_minimum, + 'Triangle': threshold_triangle, + 'Yen': threshold_yen, + } + + # Create an output widget to display the plot + output = widgets.Output() + + # Function to update the state and trigger visualization + def update_state(change): + # Update state based on widget values + state['position'] = position_slider.value + state['method'] = method_dropdown.value + + if state['method'] == 'Manual': + state['threshold'] = threshold_slider.value + threshold_slider.disabled = False + else: + threshold_func = threshold_methods.get(state['method']) + if threshold_func: + slice_img = volume[state['position'], :, :] + computed_threshold = threshold_func(slice_img) + state['threshold'] = computed_threshold + + # Programmatically update the slider without triggering callbacks + threshold_slider.unobserve_all() + threshold_slider.value = computed_threshold + threshold_slider.disabled = True + threshold_slider.observe(update_state, names='value') + else: + raise ValueError(f"Unsupported thresholding method: {state['method']}") + + # Trigger visualization + update_visualization() + + # Visualization function + def update_visualization(): + slice_img = volume[state['position'], :, :] + with output: + output.clear_output(wait=True) # Clear previous plot + fig, axes = plt.subplots(1, 4, figsize=(25, 5)) + + # Original image + new_vmin = ( + None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin + ) + new_vmax = ( + None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + ) + axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) + axes[0].set_title('Original') + axes[0].axis('off') + + # Histogram + histogram( + volume=volume, + bins=32, + slice_idx=state['position'], + vertical_line=state['threshold'], + axis=1, + kde=False, + ax=axes[1], + show=False, + ) + axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}") + + # Binary mask + mask = slice_img > state['threshold'] + axes[2].imshow(mask, cmap='gray') + axes[2].set_title('Binary mask') + axes[2].axis('off') + + # Overlay + mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + mask_rgb[:, :, 0] = mask + masked_volume = qim3d.processing.operations.overlay_rgb_images( + background=slice_img, + foreground=mask_rgb, + ) + axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax) + axes[3].set_title('Overlay') + axes[3].axis('off') + + plt.show() + + # Widgets + position_slider = widgets.IntSlider( + value=state['position'], + min=0, + max=volume.shape[0] - 1, + description='Slice', + ) + + threshold_slider = widgets.IntSlider( + value=state['threshold'], + min=volume.min(), + max=volume.max(), + description='Threshold', + ) + + method_dropdown = widgets.Dropdown( + options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'], + value=state['method'], + description='Method', + ) + + # Attach the state update function to widgets + position_slider.observe(update_state, names='value') + threshold_slider.observe(update_state, names='value') + method_dropdown.observe(update_state, names='value') + + # Layout + controls_left = widgets.VBox([position_slider, threshold_slider]) + controls_right = widgets.VBox([method_dropdown]) + controls_layout = widgets.HBox( + [controls_left, controls_right], + layout=widgets.Layout(justify_content='flex-start'), + ) + interactive_ui = widgets.VBox([controls_layout, output]) + update_visualization() + + return interactive_ui