Skip to content
Snippets Groups Projects
Commit 685033d4 authored by s204159's avatar s204159 :sunglasses: Committed by fima
Browse files

Implemented 3D connected components as wrapper class for scipy.ndimage.label

parent 75cb18fe
No related branches found
No related tags found
1 merge request!48Implemented 3D connected components as wrapper class for scipy.ndimage.label
.cache/plugin/social/e4665e2507cbe100841b0959a0e48ef6.png

45.5 KiB

...@@ -11,6 +11,8 @@ build/ ...@@ -11,6 +11,8 @@ build/
# Development and editor files # Development and editor files
.vscode/ .vscode/
.idea/ .idea/
.cache/
.pytest_cache/
*.swp *.swp
*.swo *.swo
*.pyc *.pyc
......
import numpy as np
import pytest
from qim3d.utils.cc import get_3d_cc
@pytest.fixture(scope="module")
def setup_data():
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[1,1,0,0,1,0],
[0,0,0,1,0,0]])
num_components = 4
connected_components = get_3d_cc(components)
return connected_components, components, num_components
def test_connected_components_property(setup_data):
connected_components, _, _ = setup_data
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[2,2,0,0,3,0],
[0,0,0,4,0,0]])
assert np.array_equal(connected_components.connected_components, components)
def test_num_connected_components_property(setup_data):
connected_components, _, num_components = setup_data
assert connected_components.num_connected_components == num_components
def test_get_connected_component_with_index(setup_data):
connected_components, _, _ = setup_data
expected_component = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[0,0,0,0,0,0],
[0,0,0,0,0,0]], dtype=bool)
print(connected_components.get_connected_component(index=1))
print(expected_component)
assert np.array_equal(connected_components.get_connected_component(index=1), expected_component)
def test_get_connected_component_without_index(setup_data):
connected_components, _, _ = setup_data
component = connected_components.get_connected_component()
assert np.any(component)
def test_get_connected_component_with_invalid_index(setup_data):
connected_components, _, num_components = setup_data
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=0)
with pytest.raises(AssertionError):
connected_components.get_connected_component(index=num_components + 1)
\ No newline at end of file
...@@ -3,9 +3,13 @@ import torch ...@@ -3,9 +3,13 @@ import torch
import numpy as np import numpy as np
import ipywidgets as widgets import ipywidgets as widgets
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pytest
from torch import ones
import qim3d import qim3d
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview # unit tests for grid overview
def test_grid_overview(): def test_grid_overview():
random_tuple = (torch.ones(1,256,256),torch.ones(256,256)) random_tuple = (torch.ones(1,256,256),torch.ones(256,256))
......
import qim3d
import pytest import pytest
import qim3d
#unit test for plot_metrics() #unit test for plot_metrics()
def test_plot_metrics(): def test_plot_metrics():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
......
from . import internal_tools
from .models import train_model, model_summary, inference
from .augmentations import Augmentation
from .data import Dataset, prepare_datasets, prepare_dataloaders
#from .doi import get_bibtex, get_reference #from .doi import get_bibtex, get_reference
from . import doi from . import doi, internal_tools
from .augmentations import Augmentation
from .cc import get_3d_cc
from .data import Dataset, prepare_dataloaders, prepare_datasets
from .models import inference, model_summary, train_model
from .system import Memory from .system import Memory
from .img import overlay_rgb_images from .img import overlay_rgb_images
import numpy as np
import torch
from scipy.ndimage import find_objects, label
from qim3d.io.logger import log
class CC:
def __init__(self, connected_components, num_connected_components):
"""
Initializes a ConnectedComponents object.
Args:
connected_components (np.ndarray): The connected components.
num_connected_components (int): The number of connected components.
"""
self._connected_components = connected_components
self.cc_count = num_connected_components
self.shape = connected_components.shape
def __len__(self):
"""
Returns the number of connected components in the object.
"""
return self.cc_count
def get_cc(self, index=None, crop=False):
"""
Get the connected component with the given index, if index is None selects a random component.
Args:
index (int): The index of the connected component.
If none returns all components.
If 'random' returns a random component.
crop (bool): If True, the volume is cropped to the bounding box of the connected component.
Returns:
np.ndarray: The connected component as a binary mask.
"""
if index is None:
volume = self._connected_components
elif index == "random":
index = np.random.randint(1, self.cc_count + 1)
volume = self._connected_components == index
else:
assert 1 <= index <= self.cc_count, "Index out of range. Needs to be in range [1, cc_count]."
volume = self._connected_components == index
if crop:
# As we index get_bounding_box element 0 will be the bounding box for the connected component at index
bbox = self.get_bounding_box(index)[0]
volume = volume[bbox]
return volume
def get_bounding_box(self, index=None):
"""Get the bounding boxes of the connected components.
Args:
index (int, optional): The index of the connected component. If none selects all components.
Returns:
list: A list of bounding boxes.
"""
if index:
assert 1 <= index <= self.cc_count, "Index out of range."
return find_objects(self._connected_components == index)
else:
return find_objects(self._connected_components)
def get_3d_cc(image: np.ndarray | torch.Tensor) -> CC:
""" Get the connected components of a 3D volume.
Args:
image (np.ndarray | torch.Tensor): An array-like object to be labeled. Any non-zero values in `input` are
counted as features and zero values are considered the background.
Returns:
CC: A ConnectedComponents object containing the connected components and the number of connected components.
"""
connected_components, num_connected_components = label(image)
log.info(f"Total number of connected components found: {num_connected_components}")
return CC(connected_components, num_connected_components)
from .visualizations import plot_metrics from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices, slicer, orthogonal from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc
from .k3d import vol from .k3d import vol
""" """
Provides a collection of visualization functions. Provides a collection of visualization functions.
""" """
import math import math
from typing import List, Optional, Union from typing import List, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
import torch
import numpy as np import numpy as np
import torch
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
import qim3d.io
import ipywidgets as widgets import ipywidgets as widgets
from qim3d.io.logger import log from qim3d.io.logger import log
from qim3d.utils.cc import CC
def grid_overview( def grid_overview(
...@@ -226,7 +230,7 @@ def slices( ...@@ -226,7 +230,7 @@ def slices(
img_width: int = 2, img_width: int = 2,
show: bool = False, show: bool = False,
show_position: bool = True, show_position: bool = True,
interpolation: Optional[str] = None, interpolation: Optional[str] = "none",
) -> plt.Figure: ) -> plt.Figure:
"""Displays one or several slices from a 3d volume. """Displays one or several slices from a 3d volume.
...@@ -512,3 +516,71 @@ def orthogonal( ...@@ -512,3 +516,71 @@ def orthogonal(
x_slicer.children[0].description = "X" x_slicer.children[0].description = "X"
return widgets.HBox([z_slicer, y_slicer, x_slicer]) return widgets.HBox([z_slicer, y_slicer, x_slicer])
def plot_cc(
connected_components: CC,
component_indexs: list | tuple = None,
max_cc_to_plot=32,
overlay=None,
crop=False,
show=True,
**kwargs,
) -> list[plt.Figure]:
"""
Plot the connected components of an image.
Parameters:
connected_components (CC): The connected components object.
components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
overlay (optional): Overlay image. Defaults to None.
crop (bool, optional): Whether to crop the overlay image. Defaults to False.
show (bool, optional): Whether to show the figure. Defaults to True.
**kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.
Returns:
figs (list[plt.Figure]): List of figures, if `show=False`.
"""
figs = []
# if no components are given, plot the first max_cc_to_plot=32 components
if component_indexs is None:
if len(connected_components) > max_cc_to_plot:
log.warning(
f"More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components."
)
component_indexs = range(
1, min(max_cc_to_plot + 1, len(connected_components) + 1)
)
for component in component_indexs:
if overlay is not None:
assert (
overlay.shape == connected_components.shape
), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
# plots overlay masked to connected component
if crop:
# Crop the overlay image based on the bounding box of the component
bb = connected_components.get_bounding_box(component)[0]
cc = connected_components.get_cc(component, crop=True)
overlay_crop = overlay[bb]
# use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers
overlay_crop = np.where(cc == 0, 0, overlay_crop)
fig = slices(overlay_crop, show=show, **kwargs)
else:
cc = connected_components.get_cc(component, crop=False)
overlay_crop = np.where(cc == 0, 0, overlay)
fig = slices(overlay_crop, show=show, **kwargs)
else:
# Plot the connected component without overlay
fig = slices(
connected_components.get_cc(component, crop=crop), show=show, **kwargs
)
figs.append(fig)
if not show:
return figs
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment