Skip to content
Snippets Groups Projects
Commit 88d24bb7 authored by Alessia Saccardo's avatar Alessia Saccardo
Browse files

add histogram

parent d7800a05
No related branches found
No related tags found
1 merge request!135Threshold exploration
......@@ -8,6 +8,7 @@ from .explore import (
slices,
chunks,
histogram,
threshold,
)
from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
from .k3d import vol, mesh
......
......@@ -830,6 +830,7 @@ def histogram(
element="step",
return_fig=False,
show=True,
ax=None, # New parameter for target axes
**sns_kwargs,
):
"""
......@@ -853,40 +854,26 @@ def histogram(
element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
**sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
Returns:
Optional[matplotlib.figure.Figure]: If `return_fig` is True, returns the generated figure object. Otherwise, returns None.
Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
If `return_fig` is True, returns the generated figure object.
If `return_fig` is False and `ax` is provided, returns the `Axes` object.
Otherwise, returns None.
Raises:
ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol)
```
![viz histogram](assets/screenshots/viz-histogram-vol.png)
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True)
```
![viz histogram](assets/screenshots/viz-histogram-slice.png)
"""
if not (0 <= axis < vol.ndim):
raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
if slice_idx == "middle":
slice_idx = vol.shape[axis] // 2
if slice_idx:
if slice_idx is not None:
if 0 <= slice_idx < vol.shape[axis]:
img_slice = np.take(vol, indices=slice_idx, axis=axis)
data = img_slice.ravel()
......@@ -899,10 +886,14 @@ def histogram(
data = vol.ravel()
title = f"Intensity histogram for whole volume {vol.shape}"
# Use provided Axes or create new figure
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
if log_scale:
plt.yscale("log")
ax.set_yscale("log")
if color == "qim3d":
color = qim3d.viz.colormaps.qim(1.0)
......@@ -914,42 +905,32 @@ def histogram(
color=color,
element=element,
edgecolor=edgecolor,
ax=ax, # Plot directly on the specified Axes
**sns_kwargs,
)
if despine:
sns.despine(
fig=None,
ax=None,
top=True,
right=True,
left=False,
bottom=False,
offset={"left": 0, "bottom": 18},
trim=True,
)
sns.despine(ax=ax, top=True, right=True)
plt.xlabel("Voxel Intensity")
plt.ylabel("Frequency")
ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency")
if show_title:
plt.title(title, fontsize=10)
ax.set_title(title, fontsize=10)
# Handle show and return
if show:
if show and fig is not None:
plt.show()
else:
plt.close(fig)
if return_fig:
return fig
elif ax is not None:
return ax
def threshold(
volume: np.ndarray,
cmap_image: str = 'viridis',
cmap_threshold: str = 'gray',
cmap_overlay: str = 'gray',
vmin: float = None,
vmax: float = None,
):
......@@ -1014,7 +995,7 @@ def threshold(
# Create the interactive widget
def _slicer(position, threshold, method):
fig, axes = plt.subplots(1, 4, figsize=(9, 3))
fig, axes = plt.subplots(1, 4, figsize=(25, 5))
slice_img = volume[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
......@@ -1030,12 +1011,12 @@ def threshold(
else vmax
)
# Add original image to the plot
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title('Original')
axes[0].axis('off')
# Compute the threshold value
if method == 'Manual':
threshold_slider.disabled = False
else:
......@@ -1050,16 +1031,37 @@ def threshold(
else:
raise ValueError(f"Unsupported thresholding method: {method}")
# Compute and add the histogram to the plot
histogram(
vol=volume,
bins=32,
slice_idx=position,
axis=1,
kde=False,
ax=axes[1],
show=False,
)
axes[1].axvline(
x=threshold,
color="red",
linestyle="--",
linewidth=2,
label=f"Threshold = {threshold}",
)
axes[1].set_title("Histogram")
# Compute and add the binary mask to the plot
mask = slice_img > threshold
axes[2].imshow(mask, cmap=cmap_threshold)
axes[2].imshow(mask, cmap='grey')
axes[2].set_title('Binary mask')
axes[2].axis('off')
# Compute and add the overlay to the plot
masked_volume = qim3d.processing.operations.overlay_rgb_images(
background = slice_img,
foreground = mask,
)
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = (
......@@ -1072,7 +1074,7 @@ def threshold(
if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume))
else vmax
)
axes[3].imshow(masked_volume, cmap=cmap_threshold, vmin=new_vmin, vmax=new_vmax)
axes[3].imshow(masked_volume, cmap=cmap_overlay, vmin=new_vmin, vmax=new_vmax)
axes[3].set_title('Overlay')
axes[3].axis('off')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment