Skip to content
Snippets Groups Projects
Commit 36085e42 authored by fima's avatar fima :beers:
Browse files

Merge branch 'local_thickness_wrapper' into 'main'

Local thickness wrapper

See merge request !60
parents f2a23b41 386b0e98
No related branches found
No related tags found
1 merge request!60Local thickness wrapper
from .filters import *
\ No newline at end of file
from .filters import *
from .local_thickness import local_thickness
\ No newline at end of file
"""Wrapper for the local thickness function from the localthickness package including visualization functions."""
import localthickness as lt
import numpy as np
from typing import Optional
from skimage.filters import threshold_otsu
from qim3d.io.logger import log
from qim3d.viz import local_thickness as viz_local_thickness
def local_thickness(
image: np.ndarray,
scale: float = 1,
mask: Optional[np.ndarray] = None,
visualize=False,
**viz_kwargs
) -> np.ndarray:
"""Wrapper for the local thickness function from the localthickness package (https://github.com/vedranaa/local-thickness)
Args:
image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
If binary, it will be passed directly to the local thickness function.
If grayscale, it will be binarized using Otsu's method.
scale (float, optional): Downscaling factor, e.g. 0.5 for halving each dim of the image.
Default is 1.
mask (np.ndarray, optional): binary mask of the same size of the image defining parts of the
image to be included in the computation of the local thickness. Default is None.
visualize (bool, optional): Whether to visualize the local thickness. Default is False.
**viz_kwargs: Additional keyword arguments for the visualization function. Only used if visualize=True.
Returns:
local_thickness (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input image/volume.
!!! quote "Reference"
Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW).
<https://doi.org/10.1109/cvprw59228.2023.00456>
```bibtex
@inproceedings{Dahl_2023, title={Fast Local Thickness},
url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456},
DOI={10.1109/cvprw59228.2023.00456},
booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
publisher={IEEE},
author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm},
year={2023},
month=jun }
```
"""
# Check if input is binary
if np.unique(image).size > 2:
# If not, binarize it using Otsu's method, log the threshold and compute the local thickness
threshold = threshold_otsu(image=image)
log.warning(
"Input image is not binary. It will be binarized using Otsu's method with threshold: {}".format(
threshold
)
)
local_thickness = lt.local_thickness(image > threshold, scale=scale, mask=mask)
else:
# If it is binary, compute the local thickness directly
local_thickness = lt.local_thickness(image, scale=scale, mask=mask)
# Visualize the local thickness if requested
if visualize:
display(viz_local_thickness(image, local_thickness, **viz_kwargs))
return local_thickness
import qim3d
import numpy as np
from skimage.draw import disk, ellipsoid
import pytest
def test_local_thickness_2d():
# Create a binary 2D image
shape = (100, 100)
img = np.zeros(shape, dtype=bool)
rr1, cc1 = disk((65, 65), 30, shape=shape)
rr2, cc2 = disk((25, 25), 20, shape=shape)
img[rr1, cc1] = True
img[rr2, cc2] = True
lt_manual = np.zeros(shape)
lt_manual[rr1, cc1] = 30
lt_manual[rr2, cc2] = 20
# Compute local thickness
lt = qim3d.processing.local_thickness(img)
assert np.allclose(lt, lt_manual, rtol=1e-1)
def test_local_thickness_3d():
disk3d = ellipsoid(15,15,15)
# Remove weird border pixels
border_thickness = 2
disk3d = disk3d[border_thickness:-border_thickness, border_thickness:-border_thickness, border_thickness:-border_thickness]
disk3d = np.pad(disk3d, border_thickness, mode='constant')
lt = qim3d.processing.local_thickness(disk3d)
lt_manual = np.zeros(disk3d.shape)
lt_manual[disk3d] = 15
assert np.allclose(lt, lt_manual, rtol=1e-1)
......@@ -20,11 +20,11 @@ def test_connected_components_property(setup_data):
[0,0,0,1,0,0],
[2,2,0,0,3,0],
[0,0,0,4,0,0]])
assert np.array_equal(connected_components.connected_components, components)
assert np.array_equal(connected_components.get_cc(), components)
def test_num_connected_components_property(setup_data):
connected_components, _, num_components = setup_data
assert connected_components.num_connected_components == num_components
assert len(connected_components) == num_components
def test_get_connected_component_with_index(setup_data):
connected_components, _, _ = setup_data
......@@ -32,18 +32,18 @@ def test_get_connected_component_with_index(setup_data):
[0,0,0,1,0,0],
[0,0,0,0,0,0],
[0,0,0,0,0,0]], dtype=bool)
print(connected_components.get_connected_component(index=1))
print(connected_components.get_cc(index=1))
print(expected_component)
assert np.array_equal(connected_components.get_connected_component(index=1), expected_component)
assert np.array_equal(connected_components.get_cc(index=1), expected_component)
def test_get_connected_component_without_index(setup_data):
connected_components, _, _ = setup_data
component = connected_components.get_connected_component()
component = connected_components.get_cc()
assert np.any(component)
def test_get_connected_component_with_invalid_index(setup_data):
connected_components, _, num_components = setup_data
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=0)
connected_components.get_cc(index=0)
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=num_components + 1)
\ No newline at end of file
connected_components.get_cc(index=num_components + 1)
\ No newline at end of file
......@@ -34,7 +34,7 @@ def test_get_local_ip():
return False
local_ip = qim3d.utils.internal_tools.get_local_ip()
assert validate_ip(local_ip) == True
def test_stringify_path1():
......
......@@ -9,6 +9,8 @@ from torch import ones
import qim3d
from qim3d.utils.internal_tools import temp_data
import matplotlib.pyplot as plt
import ipywidgets as widgets
# unit tests for grid overview
def test_grid_overview():
......@@ -213,3 +215,31 @@ def test_orthogonal_slider_description():
# unit tests for local thickness visualization
def test_local_thickness_2d():
blobs = qim3d.examples.blobs_256x256
lt = qim3d.processing.local_thickness(blobs)
fig = qim3d.viz.local_thickness(blobs, lt)
# Assert that returned figure is a matplotlib figure
assert isinstance(fig, plt.Figure)
def test_local_thickness_3d():
fly = qim3d.examples.fly_150x256x256
lt = qim3d.processing.local_thickness(fly)
obj = qim3d.viz.local_thickness(fly, lt)
# Assert that returned object is an interactive widget
assert isinstance(obj, widgets.interactive)
def test_local_thickness_3d_max_projection():
fly = qim3d.examples.fly_150x256x256
lt = qim3d.processing.local_thickness(fly)
fig = qim3d.viz.local_thickness(fly, lt, max_projection=True)
# Assert that returned object is an interactive widget
assert isinstance(fig, plt.Figure)
......@@ -79,7 +79,7 @@ def get_local_ip():
try:
# doesn't even have to be reachable
_socket.connect(("192.255.255.255", 1))
ip_address = _socket.getsockname()
ip_address = _socket.getsockname()[0]
except socket.error:
ip_address = "127.0.0.1"
finally:
......
from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
from .k3d import vol
......@@ -3,7 +3,7 @@ Provides a collection of visualization functions.
"""
import math
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
......@@ -584,3 +584,106 @@ def plot_cc(
return figs
return
def local_thickness(
image: np.ndarray,
image_lt: np.ndarray,
max_projection: bool = False,
axis: int = 0,
slice_idx: Optional[Union[int, float]] = None,
show: bool = False,
figsize: Tuple[int, int] = (15, 5)
) -> Union[plt.Figure, widgets.interactive]:
"""Visualizes the local thickness of a 2D or 3D image.
Args:
image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
image/volume.
max_projection (bool, optional): If True, displays the maximum projection of the local
thickness. Only used for 3D images. Defaults to False.
axis (int, optional): The axis along which to visualize the local thickness.
Unused for 2D images.
Defaults to 0.
slice_idx (int or float, optional): The initial slice to be visualized. The slice index
can afterwards be changed. If value is an integer, it will be the index of the slice
to be visualized. If value is a float between 0 and 1, it will be multiplied by the
number of slices and rounded to the nearest integer. If None, the middle slice will
be used for 3D images. Unused for 2D images. Defaults to None.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
figsize (Tuple[int, int], optional): The size of the figure. Defaults to (15, 5).
Raises:
ValueError: If the slice index is not an integer or a float between 0 and 1.
Returns:
If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.
Example:
image_lt = qim3d.processing.local_thickness(image)
qim3d.viz.local_thickness(image, image_lt, slice_idx=10)
"""
def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None):
if slice_idx is not None:
image = image.take(slice_idx, axis=axis)
image_lt = image_lt.take(slice_idx, axis=axis)
fig, axs = plt.subplots(1, 3, figsize=figsize,layout="constrained")
axs[0].imshow(image, cmap="gray")
axs[0].set_title("Original image")
axs[0].axis("off")
axs[1].imshow(image_lt, cmap="viridis")
axs[1].set_title("Local thickness")
axs[1].axis("off")
plt.colorbar(axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical")
axs[2].hist(image_lt[image_lt>0].ravel(), bins=32, edgecolor='black')
axs[2].set_title("Local thickness histogram")
axs[2].set_xlabel("Local thickness")
axs[2].set_ylabel("Count")
if show:
plt.show()
plt.close()
return fig
# Get the middle slice if the input is 3D
if len(image.shape) == 3:
if max_projection:
if slice_idx is not None:
log.warning("slice_idx is not used for max_projection. It will be ignored.")
image = image.max(axis=axis)
image_lt = image_lt.max(axis=axis)
return _local_thickness(image, image_lt, show, figsize)
else:
if slice_idx is None:
slice_idx = image.shape[axis] // 2
elif isinstance(slice_idx, float):
if slice_idx < 0 or slice_idx > 1:
raise ValueError("Values of slice_idx of float type must be between 0 and 1.")
slice_idx = int(slice_idx * image.shape[0])-1
slide_idx_slider=widgets.IntSlider(min=0, max=image.shape[axis]-1, step=1, value=slice_idx, description="Slice index")
widget_obj = widgets.interactive(
_local_thickness,
image=widgets.fixed(image),
image_lt=widgets.fixed(image_lt),
show=widgets.fixed(True),
figsize=widgets.fixed(figsize),
axis=widgets.fixed(axis),
slice_idx=slide_idx_slider
)
widget_obj.layout = widgets.Layout(align_items="center")
if show:
display(widget_obj)
return widget_obj
else:
if max_projection:
log.warning("max_projection is only used for 3D images. It will be ignored.")
if slice_idx is not None:
log.warning("slice_idx is only used for 3D images. It will be ignored.")
return _local_thickness(image, image_lt, show, figsize)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment