From 4f33d647a900f053e47420d078a55ff76079dc43 Mon Sep 17 00:00:00 2001 From: s184058 <s184058@student.dtu.dk> Date: Wed, 3 Apr 2024 12:55:57 +0200 Subject: [PATCH] Structure tensor wrapper --- docs/processing.md | 2 + qim3d/processing/__init__.py | 1 + qim3d/processing/filters.py | 38 +++- qim3d/processing/structure_tensor.py | 67 +++++++ .../tests/processing/test_structure_tensor.py | 26 +++ qim3d/viz/__init__.py | 1 + qim3d/viz/img.py | 67 ++++--- qim3d/viz/structure_tensor.py | 186 ++++++++++++++++++ requirements.txt | 1 + setup.py | 3 +- 10 files changed, 360 insertions(+), 32 deletions(-) create mode 100644 qim3d/processing/structure_tensor.py create mode 100644 qim3d/tests/processing/test_structure_tensor.py create mode 100644 qim3d/viz/structure_tensor.py diff --git a/docs/processing.md b/docs/processing.md index 8d411707..114e38d8 100644 --- a/docs/processing.md +++ b/docs/processing.md @@ -12,3 +12,5 @@ - append + +::: qim3d.processing.structure_tensor \ No newline at end of file diff --git a/qim3d/processing/__init__.py b/qim3d/processing/__init__.py index be9d29a2..7d251a6b 100644 --- a/qim3d/processing/__init__.py +++ b/qim3d/processing/__init__.py @@ -1,3 +1,4 @@ from .filters import * from .local_thickness import local_thickness +from .structure_tensor import structure_tensor from .detection import * diff --git a/qim3d/processing/filters.py b/qim3d/processing/filters.py index 6b21205a..4390ba69 100644 --- a/qim3d/processing/filters.py +++ b/qim3d/processing/filters.py @@ -4,7 +4,18 @@ from typing import Union, Type import numpy as np from scipy import ndimage -__all__ = ['Gaussian','Median','Maximum','Minimum','Pipeline','gaussian','median','maximum','minimum'] +__all__ = [ + "Gaussian", + "Median", + "Maximum", + "Minimum", + "Pipeline", + "gaussian", + "median", + "maximum", + "minimum", +] + class FilterBase: def __init__(self, *args, **kwargs): @@ -18,6 +29,7 @@ class FilterBase: self.args = args self.kwargs = kwargs + class Gaussian(FilterBase): def __call__(self, input): """ @@ -31,6 +43,7 @@ class Gaussian(FilterBase): """ return gaussian(input, *self.args, **self.kwargs) + class Median(FilterBase): def __call__(self, input): """ @@ -44,6 +57,7 @@ class Median(FilterBase): """ return median(input, **self.kwargs) + class Maximum(FilterBase): def __call__(self, input): """ @@ -57,6 +71,7 @@ class Maximum(FilterBase): """ return maximum(input, **self.kwargs) + class Minimum(FilterBase): def __call__(self, input): """ @@ -70,6 +85,7 @@ class Minimum(FilterBase): """ return minimum(input, **self.kwargs) + class Pipeline: def __init__(self, *args: Type[FilterBase]): """ @@ -90,13 +106,17 @@ class Pipeline: Args: name: A string representing the name or identifier of the filter. fn: An instance of a FilterBase subclass. - + Raises: AssertionError: If `fn` is not an instance of the FilterBase class. """ - if not isinstance(fn,FilterBase): - filter_names = [subclass.__name__ for subclass in FilterBase.__subclasses__()] - raise AssertionError(f'filters should be instances of one of the following classes: {filter_names}') + if not isinstance(fn, FilterBase): + filter_names = [ + subclass.__name__ for subclass in FilterBase.__subclasses__() + ] + raise AssertionError( + f"filters should be instances of one of the following classes: {filter_names}" + ) self.filters[name] = fn def append(self, fn: Type[FilterBase]): @@ -122,6 +142,7 @@ class Pipeline: input = fn(input) return input + def gaussian(vol, *args, **kwargs): """ Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter. @@ -136,6 +157,7 @@ def gaussian(vol, *args, **kwargs): """ return ndimage.gaussian_filter(vol, *args, **kwargs) + def median(vol, **kwargs): """ Applies a median filter to the input volume using scipy.ndimage.median_filter. @@ -149,6 +171,7 @@ def median(vol, **kwargs): """ return ndimage.median_filter(vol, **kwargs) + def maximum(vol, **kwargs): """ Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter. @@ -159,9 +182,10 @@ def maximum(vol, **kwargs): Returns: The filtered image or volume. - """ + """ return ndimage.maximum_filter(vol, **kwargs) + def minimum(vol, **kwargs): """ Applies a minimum filter to the input volume using scipy.ndimage.mainimum_filter. @@ -173,4 +197,4 @@ def minimum(vol, **kwargs): Returns: The filtered image or volume. """ - return ndimage.minimum_filter(vol, **kwargs) \ No newline at end of file + return ndimage.minimum_filter(vol, **kwargs) diff --git a/qim3d/processing/structure_tensor.py b/qim3d/processing/structure_tensor.py new file mode 100644 index 00000000..abd37cc6 --- /dev/null +++ b/qim3d/processing/structure_tensor.py @@ -0,0 +1,67 @@ +"""Wrapper for the structure tensor function from the structure_tensor package""" + +from typing import Tuple +import numpy as np +import structure_tensor as st +from qim3d.viz.structure_tensor import vectors + + +def structure_tensor( + vol: np.ndarray, + sigma: float = 1.0, + rho: float = 6.0, + full: bool = False, + visualize=False, + **viz_kwargs +) -> Tuple[np.ndarray, np.ndarray]: + """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/) + + Args: + vol (np.ndarray): 3D NumPy array representing the volume. + sigma (float): A noise scale, structures smaller than sigma will be removed by smoothing. + rho (float): An integration scale giving the size over the neighborhood in which the orientation is to be analysed. + full: A flag indicating that all three eigenvalues should be returned. Default is False. + visualize (bool, optional): Whether to visualize the structure tensor. Default is False. + **viz_kwargs: Additional keyword arguments for the visualization function. Only used if visualize=True. + Raises: + ValueError: If the input volume is not 3D. + + Returns: + val: An array with shape `(3, *vol.shape)` containing the eigenvalues of the structure tensor. + vec: An array with shape `(3, *vol.shape)` if `full` is `False`, otherwise `(3, 3, *vol.shape)` containing eigenvectors. + + !!! quote "Reference" + Jeppesen, N., et al. "Quantifying effects of manufacturing methods on fiber orientation in unidirectional composites using structure tensor analysis." Composites Part A: Applied Science and Manufacturing 149 (2021): 106541. + <https://doi.org/10.1016/j.compositesa.2021.106541> + + ```bibtex + @article{JEPPESEN2021106541, + title = {Quantifying effects of manufacturing methods on fiber orientation in unidirectional composites using structure tensor analysis}, + journal = {Composites Part A: Applied Science and Manufacturing}, + volume = {149}, + pages = {106541}, + year = {2021}, + issn = {1359-835X}, + doi = {https://doi.org/10.1016/j.compositesa.2021.106541}, + url = {https://www.sciencedirect.com/science/article/pii/S1359835X21002633}, + author = {N. Jeppesen and L.P. Mikkelsen and A.B. Dahl and A.N. Christensen and V.A. Dahl} + } + + ``` + """ + + if vol.ndim != 3: + raise ValueError("The input volume must be 3D") + + # Ensure volume is a float + if vol.dtype != np.float32 and vol.dtype != np.float64: + vol = vol.astype(np.float32) + + + s_vol = st.structure_tensor_3d(vol, sigma, rho) + val, vec = st.eig_special_3d(s_vol, full=full) + + if visualize: + display(vectors(vol, vec, **viz_kwargs)) + + return val, vec diff --git a/qim3d/tests/processing/test_structure_tensor.py b/qim3d/tests/processing/test_structure_tensor.py new file mode 100644 index 00000000..0c6421b0 --- /dev/null +++ b/qim3d/tests/processing/test_structure_tensor.py @@ -0,0 +1,26 @@ +import pytest +import numpy as np +import qim3d + +def test_wrong_ndim(): + img_2d = np.random.rand(50, 50) + with pytest.raises(ValueError, match = "The input volume must be 3D"): + qim3d.processing.structure_tensor(img_2d, 1.5, 1.5) + +def test_structure_tensor(): + volume = np.random.rand(50, 50, 50) + val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5) + assert val.shape == (3, 50, 50, 50) + assert vec.shape == (3, 50, 50, 50) + assert np.all(val[0] <= val[1]) + assert np.all(val[1] <= val[2]) + assert np.all(val[0] <= val[2]) + +def test_structure_tensor_full(): + volume = np.random.rand(50, 50, 50) + val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5, full=True) + assert val.shape == (3, 50, 50, 50) + assert vec.shape == (3, 3, 50, 50, 50) + assert np.all(val[0] <= val[1]) + assert np.all(val[1] <= val[2]) + assert np.all(val[0] <= val[2]) \ No newline at end of file diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index b770c3d0..2057c9d4 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,5 +1,6 @@ from .visualizations import plot_metrics from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness from .k3d import vol +from .structure_tensor import vectors from .colormaps import objects from .detection import circles diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index d3cf0bd4..cc5c417f 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -590,29 +590,30 @@ def plot_cc( return + def local_thickness( - image: np.ndarray, - image_lt: np.ndarray, - max_projection: bool = False, - axis: int = 0, - slice_idx: Optional[Union[int, float]] = None, - show: bool = False, - figsize: Tuple[int, int] = (15, 5) - ) -> Union[plt.Figure, widgets.interactive]: + image: np.ndarray, + image_lt: np.ndarray, + max_projection: bool = False, + axis: int = 0, + slice_idx: Optional[Union[int, float]] = None, + show: bool = False, + figsize: Tuple[int, int] = (15, 5), +) -> Union[plt.Figure, widgets.interactive]: """Visualizes the local thickness of a 2D or 3D image. Args: - image (np.ndarray): 2D or 3D NumPy array representing the image/volume. - image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input + image (np.ndarray): 2D or 3D NumPy array representing the image/volume. + image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input image/volume. max_projection (bool, optional): If True, displays the maximum projection of the local thickness. Only used for 3D images. Defaults to False. - axis (int, optional): The axis along which to visualize the local thickness. + axis (int, optional): The axis along which to visualize the local thickness. Unused for 2D images. Defaults to 0. - slice_idx (int or float, optional): The initial slice to be visualized. The slice index + slice_idx (int or float, optional): The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice - to be visualized. If value is a float between 0 and 1, it will be multiplied by the + to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used for 3D images. Unused for 2D images. Defaults to None. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. @@ -628,12 +629,13 @@ def local_thickness( image_lt = qim3d.processing.local_thickness(image) qim3d.viz.local_thickness(image, image_lt, slice_idx=10) """ + def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None): if slice_idx is not None: image = image.take(slice_idx, axis=axis) image_lt = image_lt.take(slice_idx, axis=axis) - fig, axs = plt.subplots(1, 3, figsize=figsize,layout="constrained") + fig, axs = plt.subplots(1, 3, figsize=figsize, layout="constrained") axs[0].imshow(image, cmap="gray") axs[0].set_title("Original image") @@ -643,9 +645,11 @@ def local_thickness( axs[1].set_title("Local thickness") axs[1].axis("off") - plt.colorbar(axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical") + plt.colorbar( + axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical" + ) - axs[2].hist(image_lt[image_lt>0].ravel(), bins=32, edgecolor='black') + axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor="black") axs[2].set_title("Local thickness histogram") axs[2].set_xlabel("Local thickness") axs[2].set_ylabel("Count") @@ -661,7 +665,9 @@ def local_thickness( if len(image.shape) == 3: if max_projection: if slice_idx is not None: - log.warning("slice_idx is not used for max_projection. It will be ignored.") + log.warning( + "slice_idx is not used for max_projection. It will be ignored." + ) image = image.max(axis=axis) image_lt = image_lt.max(axis=axis) return _local_thickness(image, image_lt, show, figsize) @@ -670,17 +676,26 @@ def local_thickness( slice_idx = image.shape[axis] // 2 elif isinstance(slice_idx, float): if slice_idx < 0 or slice_idx > 1: - raise ValueError("Values of slice_idx of float type must be between 0 and 1.") - slice_idx = int(slice_idx * image.shape[0])-1 - slide_idx_slider=widgets.IntSlider(min=0, max=image.shape[axis]-1, step=1, value=slice_idx, description="Slice index") + raise ValueError( + "Values of slice_idx of float type must be between 0 and 1." + ) + slice_idx = int(slice_idx * image.shape[0]) - 1 + slide_idx_slider = widgets.IntSlider( + min=0, + max=image.shape[axis] - 1, + step=1, + value=slice_idx, + description="Slice index", + layout=widgets.Layout(width="450px"), + ) widget_obj = widgets.interactive( _local_thickness, image=widgets.fixed(image), - image_lt=widgets.fixed(image_lt), + image_lt=widgets.fixed(image_lt), show=widgets.fixed(True), figsize=widgets.fixed(figsize), axis=widgets.fixed(axis), - slice_idx=slide_idx_slider + slice_idx=slide_idx_slider, ) widget_obj.layout = widgets.Layout(align_items="center") if show: @@ -688,7 +703,11 @@ def local_thickness( return widget_obj else: if max_projection: - log.warning("max_projection is only used for 3D images. It will be ignored.") + log.warning( + "max_projection is only used for 3D images. It will be ignored." + ) if slice_idx is not None: log.warning("slice_idx is only used for 3D images. It will be ignored.") - return _local_thickness(image, image_lt, show, figsize) \ No newline at end of file + return _local_thickness(image, image_lt, show, figsize) + + diff --git a/qim3d/viz/structure_tensor.py b/qim3d/viz/structure_tensor.py new file mode 100644 index 00000000..e009bfea --- /dev/null +++ b/qim3d/viz/structure_tensor.py @@ -0,0 +1,186 @@ +import numpy as np +from typing import Optional, Union, Tuple +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec +import ipywidgets as widgets +import logging as log + +def vectors( + volume: np.ndarray, + vec: np.ndarray, + axis: int = 0, + slice_idx: Optional[Union[int, float]] = None, + interactive: bool = True, + figsize: Tuple[int, int] = (10, 5), + show: bool = False, +) -> Union[plt.Figure, widgets.interactive]: + """ + Displays a grid of eigenvectors from the structure tensor to visualize the orientation of the structures in the volume. + + Args: + volume (np.ndarray): The 3D volume to be sliced. + vec (np.ndarray): The eigenvectors of the structure tensor. + axis (int, optional): The axis along which to visualize the local thickness. Defaults to 0. + slice_idx (int or float, optional): The initial slice to be visualized. The slice index + can afterwards be changed. If value is an integer, it will be the index of the slice + to be visualized. If value is a float between 0 and 1, it will be multiplied by the + number of slices and rounded to the nearest integer. If None, the middle slice will + be used. Defaults to None. + grid_size (int, optional): The size of the grid. Defaults to 10. + interactive (bool, optional): If True, returns an interactive widget. Defaults to True. + figsize (Tuple[int, int], optional): The size of the figure. Defaults to (15, 5). + show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. + + Raises: + ValueError: If the axis to slice along is not 0, 1, or 2. + ValueError: If the slice index is not an integer or a float between 0 and 1. + + """ + + # Define Grid size limits + min_grid_size = max(1, volume.shape[axis] // 50) + max_grid_size = max(1, volume.shape[axis] // 10) + if max_grid_size <= min_grid_size: + max_grid_size = min_grid_size * 5 + + # Testing + grid_size = (min_grid_size + max_grid_size) // 2 + + if grid_size < min_grid_size or grid_size > max_grid_size: + # Adjust grid size as little as possible to be within the limits + grid_size = min(max(min_grid_size, grid_size), max_grid_size) + log.warning(f"Adjusting grid size to {grid_size} as it is out of bounds.") + + def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show): + + # Create subplots + fig, ax = plt.subplots(1, 2, figsize=figsize, layout="constrained") + + # Choose the appropriate slice based on the specified dimension + if axis == 0: + data_slice = volume[slice_idx, :, :] + vectors_slice_x = vec[0, slice_idx, :, :] + vectors_slice_y = vec[1, slice_idx, :, :] + elif axis == 1: + data_slice = volume[:, slice_idx, :] + vectors_slice_x = vec[0, :, slice_idx, :] + vectors_slice_y = vec[2, :, slice_idx, :] + elif axis == 2: + data_slice = volume[:, :, slice_idx] + vectors_slice_x = vec[1, :, :, slice_idx] + vectors_slice_y = vec[2, :, :, slice_idx] + else: + raise ValueError("Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.") + + ax[0].imshow(data_slice, cmap=plt.cm.gray) + + # Create meshgrid with the correct dimensions + xmesh, ymesh = np.mgrid[0 : data_slice.shape[0], 0 : data_slice.shape[1]] + + # Create a slice object for selecting the grid points + g = slice(grid_size // 2, None, grid_size) + + # Plot vectors + ax[0].quiver( + ymesh[g, g], + xmesh[g, g], + vectors_slice_x[g, g], + vectors_slice_y[g, g], + color="orange", + angles="xy", + ) + ax[0].quiver( + ymesh[g, g], + xmesh[g, g], + -vectors_slice_x[g, g], + -vectors_slice_y[g, g], + color="orange", + angles="xy", + ) + + # Set title and turn off axis + ax[0].set_title(f"Slice {slice_idx}" if not interactive else None) + ax[0].set_axis_off() + + # Orientations histogram + nbins = 36 + angles = np.arctan2(vectors_slice_y, vectors_slice_x) # angles from 0 to pi + distribution, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi)) + + # Find the bin with the maximum count + peak_bin_idx = np.argmax(distribution) + # Calculate the center of the peak bin + peak_angle_rad = (bin_edges[peak_bin_idx] + bin_edges[peak_bin_idx + 1]) / 2 + # Convert the peak angle to degrees + peak_angle_deg = np.degrees(peak_angle_rad) + bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins # half circle (180 deg) + colors = plt.cm.hsv(bin_centers / np.pi) + ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=colors) + ax[1].set_xlabel("Angle [radians]") + ax[1].set_xlim([0, np.pi]) + ax[1].set_aspect(np.pi / ax[1].get_ylim()[1]) + ax[1].set_xticks([0, np.pi / 2, np.pi]) + ax[1].set_xticklabels(["0", "$\\frac{\\pi}{2}$", "$\\pi$"]) + ax[1].set_ylabel("Count") + ax[1].set_title(f"Histogram over angles (peak at {round(peak_angle_deg)}°)") + + if show: + plt.show() + + plt.close() + + return fig + + if vec.ndim == 5: + vec = vec[0, ...] + log.warning( + "Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used." + ) + + if slice_idx is None: + slice_idx = volume.shape[axis] // 2 + elif isinstance(slice_idx, float): + if slice_idx < 0 or slice_idx > 1: + raise ValueError( + "Values of slice_idx of float type must be between 0 and 1." + ) + slice_idx = int(slice_idx * volume.shape[0]) - 1 + + if interactive: + slide_idx_slider = widgets.IntSlider( + min=0, + max=volume.shape[axis] - 1, + step=1, + value=slice_idx, + description="Slice index", + layout=widgets.Layout(width="450px"), + ) + + grid_size_slider = widgets.IntSlider( + min=min_grid_size, + max=max_grid_size, + step=1, + value=grid_size, + description="Grid size", + layout=widgets.Layout(width="450px"), + ) + + widget_obj = widgets.interactive( + _structure_tensor, + volume=widgets.fixed(volume), + vec=widgets.fixed(vec), + axis=widgets.fixed(axis), + slice_idx=slide_idx_slider, + grid_size=grid_size_slider, + figsize=widgets.fixed(figsize), + show=widgets.fixed(True), + ) + # Arrange sliders horizontally + sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider]) + widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]]) + widget_obj.layout.align_items = "center" + if show: + display(widget_obj) + return widget_obj + else: + return _structure_tensor(volume, vec, axis, slice_idx, figsize, show) diff --git a/requirements.txt b/requirements.txt index 6726c009..3338fefb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ dask>=2023.6.0, k3d>=2.16.1 olefile>=0.46 psutil>=5.9.0 +structure-tensor>=0.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py index e674b6d6..92b4f0ec 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ setup( "dask>=2023.6.0", "k3d>=2.16.1", "olefile>=0.46", - "psutil>=5.9.0" + "psutil>=5.9.0", + "structure-tensor>=0.2.1" ], ) -- GitLab