Skip to content
Snippets Groups Projects
Commit 0282cba5 authored by fima's avatar fima :beers:
Browse files

Merge branch 'structure_tensor_wrapper' into 'main'

Structure tensor wrapper

See merge request !66
parents 955e7f8b 4f33d647
No related branches found
No related tags found
1 merge request!66Structure tensor wrapper
......@@ -12,3 +12,5 @@
- append
::: qim3d.processing.structure_tensor
\ No newline at end of file
from .filters import *
from .local_thickness import local_thickness
from .structure_tensor import structure_tensor
from .detection import *
......@@ -4,7 +4,18 @@ from typing import Union, Type
import numpy as np
from scipy import ndimage
__all__ = ['Gaussian','Median','Maximum','Minimum','Pipeline','gaussian','median','maximum','minimum']
__all__ = [
"Gaussian",
"Median",
"Maximum",
"Minimum",
"Pipeline",
"gaussian",
"median",
"maximum",
"minimum",
]
class FilterBase:
def __init__(self, *args, **kwargs):
......@@ -18,6 +29,7 @@ class FilterBase:
self.args = args
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input):
"""
......@@ -31,6 +43,7 @@ class Gaussian(FilterBase):
"""
return gaussian(input, *self.args, **self.kwargs)
class Median(FilterBase):
def __call__(self, input):
"""
......@@ -44,6 +57,7 @@ class Median(FilterBase):
"""
return median(input, **self.kwargs)
class Maximum(FilterBase):
def __call__(self, input):
"""
......@@ -57,6 +71,7 @@ class Maximum(FilterBase):
"""
return maximum(input, **self.kwargs)
class Minimum(FilterBase):
def __call__(self, input):
"""
......@@ -70,6 +85,7 @@ class Minimum(FilterBase):
"""
return minimum(input, **self.kwargs)
class Pipeline:
def __init__(self, *args: Type[FilterBase]):
"""
......@@ -90,13 +106,17 @@ class Pipeline:
Args:
name: A string representing the name or identifier of the filter.
fn: An instance of a FilterBase subclass.
Raises:
AssertionError: If `fn` is not an instance of the FilterBase class.
"""
if not isinstance(fn,FilterBase):
filter_names = [subclass.__name__ for subclass in FilterBase.__subclasses__()]
raise AssertionError(f'filters should be instances of one of the following classes: {filter_names}')
if not isinstance(fn, FilterBase):
filter_names = [
subclass.__name__ for subclass in FilterBase.__subclasses__()
]
raise AssertionError(
f"filters should be instances of one of the following classes: {filter_names}"
)
self.filters[name] = fn
def append(self, fn: Type[FilterBase]):
......@@ -122,6 +142,7 @@ class Pipeline:
input = fn(input)
return input
def gaussian(vol, *args, **kwargs):
"""
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter.
......@@ -136,6 +157,7 @@ def gaussian(vol, *args, **kwargs):
"""
return ndimage.gaussian_filter(vol, *args, **kwargs)
def median(vol, **kwargs):
"""
Applies a median filter to the input volume using scipy.ndimage.median_filter.
......@@ -149,6 +171,7 @@ def median(vol, **kwargs):
"""
return ndimage.median_filter(vol, **kwargs)
def maximum(vol, **kwargs):
"""
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter.
......@@ -159,9 +182,10 @@ def maximum(vol, **kwargs):
Returns:
The filtered image or volume.
"""
"""
return ndimage.maximum_filter(vol, **kwargs)
def minimum(vol, **kwargs):
"""
Applies a minimum filter to the input volume using scipy.ndimage.mainimum_filter.
......@@ -173,4 +197,4 @@ def minimum(vol, **kwargs):
Returns:
The filtered image or volume.
"""
return ndimage.minimum_filter(vol, **kwargs)
\ No newline at end of file
return ndimage.minimum_filter(vol, **kwargs)
"""Wrapper for the structure tensor function from the structure_tensor package"""
from typing import Tuple
import numpy as np
import structure_tensor as st
from qim3d.viz.structure_tensor import vectors
def structure_tensor(
vol: np.ndarray,
sigma: float = 1.0,
rho: float = 6.0,
full: bool = False,
visualize=False,
**viz_kwargs
) -> Tuple[np.ndarray, np.ndarray]:
"""Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/)
Args:
vol (np.ndarray): 3D NumPy array representing the volume.
sigma (float): A noise scale, structures smaller than sigma will be removed by smoothing.
rho (float): An integration scale giving the size over the neighborhood in which the orientation is to be analysed.
full: A flag indicating that all three eigenvalues should be returned. Default is False.
visualize (bool, optional): Whether to visualize the structure tensor. Default is False.
**viz_kwargs: Additional keyword arguments for the visualization function. Only used if visualize=True.
Raises:
ValueError: If the input volume is not 3D.
Returns:
val: An array with shape `(3, *vol.shape)` containing the eigenvalues of the structure tensor.
vec: An array with shape `(3, *vol.shape)` if `full` is `False`, otherwise `(3, 3, *vol.shape)` containing eigenvectors.
!!! quote "Reference"
Jeppesen, N., et al. "Quantifying effects of manufacturing methods on fiber orientation in unidirectional composites using structure tensor analysis." Composites Part A: Applied Science and Manufacturing 149 (2021): 106541.
<https://doi.org/10.1016/j.compositesa.2021.106541>
```bibtex
@article{JEPPESEN2021106541,
title = {Quantifying effects of manufacturing methods on fiber orientation in unidirectional composites using structure tensor analysis},
journal = {Composites Part A: Applied Science and Manufacturing},
volume = {149},
pages = {106541},
year = {2021},
issn = {1359-835X},
doi = {https://doi.org/10.1016/j.compositesa.2021.106541},
url = {https://www.sciencedirect.com/science/article/pii/S1359835X21002633},
author = {N. Jeppesen and L.P. Mikkelsen and A.B. Dahl and A.N. Christensen and V.A. Dahl}
}
```
"""
if vol.ndim != 3:
raise ValueError("The input volume must be 3D")
# Ensure volume is a float
if vol.dtype != np.float32 and vol.dtype != np.float64:
vol = vol.astype(np.float32)
s_vol = st.structure_tensor_3d(vol, sigma, rho)
val, vec = st.eig_special_3d(s_vol, full=full)
if visualize:
display(vectors(vol, vec, **viz_kwargs))
return val, vec
import pytest
import numpy as np
import qim3d
def test_wrong_ndim():
img_2d = np.random.rand(50, 50)
with pytest.raises(ValueError, match = "The input volume must be 3D"):
qim3d.processing.structure_tensor(img_2d, 1.5, 1.5)
def test_structure_tensor():
volume = np.random.rand(50, 50, 50)
val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5)
assert val.shape == (3, 50, 50, 50)
assert vec.shape == (3, 50, 50, 50)
assert np.all(val[0] <= val[1])
assert np.all(val[1] <= val[2])
assert np.all(val[0] <= val[2])
def test_structure_tensor_full():
volume = np.random.rand(50, 50, 50)
val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5, full=True)
assert val.shape == (3, 50, 50, 50)
assert vec.shape == (3, 3, 50, 50, 50)
assert np.all(val[0] <= val[1])
assert np.all(val[1] <= val[2])
assert np.all(val[0] <= val[2])
\ No newline at end of file
from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
from .k3d import vol
from .structure_tensor import vectors
from .colormaps import objects
from .detection import circles
......@@ -590,29 +590,30 @@ def plot_cc(
return
def local_thickness(
image: np.ndarray,
image_lt: np.ndarray,
max_projection: bool = False,
axis: int = 0,
slice_idx: Optional[Union[int, float]] = None,
show: bool = False,
figsize: Tuple[int, int] = (15, 5)
) -> Union[plt.Figure, widgets.interactive]:
image: np.ndarray,
image_lt: np.ndarray,
max_projection: bool = False,
axis: int = 0,
slice_idx: Optional[Union[int, float]] = None,
show: bool = False,
figsize: Tuple[int, int] = (15, 5),
) -> Union[plt.Figure, widgets.interactive]:
"""Visualizes the local thickness of a 2D or 3D image.
Args:
image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
image/volume.
max_projection (bool, optional): If True, displays the maximum projection of the local
thickness. Only used for 3D images. Defaults to False.
axis (int, optional): The axis along which to visualize the local thickness.
axis (int, optional): The axis along which to visualize the local thickness.
Unused for 2D images.
Defaults to 0.
slice_idx (int or float, optional): The initial slice to be visualized. The slice index
slice_idx (int or float, optional): The initial slice to be visualized. The slice index
can afterwards be changed. If value is an integer, it will be the index of the slice
to be visualized. If value is a float between 0 and 1, it will be multiplied by the
to be visualized. If value is a float between 0 and 1, it will be multiplied by the
number of slices and rounded to the nearest integer. If None, the middle slice will
be used for 3D images. Unused for 2D images. Defaults to None.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
......@@ -628,12 +629,13 @@ def local_thickness(
image_lt = qim3d.processing.local_thickness(image)
qim3d.viz.local_thickness(image, image_lt, slice_idx=10)
"""
def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None):
if slice_idx is not None:
image = image.take(slice_idx, axis=axis)
image_lt = image_lt.take(slice_idx, axis=axis)
fig, axs = plt.subplots(1, 3, figsize=figsize,layout="constrained")
fig, axs = plt.subplots(1, 3, figsize=figsize, layout="constrained")
axs[0].imshow(image, cmap="gray")
axs[0].set_title("Original image")
......@@ -643,9 +645,11 @@ def local_thickness(
axs[1].set_title("Local thickness")
axs[1].axis("off")
plt.colorbar(axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical")
plt.colorbar(
axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical"
)
axs[2].hist(image_lt[image_lt>0].ravel(), bins=32, edgecolor='black')
axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor="black")
axs[2].set_title("Local thickness histogram")
axs[2].set_xlabel("Local thickness")
axs[2].set_ylabel("Count")
......@@ -661,7 +665,9 @@ def local_thickness(
if len(image.shape) == 3:
if max_projection:
if slice_idx is not None:
log.warning("slice_idx is not used for max_projection. It will be ignored.")
log.warning(
"slice_idx is not used for max_projection. It will be ignored."
)
image = image.max(axis=axis)
image_lt = image_lt.max(axis=axis)
return _local_thickness(image, image_lt, show, figsize)
......@@ -670,17 +676,26 @@ def local_thickness(
slice_idx = image.shape[axis] // 2
elif isinstance(slice_idx, float):
if slice_idx < 0 or slice_idx > 1:
raise ValueError("Values of slice_idx of float type must be between 0 and 1.")
slice_idx = int(slice_idx * image.shape[0])-1
slide_idx_slider=widgets.IntSlider(min=0, max=image.shape[axis]-1, step=1, value=slice_idx, description="Slice index")
raise ValueError(
"Values of slice_idx of float type must be between 0 and 1."
)
slice_idx = int(slice_idx * image.shape[0]) - 1
slide_idx_slider = widgets.IntSlider(
min=0,
max=image.shape[axis] - 1,
step=1,
value=slice_idx,
description="Slice index",
layout=widgets.Layout(width="450px"),
)
widget_obj = widgets.interactive(
_local_thickness,
image=widgets.fixed(image),
image_lt=widgets.fixed(image_lt),
image_lt=widgets.fixed(image_lt),
show=widgets.fixed(True),
figsize=widgets.fixed(figsize),
axis=widgets.fixed(axis),
slice_idx=slide_idx_slider
slice_idx=slide_idx_slider,
)
widget_obj.layout = widgets.Layout(align_items="center")
if show:
......@@ -688,7 +703,11 @@ def local_thickness(
return widget_obj
else:
if max_projection:
log.warning("max_projection is only used for 3D images. It will be ignored.")
log.warning(
"max_projection is only used for 3D images. It will be ignored."
)
if slice_idx is not None:
log.warning("slice_idx is only used for 3D images. It will be ignored.")
return _local_thickness(image, image_lt, show, figsize)
\ No newline at end of file
return _local_thickness(image, image_lt, show, figsize)
import numpy as np
from typing import Optional, Union, Tuple
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ipywidgets as widgets
import logging as log
def vectors(
volume: np.ndarray,
vec: np.ndarray,
axis: int = 0,
slice_idx: Optional[Union[int, float]] = None,
interactive: bool = True,
figsize: Tuple[int, int] = (10, 5),
show: bool = False,
) -> Union[plt.Figure, widgets.interactive]:
"""
Displays a grid of eigenvectors from the structure tensor to visualize the orientation of the structures in the volume.
Args:
volume (np.ndarray): The 3D volume to be sliced.
vec (np.ndarray): The eigenvectors of the structure tensor.
axis (int, optional): The axis along which to visualize the local thickness. Defaults to 0.
slice_idx (int or float, optional): The initial slice to be visualized. The slice index
can afterwards be changed. If value is an integer, it will be the index of the slice
to be visualized. If value is a float between 0 and 1, it will be multiplied by the
number of slices and rounded to the nearest integer. If None, the middle slice will
be used. Defaults to None.
grid_size (int, optional): The size of the grid. Defaults to 10.
interactive (bool, optional): If True, returns an interactive widget. Defaults to True.
figsize (Tuple[int, int], optional): The size of the figure. Defaults to (15, 5).
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
Raises:
ValueError: If the axis to slice along is not 0, 1, or 2.
ValueError: If the slice index is not an integer or a float between 0 and 1.
"""
# Define Grid size limits
min_grid_size = max(1, volume.shape[axis] // 50)
max_grid_size = max(1, volume.shape[axis] // 10)
if max_grid_size <= min_grid_size:
max_grid_size = min_grid_size * 5
# Testing
grid_size = (min_grid_size + max_grid_size) // 2
if grid_size < min_grid_size or grid_size > max_grid_size:
# Adjust grid size as little as possible to be within the limits
grid_size = min(max(min_grid_size, grid_size), max_grid_size)
log.warning(f"Adjusting grid size to {grid_size} as it is out of bounds.")
def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show):
# Create subplots
fig, ax = plt.subplots(1, 2, figsize=figsize, layout="constrained")
# Choose the appropriate slice based on the specified dimension
if axis == 0:
data_slice = volume[slice_idx, :, :]
vectors_slice_x = vec[0, slice_idx, :, :]
vectors_slice_y = vec[1, slice_idx, :, :]
elif axis == 1:
data_slice = volume[:, slice_idx, :]
vectors_slice_x = vec[0, :, slice_idx, :]
vectors_slice_y = vec[2, :, slice_idx, :]
elif axis == 2:
data_slice = volume[:, :, slice_idx]
vectors_slice_x = vec[1, :, :, slice_idx]
vectors_slice_y = vec[2, :, :, slice_idx]
else:
raise ValueError("Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.")
ax[0].imshow(data_slice, cmap=plt.cm.gray)
# Create meshgrid with the correct dimensions
xmesh, ymesh = np.mgrid[0 : data_slice.shape[0], 0 : data_slice.shape[1]]
# Create a slice object for selecting the grid points
g = slice(grid_size // 2, None, grid_size)
# Plot vectors
ax[0].quiver(
ymesh[g, g],
xmesh[g, g],
vectors_slice_x[g, g],
vectors_slice_y[g, g],
color="orange",
angles="xy",
)
ax[0].quiver(
ymesh[g, g],
xmesh[g, g],
-vectors_slice_x[g, g],
-vectors_slice_y[g, g],
color="orange",
angles="xy",
)
# Set title and turn off axis
ax[0].set_title(f"Slice {slice_idx}" if not interactive else None)
ax[0].set_axis_off()
# Orientations histogram
nbins = 36
angles = np.arctan2(vectors_slice_y, vectors_slice_x) # angles from 0 to pi
distribution, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi))
# Find the bin with the maximum count
peak_bin_idx = np.argmax(distribution)
# Calculate the center of the peak bin
peak_angle_rad = (bin_edges[peak_bin_idx] + bin_edges[peak_bin_idx + 1]) / 2
# Convert the peak angle to degrees
peak_angle_deg = np.degrees(peak_angle_rad)
bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins # half circle (180 deg)
colors = plt.cm.hsv(bin_centers / np.pi)
ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=colors)
ax[1].set_xlabel("Angle [radians]")
ax[1].set_xlim([0, np.pi])
ax[1].set_aspect(np.pi / ax[1].get_ylim()[1])
ax[1].set_xticks([0, np.pi / 2, np.pi])
ax[1].set_xticklabels(["0", "$\\frac{\\pi}{2}$", "$\\pi$"])
ax[1].set_ylabel("Count")
ax[1].set_title(f"Histogram over angles (peak at {round(peak_angle_deg)}°)")
if show:
plt.show()
plt.close()
return fig
if vec.ndim == 5:
vec = vec[0, ...]
log.warning(
"Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used."
)
if slice_idx is None:
slice_idx = volume.shape[axis] // 2
elif isinstance(slice_idx, float):
if slice_idx < 0 or slice_idx > 1:
raise ValueError(
"Values of slice_idx of float type must be between 0 and 1."
)
slice_idx = int(slice_idx * volume.shape[0]) - 1
if interactive:
slide_idx_slider = widgets.IntSlider(
min=0,
max=volume.shape[axis] - 1,
step=1,
value=slice_idx,
description="Slice index",
layout=widgets.Layout(width="450px"),
)
grid_size_slider = widgets.IntSlider(
min=min_grid_size,
max=max_grid_size,
step=1,
value=grid_size,
description="Grid size",
layout=widgets.Layout(width="450px"),
)
widget_obj = widgets.interactive(
_structure_tensor,
volume=widgets.fixed(volume),
vec=widgets.fixed(vec),
axis=widgets.fixed(axis),
slice_idx=slide_idx_slider,
grid_size=grid_size_slider,
figsize=widgets.fixed(figsize),
show=widgets.fixed(True),
)
# Arrange sliders horizontally
sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider])
widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]])
widget_obj.layout.align_items = "center"
if show:
display(widget_obj)
return widget_obj
else:
return _structure_tensor(volume, vec, axis, slice_idx, figsize, show)
......@@ -22,3 +22,4 @@ dask>=2023.6.0,
k3d>=2.16.1
olefile>=0.46
psutil>=5.9.0
structure-tensor>=0.2.1
\ No newline at end of file
......@@ -60,6 +60,7 @@ setup(
"dask>=2023.6.0",
"k3d>=2.16.1",
"olefile>=0.46",
"psutil>=5.9.0"
"psutil>=5.9.0",
"structure-tensor>=0.2.1"
],
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment