Skip to content
Snippets Groups Projects

Threshold exploration

Closed s212246 requested to merge threshold-exploration into main
1 file
+ 27
9
Compare changes
  • Side-by-side
  • Inline
+ 246
50
@@ -5,7 +5,7 @@ Provides a collection of visualization functions.
import math
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple
import dask.array as da
import ipywidgets as widgets
@@ -15,6 +15,8 @@ import matplotlib
import numpy as np
import zarr
from qim3d.utils.logger import log
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown
from skimage.filters import threshold_otsu, threshold_isodata, threshold_li, threshold_mean, threshold_minimum, threshold_triangle, threshold_yen
import seaborn as sns
import qim3d
@@ -813,35 +815,37 @@ def chunks(zarr_path: str, **kwargs):
# Display the VBox
display(final_layout)
def histogram(
vol: np.ndarray,
volume: np.ndarray,
bins: Union[int, str] = "auto",
slice_idx: Union[int, str] = None,
slice_idx: Union[int, str, None] = None,
vertical_line: int = None,
axis: int = 0,
kde: bool = True,
log_scale: bool = False,
despine: bool = True,
show_title: bool = True,
color="qim3d",
edgecolor=None,
figsize=(8, 4.5),
element="step",
return_fig=False,
show=True,
**sns_kwargs,
):
color: str = "qim3d",
edgecolor: Optional[str] = None,
figsize: Tuple[float, float] = (8, 4.5),
element: str = "step",
return_fig: bool = False,
show: bool = True,
ax: Optional[plt.Axes] = None,
**sns_kwargs: Union[str, float, int, bool]
) -> Optional[Union[plt.Figure, plt.Axes]]:
"""
Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
Args:
vol (np.ndarray): A 3D NumPy array representing the volume to be visualized.
volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
@@ -852,56 +856,46 @@ def histogram(
element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
**sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
Returns:
Optional[matplotlib.figure.Figure]: If `return_fig` is True, returns the generated figure object. Otherwise, returns None.
Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
If `return_fig` is True, returns the generated figure object.
If `return_fig` is False and `ax` is provided, returns the `Axes` object.
Otherwise, returns None.
Raises:
ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol)
```
![viz histogram](assets/screenshots/viz-histogram-vol.png)
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True)
```
![viz histogram](assets/screenshots/viz-histogram-slice.png)
"""
if not (0 <= axis < vol.ndim):
raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
if not (0 <= axis < volume.ndim):
raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == "middle":
slice_idx = vol.shape[axis] // 2
slice_idx = volume.shape[axis] // 2
if slice_idx:
if 0 <= slice_idx < vol.shape[axis]:
img_slice = np.take(vol, indices=slice_idx, axis=axis)
if slice_idx is not None:
if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel()
title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
else:
raise ValueError(
f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}."
f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
)
else:
data = vol.ravel()
title = f"Intensity histogram for whole volume {vol.shape}"
data = volume.ravel()
title = f"Intensity histogram for whole volume {volume.shape}"
# Use provided Axes or create new figure
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
if log_scale:
plt.yscale("log")
ax.set_yscale("log")
if color == "qim3d":
color = qim3d.viz.colormaps.qim(1.0)
@@ -913,13 +907,23 @@ def histogram(
color=color,
element=element,
edgecolor=edgecolor,
ax=ax, # Plot directly on the specified Axes
**sns_kwargs,
)
if vertical_line is not None:
ax.axvline(
x=vertical_line,
color='red',
linestyle="--",
linewidth=2,
)
if despine:
sns.despine(
fig=None,
ax=None,
ax=ax,
top=True,
right=True,
left=False,
@@ -928,17 +932,209 @@ def histogram(
trim=True,
)
plt.xlabel("Voxel Intensity")
plt.ylabel("Frequency")
ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency")
if show_title:
plt.title(title, fontsize=10)
ax.set_title(title, fontsize=10)
# Handle show and return
if show:
if show and fig is not None:
plt.show()
else:
plt.close(fig)
if return_fig:
return fig
elif ax is not None:
return ax
def threshold(
volume: np.ndarray,
cmap_image: str = 'viridis',
vmin: float = None,
vmax: float = None,
) -> widgets.VBox:
"""
This function provides an interactive interface to explore thresholding on a
3D volume slice-by-slice. Users can either manually set the threshold value
using a slider or select an automatic thresholding method from `skimage`.
The visualization includes the original image slice, a binary mask showing regions above the
threshold and an overlay combining the binary mask and the original image.
Args:
volume (np.ndarray): 3D volume to threshold.
cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'.
cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'.
vmin (float, optional): Minimum value for the colormap. Defaults to None.
vmax (float, optional): Maximum value for the colormap. Defaults to None.
Returns:
slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume.
Interactivity:
- **Manual Thresholding**:
Select 'Manual' from the dropdown menu to manually adjust the threshold
using the slider.
- **Automatic Thresholding**:
Choose a method from the dropdown menu to apply an automatic thresholding
algorithm. Available methods include:
- Otsu
- Isodata
- Li
- Mean
- Minimum
- Triangle
- Yen
The threshold slider will display the computed value and will be disabled
in this mode.
```python
import qim3d
# Load a sample volume
vol = qim3d.examples.bone_128x128x128
# Visualize interactive thresholding
qim3d.viz.threshold(vol)
```
![interactive threshold](assets/screenshots/interactive_thresholding.gif)
"""
# Centralized state dictionary to track current parameters
state = {
'position': volume.shape[0] // 2,
'threshold': int((volume.min() + volume.max()) / 2),
'method': 'Manual',
}
threshold_methods = {
'Otsu': threshold_otsu,
'Isodata': threshold_isodata,
'Li': threshold_li,
'Mean': threshold_mean,
'Minimum': threshold_minimum,
'Triangle': threshold_triangle,
'Yen': threshold_yen,
}
# Create an output widget to display the plot
output = widgets.Output()
# Function to update the state and trigger visualization
def update_state(change):
# Update state based on widget values
state['position'] = position_slider.value
state['method'] = method_dropdown.value
if state['method'] == 'Manual':
state['threshold'] = threshold_slider.value
threshold_slider.disabled = False
else:
threshold_func = threshold_methods.get(state['method'])
if threshold_func:
slice_img = volume[state['position'], :, :]
computed_threshold = threshold_func(slice_img)
state['threshold'] = computed_threshold
# Programmatically update the slider without triggering callbacks
threshold_slider.unobserve_all()
threshold_slider.value = computed_threshold
threshold_slider.disabled = True
threshold_slider.observe(update_state, names='value')
else:
raise ValueError(f"Unsupported thresholding method: {state['method']}")
# Trigger visualization
update_visualization()
# Visualization function
def update_visualization():
slice_img = volume[state['position'], :, :]
with output:
output.clear_output(wait=True) # Clear previous plot
fig, axes = plt.subplots(1, 4, figsize=(25, 5))
# Original image
new_vmin = (
None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
)
new_vmax = (
None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
)
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title('Original')
axes[0].axis('off')
# Histogram
histogram(
volume=volume,
bins=32,
slice_idx=state['position'],
vertical_line=state['threshold'],
axis=1,
kde=False,
ax=axes[1],
show=False,
)
axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")
# Binary mask
mask = slice_img > state['threshold']
axes[2].imshow(mask, cmap='gray')
axes[2].set_title('Binary mask')
axes[2].axis('off')
# Overlay
mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
mask_rgb[:, :, 0] = mask
masked_volume = qim3d.processing.operations.overlay_rgb_images(
background=slice_img,
foreground=mask_rgb,
)
axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
axes[3].set_title('Overlay')
axes[3].axis('off')
plt.show()
# Widgets
position_slider = widgets.IntSlider(
value=state['position'],
min=0,
max=volume.shape[0] - 1,
description='Slice',
)
threshold_slider = widgets.IntSlider(
value=state['threshold'],
min=volume.min(),
max=volume.max(),
description='Threshold',
)
method_dropdown = widgets.Dropdown(
options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
value=state['method'],
description='Method',
)
# Attach the state update function to widgets
position_slider.observe(update_state, names='value')
threshold_slider.observe(update_state, names='value')
method_dropdown.observe(update_state, names='value')
# Layout
controls_left = widgets.VBox([position_slider, threshold_slider])
controls_right = widgets.VBox([method_dropdown])
controls_layout = widgets.HBox(
[controls_left, controls_right],
layout=widgets.Layout(justify_content='flex-start'),
)
interactive_ui = widgets.VBox([controls_layout, output])
update_visualization()
return interactive_ui
\ No newline at end of file
Loading