diff --git a/qim3d/processing/__init__.py b/qim3d/processing/__init__.py index ac51f92fb4028cf960c6d5d4d73255856d2c9c00..771b7222bed1aa8e05ce4481c86758123ab26de0 100644 --- a/qim3d/processing/__init__.py +++ b/qim3d/processing/__init__.py @@ -1 +1,2 @@ -from .filters import * \ No newline at end of file +from .filters import * +from .local_thickness import local_thickness \ No newline at end of file diff --git a/qim3d/processing/local_thickness.py b/qim3d/processing/local_thickness.py new file mode 100644 index 0000000000000000000000000000000000000000..e581c01376d96c428dcb1b39b2fc47a56f9898f2 --- /dev/null +++ b/qim3d/processing/local_thickness.py @@ -0,0 +1,70 @@ +"""Wrapper for the local thickness function from the localthickness package including visualization functions.""" + +import localthickness as lt +import numpy as np +from typing import Optional +from skimage.filters import threshold_otsu +from qim3d.io.logger import log +from qim3d.viz import local_thickness as viz_local_thickness + + +def local_thickness( + image: np.ndarray, + scale: float = 1, + mask: Optional[np.ndarray] = None, + visualize=False, + **viz_kwargs +) -> np.ndarray: + """Wrapper for the local thickness function from the localthickness package (https://github.com/vedranaa/local-thickness) + + Args: + image (np.ndarray): 2D or 3D NumPy array representing the image/volume. + If binary, it will be passed directly to the local thickness function. + If grayscale, it will be binarized using Otsu's method. + scale (float, optional): Downscaling factor, e.g. 0.5 for halving each dim of the image. + Default is 1. + mask (np.ndarray, optional): binary mask of the same size of the image defining parts of the + image to be included in the computation of the local thickness. Default is None. + visualize (bool, optional): Whether to visualize the local thickness. Default is False. + **viz_kwargs: Additional keyword arguments for the visualization function. Only used if visualize=True. + + Returns: + local_thickness (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input image/volume. + + + !!! quote "Reference" + Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW). + <https://doi.org/10.1109/cvprw59228.2023.00456> + + ```bibtex + @inproceedings{Dahl_2023, title={Fast Local Thickness}, + url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456}, + DOI={10.1109/cvprw59228.2023.00456}, + booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)}, + publisher={IEEE}, + author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm}, + year={2023}, + month=jun } + + ``` + """ + + # Check if input is binary + if np.unique(image).size > 2: + # If not, binarize it using Otsu's method, log the threshold and compute the local thickness + threshold = threshold_otsu(image=image) + log.warning( + "Input image is not binary. It will be binarized using Otsu's method with threshold: {}".format( + threshold + ) + ) + local_thickness = lt.local_thickness(image > threshold, scale=scale, mask=mask) + else: + # If it is binary, compute the local thickness directly + local_thickness = lt.local_thickness(image, scale=scale, mask=mask) + + # Visualize the local thickness if requested + if visualize: + display(viz_local_thickness(image, local_thickness, **viz_kwargs)) + + return local_thickness diff --git a/qim3d/tests/processing/test_local_thickness.py b/qim3d/tests/processing/test_local_thickness.py new file mode 100644 index 0000000000000000000000000000000000000000..edc18c4d31d422c8f85b1cf966f6e1a02a729474 --- /dev/null +++ b/qim3d/tests/processing/test_local_thickness.py @@ -0,0 +1,37 @@ +import qim3d +import numpy as np +from skimage.draw import disk, ellipsoid +import pytest + +def test_local_thickness_2d(): + # Create a binary 2D image + shape = (100, 100) + img = np.zeros(shape, dtype=bool) + rr1, cc1 = disk((65, 65), 30, shape=shape) + rr2, cc2 = disk((25, 25), 20, shape=shape) + img[rr1, cc1] = True + img[rr2, cc2] = True + + lt_manual = np.zeros(shape) + lt_manual[rr1, cc1] = 30 + lt_manual[rr2, cc2] = 20 + + # Compute local thickness + lt = qim3d.processing.local_thickness(img) + + assert np.allclose(lt, lt_manual, rtol=1e-1) + +def test_local_thickness_3d(): + disk3d = ellipsoid(15,15,15) + + # Remove weird border pixels + border_thickness = 2 + disk3d = disk3d[border_thickness:-border_thickness, border_thickness:-border_thickness, border_thickness:-border_thickness] + disk3d = np.pad(disk3d, border_thickness, mode='constant') + + lt = qim3d.processing.local_thickness(disk3d) + + lt_manual = np.zeros(disk3d.shape) + lt_manual[disk3d] = 15 + + assert np.allclose(lt, lt_manual, rtol=1e-1) diff --git a/qim3d/tests/utils/test_connected_components.py b/qim3d/tests/utils/test_connected_components.py index f7f15359f62b3fd330f027501a78e4df92e9cd52..cb2b21089b1db5b2517a58524aa5df264809ce06 100644 --- a/qim3d/tests/utils/test_connected_components.py +++ b/qim3d/tests/utils/test_connected_components.py @@ -20,11 +20,11 @@ def test_connected_components_property(setup_data): [0,0,0,1,0,0], [2,2,0,0,3,0], [0,0,0,4,0,0]]) - assert np.array_equal(connected_components.connected_components, components) + assert np.array_equal(connected_components.get_cc(), components) def test_num_connected_components_property(setup_data): connected_components, _, num_components = setup_data - assert connected_components.num_connected_components == num_components + assert len(connected_components) == num_components def test_get_connected_component_with_index(setup_data): connected_components, _, _ = setup_data @@ -32,18 +32,18 @@ def test_get_connected_component_with_index(setup_data): [0,0,0,1,0,0], [0,0,0,0,0,0], [0,0,0,0,0,0]], dtype=bool) - print(connected_components.get_connected_component(index=1)) + print(connected_components.get_cc(index=1)) print(expected_component) - assert np.array_equal(connected_components.get_connected_component(index=1), expected_component) + assert np.array_equal(connected_components.get_cc(index=1), expected_component) def test_get_connected_component_without_index(setup_data): connected_components, _, _ = setup_data - component = connected_components.get_connected_component() + component = connected_components.get_cc() assert np.any(component) def test_get_connected_component_with_invalid_index(setup_data): connected_components, _, num_components = setup_data with pytest.raises(AssertionError): - connected_components.get_connected_component(index=0) + connected_components.get_cc(index=0) with pytest.raises(AssertionError): - connected_components.get_connected_component(index=num_components + 1) \ No newline at end of file + connected_components.get_cc(index=num_components + 1) \ No newline at end of file diff --git a/qim3d/tests/utils/test_internal_tools.py b/qim3d/tests/utils/test_internal_tools.py index bb239ca703669273933a888638dd7088fd0a7078..ed318e7a5c6a93ca2ca0caba2085bb3076e62add 100644 --- a/qim3d/tests/utils/test_internal_tools.py +++ b/qim3d/tests/utils/test_internal_tools.py @@ -34,7 +34,7 @@ def test_get_local_ip(): return False local_ip = qim3d.utils.internal_tools.get_local_ip() - + assert validate_ip(local_ip) == True def test_stringify_path1(): diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index d7e372c888effd1bc4c5b6dc3ae113d76b6fab64..2ac9eed9bb09418f888b9ec9bf0e3a0bf54d8979 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -9,6 +9,8 @@ from torch import ones import qim3d from qim3d.utils.internal_tools import temp_data +import matplotlib.pyplot as plt +import ipywidgets as widgets # unit tests for grid overview def test_grid_overview(): @@ -213,3 +215,31 @@ def test_orthogonal_slider_description(): + + +# unit tests for local thickness visualization +def test_local_thickness_2d(): + blobs = qim3d.examples.blobs_256x256 + lt = qim3d.processing.local_thickness(blobs) + fig = qim3d.viz.local_thickness(blobs, lt) + + # Assert that returned figure is a matplotlib figure + assert isinstance(fig, plt.Figure) + +def test_local_thickness_3d(): + fly = qim3d.examples.fly_150x256x256 + lt = qim3d.processing.local_thickness(fly) + obj = qim3d.viz.local_thickness(fly, lt) + + # Assert that returned object is an interactive widget + assert isinstance(obj, widgets.interactive) + +def test_local_thickness_3d_max_projection(): + fly = qim3d.examples.fly_150x256x256 + lt = qim3d.processing.local_thickness(fly) + fig = qim3d.viz.local_thickness(fly, lt, max_projection=True) + + # Assert that returned object is an interactive widget + assert isinstance(fig, plt.Figure) + + diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py index 73324998041a99b0dc30f006a10312cdf372af6f..9149b2cd01d4f28b56f64f65b0d3eaea2e6c3db2 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/internal_tools.py @@ -79,7 +79,7 @@ def get_local_ip(): try: # doesn't even have to be reachable _socket.connect(("192.255.255.255", 1)) - ip_address = _socket.getsockname() + ip_address = _socket.getsockname()[0] except socket.error: ip_address = "127.0.0.1" finally: diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 0d02c563ab1ecdc6b74b6c77af45d16082b354cc..eedcc00b20c7b8f447e47db5540ab169502cb5a1 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,3 +1,3 @@ from .visualizations import plot_metrics -from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc +from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness from .k3d import vol diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 033d8f8590834a0d616f9aaa25be98d733a9cbf8..cfa7022391e8cbc27cdd6627bab171881d3b79b8 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -3,7 +3,7 @@ Provides a collection of visualization functions. """ import math -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import matplotlib.pyplot as plt import numpy as np import torch @@ -584,3 +584,106 @@ def plot_cc( return figs return + +def local_thickness( + image: np.ndarray, + image_lt: np.ndarray, + max_projection: bool = False, + axis: int = 0, + slice_idx: Optional[Union[int, float]] = None, + show: bool = False, + figsize: Tuple[int, int] = (15, 5) + ) -> Union[plt.Figure, widgets.interactive]: + """Visualizes the local thickness of a 2D or 3D image. + + Args: + image (np.ndarray): 2D or 3D NumPy array representing the image/volume. + image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input + image/volume. + max_projection (bool, optional): If True, displays the maximum projection of the local + thickness. Only used for 3D images. Defaults to False. + axis (int, optional): The axis along which to visualize the local thickness. + Unused for 2D images. + Defaults to 0. + slice_idx (int or float, optional): The initial slice to be visualized. The slice index + can afterwards be changed. If value is an integer, it will be the index of the slice + to be visualized. If value is a float between 0 and 1, it will be multiplied by the + number of slices and rounded to the nearest integer. If None, the middle slice will + be used for 3D images. Unused for 2D images. Defaults to None. + show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. + figsize (Tuple[int, int], optional): The size of the figure. Defaults to (15, 5). + + Raises: + ValueError: If the slice index is not an integer or a float between 0 and 1. + + Returns: + If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure. + + Example: + image_lt = qim3d.processing.local_thickness(image) + qim3d.viz.local_thickness(image, image_lt, slice_idx=10) + """ + def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None): + if slice_idx is not None: + image = image.take(slice_idx, axis=axis) + image_lt = image_lt.take(slice_idx, axis=axis) + + fig, axs = plt.subplots(1, 3, figsize=figsize,layout="constrained") + + axs[0].imshow(image, cmap="gray") + axs[0].set_title("Original image") + axs[0].axis("off") + + axs[1].imshow(image_lt, cmap="viridis") + axs[1].set_title("Local thickness") + axs[1].axis("off") + + plt.colorbar(axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical") + + axs[2].hist(image_lt[image_lt>0].ravel(), bins=32, edgecolor='black') + axs[2].set_title("Local thickness histogram") + axs[2].set_xlabel("Local thickness") + axs[2].set_ylabel("Count") + + if show: + plt.show() + + plt.close() + + return fig + + # Get the middle slice if the input is 3D + if len(image.shape) == 3: + if max_projection: + if slice_idx is not None: + log.warning("slice_idx is not used for max_projection. It will be ignored.") + image = image.max(axis=axis) + image_lt = image_lt.max(axis=axis) + return _local_thickness(image, image_lt, show, figsize) + else: + if slice_idx is None: + slice_idx = image.shape[axis] // 2 + elif isinstance(slice_idx, float): + if slice_idx < 0 or slice_idx > 1: + raise ValueError("Values of slice_idx of float type must be between 0 and 1.") + slice_idx = int(slice_idx * image.shape[0])-1 + slide_idx_slider=widgets.IntSlider(min=0, max=image.shape[axis]-1, step=1, value=slice_idx, description="Slice index") + widget_obj = widgets.interactive( + _local_thickness, + image=widgets.fixed(image), + image_lt=widgets.fixed(image_lt), + show=widgets.fixed(True), + figsize=widgets.fixed(figsize), + axis=widgets.fixed(axis), + slice_idx=slide_idx_slider + ) + widget_obj.layout = widgets.Layout(align_items="center") + if show: + display(widget_obj) + return widget_obj + else: + if max_projection: + log.warning("max_projection is only used for 3D images. It will be ignored.") + if slice_idx is not None: + log.warning("slice_idx is only used for 3D images. It will be ignored.") + return _local_thickness(image, image_lt, show, figsize) \ No newline at end of file