diff --git a/.cache/plugin/social/30bca4f1e5153363b72f81158901cb46.png b/.cache/plugin/social/30bca4f1e5153363b72f81158901cb46.png deleted file mode 100644 index 58d44bda3fe3304e93cc2ff9c8e45056b562500e..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/30bca4f1e5153363b72f81158901cb46.png and /dev/null differ diff --git a/.cache/plugin/social/77e45edbc302bd44ece0a9657d60c7c5.png b/.cache/plugin/social/77e45edbc302bd44ece0a9657d60c7c5.png deleted file mode 100644 index 01e762eecd8ff2ce4595eca42571cb377b7a7a33..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/77e45edbc302bd44ece0a9657d60c7c5.png and /dev/null differ diff --git a/.cache/plugin/social/9ea86137447637015901362a370a6b39.png b/.cache/plugin/social/9ea86137447637015901362a370a6b39.png deleted file mode 100644 index 13465f19f8231bb70e56999e97b02dfe2511a8fd..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/9ea86137447637015901362a370a6b39.png and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Black.ttf b/.cache/plugin/social/Roboto-Black.ttf deleted file mode 100644 index 0112e7da626ca2f959eca850c806779ba55dbfbd..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Black.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-BlackItalic.ttf b/.cache/plugin/social/Roboto-BlackItalic.ttf deleted file mode 100644 index b2c6aca57bc0d92ab3197d595766bf9285deea00..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-BlackItalic.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Bold.ttf b/.cache/plugin/social/Roboto-Bold.ttf deleted file mode 100644 index 43da14d84ecb949ca5f5e8ecca3a514aa7fe1c7d..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Bold.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-BoldItalic.ttf b/.cache/plugin/social/Roboto-BoldItalic.ttf deleted file mode 100644 index bcfdab4311f2201f45b341b36310e1cdb8051e34..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-BoldItalic.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Italic.ttf b/.cache/plugin/social/Roboto-Italic.ttf deleted file mode 100644 index 1b5eaa361c7306b4246c48497c79475c0e05c5e2..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Italic.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Light.ttf b/.cache/plugin/social/Roboto-Light.ttf deleted file mode 100644 index e7307e72c5e7bced5d36c776d0986bf71b605f15..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Light.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-LightItalic.ttf b/.cache/plugin/social/Roboto-LightItalic.ttf deleted file mode 100644 index 2d277afb231f7613a49d983217c1aba871741433..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-LightItalic.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Medium.ttf b/.cache/plugin/social/Roboto-Medium.ttf deleted file mode 100644 index ac0f908b9c9c73da558b45d65cc5c6094874d3e8..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Medium.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-MediumItalic.ttf b/.cache/plugin/social/Roboto-MediumItalic.ttf deleted file mode 100644 index fc36a4785c50c04c9b18260e4709cda077ed352d..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-MediumItalic.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Regular.ttf b/.cache/plugin/social/Roboto-Regular.ttf deleted file mode 100644 index ddf4bfacb396e97546364ccfeeb9c31dfaea4c25..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Regular.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-Thin.ttf b/.cache/plugin/social/Roboto-Thin.ttf deleted file mode 100644 index 2e0dee6a833c4b568d44ac99727f7e0c17c6eb67..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-Thin.ttf and /dev/null differ diff --git a/.cache/plugin/social/Roboto-ThinItalic.ttf b/.cache/plugin/social/Roboto-ThinItalic.ttf deleted file mode 100644 index 084f9c0f5365952d4d860431a1c2dca147e4a9b5..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/Roboto-ThinItalic.ttf and /dev/null differ diff --git a/.cache/plugin/social/ace003ce21650eb2da0ec3bf97602b8e.png b/.cache/plugin/social/ace003ce21650eb2da0ec3bf97602b8e.png deleted file mode 100644 index c114e225ba74d96f002a3b3efd73ffbafdfc4eab..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/ace003ce21650eb2da0ec3bf97602b8e.png and /dev/null differ diff --git a/.cache/plugin/social/ca78b9b203f9d03bb7b080d1112fde76.png b/.cache/plugin/social/ca78b9b203f9d03bb7b080d1112fde76.png deleted file mode 100644 index 117ff59d6f66c6d670d6f301896c8105465490f7..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/ca78b9b203f9d03bb7b080d1112fde76.png and /dev/null differ diff --git a/.cache/plugin/social/d56e638108974200bf734508f679238c.png b/.cache/plugin/social/d56e638108974200bf734508f679238c.png deleted file mode 100644 index 5beeda96b5b1f285c90b5897117d7145c3d11f68..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/d56e638108974200bf734508f679238c.png and /dev/null differ diff --git a/.cache/plugin/social/d6220b9c296ffb25159d8c73696f9f17.png b/.cache/plugin/social/d6220b9c296ffb25159d8c73696f9f17.png deleted file mode 100644 index 14651cef6ba4891ea61d11d9704a16046c113323..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/d6220b9c296ffb25159d8c73696f9f17.png and /dev/null differ diff --git a/.cache/plugin/social/e28344235ccac04098483abf8cf0980c.png b/.cache/plugin/social/e28344235ccac04098483abf8cf0980c.png deleted file mode 100644 index 3d1ec6f148c5094deb3a71225e89d4c1dff43ab6..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/e28344235ccac04098483abf8cf0980c.png and /dev/null differ diff --git a/.cache/plugin/social/e4665e2507cbe100841b0959a0e48ef6.png b/.cache/plugin/social/e4665e2507cbe100841b0959a0e48ef6.png deleted file mode 100644 index 32b0ba5e450506b46d2aea1abecc4113a7615234..0000000000000000000000000000000000000000 Binary files a/.cache/plugin/social/e4665e2507cbe100841b0959a0e48ef6.png and /dev/null differ diff --git a/.gitignore b/.gitignore index 1bb0e882a78f648f548b25151e4493d23e700fbe..ac25d6c8ab51b5a04ac6bd7d8291d08fa86d2215 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,8 @@ build/ # Development and editor files .vscode/ .idea/ +.cache/ +.pytest_cache/ *.swp *.swo *.pyc diff --git a/qim3d/tests/utils/test_connected_components.py b/qim3d/tests/utils/test_connected_components.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f15359f62b3fd330f027501a78e4df92e9cd52 --- /dev/null +++ b/qim3d/tests/utils/test_connected_components.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest + +from qim3d.utils.cc import get_3d_cc + + +@pytest.fixture(scope="module") +def setup_data(): + components = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [1,1,0,0,1,0], + [0,0,0,1,0,0]]) + num_components = 4 + connected_components = get_3d_cc(components) + return connected_components, components, num_components + +def test_connected_components_property(setup_data): + connected_components, _, _ = setup_data + components = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [2,2,0,0,3,0], + [0,0,0,4,0,0]]) + assert np.array_equal(connected_components.connected_components, components) + +def test_num_connected_components_property(setup_data): + connected_components, _, num_components = setup_data + assert connected_components.num_connected_components == num_components + +def test_get_connected_component_with_index(setup_data): + connected_components, _, _ = setup_data + expected_component = np.array([[0,0,1,1,0,0], + [0,0,0,1,0,0], + [0,0,0,0,0,0], + [0,0,0,0,0,0]], dtype=bool) + print(connected_components.get_connected_component(index=1)) + print(expected_component) + assert np.array_equal(connected_components.get_connected_component(index=1), expected_component) + +def test_get_connected_component_without_index(setup_data): + connected_components, _, _ = setup_data + component = connected_components.get_connected_component() + assert np.any(component) + +def test_get_connected_component_with_invalid_index(setup_data): + connected_components, _, num_components = setup_data + with pytest.raises(AssertionError): + connected_components.get_connected_component(index=0) + with pytest.raises(AssertionError): + connected_components.get_connected_component(index=num_components + 1) \ No newline at end of file diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 82dd98829e7a2f2c224398732b5982b5a89f0b2c..d7e372c888effd1bc4c5b6dc3ae113d76b6fab64 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -3,9 +3,13 @@ import torch import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt +import pytest +from torch import ones + import qim3d from qim3d.utils.internal_tools import temp_data + # unit tests for grid overview def test_grid_overview(): random_tuple = (torch.ones(1,256,256),torch.ones(256,256)) diff --git a/qim3d/tests/viz/test_visualizations.py b/qim3d/tests/viz/test_visualizations.py index 75e84eb9fe06cffe7fa62a600a4821ebf8265fae..c9a8beceac2ae601b03adbc77d473fd339b577d1 100644 --- a/qim3d/tests/viz/test_visualizations.py +++ b/qim3d/tests/viz/test_visualizations.py @@ -1,6 +1,8 @@ -import qim3d import pytest +import qim3d + + #unit test for plot_metrics() def test_plot_metrics(): metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} diff --git a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py index f3cd4eb1506bc7dba1cfe1698bdc5a0f1bbea9d7..7361d8cdd03af2d818a464d498ac6c854946c7b6 100644 --- a/qim3d/utils/__init__.py +++ b/qim3d/utils/__init__.py @@ -1,8 +1,8 @@ -from . import internal_tools -from .models import train_model, model_summary, inference -from .augmentations import Augmentation -from .data import Dataset, prepare_datasets, prepare_dataloaders #from .doi import get_bibtex, get_reference -from . import doi +from . import doi, internal_tools +from .augmentations import Augmentation +from .cc import get_3d_cc +from .data import Dataset, prepare_dataloaders, prepare_datasets +from .models import inference, model_summary, train_model from .system import Memory -from .img import overlay_rgb_images \ No newline at end of file +from .img import overlay_rgb_images diff --git a/qim3d/utils/cc.py b/qim3d/utils/cc.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc7cee7f18411ece830e0bf8f72c87bbcd47ff0 --- /dev/null +++ b/qim3d/utils/cc.py @@ -0,0 +1,87 @@ +import numpy as np +import torch +from scipy.ndimage import find_objects, label + +from qim3d.io.logger import log + + +class CC: + def __init__(self, connected_components, num_connected_components): + """ + Initializes a ConnectedComponents object. + + Args: + connected_components (np.ndarray): The connected components. + num_connected_components (int): The number of connected components. + """ + self._connected_components = connected_components + self.cc_count = num_connected_components + + self.shape = connected_components.shape + + def __len__(self): + """ + Returns the number of connected components in the object. + """ + return self.cc_count + + def get_cc(self, index=None, crop=False): + """ + Get the connected component with the given index, if index is None selects a random component. + + Args: + index (int): The index of the connected component. + If none returns all components. + If 'random' returns a random component. + crop (bool): If True, the volume is cropped to the bounding box of the connected component. + + Returns: + np.ndarray: The connected component as a binary mask. + """ + if index is None: + volume = self._connected_components + elif index == "random": + index = np.random.randint(1, self.cc_count + 1) + volume = self._connected_components == index + else: + assert 1 <= index <= self.cc_count, "Index out of range. Needs to be in range [1, cc_count]." + volume = self._connected_components == index + + if crop: + # As we index get_bounding_box element 0 will be the bounding box for the connected component at index + bbox = self.get_bounding_box(index)[0] + volume = volume[bbox] + + return volume + + def get_bounding_box(self, index=None): + """Get the bounding boxes of the connected components. + + Args: + index (int, optional): The index of the connected component. If none selects all components. + + Returns: + list: A list of bounding boxes. + """ + + if index: + assert 1 <= index <= self.cc_count, "Index out of range." + return find_objects(self._connected_components == index) + else: + return find_objects(self._connected_components) + + +def get_3d_cc(image: np.ndarray | torch.Tensor) -> CC: + """ Get the connected components of a 3D volume. + + Args: + image (np.ndarray | torch.Tensor): An array-like object to be labeled. Any non-zero values in `input` are + counted as features and zero values are considered the background. + + Returns: + CC: A ConnectedComponents object containing the connected components and the number of connected components. + """ + connected_components, num_connected_components = label(image) + log.info(f"Total number of connected components found: {num_connected_components}") + + return CC(connected_components, num_connected_components) diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 5f1a2614e45ffb5495eb9f9f34c9c2f3f76a8105..0d02c563ab1ecdc6b74b6c77af45d16082b354cc 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,3 +1,3 @@ from .visualizations import plot_metrics -from .img import grid_pred, grid_overview, slices, slicer, orthogonal +from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc from .k3d import vol diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 9ac1f29f502af066a701aeb489426872d9b4cb0c..033d8f8590834a0d616f9aaa25be98d733a9cbf8 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -1,15 +1,19 @@ """ Provides a collection of visualization functions. """ + import math from typing import List, Optional, Union import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap -from matplotlib import colormaps -import torch import numpy as np +import torch +from matplotlib import colormaps +from matplotlib.colors import LinearSegmentedColormap + +import qim3d.io import ipywidgets as widgets from qim3d.io.logger import log +from qim3d.utils.cc import CC def grid_overview( @@ -226,7 +230,7 @@ def slices( img_width: int = 2, show: bool = False, show_position: bool = True, - interpolation: Optional[str] = None, + interpolation: Optional[str] = "none", ) -> plt.Figure: """Displays one or several slices from a 3d volume. @@ -512,3 +516,71 @@ def orthogonal( x_slicer.children[0].description = "X" return widgets.HBox([z_slicer, y_slicer, x_slicer]) + + +def plot_cc( + connected_components: CC, + component_indexs: list | tuple = None, + max_cc_to_plot=32, + overlay=None, + crop=False, + show=True, + **kwargs, +) -> list[plt.Figure]: + """ + Plot the connected components of an image. + + Parameters: + connected_components (CC): The connected components object. + components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None. + max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32. + overlay (optional): Overlay image. Defaults to None. + crop (bool, optional): Whether to crop the overlay image. Defaults to False. + show (bool, optional): Whether to show the figure. Defaults to True. + **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`. + + Returns: + figs (list[plt.Figure]): List of figures, if `show=False`. + """ + figs = [] + # if no components are given, plot the first max_cc_to_plot=32 components + if component_indexs is None: + if len(connected_components) > max_cc_to_plot: + log.warning( + f"More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components." + ) + component_indexs = range( + 1, min(max_cc_to_plot + 1, len(connected_components) + 1) + ) + + for component in component_indexs: + if overlay is not None: + assert ( + overlay.shape == connected_components.shape + ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." + + # plots overlay masked to connected component + if crop: + # Crop the overlay image based on the bounding box of the component + bb = connected_components.get_bounding_box(component)[0] + cc = connected_components.get_cc(component, crop=True) + overlay_crop = overlay[bb] + # use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers + overlay_crop = np.where(cc == 0, 0, overlay_crop) + fig = slices(overlay_crop, show=show, **kwargs) + else: + cc = connected_components.get_cc(component, crop=False) + overlay_crop = np.where(cc == 0, 0, overlay) + fig = slices(overlay_crop, show=show, **kwargs) + else: + # Plot the connected component without overlay + fig = slices( + connected_components.get_cc(component, crop=crop), show=show, **kwargs + ) + + figs.append(fig) + + if not show: + return figs + + return