diff --git a/docs/notebooks/blob_detection.ipynb b/docs/notebooks/blob_detection.ipynb index 86f6d94691fed612a44439db08d3382fa2d47c6b..d23662f96887ba47594d1cf426a7aa8164655b8e 100644 --- a/docs/notebooks/blob_detection.ipynb +++ b/docs/notebooks/blob_detection.ipynb @@ -73,13 +73,13 @@ "cement = qim3d.examples.cement_128x128x128\n", "\n", "# Visualize slices of the original cement volume\n", - "qim3d.viz.slices(cement, n_slices = 5, show = True)\n", + "qim3d.viz.slices_grid(cement, n_slices = 5, show = True)\n", "\n", "# Apply Gaussian filter to the cement volume\n", "cement_filtered = qim3d.processing.gaussian(cement, sigma = 2)\n", "\n", "# Visualize slices of the filtered cement volume\n", - "qim3d.viz.slices(cement_filtered)" + "qim3d.viz.slices_grid(cement_filtered)" ] }, { diff --git a/docs/notebooks/local_thickness.ipynb b/docs/notebooks/local_thickness.ipynb index 717c854639f539f7035e42efdcf27048e180189b..ee405d830d89bad317b4871764bb1b3a7a0077f0 100644 --- a/docs/notebooks/local_thickness.ipynb +++ b/docs/notebooks/local_thickness.ipynb @@ -179,13 +179,13 @@ "cement = qim3d.examples.cement_128x128x128\n", "\n", "# Visualize slices of the original cement volume\n", - "qim3d.viz.slices(cement, n_slices = 5, show = True)\n", + "qim3d.viz.slices_grid(cement, n_slices = 5, show = True)\n", "\n", "# Apply Gaussian filter to the cement volume\n", "cement_filtered = qim3d.processing.gaussian(cement, sigma = 2)\n", "\n", "# Visualize slices of the filtered cement volume\n", - "qim3d.viz.slices(cement_filtered)" + "qim3d.viz.slices_grid(cement_filtered)" ] }, { diff --git a/docs/releases.md b/docs/releases.md index ff67814f84c819bfd5923c13182adb85a307cfea..cee2d6a410e228ecd75e4d9a1f1b4d84e26e146f 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -114,7 +114,7 @@ This version focus on the increased usability of the `qim3d` library - Online documentation available at [https://platform.qim.dk/qim3d](https://platform.qim.dk/qim3d) - Virtual stacks also available for `txm` files - Updated GUI launch pipeline -- New functionalities for `qim3d.viz.slices` +- New functionalities for `qim3d.viz.slices_grid` - Introduction of `qim3d.processing.filters` 🎉 - Introduction of `qim3d.viz.vol` 🎉 @@ -140,7 +140,7 @@ Includes new develoments toward the usability of the library, as well as its int - For the local thicknes GUI, now it is possible to pass and receive numpy arrays instead of using the upload functionality. - Improved data loader - Now the extensions `tif`, `h5` and `txm` are supported. -- Added `qim3d.viz.slices` for easy slice visualization. +- Added `qim3d.viz.slices_grid` for easy slice visualization. - U-net model creation - Model availabe from `qim3d.models.UNet` - Data augmentation class at `qim3d.utils.Augmentation` diff --git a/qim3d/__init__.py b/qim3d/__init__.py index 07782e4e1a6ec54d7f3f605268d707ceeab557a7..b08fa27ce5a67cf3b378936b768110bf3467e7ae 100644 --- a/qim3d/__init__.py +++ b/qim3d/__init__.py @@ -8,7 +8,7 @@ Documentation available at https://platform.qim.dk/qim3d/ """ -__version__ = "0.4.5" +__version__ = "0.9.0" import importlib as _importlib @@ -33,18 +33,23 @@ class _LazyLoader: # List of submodules _submodules = [ - "examples", - "generate", - "gui", - "io", - "models", - "processing", - "tests", - "utils", - "viz", - "cli", + 'examples', + 'generate', + 'gui', + 'io', + 'ml', + 'processing', + 'tests', + 'utils', + 'viz', + 'cli', + 'filters', + 'segmentation', + 'mesh', + 'features', + 'operations', ] # Creating lazy loaders for each submodule for submodule in _submodules: - globals()[submodule] = _LazyLoader(f"qim3d.{submodule}") + globals()[submodule] = _LazyLoader(f'qim3d.{submodule}') diff --git a/qim3d/cli.py b/qim3d/cli/__init__.py similarity index 98% rename from qim3d/cli.py rename to qim3d/cli/__init__.py index 9f6e1ae03f436ef75141f823132403c5b8ec28a7..ef60bdb89d58107a08a7e12a48aa7b42b9654e25 100644 --- a/qim3d/cli.py +++ b/qim3d/cli/__init__.py @@ -177,7 +177,7 @@ def main(): elif args.method == "k3d": volume = qim3d.io.load(str(args.source)) print("\nGenerating k3d plot...") - qim3d.viz.vol(volume, show=False, save=str(args.destination)) + qim3d.viz.volumetric(volume, show=False, save=str(args.destination)) print(f"Done, plot available at <{args.destination}>") if not args.no_browser: print("Opening in default browser...") diff --git a/qim3d/detection/__init__.py b/qim3d/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..302d7909203cfc7e6ccf1223d0a7b1e5c29be6bd --- /dev/null +++ b/qim3d/detection/__init__.py @@ -0,0 +1 @@ +from qim3d.detection._common_detection_methods import * \ No newline at end of file diff --git a/qim3d/processing/detection.py b/qim3d/detection/_common_detection_methods.py similarity index 98% rename from qim3d/processing/detection.py rename to qim3d/detection/_common_detection_methods.py index e3650cf111b71427e5e673b952244f9064e74c36..ff57026a3529c58b8e5fc984ab5e5679b804f285 100644 --- a/qim3d/processing/detection.py +++ b/qim3d/detection/_common_detection_methods.py @@ -1,7 +1,9 @@ """ Blob detection using Difference of Gaussian (DoG) method """ import numpy as np -from qim3d.utils.logger import log +from qim3d.utils._logger import log + +__all__ = ["blob_detection"] def blob_detection( vol: np.ndarray, diff --git a/qim3d/examples/__init__.py b/qim3d/examples/__init__.py index cf57469608720b5ca8c7c7b6b77c6f4cc8cd1f1b..fe15a92e8af1eb9a461c8196084ad358fe35b840 100644 --- a/qim3d/examples/__init__.py +++ b/qim3d/examples/__init__.py @@ -1,7 +1,7 @@ """ Example images for testing and demonstration purposes. """ from pathlib import Path as _Path -from qim3d.utils.logger import log as _log +from qim3d.utils._logger import log as _log from qim3d.io import load as _load # Save the original log level and set to ERROR diff --git a/qim3d/features/__init__.py b/qim3d/features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..349cedcb8d92943f5148029aeabc14c5be9dc5bb --- /dev/null +++ b/qim3d/features/__init__.py @@ -0,0 +1 @@ +from ._common_features_methods import volume, area, sphericity diff --git a/qim3d/processing/features.py b/qim3d/features/_common_features_methods.py similarity index 97% rename from qim3d/processing/features.py rename to qim3d/features/_common_features_methods.py index 9dd087028375eab5dfb949525cc79f5d20d924ba..e69b6cba63867c27cbedeba05999e201917f77b7 100644 --- a/qim3d/processing/features.py +++ b/qim3d/features/_common_features_methods.py @@ -1,6 +1,6 @@ import numpy as np import qim3d.processing -from qim3d.utils.logger import log +from qim3d.utils._logger import log import trimesh import qim3d @@ -44,7 +44,7 @@ def volume(obj, **mesh_kwargs) -> float: """ if isinstance(obj, np.ndarray): log.info("Converting volume to mesh.") - obj = qim3d.processing.create_mesh(obj, **mesh_kwargs) + obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) return obj.volume diff --git a/qim3d/filters/__init__.py b/qim3d/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a166fca5343a186b51cbe33f0af5144696c7237 --- /dev/null +++ b/qim3d/filters/__init__.py @@ -0,0 +1 @@ +from ._common_filter_methods import * \ No newline at end of file diff --git a/qim3d/processing/filters.py b/qim3d/filters/_common_filter_methods.py similarity index 98% rename from qim3d/processing/filters.py rename to qim3d/filters/_common_filter_methods.py index 3c3c1a5b02e432e53a0fe1358e9db81a306272c4..f2d421979259e1f676a54f5ec9512cf2cb6c364e 100644 --- a/qim3d/processing/filters.py +++ b/qim3d/filters/_common_filter_methods.py @@ -8,7 +8,7 @@ from skimage import morphology import dask.array as da import dask_image.ndfilters as dask_ndfilters -from qim3d.utils.logger import log +from qim3d.utils._logger import log __all__ = [ "Gaussian", @@ -119,7 +119,7 @@ class Pipeline: vol = qim3d.examples.fly_150x256x256 # Show original - qim3d.viz.slices(vol, axis=0, show=True) + qim3d.viz.slices_grid(vol, axis=0, show=True) # Create filter pipeline pipeline = Pipeline( @@ -134,7 +134,7 @@ class Pipeline: vol_filtered = pipeline(vol) # Show filtered - qim3d.viz.slices(vol_filtered, axis=0) + qim3d.viz.slices_grid(vol_filtered, axis=0) ```   diff --git a/qim3d/generate/__init__.py b/qim3d/generate/__init__.py index 6ad08aa71d5439db003afef54a64ff65aad1a79f..3f0a48780a5b9730678905dd3a5eb1c272ab700b 100644 --- a/qim3d/generate/__init__.py +++ b/qim3d/generate/__init__.py @@ -1,2 +1,2 @@ -from .blob_ import blob -from .collection_ import collection \ No newline at end of file +from ._generators import noise_object +from ._aggregators import noise_object_collection diff --git a/qim3d/generate/collection_.py b/qim3d/generate/_aggregators.py similarity index 98% rename from qim3d/generate/collection_.py rename to qim3d/generate/_aggregators.py index dc54da7e5efa65ee5dea5643fa7eed26fe90e475..9d10df37e4efb270ac539e7ff4c3f22cf7cd6b06 100644 --- a/qim3d/generate/collection_.py +++ b/qim3d/generate/_aggregators.py @@ -3,7 +3,7 @@ import scipy.ndimage from tqdm.notebook import tqdm import qim3d.generate -from qim3d.utils.logger import log +from qim3d.utils._logger import log def random_placement( @@ -120,7 +120,7 @@ def specific_placement( return collection, placed, positions -def collection( +def noise_object_collection( collection_shape: tuple = (200, 200, 200), num_objects: int = 15, positions: list[tuple] = None, @@ -255,7 +255,7 @@ def collection( ```python # Visualize slices - qim3d.viz.slices(vol, n_slices=15) + qim3d.viz.slices_grid(vol, n_slices=15) ```  @@ -285,7 +285,7 @@ def collection( ```python # Visualize slices - qim3d.viz.slices(vol, n_slices=15, axis=1) + qim3d.viz.slices_grid(vol, n_slices=15, axis=1) ```  """ @@ -350,7 +350,7 @@ def collection( log.debug(f"- Threshold: {threshold:.3f}") # Generate synthetic object - blob = qim3d.generate.blob( + blob = qim3d.generate.noise_object( base_shape=blob_shape, final_shape=final_shape, noise_scale=noise_scale, diff --git a/qim3d/generate/blob_.py b/qim3d/generate/_generators.py similarity index 96% rename from qim3d/generate/blob_.py rename to qim3d/generate/_generators.py index 153bdcfc86c898f3802dc18a3450efba24fe4e10..78bf607bc5643a8b6570a29feac0f217059ff800 100644 --- a/qim3d/generate/blob_.py +++ b/qim3d/generate/_generators.py @@ -4,7 +4,7 @@ from noise import pnoise3 import qim3d.processing -def blob( +def noise_object( base_shape: tuple = (128, 128, 128), final_shape: tuple = (128, 128, 128), noise_scale: float = 0.05, @@ -32,7 +32,7 @@ def blob( dtype (str, optional): Desired data type of the output volume. Defaults to "uint8". Returns: - synthetic_blob (numpy.ndarray): Generated 3D volume with specified parameters. + noise_object (numpy.ndarray): Generated 3D volume with specified parameters. Raises: TypeError: If `final_shape` is not a tuple or does not have three elements. @@ -52,7 +52,7 @@ def blob( ```python # Visualize slices - qim3d.viz.slices(synthetic_blob, vmin = 0, vmax = 255, n_slices = 15) + qim3d.viz.slices_grid(synthetic_blob, vmin = 0, vmax = 255, n_slices = 15) ```  @@ -69,14 +69,14 @@ def blob( object_shape = "cylinder" ) - # Visualize synthetic blob + # Visualize synthetic object qim3d.viz.vol(vol) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_blob_cylinder.html" width="100%" height="500" frameborder="0"></iframe> ```python # Visualize slices - qim3d.viz.slices(vol, n_slices=15, axis=1) + qim3d.viz.slices_grid(vol, n_slices=15, axis=1) ```  @@ -100,7 +100,7 @@ def blob( ```python # Visualize - qim3d.viz.slices(vol, n_slices=15) + qim3d.viz.slices_grid(vol, n_slices=15) ```  """ diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py index 99863cf4450710af6c49dfe16efb4313a8ab4f32..4498e1ad2fda3201efec9937d997232ec98db3b5 100644 --- a/qim3d/gui/annotation_tool.py +++ b/qim3d/gui/annotation_tool.py @@ -29,7 +29,7 @@ import numpy as np from PIL import Image from qim3d.io import load, save -from qim3d.processing.operations import overlay_rgb_images +from qim3d.operations._common_operations_methods import overlay_rgb_images from qim3d.gui.interface import BaseInterface # TODO: img in launch should be self.img diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py index 4d2b1f51bb807de6bb226616091588c91d0c4837..ecabeb81135bdbf1d4b0f64c184be62af3dbbf65 100644 --- a/qim3d/gui/data_explorer.py +++ b/qim3d/gui/data_explorer.py @@ -25,8 +25,8 @@ import numpy as np import outputformat as ouf from qim3d.io import load -from qim3d.utils.logger import log -from qim3d.utils import misc +from qim3d.utils._logger import log +from qim3d.utils import _misc from qim3d.gui.interface import BaseInterface @@ -550,7 +550,7 @@ class Interface(BaseInterface): def show_data_summary(self): summary_dict = { "Last modified": datetime.datetime.fromtimestamp(os.path.getmtime(self.file_path)).strftime("%Y-%m-%d %H:%M"), - "File size": misc.sizeof(os.path.getsize(self.file_path)), + "File size": _misc.sizeof(os.path.getsize(self.file_path)), "Z-size": str(self.vol.shape[self.axis_dict["Z"]]), "Y-size": str(self.vol.shape[self.axis_dict["Y"]]), "X-size": str(self.vol.shape[self.axis_dict["X"]]), diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py index 16bf8ffe67eb08618a1585dd8941b6dcd5b3aef0..1a20f91ea093a1c65daa0aefd73117445289c588 100644 --- a/qim3d/gui/iso3d.py +++ b/qim3d/gui/iso3d.py @@ -23,7 +23,7 @@ import plotly.graph_objects as go from scipy import ndimage from qim3d.io import load -from qim3d.utils.logger import log +from qim3d.utils._logger import log from qim3d.gui.interface import InterfaceWithExamples diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py index cf106dedeb8f47092ce21b90380161f624a1c7bf..780746a093be02faa74d2595538aa4208f2fc056 100644 --- a/qim3d/gui/layers2d.py +++ b/qim3d/gui/layers2d.py @@ -27,7 +27,7 @@ from .interface import BaseInterface # from qim3d.processing import layers2d as l2d from qim3d.processing import overlay_rgb_images, segment_layers, get_lines from qim3d.io import load -from qim3d.viz.layers2d import image_with_lines +from qim3d.viz._layers2d import image_with_lines #TODO figure out how not update anything and go through processing when there are no data loaded # So user could play with the widgets but it doesnt throw error diff --git a/qim3d/io/__init__.py b/qim3d/io/__init__.py index b56cec9b0104ee984ab10729e553ada2297a305e..8456005474cbda035509ff635a784e8330facc3c 100644 --- a/qim3d/io/__init__.py +++ b/qim3d/io/__init__.py @@ -1,7 +1,6 @@ -from .loading import DataLoader, load, load_mesh -from .downloader import Downloader -from .saving import DataSaver, save, save_mesh -from .sync import Sync -from .convert import convert -from ..utils import logger -from .ome_zarr import export_ome_zarr, import_ome_zarr +from ._loading import load, load_mesh +from ._downloader import Downloader +from ._saving import save, save_mesh +# from ._sync import Sync # this will be added back after future development +from ._convert import convert +from ._ome_zarr import export_ome_zarr, import_ome_zarr diff --git a/qim3d/io/convert.py b/qim3d/io/_convert.py similarity index 98% rename from qim3d/io/convert.py rename to qim3d/io/_convert.py index 411c35b274a7b653969515299db5957512a9670e..4ab6c0a99e71eff51949309359dd42c1276f96dd 100644 --- a/qim3d/io/convert.py +++ b/qim3d/io/_convert.py @@ -8,8 +8,8 @@ import tifffile as tiff import zarr from tqdm import tqdm -from qim3d.utils.misc import stringify_path -from qim3d.io.saving import save +from qim3d.utils._misc import stringify_path +from qim3d.io._saving import save class Convert: diff --git a/qim3d/io/downloader.py b/qim3d/io/_downloader.py similarity index 99% rename from qim3d/io/downloader.py rename to qim3d/io/_downloader.py index c94284920e39cf587835e7c2f0b8a4ed7aca4a94..fe42e410c120b5ab6efe77037d1071579add55de 100644 --- a/qim3d/io/downloader.py +++ b/qim3d/io/_downloader.py @@ -8,7 +8,7 @@ from tqdm import tqdm from pathlib import Path from qim3d.io import load -from qim3d.utils.logger import log +from qim3d.utils._logger import log import outputformat as ouf diff --git a/qim3d/io/loading.py b/qim3d/io/_loading.py similarity index 99% rename from qim3d/io/loading.py rename to qim3d/io/_loading.py index 464bddde43c543947e177ce3f32dc91b2f4d2965..e5f2191e8f71fd276e9b9460d6dabf4c6e8d4995 100644 --- a/qim3d/io/loading.py +++ b/qim3d/io/_loading.py @@ -23,10 +23,10 @@ from dask import delayed from PIL import Image, UnidentifiedImageError import qim3d -from qim3d.utils.logger import log -from qim3d.utils.misc import get_file_size, sizeof, stringify_path -from qim3d.utils.system import Memory -from qim3d.utils.progress_bar import FileLoadingProgressBar +from qim3d.utils._logger import log +from qim3d.utils._misc import get_file_size, sizeof, stringify_path +from qim3d.utils._system import Memory +from qim3d.utils._progress_bar import FileLoadingProgressBar import trimesh dask.config.set(scheduler="processes") diff --git a/qim3d/io/ome_zarr.py b/qim3d/io/_ome_zarr.py similarity index 98% rename from qim3d/io/ome_zarr.py rename to qim3d/io/_ome_zarr.py index feca06225890176d5233e15c5ee7602bead750b2..ae5173157927dbf4555b05b3a9bacc53f193380a 100644 --- a/qim3d/io/ome_zarr.py +++ b/qim3d/io/_ome_zarr.py @@ -31,9 +31,9 @@ from skimage.transform import ( resize, ) -from qim3d.utils.logger import log -from qim3d.utils.progress_bar import OmeZarrExportProgressBar -from qim3d.utils.ome_zarr import get_n_chunks +from qim3d.utils._logger import log +from qim3d.utils._progress_bar import OmeZarrExportProgressBar +from qim3d.utils._ome_zarr import get_n_chunks ListOfArrayLike = Union[List[da.Array], List[np.ndarray]] diff --git a/qim3d/io/saving.py b/qim3d/io/_saving.py similarity index 99% rename from qim3d/io/saving.py rename to qim3d/io/_saving.py index efc2333ad2f608cda6247c9f2b9070cf817ba828..d7721ee2a082a0a7db3f993c58ec9207eefcc5f4 100644 --- a/qim3d/io/saving.py +++ b/qim3d/io/_saving.py @@ -37,8 +37,8 @@ from pydicom.dataset import FileDataset, FileMetaDataset from pydicom.uid import UID import trimesh -from qim3d.utils.logger import log -from qim3d.utils.misc import sizeof, stringify_path +from qim3d.utils._logger import log +from qim3d.utils._misc import sizeof, stringify_path class DataSaver: diff --git a/qim3d/io/sync.py b/qim3d/io/_sync.py similarity index 99% rename from qim3d/io/sync.py rename to qim3d/io/_sync.py index 3992cd33f29acf26949f44b0e89e244ced0aa20e..9085ae88b3a4e60b6ce2a0744c8983c825b635a3 100644 --- a/qim3d/io/sync.py +++ b/qim3d/io/_sync.py @@ -2,7 +2,7 @@ import os import subprocess import outputformat as ouf -from qim3d.utils.logger import log +from qim3d.utils._logger import log from pathlib import Path diff --git a/qim3d/mesh/__init__.py b/qim3d/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..932dfd9e1f859491652cfde541a71ba6adc28f24 --- /dev/null +++ b/qim3d/mesh/__init__.py @@ -0,0 +1 @@ +from ._common_mesh_methods import from_volume diff --git a/qim3d/processing/mesh.py b/qim3d/mesh/_common_mesh_methods.py similarity index 98% rename from qim3d/processing/mesh.py rename to qim3d/mesh/_common_mesh_methods.py index b4ff678398c765bebe496f7ee6d5d1015b7d7b1d..535fcae886ce081c20c5bb01628341beba5163c3 100644 --- a/qim3d/processing/mesh.py +++ b/qim3d/mesh/_common_mesh_methods.py @@ -2,10 +2,10 @@ import numpy as np from skimage import measure, filters import trimesh from typing import Tuple, Any -from qim3d.utils.logger import log +from qim3d.utils._logger import log -def create_mesh( +def from_volume( volume: np.ndarray, level: float = None, step_size=1, diff --git a/qim3d/ml/__init__.py b/qim3d/ml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf15999bd205e6552a949787b44f4ba3da4ec043 --- /dev/null +++ b/qim3d/ml/__init__.py @@ -0,0 +1,4 @@ +from ._augmentations import Augmentation +from ._data import Dataset, prepare_dataloaders, prepare_datasets +from ._ml_utils import inference, model_summary, train_model +from .models import * \ No newline at end of file diff --git a/qim3d/models/augmentations.py b/qim3d/ml/_augmentations.py similarity index 100% rename from qim3d/models/augmentations.py rename to qim3d/ml/_augmentations.py diff --git a/qim3d/models/data.py b/qim3d/ml/_data.py similarity index 99% rename from qim3d/models/data.py rename to qim3d/ml/_data.py index fc61d262f753ebc637c27e7d60e38d0730a714c8..38fbdac79c3fb7d055c457376351a29b1ad0bf79 100644 --- a/qim3d/models/data.py +++ b/qim3d/ml/_data.py @@ -1,7 +1,7 @@ """Provides a custom Dataset class for building a PyTorch dataset.""" from pathlib import Path from PIL import Image -from qim3d.utils.logger import log +from qim3d.utils._logger import log import torch import numpy as np diff --git a/qim3d/models/models.py b/qim3d/ml/_ml_utils.py similarity index 99% rename from qim3d/models/models.py rename to qim3d/ml/_ml_utils.py index f04052498907ec6f0bda81097c93bfada7faf8ed..7c8a3b86d005e467c15a378e3216011cbf3ac60c 100644 --- a/qim3d/models/models.py +++ b/qim3d/ml/_ml_utils.py @@ -4,8 +4,8 @@ import torch import numpy as np from torchinfo import summary -from qim3d.utils.logger import log -from qim3d.viz.metrics import plot_metrics +from qim3d.utils._logger import log +from qim3d.viz._metrics import plot_metrics from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4624be5fb111dac4b820b52f5ab072bf744c9154 --- /dev/null +++ b/qim3d/ml/models/__init__.py @@ -0,0 +1 @@ +from ._unet import UNet, Hyperparameters diff --git a/qim3d/models/unet.py b/qim3d/ml/models/_unet.py similarity index 99% rename from qim3d/models/unet.py rename to qim3d/ml/models/_unet.py index 6c57f55e44eb13403a39d36b8a85c5d400cc00ea..353d10bde5dc08f1b44044b9e482d2adf66a346d 100644 --- a/qim3d/models/unet.py +++ b/qim3d/ml/models/_unet.py @@ -2,7 +2,7 @@ import torch.nn as nn -from qim3d.utils.logger import log +from qim3d.utils._logger import log class UNet(nn.Module): diff --git a/qim3d/models/__init__.py b/qim3d/models/__init__.py deleted file mode 100644 index a145a44fd76a561174a7f28154bd20dc0134d6fd..0000000000000000000000000000000000000000 --- a/qim3d/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .unet import UNet, Hyperparameters -from .augmentations import Augmentation -from .data import Dataset, prepare_dataloaders, prepare_datasets -from .models import inference, model_summary, train_model diff --git a/qim3d/operations/__init__.py b/qim3d/operations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..add2a01195ac464bf6c9b9d16a572f9561e235db --- /dev/null +++ b/qim3d/operations/__init__.py @@ -0,0 +1 @@ +from ._common_operations_methods import * diff --git a/qim3d/processing/operations.py b/qim3d/operations/_common_operations_methods.py similarity index 76% rename from qim3d/processing/operations.py rename to qim3d/operations/_common_operations_methods.py index e559a16d0100186e36ba96b597b6bdfa58f286fb..1c14e16329dbf0a9cce14695c56f927c5b69d820 100644 --- a/qim3d/processing/operations.py +++ b/qim3d/operations/_common_operations_methods.py @@ -1,7 +1,8 @@ import numpy as np -import qim3d.processing.filters as filters -from qim3d.utils.logger import log +import qim3d.filters as filters +from qim3d.utils._logger import log +__all__ = ["remove_background", "fade_mask", "overlay_rgb_images"] def remove_background( vol: np.ndarray, @@ -29,7 +30,7 @@ def remove_background( import qim3d vol = qim3d.examples.cement_128x128x128 - qim3d.viz.slices(vol, vmin=0, vmax=255) + qim3d.viz.slices_grid(vol, vmin=0, vmax=255) ```  @@ -37,7 +38,7 @@ def remove_background( vol_filtered = qim3d.processing.operations.remove_background(vol, min_object_radius=3, background="bright") - qim3d.viz.slices(vol_filtered, vmin=0, vmax=255) + qim3d.viz.slices_grid(vol_filtered, vmin=0, vmax=255) ```  """ @@ -52,74 +53,6 @@ def remove_background( return pipeline(vol) -def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, int]: - """ - Apply watershed segmentation to a binary volume. - - Args: - bin_vol (np.ndarray): Binary volume to segment. The input should be a 3D binary image where non-zero elements - represent the objects to be segmented. - min_distance (int): Minimum number of pixels separating peaks in the distance transform. Peaks that are - too close will be merged, affecting the number of segmented objects. Default is 5. - - Returns: - tuple[np.ndarray, int]: - - Labeled volume (np.ndarray): A 3D array of the same shape as the input `bin_vol`, where each segmented object - is assigned a unique integer label. - - num_labels (int): The total number of unique objects found in the labeled volume. - - Example: - ```python - import qim3d - - vol = qim3d.examples.cement_128x128x128 - binary = qim3d.processing.filters.gaussian(vol, sigma = 2)<60 - - qim3d.viz.slices(binary, axis=1) - ``` -  - - ```python - labeled_volume, num_labels = qim3d.processing.operations.watershed(binary) - - cmap = qim3d.viz.colormaps.objects(num_labels) - qim3d.viz.slices(labeled_volume, axis = 1, cmap = cmap) - ``` -  - - """ - import skimage - import scipy - - if len(np.unique(bin_vol)) > 2: - raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.") - - # Compute distance transform of binary volume - distance = scipy.ndimage.distance_transform_edt(bin_vol) - - # Find peak coordinates in distance transform - coords = skimage.feature.peak_local_max( - distance, min_distance=min_distance, labels=bin_vol - ) - - # Create a mask with peak coordinates - mask = np.zeros(distance.shape, dtype=bool) - mask[tuple(coords.T)] = True - - # Label peaks - markers, _ = scipy.ndimage.label(mask) - - # Apply watershed segmentation - labeled_volume = skimage.segmentation.watershed( - -distance, markers=markers, mask=bin_vol - ) - - # Extract number of objects found - num_labels = len(np.unique(labeled_volume)) - 1 - log.info(f"Total number of objects found: {num_labels}") - - return labeled_volume, num_labels - def fade_mask( vol: np.ndarray, diff --git a/qim3d/processing/__init__.py b/qim3d/processing/__init__.py index 8f6edf04173e5781339f0ca33ff304a29851f0f9..913d6ddb12fd7a22df1ba522bc25b2db93d35dbf 100644 --- a/qim3d/processing/__init__.py +++ b/qim3d/processing/__init__.py @@ -1,9 +1,3 @@ -from .local_thickness_ import local_thickness -from .structure_tensor_ import structure_tensor -from .detection import blob_detection -from .filters import * -from .operations import * -from .cc import get_3d_cc -from .layers2d import segment_layers, get_lines -from .mesh import create_mesh -from .features import volume, area, sphericity +from ._local_thickness import local_thickness +from ._structure_tensor import structure_tensor +from ._layers import segment_layers, get_lines diff --git a/qim3d/processing/layers2d.py b/qim3d/processing/_layers.py similarity index 100% rename from qim3d/processing/layers2d.py rename to qim3d/processing/_layers.py diff --git a/qim3d/processing/local_thickness_.py b/qim3d/processing/_local_thickness.py similarity index 99% rename from qim3d/processing/local_thickness_.py rename to qim3d/processing/_local_thickness.py index f106cbcbb6de1f14e47d256fd654106a9e5ff5ac..a7e877d6d01b51c1aa3af9d403dd5447b7a42f6d 100644 --- a/qim3d/processing/local_thickness_.py +++ b/qim3d/processing/_local_thickness.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional -from qim3d.utils.logger import log +from qim3d.utils._logger import log import qim3d diff --git a/qim3d/processing/structure_tensor_.py b/qim3d/processing/_structure_tensor.py similarity index 98% rename from qim3d/processing/structure_tensor_.py rename to qim3d/processing/_structure_tensor.py index 97a0ba731bda4fae777cf00db4aa91ccfaeb7aa7..10eba89076a3c1158c810f167ec7049735b15237 100644 --- a/qim3d/processing/structure_tensor_.py +++ b/qim3d/processing/_structure_tensor.py @@ -3,7 +3,7 @@ from typing import Tuple import logging import numpy as np -from qim3d.utils.logger import log +from qim3d.utils._logger import log def structure_tensor( @@ -108,7 +108,7 @@ def structure_tensor( val, vec = st.eig_special_3d(s_vol, full=full) if visualize: - from qim3d.viz.structure_tensor import vectors + from qim3d.viz._structure_tensor import vectors display(vectors(vol, vec, **viz_kwargs)) diff --git a/qim3d/segmentation/__init__.py b/qim3d/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efdd62321da14d0e45ebf6d05c9245d923607e9e --- /dev/null +++ b/qim3d/segmentation/__init__.py @@ -0,0 +1,2 @@ +from ._common_segmentation_methods import * +from ._connected_components import get_3d_cc diff --git a/qim3d/segmentation/_common_segmentation_methods.py b/qim3d/segmentation/_common_segmentation_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4f231d16a43b5718328e817905e8df28e7c551 --- /dev/null +++ b/qim3d/segmentation/_common_segmentation_methods.py @@ -0,0 +1,72 @@ +import numpy as np +from qim3d.utils._logger import log + +__all__ = ["watershed"] + +def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, int]: + """ + Apply watershed segmentation to a binary volume. + + Args: + bin_vol (np.ndarray): Binary volume to segment. The input should be a 3D binary image where non-zero elements + represent the objects to be segmented. + min_distance (int): Minimum number of pixels separating peaks in the distance transform. Peaks that are + too close will be merged, affecting the number of segmented objects. Default is 5. + + Returns: + tuple[np.ndarray, int]: + - Labeled volume (np.ndarray): A 3D array of the same shape as the input `bin_vol`, where each segmented object + is assigned a unique integer label. + - num_labels (int): The total number of unique objects found in the labeled volume. + + Example: + ```python + import qim3d + + vol = qim3d.examples.cement_128x128x128 + binary = qim3d.processing.filters.gaussian(vol, sigma = 2)<60 + + qim3d.viz.slices_grid(binary, axis=1) + ``` +  + + ```python + labeled_volume, num_labels = qim3d.processing.operations.watershed(binary) + + cmap = qim3d.viz.colormaps.objects(num_labels) + qim3d.viz.slices_grid(labeled_volume, axis = 1, cmap = cmap) + ``` +  + + """ + import skimage + import scipy + + if len(np.unique(bin_vol)) > 2: + raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.") + + # Compute distance transform of binary volume + distance = scipy.ndimage.distance_transform_edt(bin_vol) + + # Find peak coordinates in distance transform + coords = skimage.feature.peak_local_max( + distance, min_distance=min_distance, labels=bin_vol + ) + + # Create a mask with peak coordinates + mask = np.zeros(distance.shape, dtype=bool) + mask[tuple(coords.T)] = True + + # Label peaks + markers, _ = scipy.ndimage.label(mask) + + # Apply watershed segmentation + labeled_volume = skimage.segmentation.watershed( + -distance, markers=markers, mask=bin_vol + ) + + # Extract number of objects found + num_labels = len(np.unique(labeled_volume)) - 1 + log.info(f"Total number of objects found: {num_labels}") + + return labeled_volume, num_labels \ No newline at end of file diff --git a/qim3d/processing/cc.py b/qim3d/segmentation/_connected_components.py similarity index 98% rename from qim3d/processing/cc.py rename to qim3d/segmentation/_connected_components.py index 43283c230fd71cd36f841861f58c28a1a3ec7777..f43f0dc216a1c49f2d35719e8a6a397340092107 100644 --- a/qim3d/processing/cc.py +++ b/qim3d/segmentation/_connected_components.py @@ -1,6 +1,6 @@ import numpy as np from scipy.ndimage import find_objects, label -from qim3d.utils.logger import log +from qim3d.utils._logger import log class CC: diff --git a/qim3d/tests/__init__.py b/qim3d/tests/__init__.py index c08e68dde17af7ac5c45d2fab81447f168750f6a..1fa15c27df5773c234deedf1e0bc500d41b6ecd4 100644 --- a/qim3d/tests/__init__.py +++ b/qim3d/tests/__init__.py @@ -8,7 +8,7 @@ import shutil from PIL import Image import socket import numpy as np -from qim3d.utils.logger import log +from qim3d.utils._logger import log def mock_plot(): diff --git a/qim3d/tests/processing/test_connected_components.py b/qim3d/tests/processing/test_connected_components.py index 0e972123f11b4b9f163127afc110872155e84e95..9d59aff94281eb08acc361a2c633b4e9e30d6916 100644 --- a/qim3d/tests/processing/test_connected_components.py +++ b/qim3d/tests/processing/test_connected_components.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from qim3d.processing.cc import get_3d_cc +from qim3d.segmentation._connected_components import get_3d_cc @pytest.fixture(scope="module") diff --git a/qim3d/tests/processing/test_filters.py b/qim3d/tests/processing/test_filters.py index e63628118775b34d2c98f897d303d89e4c18b60a..5a6d0993fb4ac6cecaf9c26d06b450eeb2a9f893 100644 --- a/qim3d/tests/processing/test_filters.py +++ b/qim3d/tests/processing/test_filters.py @@ -1,5 +1,5 @@ import qim3d -from qim3d.processing.filters import * +from qim3d.filters import * import numpy as np import pytest import re diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index bcbed39d1cfca3dc5179c8b33c19527a9a5db81e..ec9599105a9cdccb3d44fdfa048f0a452cd47c8a 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -52,14 +52,14 @@ def test_grid_pred(): # unit tests for slices function def test_slices_numpy_array_input(): example_volume = np.ones((10, 10, 10)) - fig = qim3d.viz.slices(example_volume, n_slices=1) + fig = qim3d.viz.slices_grid(example_volume, n_slices=1) assert isinstance(fig, plt.Figure) def test_slices_wrong_input_format(): input = "not_a_volume" with pytest.raises(ValueError, match="Data type not supported"): - qim3d.viz.slices(input) + qim3d.viz.slices_grid(input) def test_slices_not_volume(): @@ -68,7 +68,7 @@ def test_slices_not_volume(): ValueError, match="The provided object is not a volume as it has less than 3 dimensions.", ): - qim3d.viz.slices(example_volume) + qim3d.viz.slices_grid(example_volume) def test_slices_wrong_position_format1(): @@ -77,7 +77,7 @@ def test_slices_wrong_position_format1(): ValueError, match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', ): - qim3d.viz.slices(example_volume, position="invalid_slice") + qim3d.viz.slices_grid(example_volume, position="invalid_slice") def test_slices_wrong_position_format2(): @@ -86,7 +86,7 @@ def test_slices_wrong_position_format2(): ValueError, match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', ): - qim3d.viz.slices(example_volume, position=1.5) + qim3d.viz.slices_grid(example_volume, position=1.5) def test_slices_wrong_position_format3(): @@ -95,7 +95,7 @@ def test_slices_wrong_position_format3(): ValueError, match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', ): - qim3d.viz.slices(example_volume, position=[1, 2, 3.5]) + qim3d.viz.slices_grid(example_volume, position=[1, 2, 3.5]) def test_slices_invalid_axis_value(): @@ -104,14 +104,14 @@ def test_slices_invalid_axis_value(): ValueError, match="Invalid value for 'axis'. It should be an integer between 0 and 2", ): - qim3d.viz.slices(example_volume, axis=3) + qim3d.viz.slices_grid(example_volume, axis=3) def test_slices_interpolation_option(): example_volume = np.ones((10, 10, 10)) img_width = 3 interpolation_method = "bilinear" - fig = qim3d.viz.slices( + fig = qim3d.viz.slices_grid( example_volume, n_slices=1, img_width=img_width, @@ -130,7 +130,7 @@ def test_slices_multiple_slices(): example_volume = np.ones((10, 10, 10)) img_width = 3 n_slices = 3 - fig = qim3d.viz.slices(example_volume, n_slices=n_slices, img_width=img_width) + fig = qim3d.viz.slices_grid(example_volume, n_slices=n_slices, img_width=img_width) # Add assertions for the expected number of subplots in the figure assert len(fig.get_axes()) == n_slices @@ -141,13 +141,13 @@ def test_slices_axis_argument(): img_width = 3 # Call the function with different values of the axis - fig_axis_0 = qim3d.viz.slices( + fig_axis_0 = qim3d.viz.slices_grid( example_volume, n_slices=1, img_width=img_width, axis=0 ) - fig_axis_1 = qim3d.viz.slices( + fig_axis_1 = qim3d.viz.slices_grid( example_volume, n_slices=1, img_width=img_width, axis=1 ) - fig_axis_2 = qim3d.viz.slices( + fig_axis_2 = qim3d.viz.slices_grid( example_volume, n_slices=1, img_width=img_width, axis=2 ) diff --git a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py index e9eb7e7fbc36af21a80f7421b91a879b75a1763d..0c7d5b30a7b3daf15b80d5995e6d0e4a1f18e5cc 100644 --- a/qim3d/utils/__init__.py +++ b/qim3d/utils/__init__.py @@ -1,7 +1,7 @@ -from . import doi -from .system import Memory +from . import _doi +from ._system import Memory -from .misc import ( +from ._misc import ( get_local_ip, port_from_str, gradio_header, @@ -13,4 +13,4 @@ from .misc import ( scale_to_float16, ) -from .server import start_http_server \ No newline at end of file +from ._server import start_http_server \ No newline at end of file diff --git a/qim3d/utils/cli.py b/qim3d/utils/_cli.py similarity index 100% rename from qim3d/utils/cli.py rename to qim3d/utils/_cli.py diff --git a/qim3d/utils/doi.py b/qim3d/utils/_doi.py similarity index 98% rename from qim3d/utils/doi.py rename to qim3d/utils/_doi.py index 60e01d76281159a86a08e3868527a7442f2a230d..b3ece6636b062d24d2002405a08abb1726242f4f 100644 --- a/qim3d/utils/doi.py +++ b/qim3d/utils/_doi.py @@ -1,7 +1,7 @@ """ Deals with DOI for references """ import json import requests -from qim3d.utils.logger import log +from qim3d.utils._logger import log def _validate_response(response): diff --git a/qim3d/utils/logger.py b/qim3d/utils/_logger.py similarity index 100% rename from qim3d/utils/logger.py rename to qim3d/utils/_logger.py diff --git a/qim3d/utils/misc.py b/qim3d/utils/_misc.py similarity index 100% rename from qim3d/utils/misc.py rename to qim3d/utils/_misc.py diff --git a/qim3d/utils/ome_zarr.py b/qim3d/utils/_ome_zarr.py similarity index 100% rename from qim3d/utils/ome_zarr.py rename to qim3d/utils/_ome_zarr.py diff --git a/qim3d/utils/progress_bar.py b/qim3d/utils/_progress_bar.py similarity index 99% rename from qim3d/utils/progress_bar.py rename to qim3d/utils/_progress_bar.py index 6790b1bfdd2cb482d6f992297ea759f3161e0b10..629026f4b08a84e664925ce7a0a1a09a9c214434 100644 --- a/qim3d/utils/progress_bar.py +++ b/qim3d/utils/_progress_bar.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from tqdm.auto import tqdm -from qim3d.utils.misc import get_file_size +from qim3d.utils._misc import get_file_size class RepeatTimer(Timer): diff --git a/qim3d/utils/server.py b/qim3d/utils/_server.py similarity index 98% rename from qim3d/utils/server.py rename to qim3d/utils/_server.py index 65344f1b8483d4066a99c32ad50a9870914184ab..9d20fccd76c47043ea50c0c74a4c2bee11097d5e 100644 --- a/qim3d/utils/server.py +++ b/qim3d/utils/_server.py @@ -1,7 +1,7 @@ import os from http.server import SimpleHTTPRequestHandler, HTTPServer import threading -from qim3d.utils.logger import log +from qim3d.utils._logger import log class CustomHTTPRequestHandler(SimpleHTTPRequestHandler): def end_headers(self): diff --git a/qim3d/utils/system.py b/qim3d/utils/_system.py similarity index 98% rename from qim3d/utils/system.py rename to qim3d/utils/_system.py index 39dc3af72ffd133e6d758d99070d0e0c1e36819d..197926270bcc3a073308b5e18f6a09ba904b92fb 100644 --- a/qim3d/utils/system.py +++ b/qim3d/utils/_system.py @@ -2,8 +2,8 @@ import os import time import psutil -from qim3d.utils.misc import sizeof -from qim3d.utils.logger import log +from qim3d.utils._misc import sizeof +from qim3d.utils._logger import log import numpy as np diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 33d94416d98cdf1af5d2c9fca548416eee65a359..bc65ee16184325418e881e7c2303061a37cc60c9 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,18 +1,18 @@ from . import colormaps -from .cc import plot_cc -from .detection import circles -from .explore import ( - interactive_fade_mask, - orthogonal, +from ._cc import plot_cc +from ._detection import circles +from ._data_exploration import ( + fade_mask, slicer, - slices, + slicer_orthogonal, + slices_grid, chunks, histogram, ) from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError -from .k3d import vol, mesh -from .local_thickness_ import local_thickness -from .structure_tensor import vectors -from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked -from .preview import image_preview -from . import layers2d +from ._k3d import volumetric, mesh +from ._local_thickness import local_thickness +from ._structure_tensor import vectors +from ._metrics import plot_metrics, grid_overview, grid_pred, vol_masked +from ._preview import image_preview +from . import _layers2d diff --git a/qim3d/viz/cc.py b/qim3d/viz/_cc.py similarity index 90% rename from qim3d/viz/cc.py rename to qim3d/viz/_cc.py index b7dd999ec409e8094aac1b04f02f23d21173966a..1c0ed56098696de99ccb1706d58321102317aeb8 100644 --- a/qim3d/viz/cc.py +++ b/qim3d/viz/_cc.py @@ -1,8 +1,7 @@ import matplotlib.pyplot as plt import numpy as np - -from qim3d.utils.logger import log -import qim3d.viz.colormaps +import qim3d +from qim3d.utils._logger import log def plot_cc( @@ -12,9 +11,9 @@ def plot_cc( overlay=None, crop=False, show=True, - cmap:str = 'viridis', - vmin:float = None, - vmax:float = None, + cmap: str = "viridis", + vmin: float = None, + vmax: float = None, **kwargs, ) -> list[plt.Figure]: """ @@ -30,7 +29,7 @@ def plot_cc( cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`. + **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices_grid`. Returns: figs (list[plt.Figure]): List of figures, if `show=False`. @@ -75,14 +74,16 @@ def plot_cc( else: cc = connected_components.get_cc(component, crop=False) overlay_crop = np.where(cc == 0, 0, overlay) - fig = qim3d.viz.slices(overlay_crop, show=show, cmap = cmap, vmin = vmin, vmax = vmax, **kwargs) + fig = qim3d.viz.slices_grid( + overlay_crop, show=show, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs + ) else: # assigns discrete color map to each connected component if not given if "cmap" not in kwargs: - kwargs["cmap"] = qim3d.viz.colormaps.objects(len(component_indexs)) + kwargs["cmap"] = qim3d.viz.colormaps.segmentation(len(component_indexs)) # Plot the connected component without overlay - fig = qim3d.viz.slices( + fig = qim3d.viz.slices_grid( connected_components.get_cc(component, crop=crop), show=show, **kwargs ) diff --git a/qim3d/viz/explore.py b/qim3d/viz/_data_exploration.py similarity index 61% rename from qim3d/viz/explore.py rename to qim3d/viz/_data_exploration.py index 8fe2a545adcf4540c8fdc50285759837e05b6fac..2198cd94a0eb703b7a3b95f15a4407397d642a06 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/_data_exploration.py @@ -14,150 +14,162 @@ from IPython.display import SVG, display import matplotlib import numpy as np import zarr -from qim3d.utils.logger import log +from qim3d.utils._logger import log import seaborn as sns import qim3d -def slices( - vol: np.ndarray, - axis: int = 0, - position: Optional[Union[str, int, List[int]]] = None, - n_slices: int = 5, - max_cols: int = 5, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 2, - img_width: int = 2, - show: bool = False, - show_position: bool = True, + +def slices_grid( + volume: np.ndarray, + slice_axis: int = 0, + slice_positions: Optional[Union[str, int, List[int]]] = None, + num_slices: int = 15, + max_columns: int = 5, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, + image_size=None, + image_height: int = 2, + image_width: int = 2, + display_figure: bool = False, + display_positions: bool = True, interpolation: Optional[str] = None, - img_size=None, - cbar: bool = False, - cbar_style: str = "small", - **imshow_kwargs, + color_bar: bool = False, + color_bar_style: str = "small", + **matplotlib_imshow_kwargs, ) -> plt.Figure: """Displays one or several slices from a 3d volume. - By default if `position` is None, slices plots `n_slices` linearly spaced slices. - If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position. - If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted. + By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices. + If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position. + If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted. Args: vol np.ndarray: The 3D volume to be sliced. - axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. - position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. - n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5. - max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. - img_width (int, optional): Width of the figure. - show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. - show_position (bool, optional): If True, displays the position of the slices. Defaults to True. + slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + slice_positions (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. + num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15. + max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. + image_width (int, optional): Width of the figure. + display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to True. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. - cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. - cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'. + color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. + color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'. Returns: fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. Raises: ValueError: If the input is not a numpy.ndarray or da.core.Array. - ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. + ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. ValueError: If the file or array is not a volume with at least 3 dimensions. ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end". - ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'. + ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'. Example: ```python import qim3d vol = qim3d.examples.shell_225x128x128 - qim3d.viz.slices(vol, n_slices=15) + qim3d.viz.slices_grid_grid(vol, num_slices=15) ```  """ - if img_size: - img_height = img_size - img_width = img_size + if image_size: + image_height = image_size + image_width = image_size - # If we pass python None to the imshow function, it will set to + # If we pass python None to the imshow function, it will set to # default value 'antialiased' if interpolation is None: - interpolation = 'none' + interpolation = "none" # Numpy array or Torch tensor input - if not isinstance(vol, (np.ndarray, da.core.Array)): + if not isinstance(volume, (np.ndarray, da.core.Array)): raise ValueError("Data type not supported") - if vol.ndim < 3: + if volume.ndim < 3: raise ValueError( "The provided object is not a volume as it has less than 3 dimensions." ) - cbar_style_options = ["small", "large"] - if cbar_style not in cbar_style_options: - raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.") - - if isinstance(vol, da.core.Array): - vol = vol.compute() + color_bar_style_options = ["small", "large"] + if color_bar_style not in color_bar_style_options: + raise ValueError( + f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}." + ) + + if isinstance(volume, da.core.Array): + volume = volume.compute() # Ensure axis is a valid choice - if not (0 <= axis < vol.ndim): + if not (0 <= slice_axis < volume.ndim): raise ValueError( - f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}." + f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}." ) - if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects': - num_labels = len(np.unique(vol)) + # Here we deal with the case that the user wants to use the objects colormap directly + if ( + type(color_map) == matplotlib.colors.LinearSegmentedColormap + or color_map == "segmentation" + ): + num_labels = len(np.unique(volume)) - if cmap == 'objects': - cmap = qim3d.viz.colormaps.objects(num_labels) - # If vmin and vmax are not set like this, then in case the - # number of objects changes on new slice, objects might change - # colors. So when using a slider, the same object suddently + if color_map == "segmentation": + color_map = qim3d.viz.colormaps.segmentation(num_labels) + # If value_min and value_max are not set like this, then in case the + # number of objects changes on new slice, objects might change + # colors. So when using a slider, the same object suddently # changes color (flickers), which is confusing and annoying. - vmin = 0 - vmax = num_labels - + value_min = 0 + value_max = num_labels # Get total number of slices in the specified dimension - n_total = vol.shape[axis] + n_total = volume.shape[slice_axis] # Position is not provided - will use linearly spaced slices - if position is None: - slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int) + if slice_positions is None: + slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int) # Position is a string - elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]: - if position.lower() == "start": - slice_idxs = _get_slice_range(0, n_slices, n_total) - elif position.lower() == "mid": - slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total) - elif position.lower() == "end": - slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total) + elif isinstance(slice_positions, str) and slice_positions.lower() in [ + "start", + "mid", + "end", + ]: + if slice_positions.lower() == "start": + slice_idxs = _get_slice_range(0, num_slices, n_total) + elif slice_positions.lower() == "mid": + slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total) + elif slice_positions.lower() == "end": + slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total) # Position is an integer - elif isinstance(position, int): - slice_idxs = _get_slice_range(position, n_slices, n_total) + elif isinstance(slice_positions, int): + slice_idxs = _get_slice_range(slice_positions, num_slices, n_total) # Position is a list of integers - elif isinstance(position, list) and all(isinstance(idx, int) for idx in position): - slice_idxs = position + elif isinstance(slice_positions, list) and all( + isinstance(idx, int) for idx in slice_positions + ): + slice_idxs = slice_positions else: raise ValueError( 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' ) # Make grid - nrows = math.ceil(n_slices / max_cols) - ncols = min(n_slices, max_cols) + nrows = math.ceil(num_slices / max_columns) + ncols = min(num_slices, max_columns) # Generate figure fig, axs = plt.subplots( nrows=nrows, ncols=ncols, - figsize=(ncols * img_height, nrows * img_width), + figsize=(ncols * image_height, nrows * image_width), constrained_layout=True, ) @@ -165,47 +177,53 @@ def slices( axs = [axs] # Convert to a list for uniformity # Convert to NumPy array in order to use the numpy.take method - if isinstance(vol, da.core.Array): - vol = vol.compute() + if isinstance(volume, da.core.Array): + volume = volume.compute() - if cbar: - # In this case, we want the vrange to be constant across the - # slices, which makes them all comparable to a single cbar. - new_vmin = vmin if vmin is not None else np.min(vol) - new_vmax = vmax if vmax is not None else np.max(vol) + if color_bar: + # In this case, we want the vrange to be constant across the + # slices, which makes them all comparable to a single color_bar. + new_value_min = value_min if value_min is not None else np.min(volume) + new_value_max = value_max if value_max is not None else np.max(volume) # Run through each ax of the grid for i, ax_row in enumerate(axs): for j, ax in enumerate(np.atleast_1d(ax_row)): - slice_idx = i * max_cols + j + slice_idx = i * max_columns + j try: - slice_img = vol.take(slice_idxs[slice_idx], axis=axis) + slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis) - if not cbar: - # If vmin is higher than the highest value in the - # image ValueError is raised. We don't want to + if not color_bar: + # If value_min is higher than the highest value in the + # image ValueError is raised. We don't want to # override the values because next slices might be okay - new_vmin = ( + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if ( + isinstance(value_min, (float, int)) + and value_min > np.max(slice_img) + ) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if ( + isinstance(value_max, (float, int)) + and value_max < np.min(slice_img) + ) + else value_max ) ax.imshow( slice_img, - cmap=cmap, + cmap=color_map, interpolation=interpolation, - vmin=new_vmin, - vmax=new_vmax, - **imshow_kwargs, + vmin=new_value_min, + vmax=new_value_max, + **matplotlib_imshow_kwargs, ) - if show_position: + if display_positions: ax.text( 0.0, 1.0, @@ -221,7 +239,7 @@ def slices( ax.text( 1.0, 0.0, - f"axis {axis} ", + f"axis {slice_axis} ", transform=ax.transAxes, color="white", fontsize=8, @@ -237,33 +255,40 @@ def slices( # Hide the axis, so that we have a nice grid ax.axis("off") - if cbar: + if color_bar: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) fig.tight_layout() - norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True) - mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + norm = matplotlib.colors.Normalize( + vmin=new_value_min, vmax=new_value_max, clip=True + ) + mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map) - if cbar_style =="small": + if color_bar_style == "small": # Figure coordinates of top-right axis tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images - cbar_ax = fig.add_axes( + color_bar_ax = fig.add_axes( [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] ) - fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") - elif cbar_style == "large": + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") + elif color_bar_style == "large": # Figure coordinates of bottom- and top-right axis br_pos = np.atleast_1d(axs[-1])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images - cbar_ax = fig.add_axes( - [br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015] + color_bar_ax = fig.add_axes( + [ + br_pos.xmax + 0.05 / ncols, + br_pos.y0 + 0.0015, + 0.05 / ncols, + (tr_pos.y1 - br_pos.y0) - 0.0015, + ] ) - fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") - if show: + if display_figure: plt.show() plt.close() @@ -271,49 +296,51 @@ def slices( return fig -def _get_slice_range(position: int, n_slices: int, n_total): +def _get_slice_range(position: int, num_slices: int, n_total): """Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" - start_idx = position - n_slices // 2 + start_idx = position - num_slices // 2 end_idx = ( - position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1 + position + num_slices // 2 + if num_slices % 2 == 0 + else position + num_slices // 2 + 1 ) slice_idxs = np.arange(start_idx, end_idx) if slice_idxs[0] < 0: - slice_idxs = np.arange(0, n_slices) + slice_idxs = np.arange(0, num_slices) elif slice_idxs[-1] > n_total: - slice_idxs = np.arange(n_total - n_slices, n_total) + slice_idxs = np.arange(n_total - num_slices, n_total) return slice_idxs def slicer( - vol: np.ndarray, - axis: int = 0, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 3, - img_width: int = 3, - show_position: bool = False, + volume: np.ndarray, + slice_axis: int = 0, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, + image_height: int = 3, + image_width: int = 3, + display_positions: bool = False, interpolation: Optional[str] = None, - img_size=None, - cbar: bool = False, - **imshow_kwargs, + image_size=None, + color_bar: bool = False, + **matplotlib_imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. Args: - vol (np.ndarray): The 3D volume to be sliced. - axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. Defaults to 3. - img_width (int, optional): Width of the figure. Defaults to 3. - show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + volume (np.ndarray): The 3D volume to be sliced. + slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. Defaults to 3. + image_width (int, optional): Width of the figure. Defaults to 3. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. - cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. + color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. Returns: slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. @@ -328,98 +355,98 @@ def slicer(  """ - if img_size: - img_height = img_size - img_width = img_size + if image_size: + image_height = image_size + image_width = image_size # Create the interactive widget - def _slicer(position): - fig = slices( - vol, - axis=axis, - cmap=cmap, - vmin=vmin, - vmax=vmax, - img_height=img_height, - img_width=img_width, - show_position=show_position, + def _slicer(slice_positions): + fig = slices_grid( + volume, + slice_axis=slice_axis, + color_map=color_map, + value_min=value_min, + value_max=value_max, + image_height=image_height, + image_width=image_width, + display_positions=display_positions, interpolation=interpolation, - position=position, - n_slices=1, - show=True, - cbar=cbar, - **imshow_kwargs, + slice_positions=slice_positions, + num_slices=1, + display_figure=True, + color_bar=color_bar, + **matplotlib_imshow_kwargs, ) return fig position_slider = widgets.IntSlider( - value=vol.shape[axis] // 2, + value=volume.shape[slice_axis] // 2, min=0, - max=vol.shape[axis] - 1, + max=volume.shape[slice_axis] - 1, description="Slice", continuous_update=True, ) - slicer_obj = widgets.interactive(_slicer, position=position_slider) + slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider) slicer_obj.layout = widgets.Layout(align_items="flex-start") return slicer_obj -def orthogonal( - vol: np.ndarray, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 3, - img_width: int = 3, - show_position: bool = False, +def slicer_orthogonal( + volume: np.ndarray, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, + image_height: int = 3, + image_width: int = 3, + display_positions: bool = False, interpolation: Optional[str] = None, - img_size=None, + image_size=None, ): """Interactive widget for visualizing orthogonal slices of a 3D volume. Args: - vol (np.ndarray): The 3D volume to be sliced. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. - img_width (int, optional): Width of the figure. - show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + volume (np.ndarray): The 3D volume to be sliced. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. + image_width (int, optional): Width of the figure. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. Returns: - orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. + slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. Example: ```python import qim3d vol = qim3d.examples.fly_150x256x256 - qim3d.viz.orthogonal(vol, cmap="magma") + qim3d.viz.slicer_orthogonal(vol, color_map="magma") ``` -  +  """ - if img_size: - img_height = img_size - img_width = img_size - - get_slicer_for_axis = lambda axis: slicer( - vol, - axis=axis, - cmap=cmap, - vmin=vmin, - vmax=vmax, - img_height=img_height, - img_width=img_width, - show_position=show_position, + if image_size: + image_height = image_size + image_width = image_size + + get_slicer_for_axis = lambda slice_axis: slicer( + volume, + slice_axis=slice_axis, + color_map=color_map, + value_min=value_min, + value_max=value_max, + image_height=image_height, + image_width=image_width, + display_positions=display_positions, interpolation=interpolation, ) - z_slicer = get_slicer_for_axis(axis=0) - y_slicer = get_slicer_for_axis(axis=1) - x_slicer = get_slicer_for_axis(axis=2) + z_slicer = get_slicer_for_axis(slice_axis=0) + y_slicer = get_slicer_for_axis(slice_axis=1) + x_slicer = get_slicer_for_axis(slice_axis=2) z_slicer.children[0].description = "Z" y_slicer.children[0].description = "Y" @@ -428,29 +455,29 @@ def orthogonal( return widgets.HBox([z_slicer, y_slicer, x_slicer]) -def interactive_fade_mask( - vol: np.ndarray, +def fade_mask( + volume: np.ndarray, axis: int = 0, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, ): """Interactive widget for visualizing the effect of edge fading on a 3D volume. This can be used to select the best parameters before applying the mask. Args: - vol (np.ndarray): The volume to apply edge fading to. + volume (np.ndarray): The volume to apply edge fading to. axis (int, optional): The axis along which to apply the fading. Defaults to 0. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None Example: ```python import qim3d vol = qim3d.examples.cement_128x128x128 - qim3d.viz.interactive_fade_mask(vol) + qim3d.viz.fade_mask(vol) ```  @@ -460,58 +487,62 @@ def interactive_fade_mask( def _slicer(position, decay_rate, ratio, geometry, invert): fig, axes = plt.subplots(1, 3, figsize=(9, 3)) - slice_img = vol[position, :, :] - # If vmin is higher than the highest value in the image ValueError is raised + slice_img = volume[position, :, :] + # If value_min is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - new_vmin = ( + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img)) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img)) + else value_max ) - axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) + axes[0].imshow( + slice_img, cmap=color_map, value_min=new_value_min, value_max=new_value_max + ) axes[0].set_title("Original") axes[0].axis("off") mask = qim3d.processing.operations.fade_mask( - np.ones_like(vol), + np.ones_like(volume), decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert, ) - axes[1].imshow(mask[position, :, :], cmap=cmap) + axes[1].imshow(mask[position, :, :], cmap=color_map) axes[1].set_title("Mask") axes[1].axis("off") - masked_vol = qim3d.processing.operations.fade_mask( - vol, + masked_volume = qim3d.processing.operations.fade_mask( + volume, decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert, ) - # If vmin is higher than the highest value in the image ValueError is raised + # If value_min is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - slice_img = masked_vol[position, :, :] - new_vmin = ( + slice_img = masked_volume[position, :, :] + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img)) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img)) + else value_max + ) + axes[2].imshow( + slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max ) - axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) axes[2].set_title("Masked") axes[2].axis("off") @@ -524,9 +555,9 @@ def interactive_fade_mask( ) position_slider = widgets.IntSlider( - value=vol.shape[0] // 2, + value=volume.shape[0] // 2, min=0, - max=vol.shape[0] - 1, + max=volume.shape[0] - 1, description="Slice", continuous_update=False, ) @@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs): viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) - fig = qim3d.viz.slices(chunk, **kwargs) + fig = qim3d.viz.slices_grid_grid(chunk, **kwargs) display(fig) - elif visualization_method == "vol": + elif visualization_method == "volume": viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) - out = qim3d.viz.vol(chunk, show=False, **kwargs) + out = qim3d.viz.volumetric(chunk, show=False, **kwargs) display(out) else: log.info(f"Invalid visualization method: {visualization_method}") @@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs): ) method_dropdown = widgets.Dropdown( - options=["slicer", "slices", "vol"], + options=["slicer", "slices", "volume"], value="slicer", description="Visualization", style={"description_width": description_width, "text_align": "left"}, @@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs): def histogram( - vol: np.ndarray, + volume: np.ndarray, bins: Union[int, str] = "auto", slice_idx: Union[int, str] = None, axis: int = 0, @@ -833,11 +864,11 @@ def histogram( ): """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. - + Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Args: - vol (np.ndarray): A 3D NumPy array representing the volume to be visualized. + volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". axis (int, optional): Axis along which to take a slice. Default is 0. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. @@ -879,24 +910,24 @@ def histogram(  """ - if not (0 <= axis < vol.ndim): - raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") + if not (0 <= axis < volume.ndim): + raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.") if slice_idx == "middle": - slice_idx = vol.shape[axis] // 2 + slice_idx = volume.shape[axis] // 2 if slice_idx: - if 0 <= slice_idx < vol.shape[axis]: - img_slice = np.take(vol, indices=slice_idx, axis=axis) + if 0 <= slice_idx < volume.shape[axis]: + img_slice = np.take(volume, indices=slice_idx, axis=axis) data = img_slice.ravel() title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" else: raise ValueError( - f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}." + f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}." ) else: - data = vol.ravel() - title = f"Intensity histogram for whole volume {vol.shape}" + data = volume.ravel() + title = f"Intensity histogram for whole volume {volume.shape}" fig, ax = plt.subplots(figsize=figsize) diff --git a/qim3d/viz/detection.py b/qim3d/viz/_detection.py similarity index 97% rename from qim3d/viz/detection.py rename to qim3d/viz/_detection.py index fcf493cd7e550aa183f09e2dbae56c9e7bc3de7c..a0926105e4c01cb99ab708d1a42c395246411d3e 100644 --- a/qim3d/viz/detection.py +++ b/qim3d/viz/_detection.py @@ -1,5 +1,5 @@ import matplotlib.pyplot as plt -from qim3d.utils.logger import log +from qim3d.utils._logger import log import numpy as np import ipywidgets as widgets from IPython.display import clear_output, display @@ -29,7 +29,7 @@ def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs): def _slicer(z_slice): clear_output(wait=True) - fig = qim3d.viz.slices( + fig = qim3d.viz.slices_grid( vol, n_slices=1, position=z_slice, diff --git a/qim3d/viz/k3d.py b/qim3d/viz/_k3d.py similarity index 85% rename from qim3d/viz/k3d.py rename to qim3d/viz/_k3d.py index c3018e081bee7fd2179bae8ae92a3a7884f00d3f..b80e9556c22ce8d2fa19806b64f1cecc1f6be625 100644 --- a/qim3d/viz/k3d.py +++ b/qim3d/viz/_k3d.py @@ -10,17 +10,17 @@ Volumetric visualization using K3D import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import Colormap -from qim3d.utils.logger import log -from qim3d.utils.misc import downscale_img, scale_to_float16 +from qim3d.utils._logger import log +from qim3d.utils._misc import downscale_img, scale_to_float16 -def vol( +def volumetric( img, aspectmode="data", show=True, save=False, grid_visible=False, - cmap=None, + color_map='magma', constant_opacity=False, vmin=None, vmax=None, @@ -43,8 +43,8 @@ def vol( If a string is provided, it's interpreted as the file path where the HTML file will be saved. Defaults to False. grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. - cmap (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None. - constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding cmap; otherwise, the plot may appear poorly. Defaults to False. + color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None. + constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may appear poorly. Defaults to False. vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None. vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. @@ -60,7 +60,7 @@ def vol( ValueError: If `aspectmode` is not `'data'` or `'cube'`. Tip: - The function can be used for object label visualization using a `cmap` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering. + The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering. Example: Display a volume inline: @@ -69,7 +69,7 @@ def vol( import qim3d vol = qim3d.examples.bone_128x128x128 - qim3d.viz.vol(vol) + qim3d.viz.volumetric(vol) ``` <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe> @@ -78,7 +78,7 @@ def vol( ```python import qim3d vol = qim3d.examples.bone_128x128x128 - plot = qim3d.viz.vol(vol, show=False, save="plot.html") + plot = qim3d.viz.volumetric(vol, show=False, save="plot.html") ``` """ @@ -129,21 +129,21 @@ def vol( if vmax: color_range[1] = vmax - # Handle the different formats that cmap can take - if cmap: - if isinstance(cmap, str): - cmap = plt.get_cmap(cmap) # Convert to Colormap object - if isinstance(cmap, Colormap): - # Convert to the format of cmap required by k3d.volume - attr_vals = np.linspace(0.0, 1.0, num=cmap.N) - RGB_vals = cmap(np.arange(0, cmap.N))[:, :3] - cmap = np.column_stack((attr_vals, RGB_vals)).tolist() + # Handle the different formats that color_map can take + if color_map: + if isinstance(color_map, str): + color_map = plt.get_cmap(color_map) # Convert to Colormap object + if isinstance(color_map, Colormap): + # Convert to the format of color_map required by k3d.volume + attr_vals = np.linspace(0.0, 1.0, num=color_map.N) + RGB_vals = color_map(np.arange(0, color_map.N))[:, :3] + color_map = np.column_stack((attr_vals, RGB_vals)).tolist() # Default k3d.volume settings opacity_function = [] interpolation = True if constant_opacity: - # without these settings, the plot will look bad when cmap is created with qim3d.viz.colormaps.objects + # without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)] interpolation = False @@ -155,7 +155,7 @@ def vol( if aspectmode.lower() == "data" else None ), - color_map=cmap, + color_map=color_map, samples=samples, color_range=color_range, opacity_function=opacity_function, diff --git a/qim3d/viz/layers2d.py b/qim3d/viz/_layers2d.py similarity index 100% rename from qim3d/viz/layers2d.py rename to qim3d/viz/_layers2d.py diff --git a/qim3d/viz/local_thickness_.py b/qim3d/viz/_local_thickness.py similarity index 99% rename from qim3d/viz/local_thickness_.py rename to qim3d/viz/_local_thickness.py index e1195e7dfae54058f18e66e5a9a83867c6dea503..c07438eec0c6dfb06f72c80373d4698fbffec6bc 100644 --- a/qim3d/viz/local_thickness_.py +++ b/qim3d/viz/_local_thickness.py @@ -1,4 +1,4 @@ -from qim3d.utils.logger import log +from qim3d.utils._logger import log import numpy as np import matplotlib.pyplot as plt from typing import Optional, Union, Tuple diff --git a/qim3d/viz/metrics.py b/qim3d/viz/_metrics.py similarity index 99% rename from qim3d/viz/metrics.py rename to qim3d/viz/_metrics.py index 7cbd08ea71decebbd729fe6e20c6c42a50171629..53a7f3a61178534e703c545287ddcd204d26b940 100644 --- a/qim3d/viz/metrics.py +++ b/qim3d/viz/_metrics.py @@ -4,7 +4,7 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib import colormaps -from qim3d.utils.logger import log +from qim3d.utils._logger import log def plot_metrics( diff --git a/qim3d/viz/preview.py b/qim3d/viz/_preview.py similarity index 100% rename from qim3d/viz/preview.py rename to qim3d/viz/_preview.py diff --git a/qim3d/viz/structure_tensor.py b/qim3d/viz/_structure_tensor.py similarity index 99% rename from qim3d/viz/structure_tensor.py rename to qim3d/viz/_structure_tensor.py index 21b0e24d35e4f2302686b8eae343f488449384b0..63d141c41ae9dd365e89a9d82c959d08173d76a6 100644 --- a/qim3d/viz/structure_tensor.py +++ b/qim3d/viz/_structure_tensor.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec import ipywidgets as widgets import logging -from qim3d.utils.logger import log +from qim3d.utils._logger import log previous_logging_level = logging.getLogger().getEffectiveLevel() diff --git a/qim3d/viz/colormaps/__init__.py b/qim3d/viz/colormaps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9807422f1fec28680c17a383a7d7377096220d6c --- /dev/null +++ b/qim3d/viz/colormaps/__init__.py @@ -0,0 +1,2 @@ +from ._segmentation import segmentation +from ._qim_colors import qim \ No newline at end of file diff --git a/qim3d/viz/colormaps/_qim_colors.py b/qim3d/viz/colormaps/_qim_colors.py new file mode 100644 index 0000000000000000000000000000000000000000..429151e1664ac911a2d07d9a8c27059153e55fec --- /dev/null +++ b/qim3d/viz/colormaps/_qim_colors.py @@ -0,0 +1,24 @@ +from matplotlib import colormaps +from matplotlib.colors import LinearSegmentedColormap + + +qim = LinearSegmentedColormap.from_list( + "qim", + [ + (0.6, 0.0, 0.0), # 990000 + (1.0, 0.6, 0.0), # ff9900 + ], +) +""" +Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'``` + +Example: + ```python + + import qim3d + + display(qim3d.viz.colormaps.qim) + ``` +  +""" +colormaps.register(qim) diff --git a/qim3d/viz/colormaps.py b/qim3d/viz/colormaps/_segmentation.py similarity index 81% rename from qim3d/viz/colormaps.py rename to qim3d/viz/colormaps/_segmentation.py index 807a083885544eb2901ccb86261f1fd834b12100..b6ee2f8834b39eaf166da8f1471d9e36d44c2485 100644 --- a/qim3d/viz/colormaps.py +++ b/qim3d/viz/colormaps/_segmentation.py @@ -7,7 +7,6 @@ from typing import Union, Tuple import numpy as np import math from matplotlib.colors import LinearSegmentedColormap -from matplotlib import colormaps def rearrange_colors(randRGBcolors_old, min_dist=0.5): @@ -33,8 +32,8 @@ def rearrange_colors(randRGBcolors_old, min_dist=0.5): return randRGBcolors_new -def objects( - nlabels: int, +def segmentation( + num_labels: int, style: str = "bright", first_color_background: bool = True, last_color_background: bool = False, @@ -46,7 +45,7 @@ def objects( Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks Args: - nlabels (int): Number of labels (size of colormap). + num_labels (int): Number of labels (size of colormap). style (str, optional): 'bright' for strong colors, 'soft' for pastel colors, 'earth' for yellow/green/blue colors, 'ocean' for blue/purple/pink colors. Defaults to 'bright'. first_color_background (bool, optional): If True, the first color is used as background. Defaults to True. last_color_background (bool, optional): If True, the last color is used as background. Defaults to False. @@ -62,10 +61,10 @@ def objects( ```python import qim3d - cmap_bright = qim3d.viz.colormaps.objects(nlabels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7) - cmap_soft = qim3d.viz.colormaps.objects(nlabels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2) - cmap_earth = qim3d.viz.colormaps.objects(nlabels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8) - cmap_ocean = qim3d.viz.colormaps.objects(nlabels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9) + cmap_bright = qim3d.viz.colormaps.objects(num_labels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7) + cmap_soft = qim3d.viz.colormaps.objects(num_labels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2) + cmap_earth = qim3d.viz.colormaps.objects(num_labels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8) + cmap_ocean = qim3d.viz.colormaps.objects(num_labels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9) display(cmap_bright) display(cmap_soft) @@ -89,7 +88,7 @@ def objects( Tip: It can be easily used when calling visualization functions as ```python - qim3d.viz.slices(segmented_volume, cmap = 'objects') + qim3d.viz.slices_grid(segmented_volume, cmap = 'objects') ``` which automatically detects number of unique classes and creates the colormap object with defualt arguments. @@ -116,8 +115,8 @@ def objects( f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.' ) - # Add one to nlabels to include the background color - nlabels += 1 + # Add one to num_labels to include the background color + num_labels += 1 # Create a new random generator, to locally set seed rng = np.random.default_rng(seed) @@ -130,7 +129,7 @@ def objects( rng.uniform(low=0.4, high=1), rng.uniform(low=0.9, high=1), ) - for i in range(nlabels) + for i in range(num_labels) ] # Convert HSV list to RGB @@ -150,7 +149,7 @@ def objects( rng.uniform(low=low, high=high), rng.uniform(low=low, high=high), ) - for i in range(nlabels) + for i in range(num_labels) ] # Generate color map for earthy colors, based on LAB @@ -161,7 +160,7 @@ def objects( rng.uniform(low=-120, high=70), rng.uniform(low=-70, high=70), ) - for i in range(nlabels) + for i in range(num_labels) ] # Convert LAB list to RGB @@ -177,7 +176,7 @@ def objects( rng.uniform(low=-128, high=160), rng.uniform(low=-128, high=0), ) - for i in range(nlabels) + for i in range(num_labels) ] # Convert LAB list to RGB @@ -196,28 +195,8 @@ def objects( randRGBcolors[-1] = background_color # Create colormap - objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=nlabels) + objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=num_labels) return objects -qim = LinearSegmentedColormap.from_list( - "qim", - [ - (0.6, 0.0, 0.0), # 990000 - (1.0, 0.6, 0.0), # ff9900 - ], -) -""" -Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'``` - -Example: - ```python - - import qim3d - - display(qim3d.viz.colormaps.qim) - ``` -  -""" -colormaps.register(qim) diff --git a/qim3d/viz/itk_vtk_viewer/run.py b/qim3d/viz/itk_vtk_viewer/run.py index 8ff029fc63f29df8e3dbffd81ec29c81ce74c52b..8fa5e6b8edc85573ccbb78e19c7838c08a8d6d23 100644 --- a/qim3d/viz/itk_vtk_viewer/run.py +++ b/qim3d/viz/itk_vtk_viewer/run.py @@ -3,7 +3,7 @@ import platform from pathlib import Path import os import qim3d.utils -from qim3d.utils.logger import log +from qim3d.utils._logger import log # from .helpers import get_qim_dir, get_nvm_dir, get_viewer_binaries, get_viewer_dir, get_node_binaries_dir, NotInstalledError, SOURCE_FNM from .helpers import *