diff --git a/.gitignore b/.gitignore index 47cea172a9aaddf34d7d9cdba081ea79a2f9718d..d0bfaecc64dcd278ebdc9b47371ce85bbaaf53ef 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ build/ .idea/ .cache/ .pytest_cache/ +.ruff_cache/ *.swp *.swo *.pyc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe60fae38acc23e797d0e808c0816cf22bbddcd0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: detect-private-key + - id: check-added-large-files + - id: check-docstring-first + - id: debug-statements + - id: double-quote-string-fixer + - id: name-tests-test + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.7 + hooks: + # Run the formatter and fix code styling + - id: ruff-format + + # Run the linter and fix what is possible + - id: ruff + args: ['--fix'] \ No newline at end of file diff --git a/docs/assets/screenshots/interactive_thresholding.gif b/docs/assets/screenshots/interactive_thresholding.gif new file mode 100644 index 0000000000000000000000000000000000000000..f7ec78d06c3608954446a03f0e6c9ce88dcaf59c Binary files /dev/null and b/docs/assets/screenshots/interactive_thresholding.gif differ diff --git a/docs/assets/screenshots/pygel3d_visualization.png b/docs/assets/screenshots/pygel3d_visualization.png new file mode 100644 index 0000000000000000000000000000000000000000..93d87dde129906bfaa1944ded8f5e0df7783251c Binary files /dev/null and b/docs/assets/screenshots/pygel3d_visualization.png differ diff --git a/docs/assets/screenshots/viz-line_profile.gif b/docs/assets/screenshots/viz-line_profile.gif new file mode 100644 index 0000000000000000000000000000000000000000..74b322faa292ad685d1d304c1e0f276114064d1a Binary files /dev/null and b/docs/assets/screenshots/viz-line_profile.gif differ diff --git a/docs/doc/cli/cli.md b/docs/doc/cli/cli.md index 58fae7e2253830a9ab3156712fe7fac14d2c3f24..123fee26161db2f28737d6302cccb10520b9d88a 100644 --- a/docs/doc/cli/cli.md +++ b/docs/doc/cli/cli.md @@ -44,7 +44,7 @@ The command line interface allows you to run graphical user interfaces directly !!! Example - Here's an example of how to open the [Data Explorer](gui.md#qim3d.gui.data_explorer) + Here's an example of how to open the [Data Explorer](../gui/gui.md#qim3d.gui.data_explorer) ``` title="Command" qim3d gui --data-explorer diff --git a/docs/doc/data_handling/io.md b/docs/doc/data_handling/io.md index 82c984bbced233b2a35836bdd072e3c73aa8bf8b..607d81ede891e5700eb593ab7ddfa876a6037ac0 100644 --- a/docs/doc/data_handling/io.md +++ b/docs/doc/data_handling/io.md @@ -8,5 +8,5 @@ - Downloader - export_ome_zarr - import_ome_zarr - - save_mesh - - load_mesh \ No newline at end of file + - load_mesh + - save_mesh \ No newline at end of file diff --git a/docs/doc/gui/gui.md b/docs/doc/gui/gui.md index 681041fc4f2515ce6c89112b07c12d1d407f44f3..4bc2de99a9e81559313a703df31a1ef40be4c811 100644 --- a/docs/doc/gui/gui.md +++ b/docs/doc/gui/gui.md @@ -21,7 +21,7 @@ The `qim3d` library provides a set of custom made GUIs that ease the interaction ``` In general, the GUIs can be launched directly from the command line. -For details see [here](cli.md#qim3d-gui). +For details see [here](../cli/cli.md#qim3d-gui). ::: qim3d.gui.data_explorer options: diff --git a/docs/doc/releases/releases.md b/docs/doc/releases/releases.md index 8174ea97c706b522de00460a71f1cdf7c59da2f8..bdb5c2a8007d46a3622a5215bfb2231752d45dbc 100644 --- a/docs/doc/releases/releases.md +++ b/docs/doc/releases/releases.md @@ -11,7 +11,7 @@ hide: Below, you'll find details about the version history of `qim3d`. -Remember to keep your pip installation [up to date](index.md/#get-the-latest-version) so that you have the latest features! +Remember to keep your pip installation [up to date](../../index.md/#get-the-latest-version) so that you have the latest features! ### v1.0.0 (21/01/2025) diff --git a/docs/doc/visualization/viz.md b/docs/doc/visualization/viz.md index 0f22a927ac31aaab6a0940fdbaa075caab389d0f..2cdb54d6130c8ec7c47bc856a79bc75f7bcabcaa 100644 --- a/docs/doc/visualization/viz.md +++ b/docs/doc/visualization/viz.md @@ -23,6 +23,8 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - plot_cc - colormaps - fade_mask + - line_profile + - threshold ::: qim3d.viz.colormaps options: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..1c7971317a569910a6347357db6ffd10ae0deb63 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,92 @@ +# See list of rules here: https://docs.astral.sh/ruff/rules/ + +[tool.ruff] +line-length = 88 +indent-width = 4 + +[tool.ruff.lint] +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +select = [ + "F", + "E", # Errors + "W", # Warnings + "I", # Imports + "N", # Naming + "D", # Documentation + "UP", # Upgrades + "YTT", + "ANN", + "ASYNC", + "S", + "BLE", + "B", + "A", + "COM", + "C4", + "T10", + "DJ", + "EM", + "EXE", + "ISC", + "LOG", + "PIE", + "PYI", + "PT", + "RSE", + "SLF", + "SLOT", + "SIM", + "TID", + "TCH", + "INT", + "ERA", + "PGH", +] + +ignore = [ + "F821", + "F841", + "E501", + "E731", + "D100", + "D101", + "D107", + "D201", + "D202", + "D205", + "D211", + "D212", + "D401", + "D407", + "ANN002", + "ANN003", + "ANN101", + "ANN201", + "ANN204", + "S101", + "S301", + "S311", + "S507", + "S603", + "S605", + "S607", + "B008", + "B026", + "B028", + "B905", + "W291", + "W293", + "COM812", + "ISC001", + "SIM113", +] + +[tool.ruff.format] +# Use single quotes for strings +quote-style = "single" \ No newline at end of file diff --git a/qim3d/__init__.py b/qim3d/__init__.py index e27175b8a3eaa9ae11eb98623937e61716dc1d85..5857fad7f1ed0bb1493b9a622be948cac277aea2 100644 --- a/qim3d/__init__.py +++ b/qim3d/__init__.py @@ -1,6 +1,7 @@ -"""qim3d: A Python package for 3D image processing and visualization. +""" +qim3d: A Python package for 3D image processing and visualization. -The qim3d library is designed to make it easier to work with 3D imaging data in Python. +The qim3d library is designed to make it easier to work with 3D imaging data in Python. It offers a range of features, including data loading and manipulation, image processing and filtering, visualization of 3D data, and analysis of imaging results. @@ -8,13 +9,14 @@ Documentation available at https://platform.qim.dk/qim3d/ """ -__version__ = "1.0.0" +__version__ = '1.1.0' import importlib as _importlib class _LazyLoader: + """Lazy loader to load submodules only when they are accessed""" def __init__(self, module_name): @@ -48,7 +50,7 @@ _submodules = [ 'mesh', 'features', 'operations', - 'detection' + 'detection', ] # Creating lazy loaders for each submodule diff --git a/qim3d/cli/__init__.py b/qim3d/cli/__init__.py index 5ba85ad2cd0a31833cae9b2b9cf2e331933492d1..33bbfe023355a0d8c0390bf7bacc9bab91177ed7 100644 --- a/qim3d/cli/__init__.py +++ b/qim3d/cli/__init__.py @@ -1,9 +1,11 @@ import argparse -import webbrowser +import os import platform +import webbrowser + import outputformat as ouf + import qim3d -import os QIM_TITLE = ouf.rainbow( rf""" @@ -16,126 +18,123 @@ QIM_TITLE = ouf.rainbow( """, return_str=True, - cmap="hot", + cmap='hot', ) def parse_tuple(arg): # Remove parentheses if they are included and split by comma - return tuple(map(int, arg.strip("()").split(","))) + return tuple(map(int, arg.strip('()').split(','))) def main(): - parser = argparse.ArgumentParser(description="qim3d command-line interface.") - subparsers = parser.add_subparsers(title="Subcommands", dest="subcommand") + parser = argparse.ArgumentParser(description='qim3d command-line interface.') + subparsers = parser.add_subparsers(title='Subcommands', dest='subcommand') # GUIs - gui_parser = subparsers.add_parser("gui", help="Graphical User Interfaces.") + gui_parser = subparsers.add_parser('gui', help='Graphical User Interfaces.') gui_parser.add_argument( - "--data-explorer", action="store_true", help="Run data explorer." - ) - gui_parser.add_argument("--iso3d", action="store_true", help="Run iso3d.") - gui_parser.add_argument( - "--annotation-tool", action="store_true", help="Run annotation tool." + '--data-explorer', action='store_true', help='Run data explorer.' ) + gui_parser.add_argument('--iso3d', action='store_true', help='Run iso3d.') gui_parser.add_argument( - "--local-thickness", action="store_true", help="Run local thickness tool." + '--annotation-tool', action='store_true', help='Run annotation tool.' ) gui_parser.add_argument( - "--layers", action="store_true", help="Run Layers." + '--local-thickness', action='store_true', help='Run local thickness tool.' ) - gui_parser.add_argument("--host", default="0.0.0.0", help="Desired host.") + gui_parser.add_argument('--layers', action='store_true', help='Run Layers.') + gui_parser.add_argument('--host', default='0.0.0.0', help='Desired host.') gui_parser.add_argument( - "--platform", action="store_true", help="Use QIM platform address" + '--platform', action='store_true', help='Use QIM platform address' ) gui_parser.add_argument( - "--no-browser", action="store_true", help="Do not launch browser." + '--no-browser', action='store_true', help='Do not launch browser.' ) # Viz - viz_parser = subparsers.add_parser("viz", help="Volumetric visualization.") - viz_parser.add_argument("source", help="Path to the image file") + viz_parser = subparsers.add_parser('viz', help='Volumetric visualization.') + viz_parser.add_argument('source', help='Path to the image file') viz_parser.add_argument( - "-m", - "--method", + '-m', + '--method', type=str, - metavar="METHOD", - default="itk-vtk", - help="Which method is used to display file.", + metavar='METHOD', + default='itk-vtk', + help='Which method is used to display file.', ) viz_parser.add_argument( - "--destination", default="k3d.html", help="Path to save html file." + '--destination', default='k3d.html', help='Path to save html file.' ) viz_parser.add_argument( - "--no-browser", action="store_true", help="Do not launch browser." + '--no-browser', action='store_true', help='Do not launch browser.' ) # Preview preview_parser = subparsers.add_parser( - "preview", help="Preview of the image in CLI" + 'preview', help='Preview of the image in CLI' ) preview_parser.add_argument( - "filename", + 'filename', type=str, - metavar="FILENAME", - help="Path to image that will be displayed", + metavar='FILENAME', + help='Path to image that will be displayed', ) preview_parser.add_argument( - "--slice", + '--slice', type=int, - metavar="S", + metavar='S', default=None, - help="Specifies which slice of the image will be displayed.\nDefaults to middle slice. If number exceeds number of slices, last slice will be displayed.", + help='Specifies which slice of the image will be displayed.\nDefaults to middle slice. If number exceeds number of slices, last slice will be displayed.', ) preview_parser.add_argument( - "--axis", + '--axis', type=int, - metavar="AX", + metavar='AX', default=0, - help="Specifies from which axis will be the slice taken. Defaults to 0.", + help='Specifies from which axis will be the slice taken. Defaults to 0.', ) preview_parser.add_argument( - "--resolution", + '--resolution', type=int, - metavar="RES", + metavar='RES', default=80, - help="Resolution of displayed image. Defaults to 80.", + help='Resolution of displayed image. Defaults to 80.', ) preview_parser.add_argument( - "--absolute_values", - action="store_false", - help="By default set the maximum value to be 255 so the contrast is strong. This turns it off.", + '--absolute_values', + action='store_false', + help='By default set the maximum value to be 255 so the contrast is strong. This turns it off.', ) # File Convert convert_parser = subparsers.add_parser( - "convert", - help="Convert files to different formats without loading the entire file into memory", + 'convert', + help='Convert files to different formats without loading the entire file into memory', ) convert_parser.add_argument( - "input_path", + 'input_path', type=str, - metavar="Input path", - help="Path to image that will be converted", + metavar='Input path', + help='Path to image that will be converted', ) convert_parser.add_argument( - "output_path", + 'output_path', type=str, - metavar="Output path", - help="Path to save converted image", + metavar='Output path', + help='Path to save converted image', ) convert_parser.add_argument( - "--chunks", + '--chunks', type=parse_tuple, - metavar="Chunk shape", + metavar='Chunk shape', default=(64, 64, 64), - help="Chunk size for the zarr file. Defaults to (64, 64, 64).", + help='Chunk size for the zarr file. Defaults to (64, 64, 64).', ) args = parser.parse_args() - if args.subcommand == "gui": - + if args.subcommand == 'gui': arghost = args.host inbrowser = not args.no_browser # Should automatically open in browser @@ -152,7 +151,7 @@ def main(): interface_class = qim3d.gui.layers2d.Interface else: print( - "Please select a tool by choosing one of the following flags:\n\t--data-explorer\n\t--iso3d\n\t--annotation-tool\n\t--local-thickness" + 'Please select a tool by choosing one of the following flags:\n\t--data-explorer\n\t--iso3d\n\t--annotation-tool\n\t--local-thickness' ) return interface = ( @@ -164,31 +163,27 @@ def main(): else: interface.launch(inbrowser=inbrowser, force_light_mode=False) - elif args.subcommand == "viz": - - if args.method == "itk-vtk": - + elif args.subcommand == 'viz': + if args.method == 'itk-vtk': # We need the full path to the file for the viewer current_dir = os.getcwd() full_path = os.path.normpath(os.path.join(current_dir, args.source)) - qim3d.viz.itk_vtk(full_path, open_browser = not args.no_browser) + qim3d.viz.itk_vtk(full_path, open_browser=not args.no_browser) - - elif args.method == "k3d": + elif args.method == 'k3d': volume = qim3d.io.load(str(args.source)) - print("\nGenerating k3d plot...") + print('\nGenerating k3d plot...') qim3d.viz.volumetric(volume, show=False, save=str(args.destination)) - print(f"Done, plot available at <{args.destination}>") + print(f'Done, plot available at <{args.destination}>') if not args.no_browser: - print("Opening in default browser...") + print('Opening in default browser...') webbrowser.open_new_tab(args.destination) else: raise NotImplementedError( f"Method '{args.method}' is not valid. Try 'k3d' or default 'itk-vtk-viewer'" ) - elif args.subcommand == "preview": - + elif args.subcommand == 'preview': image = qim3d.io.load(args.filename) qim3d.viz.image_preview( @@ -199,22 +194,21 @@ def main(): relative_intensity=args.absolute_values, ) - elif args.subcommand == "convert": - + elif args.subcommand == 'convert': qim3d.io.convert(args.input_path, args.output_path, chunk_shape=args.chunks) elif args.subcommand is None: print(QIM_TITLE) welcome_text = ( - "\nqim3d is a Python package for 3D image processing and visualization.\n" + '\nqim3d is a Python package for 3D image processing and visualization.\n' f"For more information, please visit {ouf.c('https://platform.qim.dk/qim3d/', color='orange', return_str=True)}\n" - " \n" + ' \n' "For more information on each subcommand, type 'qim3d <subcommand> --help'.\n" ) print(welcome_text) parser.print_help() - print("\n") + print('\n') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/qim3d/detection/__init__.py b/qim3d/detection/__init__.py index 302d7909203cfc7e6ccf1223d0a7b1e5c29be6bd..6a5aa44a69043d33da008198c7442b763d20f4bc 100644 --- a/qim3d/detection/__init__.py +++ b/qim3d/detection/__init__.py @@ -1 +1 @@ -from qim3d.detection._common_detection_methods import * \ No newline at end of file +from qim3d.detection._common_detection_methods import * diff --git a/qim3d/detection/_common_detection_methods.py b/qim3d/detection/_common_detection_methods.py index 3049baef5dbfee1e9f053cd19e8e8c2f79a01a0a..17ceb922dc5fae4fd1e72f2a8e49a5473f9d58b9 100644 --- a/qim3d/detection/_common_detection_methods.py +++ b/qim3d/detection/_common_detection_methods.py @@ -1,13 +1,15 @@ -""" Blob detection using Difference of Gaussian (DoG) method """ +"""Blob detection using Difference of Gaussian (DoG) method""" import numpy as np + from qim3d.utils._logger import log -__all__ = ["blobs"] +__all__ = ['blobs'] + def blobs( vol: np.ndarray, - background: str = "dark", + background: str = 'dark', min_sigma: float = 1, max_sigma: float = 50, sigma_ratio: float = 1.6, @@ -56,18 +58,19 @@ def blobs( # Visualize detected blobs qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue') ``` -  +  ```python # Visualize binary binary_volume qim3d.viz.slicer(binary_volume) ```  + """ from skimage.feature import blob_dog - if background == "bright": - log.info("Bright background selected, volume will be inverted.") + if background == 'bright': + log.info('Bright background selected, volume will be inverted.') vol = np.invert(vol) blobs = blob_dog( @@ -109,8 +112,8 @@ def blobs( (x_indices - x) ** 2 + (y_indices - y) ** 2 + (z_indices - z) ** 2 ) - binary_volume[z_start:z_end, y_start:y_end, x_start:x_end][ - dist <= radius - ] = True + binary_volume[z_start:z_end, y_start:y_end, x_start:x_end][dist <= radius] = ( + True + ) return blobs, binary_volume diff --git a/qim3d/examples/__init__.py b/qim3d/examples/__init__.py index fe15a92e8af1eb9a461c8196084ad358fe35b840..0402fb62deac2faf56db7e1a712853c0a0e379ab 100644 --- a/qim3d/examples/__init__.py +++ b/qim3d/examples/__init__.py @@ -1,16 +1,17 @@ -""" Example images for testing and demonstration purposes. """ +"""Example images for testing and demonstration purposes.""" from pathlib import Path as _Path -from qim3d.utils._logger import log as _log + from qim3d.io import load as _load +from qim3d.utils._logger import log as _log # Save the original log level and set to ERROR # to suppress the log messages during loading _original_log_level = _log.level -_log.setLevel("ERROR") +_log.setLevel('ERROR') # Load image examples -for _file_path in _Path(__file__).resolve().parent.glob("*.tif"): +for _file_path in _Path(__file__).resolve().parent.glob('*.tif'): globals().update({_file_path.stem: _load(_file_path, progress_bar=False)}) # Restore the original log level diff --git a/qim3d/features/__init__.py b/qim3d/features/__init__.py index 349cedcb8d92943f5148029aeabc14c5be9dc5bb..5bf2fc9049c8058a24d1a12260971b079d04bd84 100644 --- a/qim3d/features/__init__.py +++ b/qim3d/features/__init__.py @@ -1 +1 @@ -from ._common_features_methods import volume, area, sphericity +from ._common_features_methods import area, sphericity, volume diff --git a/qim3d/features/_common_features_methods.py b/qim3d/features/_common_features_methods.py index b448b1d300b4985231262b0428eee158611cd453..4a547b6c30c730e7c0d0ab583f3f9995d91df747 100644 --- a/qim3d/features/_common_features_methods.py +++ b/qim3d/features/_common_features_methods.py @@ -1,23 +1,19 @@ import numpy as np -import qim3d.processing +import qim3d from qim3d.utils._logger import log -import trimesh import qim3d +from pygel3d import hmesh - -def volume(obj: np.ndarray|trimesh.Trimesh, - **mesh_kwargs - ) -> float: +def volume(obj: np.ndarray|hmesh.Manifold) -> float: """ - Compute the volume of a 3D volume or mesh. + Compute the volume of a 3D mesh using the Pygel3D library. Args: - obj (np.ndarray or trimesh.Trimesh): Either a np.ndarray volume or a mesh object of type trimesh.Trimesh. - **mesh_kwargs (Any): Additional arguments for mesh creation if the input is a volume. + obj (numpy.ndarray or pygel3d.hmesh.Manifold): Either a np.ndarray volume or a mesh object of type pygel3d.hmesh.Manifold. Returns: volume (float): The volume of the object. - + Example: Compute volume from a mesh: ```python @@ -27,8 +23,8 @@ def volume(obj: np.ndarray|trimesh.Trimesh, mesh = qim3d.io.load_mesh('path/to/mesh.obj') # Compute the volume of the mesh - vol = qim3d.features.volume(mesh) - print('Volume:', vol) + volume = qim3d.features.volume(mesh) + print(f'Volume: {volume}') ``` Compute volume from a np.ndarray: @@ -37,35 +33,30 @@ def volume(obj: np.ndarray|trimesh.Trimesh, # Generate a 3D blob synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) - synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) # Compute the volume of the blob - volume = qim3d.features.volume(synthetic_blob, level=0.5) - volume = qim3d.features.volume(synthetic_blob, level=0.5) - print('Volume:', volume) + volume = qim3d.features.volume(synthetic_blob) + print(f'Volume: {volume}') ``` """ + if isinstance(obj, np.ndarray): log.info("Converting volume to mesh.") - obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) + obj = qim3d.mesh.from_volume(obj) - return obj.volume + return hmesh.volume(obj) - -def area(obj: np.ndarray|trimesh.Trimesh, - **mesh_kwargs - ) -> float: +def area(obj: np.ndarray|hmesh.Manifold) -> float: """ - Compute the surface area of a 3D volume or mesh. + Compute the surface area of a 3D mesh using the Pygel3D library. Args: - obj (np.ndarray or trimesh.Trimesh): Either a np.ndarray volume or a mesh object of type trimesh.Trimesh. - **mesh_kwargs (Any): Additional arguments for mesh creation if the input is a volume. + obj (numpy.ndarray or pygel3d.hmesh.Manifold): Either a np.ndarray volume or a mesh object of type pygel3d.hmesh.Manifold. Returns: - area (float): The surface area of the object. - + area (float): The surface area of the object. + Example: Compute area from a mesh: ```python @@ -76,8 +67,7 @@ def area(obj: np.ndarray|trimesh.Trimesh, # Compute the surface area of the mesh area = qim3d.features.area(mesh) - area = qim3d.features.area(mesh) - print(f"Area: {area}") + print(f'Area: {area}') ``` Compute area from a np.ndarray: @@ -86,39 +76,30 @@ def area(obj: np.ndarray|trimesh.Trimesh, # Generate a 3D blob synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) - synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) # Compute the surface area of the blob - volume = qim3d.features.area(synthetic_blob, level=0.5) - volume = qim3d.features.area(synthetic_blob, level=0.5) - print('Area:', volume) + area = qim3d.features.area(synthetic_blob) + print(f'Area: {area}') ``` + """ + if isinstance(obj, np.ndarray): log.info("Converting volume to mesh.") - obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) - obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) - - return obj.area + obj = qim3d.mesh.from_volume(obj) + return hmesh.area(obj) -def sphericity(obj: np.ndarray|trimesh.Trimesh, - **mesh_kwargs - ) -> float: +def sphericity(obj: np.ndarray|hmesh.Manifold) -> float: """ - Compute the sphericity of a 3D volume or mesh. - - Sphericity is a measure of how spherical an object is. It is defined as the ratio - of the surface area of a sphere with the same volume as the object to the object's - actual surface area. + Compute the sphericity of a 3D mesh using the Pygel3D library. Args: - obj (np.ndarray or trimesh.Trimesh): Either a np.ndarray volume or a mesh object of type trimesh.Trimesh. - **mesh_kwargs (Any): Additional arguments for mesh creation if the input is a volume. + obj (numpy.ndarray or pygel3d.hmesh.Manifold): Either a np.ndarray volume or a mesh object of type pygel3d.hmesh.Manifold. Returns: - sphericity (float): A float value representing the sphericity of the object. - + sphericity (float): The sphericity of the object. + Example: Compute sphericity from a mesh: ```python @@ -129,7 +110,7 @@ def sphericity(obj: np.ndarray|trimesh.Trimesh, # Compute the sphericity of the mesh sphericity = qim3d.features.sphericity(mesh) - sphericity = qim3d.features.sphericity(mesh) + print(f'Sphericity: {sphericity}') ``` Compute sphericity from a np.ndarray: @@ -138,32 +119,25 @@ def sphericity(obj: np.ndarray|trimesh.Trimesh, # Generate a 3D blob synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) - synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) # Compute the sphericity of the blob - sphericity = qim3d.features.sphericity(synthetic_blob, level=0.5) - sphericity = qim3d.features.sphericity(synthetic_blob, level=0.5) + sphericity = qim3d.features.sphericity(synthetic_blob) + print(f'Sphericity: {sphericity}') ``` - !!! info "Limitations due to pixelation" - Sphericity is particularly sensitive to the resolution of the mesh, as it directly impacts the accuracy of surface area and volume calculations. - Since the mesh is generated from voxel-based 3D volume data, the discrete nature of the voxels leads to pixelation effects that reduce the precision of sphericity measurements. - Higher resolution meshes may mitigate these errors but often at the cost of increased computational demands. """ + if isinstance(obj, np.ndarray): log.info("Converting volume to mesh.") - obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) - obj = qim3d.mesh.from_volume(obj, **mesh_kwargs) + obj = qim3d.mesh.from_volume(obj) volume = qim3d.features.volume(obj) area = qim3d.features.area(obj) - volume = qim3d.features.volume(obj) - area = qim3d.features.area(obj) if area == 0: - log.warning("Surface area is zero, sphericity is undefined.") + log.warning('Surface area is zero, sphericity is undefined.') return np.nan sphericity = (np.pi ** (1 / 3) * (6 * volume) ** (2 / 3)) / area - log.info(f"Sphericity: {sphericity}") - return sphericity + # log.info(f"Sphericity: {sphericity}") + return sphericity \ No newline at end of file diff --git a/qim3d/filters/__init__.py b/qim3d/filters/__init__.py index 8a166fca5343a186b51cbe33f0af5144696c7237..0bada798b8d0824dcebc2a8855720fbb5d3d46d4 100644 --- a/qim3d/filters/__init__.py +++ b/qim3d/filters/__init__.py @@ -1 +1 @@ -from ._common_filter_methods import * \ No newline at end of file +from ._common_filter_methods import * diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py index 3b6ded2733d9c77c3e9f487919b3c8a1984eed6a..c550fef0b1ca7a2f1ffc74cd3dd19af5c678baa6 100644 --- a/qim3d/filters/_common_filter_methods.py +++ b/qim3d/filters/_common_filter_methods.py @@ -1,43 +1,40 @@ """Provides filter functions and classes for image processing""" -from typing import Type, Union +from typing import Type +import dask.array as da +import dask_image.ndfilters as dask_ndfilters import numpy as np from scipy import ndimage from skimage import morphology -import dask.array as da -import dask_image.ndfilters as dask_ndfilters from qim3d.utils import log __all__ = [ - "FilterBase", - "Gaussian", - "Median", - "Maximum", - "Minimum", - "Pipeline", - "Tophat", - "gaussian", - "median", - "maximum", - "minimum", - "tophat", + 'FilterBase', + 'Gaussian', + 'Median', + 'Maximum', + 'Minimum', + 'Pipeline', + 'Tophat', + 'gaussian', + 'median', + 'maximum', + 'minimum', + 'tophat', ] class FilterBase: - def __init__(self, - *args, - dask: bool = False, - chunks: str = "auto", - **kwargs): + def __init__(self, *args, dask: bool = False, chunks: str = 'auto', **kwargs): """ Base class for image filters. Args: *args: Additional positional arguments for filter initialization. **kwargs: Additional keyword arguments for filter initialization. + """ self.args = args self.dask = dask @@ -54,6 +51,7 @@ class Gaussian(FilterBase): sigma (float): Standard deviation for Gaussian kernel. *args: Additional arguments. **kwargs: Additional keyword arguments. + """ super().__init__(*args, **kwargs) self.sigma = sigma @@ -67,14 +65,22 @@ class Gaussian(FilterBase): Returns: The filtered image or volume. + """ return gaussian( - input, sigma=self.sigma, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs + input, + sigma=self.sigma, + dask=self.dask, + chunks=self.chunks, + *self.args, + **self.kwargs, ) class Median(FilterBase): - def __init__(self, size: float = None, footprint: np.ndarray = None, *args, **kwargs): + def __init__( + self, size: float = None, footprint: np.ndarray = None, *args, **kwargs + ): """ Median filter initialization. @@ -83,6 +89,7 @@ class Median(FilterBase): footprint (np.ndarray, optional): The structuring element for filtering. *args: Additional arguments. **kwargs: Additional keyword arguments. + """ if size is None and footprint is None: raise ValueError("Either 'size' or 'footprint' must be provided.") @@ -99,12 +106,22 @@ class Median(FilterBase): Returns: The filtered image or volume. + """ - return median(vol=input, size=self.size, footprint=self.footprint, dask=self.dask, chunks=self.chunks, **self.kwargs) + return median( + vol=input, + size=self.size, + footprint=self.footprint, + dask=self.dask, + chunks=self.chunks, + **self.kwargs, + ) class Maximum(FilterBase): - def __init__(self, size: float = None, footprint: np.ndarray = None, *args, **kwargs): + def __init__( + self, size: float = None, footprint: np.ndarray = None, *args, **kwargs + ): """ Maximum filter initialization. @@ -113,6 +130,7 @@ class Maximum(FilterBase): footprint (np.ndarray, optional): The structuring element for filtering. *args: Additional arguments. **kwargs: Additional keyword arguments. + """ if size is None and footprint is None: raise ValueError("Either 'size' or 'footprint' must be provided.") @@ -129,12 +147,22 @@ class Maximum(FilterBase): Returns: The filtered image or volume. + """ - return maximum(vol=input, size=self.size, footprint=self.footprint, dask=self.dask, chunks=self.chunks, **self.kwargs) + return maximum( + vol=input, + size=self.size, + footprint=self.footprint, + dask=self.dask, + chunks=self.chunks, + **self.kwargs, + ) class Minimum(FilterBase): - def __init__(self, size: float = None, footprint: np.ndarray = None, *args, **kwargs): + def __init__( + self, size: float = None, footprint: np.ndarray = None, *args, **kwargs + ): """ Minimum filter initialization. @@ -143,6 +171,7 @@ class Minimum(FilterBase): footprint (np.ndarray, optional): The structuring element for filtering. *args: Additional arguments. **kwargs: Additional keyword arguments. + """ if size is None and footprint is None: raise ValueError("Either 'size' or 'footprint' must be provided.") @@ -159,8 +188,16 @@ class Minimum(FilterBase): Returns: The filtered image or volume. + """ - return minimum(vol=input, size=self.size, footprint=self.footprint, dask=self.dask, chunks=self.chunks, **self.kwargs) + return minimum( + vol=input, + size=self.size, + footprint=self.footprint, + dask=self.dask, + chunks=self.chunks, + **self.kwargs, + ) class Tophat(FilterBase): @@ -173,11 +210,13 @@ class Tophat(FilterBase): Returns: The filtered image or volume. + """ return tophat(input, dask=self.dask, **self.kwargs) class Pipeline: + """ Example: ```python @@ -216,6 +255,7 @@ class Pipeline: Args: *args: Variable number of filter instances to be applied sequentially. + """ self.filters = {} @@ -232,13 +272,14 @@ class Pipeline: Raises: AssertionError: If `fn` is not an instance of the FilterBase class. + """ if not isinstance(fn, FilterBase): filter_names = [ subclass.__name__ for subclass in FilterBase.__subclasses__() ] raise AssertionError( - f"filters should be instances of one of the following classes: {filter_names}" + f'filters should be instances of one of the following classes: {filter_names}' ) self.filters[name] = fn @@ -248,7 +289,7 @@ class Pipeline: Args: fn (FilterBase): An instance of a FilterBase subclass to be appended. - + Example: ```python import qim3d @@ -262,6 +303,7 @@ class Pipeline: # Append a second filter to the pipeline pipeline.append(Median(size=5)) ``` + """ self._add_filter(str(len(self.filters)), fn) @@ -274,6 +316,7 @@ class Pipeline: Returns: The filtered image or volume after applying all sequential filters. + """ for fn in self.filters.values(): input = fn(input) @@ -281,12 +324,8 @@ class Pipeline: def gaussian( - vol: np.ndarray, - sigma: float, - dask: bool = False, - chunks: str = "auto", - **kwargs - ) -> np.ndarray: + vol: np.ndarray, sigma: float, dask: bool = False, chunks: str = 'auto', **kwargs +) -> np.ndarray: """ Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter. @@ -295,11 +334,11 @@ def gaussian( sigma (float or sequence of floats): The standard deviations of the Gaussian filter are given for each axis as a sequence, or as a single number, in which case it is equal for all axes. dask (bool, optional): Whether to use Dask for the Gaussian filter. chunks (int or tuple or "'auto'", optional): Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing. - *args (Any): Additional positional arguments for the Gaussian filter. **kwargs (Any): Additional keyword arguments for the Gaussian filter. Returns: filtered_vol (np.ndarray): The filtered image or volume. + """ if dask: @@ -314,12 +353,12 @@ def gaussian( def median( - vol: np.ndarray, + vol: np.ndarray, size: float = None, footprint: np.ndarray = None, - dask: bool = False, - chunks: str = "auto", - **kwargs + dask: bool = False, + chunks: str = 'auto', + **kwargs, ) -> np.ndarray: """ Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter. @@ -337,11 +376,12 @@ def median( Raises: RuntimeError: If neither size nor footprint is defined + """ if size is None: if footprint is None: - raise RuntimeError("no footprint or filter size provided") - + raise RuntimeError('no footprint or filter size provided') + if dask: if not isinstance(vol, da.Array): vol = da.from_array(vol, chunks=chunks) @@ -357,9 +397,9 @@ def maximum( vol: np.ndarray, size: float = None, footprint: np.ndarray = None, - dask: bool = False, - chunks: str = "auto", - **kwargs + dask: bool = False, + chunks: str = 'auto', + **kwargs, ) -> np.ndarray: """ Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter. @@ -374,14 +414,15 @@ def maximum( Returns: filtered_vol (np.ndarray): The filtered image or volume. - + Raises: RuntimeError: If neither size nor footprint is defined + """ if size is None: if footprint is None: - raise RuntimeError("no footprint or filter size provided") - + raise RuntimeError('no footprint or filter size provided') + if dask: if not isinstance(vol, da.Array): vol = da.from_array(vol, chunks=chunks) @@ -394,12 +435,12 @@ def maximum( def minimum( - vol: np.ndarray, + vol: np.ndarray, size: float = None, footprint: np.ndarray = None, - dask: bool = False, - chunks: str = "auto", - **kwargs + dask: bool = False, + chunks: str = 'auto', + **kwargs, ) -> np.ndarray: """ Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter. @@ -417,11 +458,12 @@ def minimum( Raises: RuntimeError: If neither size nor footprint is defined + """ if size is None: if footprint is None: - raise RuntimeError("no footprint or filter size provided") - + raise RuntimeError('no footprint or filter size provided') + if dask: if not isinstance(vol, da.Array): vol = da.from_array(vol, chunks=chunks) @@ -432,10 +474,8 @@ def minimum( res = ndimage.minimum_filter(vol, size, footprint, **kwargs) return res -def tophat(vol: np.ndarray, - dask: bool = False, - **kwargs - ): + +def tophat(vol: np.ndarray, dask: bool = False, **kwargs): """ Remove background from the volume. @@ -448,24 +488,25 @@ def tophat(vol: np.ndarray, Returns: filtered_vol (np.ndarray): The volume with background removed. + """ - radius = kwargs["radius"] if "radius" in kwargs else 3 - background = kwargs["background"] if "background" in kwargs else "dark" + radius = kwargs['radius'] if 'radius' in kwargs else 3 + background = kwargs['background'] if 'background' in kwargs else 'dark' if dask: - log.info("Dask not supported for tophat filter, switching to scipy.") + log.info('Dask not supported for tophat filter, switching to scipy.') - if background == "bright": + if background == 'bright': log.info( - "Bright background selected, volume will be temporarily inverted when applying white_tophat" + 'Bright background selected, volume will be temporarily inverted when applying white_tophat' ) vol = np.invert(vol) selem = morphology.ball(radius) vol = vol - morphology.white_tophat(vol, selem) - if background == "bright": + if background == 'bright': vol = np.invert(vol) - return vol \ No newline at end of file + return vol diff --git a/qim3d/generate/__init__.py b/qim3d/generate/__init__.py index 3f0a48780a5b9730678905dd3a5eb1c272ab700b..1bae47513ab6e6f07f7d7a64d83294b3bc2497c8 100644 --- a/qim3d/generate/__init__.py +++ b/qim3d/generate/__init__.py @@ -1,2 +1,2 @@ -from ._generators import noise_object from ._aggregators import noise_object_collection +from ._generators import noise_object diff --git a/qim3d/generate/_aggregators.py b/qim3d/generate/_aggregators.py index 8ede2caef7634742be5df2b91d218abd86e02340..cbf8407cede918e05329eb45112d9b357aa35714 100644 --- a/qim3d/generate/_aggregators.py +++ b/qim3d/generate/_aggregators.py @@ -22,6 +22,7 @@ def random_placement( Returns: collection (numpy.ndarray): 3D volume of the collection with the blob placed. placed (bool): Flag for placement success. + """ # Find available (zero) elements in collection available_z, available_y, available_x = np.where(collection == 0) @@ -44,14 +45,12 @@ def random_placement( if np.all( collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0 ): - # Check if placement is within bounds (bool) within_bounds = np.all(start >= 0) and np.all( end <= np.array(collection.shape) ) if within_bounds: - # Place blob collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = ( blob @@ -81,6 +80,7 @@ def specific_placement( collection (numpy.ndarray): 3D volume of the collection with the blob placed. placed (bool): Flag for placement success. positions (list[tuple]): List of remaining positions to place blobs. + """ # Flag for placement success placed = False @@ -99,14 +99,12 @@ def specific_placement( if np.all( collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0 ): - # Check if placement is within bounds (bool) within_bounds = np.all(start >= 0) and np.all( end <= np.array(collection.shape) ) if within_bounds: - # Place blob collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = ( blob @@ -253,13 +251,13 @@ def noise_object_collection( ``` <iframe src="https://platform.qim.dk/k3d/synthetic_collection_cylinder.html" width="100%" height="500" frameborder="0"></iframe> - + ```python # Visualize slices qim3d.viz.slices_grid(vol, num_slices=15) ``` -  - +  + Example: ```python import qim3d @@ -283,29 +281,30 @@ def noise_object_collection( qim3d.viz.volumetric(vol) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_collection_tube.html" width="100%" height="500" frameborder="0"></iframe> - + ```python # Visualize slices qim3d.viz.slices_grid(vol, num_slices=15, slice_axis=1) ```  + """ if verbose: original_log_level = log.getEffectiveLevel() - log.setLevel("DEBUG") + log.setLevel('DEBUG') # Check valid input types if not isinstance(collection_shape, tuple) or len(collection_shape) != 3: raise TypeError( - "Shape of collection must be a tuple with three dimensions (z, y, x)" + 'Shape of collection must be a tuple with three dimensions (z, y, x)' ) if len(min_shape) != len(max_shape): - raise ValueError("Object shapes must be tuples of the same length") + raise ValueError('Object shapes must be tuples of the same length') if (positions is not None) and (len(positions) != num_objects): raise ValueError( - "Number of objects must match number of positions, otherwise set positions = None" + 'Number of objects must match number of positions, otherwise set positions = None' ) # Set seed for random number generator @@ -318,8 +317,8 @@ def noise_object_collection( labels = np.zeros_like(collection_array) # Fill the 3D array with synthetic blobs - for i in tqdm(range(num_objects), desc="Objects placed"): - log.debug(f"\nObject #{i+1}") + for i in tqdm(range(num_objects), desc='Objects placed'): + log.debug(f'\nObject #{i+1}') # Sample from blob parameter ranges if min_shape == max_shape: @@ -328,27 +327,27 @@ def noise_object_collection( blob_shape = tuple( rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3) ) - log.debug(f"- Blob shape: {blob_shape}") + log.debug(f'- Blob shape: {blob_shape}') # Scale object shape final_shape = tuple(l * r for l, r in zip(blob_shape, object_shape_zoom)) - final_shape = tuple(int(x) for x in final_shape) # NOTE: Added this + final_shape = tuple(int(x) for x in final_shape) # NOTE: Added this # Sample noise scale noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise) - log.debug(f"- Object noise scale: {noise_scale:.4f}") + log.debug(f'- Object noise scale: {noise_scale:.4f}') gamma = rng.uniform(low=min_gamma, high=max_gamma) - log.debug(f"- Gamma correction: {gamma:.3f}") + log.debug(f'- Gamma correction: {gamma:.3f}') if max_high_value > min_high_value: max_value = rng.integers(low=min_high_value, high=max_high_value) else: max_value = min_high_value - log.debug(f"- Max value: {max_value}") + log.debug(f'- Max value: {max_value}') threshold = rng.uniform(low=min_threshold, high=max_threshold) - log.debug(f"- Threshold: {threshold:.3f}") + log.debug(f'- Threshold: {threshold:.3f}') # Generate synthetic object blob = qim3d.generate.noise_object( @@ -368,7 +367,7 @@ def noise_object_collection( low=min_rotation_degrees, high=max_rotation_degrees ) # Sample rotation angle axes = rng.choice(rotation_axes) # Sample the two axes to rotate around - log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}") + log.debug(f'- Rotation angle: {angle:.2f} at axes: {axes}') blob = scipy.ndimage.rotate(blob, angle, axes, order=1) @@ -397,7 +396,7 @@ def noise_object_collection( if not placed: # Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar) log.error( - f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects." + f'Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects.' ) if verbose: diff --git a/qim3d/generate/_generators.py b/qim3d/generate/_generators.py index ed54b8b4de9fef663caea00fd75429524c7c5f0e..cb359e2a13ac30b1609d08a93fd71b5fe6beec7e 100644 --- a/qim3d/generate/_generators.py +++ b/qim3d/generate/_generators.py @@ -4,6 +4,7 @@ from noise import pnoise3 import qim3d.processing + def noise_object( base_shape: tuple = (128, 128, 128), final_shape: tuple = (128, 128, 128), @@ -14,8 +15,8 @@ def noise_object( threshold: float = 0.5, smooth_borders: bool = False, object_shape: str = None, - dtype: str = "uint8", - ) -> np.ndarray: + dtype: str = 'uint8', +) -> np.ndarray: """ Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction. @@ -97,18 +98,19 @@ def noise_object( qim3d.viz.volumetric(vol) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_blob_tube.html" width="100%" height="500" frameborder="0"></iframe> - + ```python # Visualize qim3d.viz.slices_grid(vol, num_slices=15) ``` -  +  + """ if not isinstance(final_shape, tuple) or len(final_shape) != 3: - raise TypeError("Size must be a tuple of 3 dimensions") + raise TypeError('Size must be a tuple of 3 dimensions') if not np.issubdtype(dtype, np.number): - raise ValueError("Invalid data type") + raise ValueError('Invalid data type') # Initialize the 3D array for the shape volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32) @@ -119,19 +121,18 @@ def noise_object( # Calculate the distance from the center of the shape center = np.array(base_shape) / 2 - dist = np.sqrt((z - center[0])**2 + - (y - center[1])**2 + - (x - center[2])**2) - - dist /= np.sqrt(3 * (center[0]**2)) + dist = np.sqrt((z - center[0]) ** 2 + (y - center[1]) ** 2 + (x - center[2]) ** 2) + + dist /= np.sqrt(3 * (center[0] ** 2)) # Generate Perlin noise and adjust the values based on the distance from the center - vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input + vectorized_pnoise3 = np.vectorize( + pnoise3 + ) # Vectorize pnoise3, since it only takes scalar input - noise = vectorized_pnoise3(z.flatten() * noise_scale, - y.flatten() * noise_scale, - x.flatten() * noise_scale - ).reshape(base_shape) + noise = vectorized_pnoise3( + z.flatten() * noise_scale, y.flatten() * noise_scale, x.flatten() * noise_scale + ).reshape(base_shape) volume = (1 + noise) * (1 - dist) @@ -148,17 +149,22 @@ def noise_object( if object_shape: smooth_borders = False - if smooth_borders: + if smooth_borders: # Maximum value among the six sides of the 3D volume - max_border_value = np.max([ - np.max(volume[0, :, :]), np.max(volume[-1, :, :]), - np.max(volume[:, 0, :]), np.max(volume[:, -1, :]), - np.max(volume[:, :, 0]), np.max(volume[:, :, -1]) - ]) + max_border_value = np.max( + [ + np.max(volume[0, :, :]), + np.max(volume[-1, :, :]), + np.max(volume[:, 0, :]), + np.max(volume[:, -1, :]), + np.max(volume[:, :, 0]), + np.max(volume[:, :, -1]), + ] + ) # Compute threshold such that there will be no straight cuts in the blob threshold = max_border_value / max_value - + # Clip the low values of the volume to create a coherent volume volume[volume < threshold * max_value] = 0 @@ -171,45 +177,50 @@ def noise_object( ) # Fade into a shape if specified - if object_shape == "cylinder": - + if object_shape == 'cylinder': # Arguments for the fade_mask function - geometry = "cylindrical" # Fade in cylindrical geometry - axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest - target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical - - volume = qim3d.operations.fade_mask(volume, - geometry = geometry, - axis = axis, - target_max_normalized_distance = target_max_normalized_distance - ) - - elif object_shape == "tube": - + geometry = 'cylindrical' # Fade in cylindrical geometry + axis = np.argmax( + volume.shape + ) # Fade along the dimension where the object is the largest + target_max_normalized_distance = ( + 1.4 # This value ensures that the object will become cylindrical + ) + + volume = qim3d.operations.fade_mask( + volume, + geometry=geometry, + axis=axis, + target_max_normalized_distance=target_max_normalized_distance, + ) + + elif object_shape == 'tube': # Arguments for the fade_mask function - geometry = "cylindrical" # Fade in cylindrical geometry - axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest - decay_rate = 5 # Decay rate for the fade operation - target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical + geometry = 'cylindrical' # Fade in cylindrical geometry + axis = np.argmax( + volume.shape + ) # Fade along the dimension where the object is the largest + decay_rate = 5 # Decay rate for the fade operation + target_max_normalized_distance = ( + 1.4 # This value ensures that the object will become cylindrical + ) # Fade once for making the object cylindrical - volume = qim3d.operations.fade_mask(volume, - geometry = geometry, - axis = axis, - decay_rate = decay_rate, - target_max_normalized_distance = target_max_normalized_distance, - invert = False - ) + volume = qim3d.operations.fade_mask( + volume, + geometry=geometry, + axis=axis, + decay_rate=decay_rate, + target_max_normalized_distance=target_max_normalized_distance, + invert=False, + ) # Fade again with invert = True for making the object a tube (i.e. with a hole in the middle) - volume = qim3d.operations.fade_mask(volume, - geometry = geometry, - axis = axis, - decay_rate = decay_rate, - invert = True - ) - + volume = qim3d.operations.fade_mask( + volume, geometry=geometry, axis=axis, decay_rate=decay_rate, invert=True + ) + # Convert to desired data type volume = volume.astype(dtype) - return volume \ No newline at end of file + return volume diff --git a/qim3d/gui/__init__.py b/qim3d/gui/__init__.py index cc7c9ff40272df922cf79a315cf8729e98f6a858..2557fb2ca2de61590f204b9fbd7fbbca59d6d2eb 100644 --- a/qim3d/gui/__init__.py +++ b/qim3d/gui/__init__.py @@ -1,26 +1,24 @@ from fastapi import FastAPI + import qim3d.utils -from . import data_explorer -from . import iso3d -from . import local_thickness -from . import annotation_tool -from . import layers2d + +from . import annotation_tool, data_explorer, iso3d, layers2d, local_thickness from .qim_theme import QimTheme -def run_gradio_app(gradio_interface, host="0.0.0.0"): +def run_gradio_app(gradio_interface, host='0.0.0.0'): import gradio as gr import uvicorn # Get port using the QIM API port_dict = qim3d.utils.get_port_dict() - if "gradio_port" in port_dict: - port = port_dict["gradio_port"] - elif "port" in port_dict: - port = port_dict["port"] + if 'gradio_port' in port_dict: + port = port_dict['gradio_port'] + elif 'port' in port_dict: + port = port_dict['port'] else: - raise Exception("Port not specified from QIM API") + raise Exception('Port not specified from QIM API') qim3d.utils.gradio_header(gradio_interface.title, port) @@ -30,7 +28,7 @@ def run_gradio_app(gradio_interface, host="0.0.0.0"): app = gr.mount_gradio_app(app, gradio_interface, path=path) # Full path - print(f"http://{host}:{port}{path}") + print(f'http://{host}:{port}{path}') # Run the FastAPI server usign uvicorn uvicorn.run(app, host=host, port=int(port)) diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py index 8f386da385ab9e26c089a99670fb972f2afc17bd..05312ee79cca19fe839aea8cbee4a89a3add2d3a 100644 --- a/qim3d/gui/annotation_tool.py +++ b/qim3d/gui/annotation_tool.py @@ -27,6 +27,7 @@ import tempfile import gradio as gr import numpy as np from PIL import Image + import qim3d from qim3d.gui.interface import BaseInterface @@ -34,17 +35,19 @@ from qim3d.gui.interface import BaseInterface class Interface(BaseInterface): - def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None): + def __init__( + self, name_suffix: str = '', verbose: bool = False, img: np.ndarray = None + ): super().__init__( - title="Annotation Tool", + title='Annotation Tool', height=768, - width="100%", + width='100%', verbose=verbose, - custom_css="annotation_tool.css", + custom_css='annotation_tool.css', ) self.username = getpass.getuser() - self.temp_dir = os.path.join(tempfile.gettempdir(), f"qim-{self.username}") + self.temp_dir = os.path.join(tempfile.gettempdir(), f'qim-{self.username}') self.name_suffix = name_suffix self.img = img @@ -57,7 +60,7 @@ class Interface(BaseInterface): # Get the temporary files from gradio temp_path_list = [] for filename in os.listdir(self.temp_dir): - if "mask" and self.name_suffix in str(filename): + if 'mask' and self.name_suffix in str(filename): # Get the list of the temporary files temp_path_list.append(os.path.join(self.temp_dir, filename)) @@ -76,9 +79,9 @@ class Interface(BaseInterface): this is safer and backwards compatible (should be) """ self.mask_names = [ - f"red{self.name_suffix}", - f"green{self.name_suffix}", - f"blue{self.name_suffix}", + f'red{self.name_suffix}', + f'green{self.name_suffix}', + f'blue{self.name_suffix}', ] # Clean up old files @@ -86,7 +89,7 @@ class Interface(BaseInterface): files = os.listdir(self.temp_dir) for filename in files: # Check if "mask" is in the filename - if ("mask" in filename) and (self.name_suffix in filename): + if ('mask' in filename) and (self.name_suffix in filename): file_path = os.path.join(self.temp_dir, filename) os.remove(file_path) @@ -94,13 +97,13 @@ class Interface(BaseInterface): files = None def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray: - background = img_editor["background"] - masks = img_editor["layers"][0] + background = img_editor['background'] + masks = img_editor['layers'][0] overlay_image = qim3d.operations.overlay_rgb_images(background, masks) return overlay_image def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]: - masks_rgb = img_editor["layers"][0] + masks_rgb = img_editor['layers'][0] mask_threshold = 200 # This value is based mask_list = [] @@ -114,7 +117,7 @@ class Interface(BaseInterface): # Save only if we have a mask if np.sum(mask) > 0: mask_list.append(mask) - filename = f"mask_{self.mask_names[idx]}.tif" + filename = f'mask_{self.mask_names[idx]}.tif' if not os.path.exists(self.temp_dir): os.makedirs(self.temp_dir) filepath = os.path.join(self.temp_dir, filename) @@ -128,11 +131,11 @@ class Interface(BaseInterface): def define_interface(self, **kwargs): brush = gr.Brush( colors=[ - "rgb(255,50,100)", - "rgb(50,250,100)", - "rgb(50,100,255)", + 'rgb(255,50,100)', + 'rgb(50,250,100)', + 'rgb(50,100,255)', ], - color_mode="fixed", + color_mode='fixed', default_size=10, ) with gr.Row(): @@ -142,26 +145,25 @@ class Interface(BaseInterface): img_editor = gr.ImageEditor( value=( { - "background": self.img, - "layers": [Image.new("RGBA", self.img.shape, (0, 0, 0, 0))], - "composite": None, + 'background': self.img, + 'layers': [Image.new('RGBA', self.img.shape, (0, 0, 0, 0))], + 'composite': None, } if self.img is not None else None ), - type="numpy", - image_mode="RGB", + type='numpy', + image_mode='RGB', brush=brush, - sources="upload", + sources='upload', interactive=True, show_download_button=True, container=False, - transforms=["crop"], + transforms=['crop'], layers=False, ) with gr.Column(scale=1, min_width=256): - with gr.Row(): overlay_img = gr.Image( show_download_button=False, @@ -169,7 +171,7 @@ class Interface(BaseInterface): visible=False, ) with gr.Row(): - masks_download = gr.File(label="Download masks", visible=False) + masks_download = gr.File(label='Download masks', visible=False) # fmt: off img_editor.change( diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py index 80222988f0d079f9fa663ac2be9c6883aaa0dfe7..c0ef407939cd43d5f7822a334137bcc894353a54 100644 --- a/qim3d/gui/data_explorer.py +++ b/qim3d/gui/data_explorer.py @@ -18,54 +18,51 @@ app.launch() import datetime import os import re +from typing import Any, Callable, Dict import gradio as gr +import matplotlib import matplotlib.figure import matplotlib.pyplot as plt import numpy as np import outputformat as ouf +from qim3d.gui.interface import BaseInterface from qim3d.io import load -from qim3d.utils._logger import log from qim3d.utils import _misc - -from qim3d.gui.interface import BaseInterface -from typing import Callable, Any, Dict -import matplotlib +from qim3d.utils._logger import log class Interface(BaseInterface): - def __init__(self, - verbose:bool = False, - figsize:int = 8, - display_saturation_percentile:int = 99, - nbins:int = 32): + def __init__( + self, + verbose: bool = False, + figsize: int = 8, + display_saturation_percentile: int = 99, + nbins: int = 32, + ): """ - Parameters: - ----------- + Parameters + ---------- verbose (bool, optional): If true, prints info during session into terminal. Defualt is False. figsize (int, optional): Sets the size of plots displaying the slices. Default is 8. display_saturation_percentile (int, optional): Sets the display saturation percentile. Defaults to 99. + """ - super().__init__( - title = "Data Explorer", - height = 1024, - width = 900, - verbose = verbose - ) - self.axis_dict = {"Z":0, "Y":1, "X":2} + super().__init__(title='Data Explorer', height=1024, width=900, verbose=verbose) + self.axis_dict = {'Z': 0, 'Y': 1, 'X': 2} self.all_operations = [ - "Z Slicer", - "Y Slicer", - "X Slicer", - "Z max projection", - "Z min projection", - "Intensity histogram", - "Data summary", + 'Z Slicer', + 'Y Slicer', + 'X Slicer', + 'Z max projection', + 'Z min projection', + 'Intensity histogram', + 'Data summary', ] - self.calculated_operations = [] # For changing the visibility of results, we keep track what was calculated and thus will be displayed + self.calculated_operations = [] # For changing the visibility of results, we keep track what was calculated and thus will be displayed - self.vol = None # The loaded volume + self.vol = None # The loaded volume # Plotting self.figsize = figsize @@ -79,7 +76,12 @@ class Interface(BaseInterface): # Spinner state - what phase after clicking run button are we in self.spinner_state = -1 - self.spinner_messages = ["Starting session...", "Loading data...", "Running pipeline...", "Relaunch"] + self.spinner_messages = [ + 'Starting session...', + 'Loading data...', + 'Running pipeline...', + 'Relaunch', + ] # Error message that we want to show, for more details look inside function check error state self.error_message = None @@ -87,57 +89,55 @@ class Interface(BaseInterface): # File selection and parameters with gr.Row(): with gr.Column(scale=2): - gr.Markdown("### File selection") + gr.Markdown('### File selection') with gr.Row(): with gr.Column(scale=99, min_width=128): base_path = gr.Textbox( max_lines=1, container=False, - label="Base path", + label='Base path', value=os.getcwd(), ) with gr.Column(scale=1, min_width=36): - reload_base_path = gr.Button( - value="⟳" - ) + reload_base_path = gr.Button(value='⟳') explorer = gr.FileExplorer( - ignore_glob="*/.*", # ignores hidden files + ignore_glob='*/.*', # ignores hidden files root_dir=os.getcwd(), label=os.getcwd(), render=True, - file_count="single", + file_count='single', interactive=True, - height = 320, + height=320, ) with gr.Column(scale=1): - gr.Markdown("### Parameters") + gr.Markdown('### Parameters') cmap = gr.Dropdown( - value="viridis", + value='viridis', choices=plt.colormaps(), - label="Colormap", + label='Colormap', interactive=True, ) virtual_stack = gr.Checkbox( value=False, - label="Virtual stack", - info="If checked, will use less memory by loading the images on demand.", + label='Virtual stack', + info='If checked, will use less memory by loading the images on demand.', ) load_series = gr.Checkbox( value=False, - label="Load series", - info="If checked, will load the whole series of images in the same folder as the selected file.", + label='Load series', + info='If checked, will load the whole series of images in the same folder as the selected file.', ) series_contains = gr.Textbox( - label="Specify common part of file names for series", - value="", + label='Specify common part of file names for series', + value='', visible=False, ) dataset_name = gr.Textbox( - label="Dataset name (in case of H5 files, for example)", - value="exchange/data", + label='Dataset name (in case of H5 files, for example)', + value='exchange/data', ) def toggle_show(checkbox): @@ -151,7 +151,7 @@ class Interface(BaseInterface): load_series.change(toggle_show, load_series, series_contains) with gr.Column(scale=1): - gr.Markdown("### Operations") + gr.Markdown('### Operations') operations = gr.CheckboxGroup( choices=self.all_operations, value=[self.all_operations[0], self.all_operations[-1]], @@ -161,11 +161,13 @@ class Interface(BaseInterface): ) with gr.Row(): btn_run = gr.Button( - value="Load & Run", variant = "primary", + value='Load & Run', + variant='primary', ) # Visualization and results with gr.Row(): + def create_uniform_image(intensity=1): """ Generates a blank image with a single color. @@ -174,50 +176,50 @@ class Interface(BaseInterface): """ pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255) fig, ax = plt.subplots(figsize=(10, 10)) - ax.imshow(pixels, interpolation="nearest") + ax.imshow(pixels, interpolation='nearest') # Adjustments - ax.axis("off") + ax.axis('off') fig.subplots_adjust(left=0, right=1, bottom=0, top=1) return fig # Z Slicer with gr.Column(visible=False) as result_z_slicer: - zslice_plot = gr.Plot(label="Z slice", value=create_uniform_image(1)) + zslice_plot = gr.Plot(label='Z slice', value=create_uniform_image(1)) zpos = gr.Slider( - minimum=0, maximum=1, value=0.5, step=0.01, label="Z position" + minimum=0, maximum=1, value=0.5, step=0.01, label='Z position' ) # Y Slicer with gr.Column(visible=False) as result_y_slicer: - yslice_plot = gr.Plot(label="Y slice", value=create_uniform_image(1)) + yslice_plot = gr.Plot(label='Y slice', value=create_uniform_image(1)) ypos = gr.Slider( - minimum=0, maximum=1, value=0.5, step=0.01, label="Y position" + minimum=0, maximum=1, value=0.5, step=0.01, label='Y position' ) # X Slicer with gr.Column(visible=False) as result_x_slicer: - xslice_plot = gr.Plot(label="X slice", value=create_uniform_image(1)) + xslice_plot = gr.Plot(label='X slice', value=create_uniform_image(1)) xpos = gr.Slider( - minimum=0, maximum=1, value=0.5, step=0.01, label="X position" + minimum=0, maximum=1, value=0.5, step=0.01, label='X position' ) # Z Max projection with gr.Column(visible=False) as result_z_max_projection: max_projection_plot = gr.Plot( - label="Z max projection", + label='Z max projection', ) # Z Min projection with gr.Column(visible=False) as result_z_min_projection: min_projection_plot = gr.Plot( - label="Z min projection", + label='Z min projection', ) # Intensity histogram with gr.Column(visible=False) as result_intensity_histogram: - hist_plot = gr.Plot(label="Volume intensity histogram") + hist_plot = gr.Plot(label='Volume intensity histogram') # Text box with data summary with gr.Column(visible=False) as result_data_summary: @@ -225,12 +227,10 @@ class Interface(BaseInterface): lines=24, label=None, show_label=False, - - value="Data summary", + value='Data summary', ) ### Gradio objects lists - #################################### # EVENT LISTENERS ################################### @@ -247,29 +247,38 @@ class Interface(BaseInterface): ] results = [ - result_z_slicer, - result_y_slicer, - result_x_slicer, - result_z_max_projection, - result_z_min_projection, - result_intensity_histogram, - result_data_summary, - ] - - reload_base_path.click(fn=self.update_explorer,inputs=base_path, outputs=explorer) - - btn_run.click( - fn=self.update_run_btn, inputs = [], outputs = btn_run).then( - fn=self.start_session, inputs = [load_series, series_contains, explorer, base_path], outputs = []).then( - fn=self.update_run_btn, inputs = [], outputs = btn_run).then( - fn=self.check_error_state, inputs = [], outputs = []).success( - fn=self.load_data, inputs= [virtual_stack, dataset_name, series_contains], outputs= []).then( - fn=self.update_run_btn, inputs = [], outputs = btn_run).then( - fn=self.check_error_state, inputs = [], outputs = []).success( - fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs).then( - fn=self.update_run_btn, inputs = [], outputs = btn_run).then( - fn=self.check_error_state, inputs = [], outputs = []).success( - fn=self.show_results, inputs = operations, outputs = results) # results are columns of images and other component, not just the components + result_z_slicer, + result_y_slicer, + result_x_slicer, + result_z_max_projection, + result_z_min_projection, + result_intensity_histogram, + result_data_summary, + ] + + reload_base_path.click( + fn=self.update_explorer, inputs=base_path, outputs=explorer + ) + + btn_run.click(fn=self.update_run_btn, inputs=[], outputs=btn_run).then( + fn=self.start_session, + inputs=[load_series, series_contains, explorer, base_path], + outputs=[], + ).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then( + fn=self.check_error_state, inputs=[], outputs=[] + ).success( + fn=self.load_data, + inputs=[virtual_stack, dataset_name, series_contains], + outputs=[], + ).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then( + fn=self.check_error_state, inputs=[], outputs=[] + ).success( + fn=self.run_operations, inputs=pipeline_inputs, outputs=pipeline_outputs + ).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then( + fn=self.check_error_state, inputs=[], outputs=[] + ).success( + fn=self.show_results, inputs=operations, outputs=results + ) # results are columns of images and other component, not just the components """ Gradio passes only the value to the function, not the whole component. @@ -278,15 +287,21 @@ class Interface(BaseInterface): The self.update_slice_wrapper returns a function. """ sliders = [xpos, ypos, zpos] - letters = ["X", "Y", "Z"] + letters = ['X', 'Y', 'Z'] plots = [xslice_plot, yslice_plot, zslice_plot] for slider, letter, plot in zip(sliders, letters, plots): - slider.change(fn = self.update_slice_wrapper(letter), inputs = [slider, cmap], outputs = plot, show_progress="hidden") + slider.change( + fn=self.update_slice_wrapper(letter), + inputs=[slider, cmap], + outputs=plot, + show_progress='hidden', + ) - # Immediate change without the need of pressing the relaunch button - operations.change(fn=self.show_results, inputs = operations, outputs = results) - cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs) + operations.change(fn=self.show_results, inputs=operations, outputs=results) + cmap.change( + fn=self.run_operations, inputs=pipeline_inputs, outputs=pipeline_outputs + ) def update_explorer(self, new_path: str): new_path = os.path.expanduser(new_path) @@ -301,51 +316,60 @@ class Interface(BaseInterface): return gr.update(root_dir=parent_dir, label=parent_dir, value=file_name) else: - raise ValueError("Invalid path") + raise ValueError('Invalid path') def update_run_btn(self): """ When run_btn is clicked, it becomes uninteractive and displays which operation is now in progress When all operations are done, it becomes interactive again with 'Relaunch' label """ - self.spinner_state = (self.spinner_state + 1) % len(self.spinner_messages) if self.error_message is None else len(self.spinner_messages) - 1 + self.spinner_state = ( + (self.spinner_state + 1) % len(self.spinner_messages) + if self.error_message is None + else len(self.spinner_messages) - 1 + ) message = self.spinner_messages[self.spinner_state] - interactive = (self.spinner_state == len(self.spinner_messages) - 1) + interactive = self.spinner_state == len(self.spinner_messages) - 1 return gr.update( - value=f"{message}", + value=f'{message}', interactive=interactive, ) def check_error_state(self): """ - Raising gr.Error doesn't allow us to return anything and thus we can not update the Run button with + Raising gr.Error doesn't allow us to return anything and thus we can not update the Run button with progress messages. We have to first update the button and then raise an Error so the button is interactive """ if self.error_message is not None: error_message = self.error_message self.error_message = None raise gr.Error(error_message) - -####################################################### -# -# THE PIPELINE -# -####################################################### - def start_session(self, load_series:bool, series_contains:str, explorer:str, base_path:str): - self.projections_calculated = False # Probably new file was loaded, we would need new projections + ####################################################### + # + # THE PIPELINE + # + ####################################################### + + def start_session( + self, load_series: bool, series_contains: str, explorer: str, base_path: str + ): + self.projections_calculated = ( + False # Probably new file was loaded, we would need new projections + ) - if load_series and series_contains == "": + if load_series and series_contains == '': # Try to guess the common part of the file names try: - filename = explorer.split("/")[-1] # Extract filename from path - series_contains = re.search(r"[^0-9]+", filename).group() - gr.Info(f"Using '{series_contains}' as common file name part for loading.") + filename = explorer.split('/')[-1] # Extract filename from path + series_contains = re.search(r'[^0-9]+', filename).group() + gr.Info( + f"Using '{series_contains}' as common file name part for loading." + ) self.series_contains = series_contains except: self.error_message = "For series, common part of file name must be provided in 'series_contains' field." - # Get the file path from the explorer or base path # priority is given to the explorer if file is selected @@ -357,20 +381,19 @@ class Interface(BaseInterface): self.file_path = base_path else: - self.error_message = "Invalid file path" + self.error_message = 'Invalid file path' # If we are loading a series, we need to get the directory if load_series: self.file_path = os.path.dirname(self.file_path) - - def load_data(self, virtual_stack:bool, dataset_name:str, contains:str): + def load_data(self, virtual_stack: bool, dataset_name: str, contains: str): try: self.vol = load( - path = self.file_path, - virtual_stack = virtual_stack, - dataset_name = dataset_name, - contains = contains + path=self.file_path, + virtual_stack=virtual_stack, + dataset_name=dataset_name, + contains=contains, ) # Incase the data is 4D (RGB for example), we take the mean of the last dimension @@ -379,54 +402,58 @@ class Interface(BaseInterface): # The rest of the pipeline expects 3D data if self.vol.ndim != 3: - self.error_message = F"Invalid data shape should be 3 dimensional, not shape: {self.vol.shape}" + self.error_message = f'Invalid data shape should be 3 dimensional, not shape: {self.vol.shape}' except Exception as error_message: - self.error_message = F"Error when loading data: {error_message}" - + self.error_message = f'Error when loading data: {error_message}' + def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]: outputs = [] self.calculated_operations = [] for operation in self.all_operations: if operation in operations: - log.info(f"Running {operation}") + log.info(f'Running {operation}') try: outputs.append(self.run_operation(operation, *args)) self.calculated_operations.append(operation) except Exception as err: - self.error_message = F"Error while running operation '{operation}': {err}" + self.error_message = ( + f"Error while running operation '{operation}': {err}" + ) log.info(self.error_message) outputs.append(gr.update()) else: - log.info(f"Skipping {operation}") - outputs.append(gr.update()) - + log.info(f'Skipping {operation}') + outputs.append(gr.update()) + return outputs - def run_operation(self, operation:list, zpos:float, ypos:float, xpos:float, cmap:str, *args): + def run_operation( + self, operation: list, zpos: float, ypos: float, xpos: float, cmap: str, *args + ): match operation: - case "Z Slicer": - return self.update_slice_wrapper("Z")(zpos, cmap) - case "Y Slicer": - return self.update_slice_wrapper("Y")(ypos, cmap) - case "X Slicer": - return self.update_slice_wrapper("X")(xpos, cmap) - case "Z max projection": + case 'Z Slicer': + return self.update_slice_wrapper('Z')(zpos, cmap) + case 'Y Slicer': + return self.update_slice_wrapper('Y')(ypos, cmap) + case 'X Slicer': + return self.update_slice_wrapper('X')(xpos, cmap) + case 'Z max projection': return self.create_projections_figs()[0] - case "Z min projection": + case 'Z min projection': return self.create_projections_figs()[1] - case "Intensity histogram": - # If the operations are run with the run_button, spinner_state == 2, - # If we just changed cmap, spinner state would be 3 + case 'Intensity histogram': + # If the operations are run with the run_button, spinner_state == 2, + # If we just changed cmap, spinner state would be 3 # and we don't have to calculate histogram again # That saves a lot of time as the histogram takes the most time to calculate - return self.plot_histogram() if self.spinner_state == 2 else gr.update() - case "Data summary": + return self.plot_histogram() if self.spinner_state == 2 else gr.update() + case 'Data summary': return self.show_data_summary() case _: - raise NotImplementedError(F"Operation '{operation} is not defined") + raise NotImplementedError(f"Operation '{operation} is not defined") def show_results(self, operations: list[str]) -> list[Dict[str, Any]]: update_list = [] @@ -437,32 +464,34 @@ class Interface(BaseInterface): update_list.append(gr.update(visible=False)) return update_list -####################################################### -# -# CALCULATION OF IMAGES -# -####################################################### + ####################################################### + # + # CALCULATION OF IMAGES + # + ####################################################### def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure: fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) - ax.imshow(img, interpolation="nearest", **kwargs) + ax.imshow(img, interpolation='nearest', **kwargs) # Adjustments - ax.axis("off") + ax.axis('off') fig.subplots_adjust(left=0, right=1, bottom=0, top=1) return fig - def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]: - def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]: + def update_slice_wrapper( + self, letter: str + ) -> Callable[[float, str], Dict[str, Any]]: + def update_slice(position_slider: float, cmap: str) -> Dict[str, Any]: """ position_slider: float from gradio slider, saying which relative slice we want to see cmap: string gradio drop down menu, saying what cmap we want to use for display """ axis = self.axis_dict[letter] slice_index = int(position_slider * (self.vol.shape[axis] - 1)) - + plt.close() plt.set_cmap(cmap) @@ -475,14 +504,19 @@ class Interface(BaseInterface): # The axis we want to slice along is moved to be the last one, could also be the first one, it doesn't matter # Then we take out the slice defined in self.position for this axis - slice_img = np.moveaxis(self.vol, axis, -1)[:,:,slice_index] + slice_img = np.moveaxis(self.vol, axis, -1)[:, :, slice_index] + + fig_img = self.create_img_fig(slice_img, vmin=vmin, vmax=vmax) + + return gr.update( + value=fig_img, label=f'{letter} Slice: {slice_index}', visible=True + ) - fig_img = self.create_img_fig(slice_img, vmin = vmin, vmax = vmax) - - return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True) return update_slice - - def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]: + + def vol_histogram( + self, nbins: int, min_value: float, max_value: float + ) -> tuple[np.ndarray, np.ndarray]: # Start histogram vol_hist = np.zeros(nbins) @@ -500,22 +534,28 @@ class Interface(BaseInterface): if not self.projections_calculated: _ = self.get_projections() - vol_hist, bin_edges = self.vol_histogram(self.nbins, self.min_value, self.max_value) + vol_hist, bin_edges = self.vol_histogram( + self.nbins, self.min_value, self.max_value + ) fig, ax = plt.subplots(figsize=(6, 4)) - ax.bar(bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge") + ax.bar( + bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge' + ) # Adjustments - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["left"].set_visible(True) - ax.spines["bottom"].set_visible(True) - ax.set_yscale("log") + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['left'].set_visible(True) + ax.spines['bottom'].set_visible(True) + ax.set_yscale('log') return fig - - def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]: + + def create_projections_figs( + self, + ) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]: if not self.projections_calculated: projections = self.get_projections() self.max_projection = projections[0] @@ -539,7 +579,7 @@ class Interface(BaseInterface): def get_projections(self) -> tuple[np.ndarray, np.ndarray]: # Create arrays for iteration max_projection = np.zeros(np.shape(self.vol[0])) - min_projection = np.ones(np.shape(self.vol[0])) * float("inf") + min_projection = np.ones(np.shape(self.vol[0])) * float('inf') intensity_sum = 0 # Iterate over slices. This is needed in case of virtual stacks. @@ -566,20 +606,22 @@ class Interface(BaseInterface): def show_data_summary(self): summary_dict = { - "Last modified": datetime.datetime.fromtimestamp(os.path.getmtime(self.file_path)).strftime("%Y-%m-%d %H:%M"), - "File size": _misc.sizeof(os.path.getsize(self.file_path)), - "Z-size": str(self.vol.shape[self.axis_dict["Z"]]), - "Y-size": str(self.vol.shape[self.axis_dict["Y"]]), - "X-size": str(self.vol.shape[self.axis_dict["X"]]), - "Data type": str(self.vol.dtype), - "Min value": str(self.vol.min()), - "Mean value": str(np.mean(self.vol)), - "Max value": str(self.vol.max()), + 'Last modified': datetime.datetime.fromtimestamp( + os.path.getmtime(self.file_path) + ).strftime('%Y-%m-%d %H:%M'), + 'File size': _misc.sizeof(os.path.getsize(self.file_path)), + 'Z-size': str(self.vol.shape[self.axis_dict['Z']]), + 'Y-size': str(self.vol.shape[self.axis_dict['Y']]), + 'X-size': str(self.vol.shape[self.axis_dict['X']]), + 'Data type': str(self.vol.dtype), + 'Min value': str(self.vol.min()), + 'Mean value': str(np.mean(self.vol)), + 'Max value': str(self.vol.max()), } display_dict = {k: v for k, v in summary_dict.items() if v is not None} - return ouf.showdict(display_dict, return_str=True, title="Data summary") - + return ouf.showdict(display_dict, return_str=True, title='Data summary') + -if __name__ == "__main__": +if __name__ == '__main__': Interface().run_interface() diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py index 36d820dde8ab24ac13c94590ae1f69c542f1a841..b350bca637cbd8a9e5ea53ef80db5da261554d69 100644 --- a/qim3d/gui/interface.py +++ b/qim3d/gui/interface.py @@ -1,16 +1,16 @@ +from abc import ABC, abstractmethod +from os import listdir, path from pathlib import Path -from abc import abstractmethod, ABC -from os import path, listdir import gradio as gr +import numpy as np -from .qim_theme import QimTheme import qim3d.gui -import numpy as np # TODO: when offline it throws an error in cli class BaseInterface(ABC): + """ Annotation tool and Data explorer as those don't need any examples. """ @@ -19,7 +19,7 @@ class BaseInterface(ABC): self, title: str, height: int, - width: int = "100%", + width: int = '100%', verbose: bool = False, custom_css: str = None, ): @@ -38,7 +38,7 @@ class BaseInterface(ABC): self.qim_dir = Path(qim3d.__file__).parents[0] self.custom_css = ( - path.join(self.qim_dir, "css", custom_css) + path.join(self.qim_dir, 'css', custom_css) if custom_css is not None else None ) @@ -48,9 +48,9 @@ class BaseInterface(ABC): def set_invisible(self): return gr.update(visible=False) - + def change_visibility(self, is_visible: bool): - return gr.update(visible = is_visible) + return gr.update(visible=is_visible) def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs): """ @@ -72,8 +72,7 @@ class BaseInterface(ABC): quiet=not self.verbose, height=self.height, width=self.width, - favicon_path=Path(qim3d.__file__).parents[0] - / "gui/assets/qim3d-icon.svg", + favicon_path=Path(qim3d.__file__).parents[0] / 'gui/assets/qim3d-icon.svg', **kwargs, ) @@ -88,7 +87,7 @@ class BaseInterface(ABC): title=self.title, css=self.custom_css, ) as gradio_interface: - gr.Markdown(f"# {self.title}") + gr.Markdown(f'# {self.title}') self.define_interface(**kwargs) return gradio_interface @@ -96,11 +95,12 @@ class BaseInterface(ABC): def define_interface(self, **kwargs): pass - def run_interface(self, host: str = "0.0.0.0"): + def run_interface(self, host: str = '0.0.0.0'): qim3d.gui.run_gradio_app(self.create_interface(), host) class InterfaceWithExamples(BaseInterface): + """ For Iso3D and Local Thickness """ @@ -117,7 +117,23 @@ class InterfaceWithExamples(BaseInterface): self._set_examples_list() def _set_examples_list(self): - valid_sufixes = (".tif", ".tiff", ".h5", ".nii", ".gz", ".dcm", ".DCM", ".vol", ".vgi", ".txrm", ".txm", ".xrm") + valid_sufixes = ( + '.tif', + '.tiff', + '.h5', + '.nii', + '.gz', + '.dcm', + '.DCM', + '.vol', + '.vgi', + '.txrm', + '.txm', + '.xrm', + ) examples_folder = path.join(self.qim_dir, 'examples') - self.img_examples = [path.join(examples_folder, example) for example in listdir(examples_folder) if example.endswith(valid_sufixes)] - + self.img_examples = [ + path.join(examples_folder, example) + for example in listdir(examples_folder) + if example.endswith(valid_sufixes) + ] diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py index 2725403f094dd58e89030f23dba06982551491f4..7226df8b2480c3bee6b1473138ce4351e130dafc 100644 --- a/qim3d/gui/iso3d.py +++ b/qim3d/gui/iso3d.py @@ -15,6 +15,7 @@ app.launch() ``` """ + import os import gradio as gr @@ -23,21 +24,19 @@ import plotly.graph_objects as go from scipy import ndimage import qim3d -from qim3d.utils._logger import log from qim3d.gui.interface import InterfaceWithExamples +from qim3d.utils._logger import log -#TODO img in launch should be self.img +# TODO img in launch should be self.img class Interface(InterfaceWithExamples): - def __init__(self, - verbose:bool = False, - plot_height:int = 768, - img = None): - - super().__init__(title = "Isosurfaces for 3D visualization", - height = 1024, - width = 960, - verbose = verbose) + def __init__(self, verbose: bool = False, plot_height: int = 768, img=None): + super().__init__( + title='Isosurfaces for 3D visualization', + height=1024, + width=960, + verbose=verbose, + ) self.interface = None self.img = img @@ -48,11 +47,13 @@ class Interface(InterfaceWithExamples): self.vol = qim3d.io.load(gradiofile.name) assert self.vol.ndim == 3 except AttributeError: - raise gr.Error("You have to select a file") + raise gr.Error('You have to select a file') except ValueError: - raise gr.Error("Unsupported file format") + raise gr.Error('Unsupported file format') except AssertionError: - raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") + raise gr.Error( + f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}" + ) def resize_vol(self, display_size: int): """Resizes the loaded volume to the display size""" @@ -61,12 +62,12 @@ class Interface(InterfaceWithExamples): original_Z, original_Y, original_X = np.shape(self.vol) max_size = np.max([original_Z, original_Y, original_X]) if self.verbose: - log.info(f"\nOriginal volume: {original_Z, original_Y, original_X}") + log.info(f'\nOriginal volume: {original_Z, original_Y, original_X}') # Resize for display self.vol = ndimage.zoom( input=self.vol, - zoom = display_size / max_size, + zoom=display_size / max_size, order=0, prefilter=False, ) @@ -76,16 +77,17 @@ class Interface(InterfaceWithExamples): ) if self.verbose: log.info( - f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}" + f'Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}' ) def save_fig(self, fig: go.Figure, filename: str): # Write Plotly figure to disk fig.write_html(filename) - def create_fig(self, + def create_fig( + self, gradio_file: gr.File, - display_size: int , + display_size: int, opacity: float, opacityscale: str, only_wireframe: bool, @@ -105,8 +107,7 @@ class Interface(InterfaceWithExamples): slice_y_location: int, show_x_slice: bool, slice_x_location: int, - ) -> tuple[go.Figure, str]: - + ) -> tuple[go.Figure, str]: # Load volume self.load_data(gradio_file) @@ -129,191 +130,184 @@ class Interface(InterfaceWithExamples): fig = go.Figure( go.Volume( - z = Z.flatten(), - y = Y.flatten(), - x = X.flatten(), - value = self.vol.flatten(), - isomin = min_value * np.max(self.vol), - isomax = max_value * np.max(self.vol), - cmin = np.min(self.vol), - cmax = np.max(self.vol), - opacity = opacity, - opacityscale = opacityscale, - surface_count = surface_count, - colorscale = colormap, - slices_z = dict( - show = show_z_slice, - locations = [int(self.display_size_z * slice_z_location)], + z=Z.flatten(), + y=Y.flatten(), + x=X.flatten(), + value=self.vol.flatten(), + isomin=min_value * np.max(self.vol), + isomax=max_value * np.max(self.vol), + cmin=np.min(self.vol), + cmax=np.max(self.vol), + opacity=opacity, + opacityscale=opacityscale, + surface_count=surface_count, + colorscale=colormap, + slices_z=dict( + show=show_z_slice, + locations=[int(self.display_size_z * slice_z_location)], ), - slices_y = dict( - show = show_y_slice, + slices_y=dict( + show=show_y_slice, locations=[int(self.display_size_y * slice_y_location)], ), - slices_x = dict( - show = show_x_slice, - locations = [int(self.display_size_x * slice_x_location)], + slices_x=dict( + show=show_x_slice, + locations=[int(self.display_size_x * slice_x_location)], ), - surface = dict(fill=surface_fill), - caps = dict( - x_show = show_caps, - y_show = show_caps, - z_show = show_caps, + surface=dict(fill=surface_fill), + caps=dict( + x_show=show_caps, + y_show=show_caps, + z_show=show_caps, ), - showscale = show_colorbar, + showscale=show_colorbar, colorbar=dict( - thickness=8, outlinecolor="#fff", len=0.5, orientation="h" + thickness=8, outlinecolor='#fff', len=0.5, orientation='h' ), - reversescale = reversescale, - hoverinfo = "skip", + reversescale=reversescale, + hoverinfo='skip', ) ) fig.update_layout( - scene_xaxis_showticklabels = show_ticks, - scene_yaxis_showticklabels = show_ticks, - scene_zaxis_showticklabels = show_ticks, - scene_xaxis_visible = show_axis, - scene_yaxis_visible = show_axis, - scene_zaxis_visible = show_axis, - scene_aspectmode="data", + scene_xaxis_showticklabels=show_ticks, + scene_yaxis_showticklabels=show_ticks, + scene_zaxis_showticklabels=show_ticks, + scene_xaxis_visible=show_axis, + scene_yaxis_visible=show_axis, + scene_zaxis_visible=show_axis, + scene_aspectmode='data', height=self.plot_height, hovermode=False, scene_camera_eye=dict(x=2.0, y=-2.0, z=1.5), ) - filename = "iso3d.html" + filename = 'iso3d.html' self.save_fig(fig, filename) return fig, filename - + def remove_unused_file(self): # Remove localthickness.tif file from working directory # as it otherwise is not deleted - os.remove("iso3d.html") + os.remove('iso3d.html') def define_interface(self, **kwargs): - gr.Markdown( - """ + """ This tool uses Plotly Volume (https://plotly.com/python/3d-volume-plots/) to create iso surfaces from voxels based on their intensity levels. To optimize performance when generating visualizations, set the number of voxels (_display resolution_) and isosurfaces (_total surfaces_) to lower levels. """ - ) + ) with gr.Row(): # Input and parameters column with gr.Column(scale=1, min_width=320): - with gr.Tab("Input"): + with gr.Tab('Input'): # File loader - gradio_file = gr.File( - show_label=False - ) - with gr.Tab("Examples"): + gradio_file = gr.File(show_label=False) + with gr.Tab('Examples'): gr.Examples(examples=self.img_examples, inputs=gradio_file) # Run button with gr.Row(): with gr.Column(scale=3, min_width=64): btn_run = gr.Button( - value="Run 3D visualization", variant = "primary" + value='Run 3D visualization', variant='primary' ) with gr.Column(scale=1, min_width=64): - btn_clear = gr.Button( - value="Clear", variant = "stop" - ) + btn_clear = gr.Button(value='Clear', variant='stop') - with gr.Tab("Display"): + with gr.Tab('Display'): # Display options display_size = gr.Slider( 32, 128, step=4, - label="Display resolution", - info="Number of voxels for the largest dimension", + label='Display resolution', + info='Number of voxels for the largest dimension', value=64, ) surface_count = gr.Slider( - 2, 16, step=1, label="Total iso-surfaces", value=6 + 2, 16, step=1, label='Total iso-surfaces', value=6 ) - show_caps = gr.Checkbox(value=False, label="Show surface caps") + show_caps = gr.Checkbox(value=False, label='Show surface caps') with gr.Row(): opacityscale = gr.Dropdown( - choices=["uniform", "extremes", "min", "max"], - value="uniform", - label="Opacity scale", - info="Handles opacity acording to voxel value", + choices=['uniform', 'extremes', 'min', 'max'], + value='uniform', + label='Opacity scale', + info='Handles opacity acording to voxel value', ) opacity = gr.Slider( - 0.0, 1.0, step=0.1, label="Max opacity", value=0.4 + 0.0, 1.0, step=0.1, label='Max opacity', value=0.4 ) with gr.Row(): min_value = gr.Slider( - 0.0, 1.0, step=0.05, label="Min value", value=0.1 + 0.0, 1.0, step=0.05, label='Min value', value=0.1 ) max_value = gr.Slider( - 0.0, 1.0, step=0.05, label="Max value", value=1 + 0.0, 1.0, step=0.05, label='Max value', value=1 ) - with gr.Tab("Slices") as slices: - show_z_slice = gr.Checkbox(value=False, label="Show Z slice") + with gr.Tab('Slices') as slices: + show_z_slice = gr.Checkbox(value=False, label='Show Z slice') slice_z_location = gr.Slider( - 0.0, 1.0, step=0.05, value=0.5, label="Position" + 0.0, 1.0, step=0.05, value=0.5, label='Position' ) - show_y_slice = gr.Checkbox(value=False, label="Show Y slice") + show_y_slice = gr.Checkbox(value=False, label='Show Y slice') slice_y_location = gr.Slider( - 0.0, 1.0, step=0.05, value=0.5, label="Position" + 0.0, 1.0, step=0.05, value=0.5, label='Position' ) - show_x_slice = gr.Checkbox(value=False, label="Show X slice") + show_x_slice = gr.Checkbox(value=False, label='Show X slice') slice_x_location = gr.Slider( - 0.0, 1.0, step=0.05, value=0.5, label="Position" + 0.0, 1.0, step=0.05, value=0.5, label='Position' ) - with gr.Tab("Misc"): + with gr.Tab('Misc'): with gr.Row(): colormap = gr.Dropdown( choices=[ - "Blackbody", - "Bluered", - "Blues", - "Cividis", - "Earth", - "Electric", - "Greens", - "Greys", - "Hot", - "Jet", - "Magma", - "Picnic", - "Portland", - "Rainbow", - "RdBu", - "Reds", - "Viridis", - "YlGnBu", - "YlOrRd", + 'Blackbody', + 'Bluered', + 'Blues', + 'Cividis', + 'Earth', + 'Electric', + 'Greens', + 'Greys', + 'Hot', + 'Jet', + 'Magma', + 'Picnic', + 'Portland', + 'Rainbow', + 'RdBu', + 'Reds', + 'Viridis', + 'YlGnBu', + 'YlOrRd', ], - value="Magma", - label="Colormap", + value='Magma', + label='Colormap', ) show_colorbar = gr.Checkbox( - value=False, label="Show color scale" + value=False, label='Show color scale' ) reversescale = gr.Checkbox( - value=False, label="Reverse color scale" + value=False, label='Reverse color scale' ) - flip_z = gr.Checkbox(value=True, label="Flip Z axis") - show_axis = gr.Checkbox(value=True, label="Show axis") - show_ticks = gr.Checkbox(value=False, label="Show ticks") - only_wireframe = gr.Checkbox( - value=False, label="Only wireframe" - ) + flip_z = gr.Checkbox(value=True, label='Flip Z axis') + show_axis = gr.Checkbox(value=True, label='Show axis') + show_ticks = gr.Checkbox(value=False, label='Show ticks') + only_wireframe = gr.Checkbox(value=False, label='Only wireframe') # Inputs for gradio inputs = [ @@ -346,7 +340,7 @@ class Interface(InterfaceWithExamples): plot_download = gr.File( interactive=False, - label="Download interactive plot", + label='Download interactive plot', show_label=True, visible=False, ) @@ -367,5 +361,6 @@ class Interface(InterfaceWithExamples): fn=self.remove_unused_file).success( fn=self.set_visible, inputs=None, outputs=plot_download) -if __name__ == "__main__": - Interface().run_interface() \ No newline at end of file + +if __name__ == '__main__': + Interface().run_interface() diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py index 5a25f6e975a084b263fc5a3657aa72eed01ed94a..e5eb9f4dfde36d6d4fe783671e1ea4664a5501c6 100644 --- a/qim3d/gui/layers2d.py +++ b/qim3d/gui/layers2d.py @@ -18,19 +18,15 @@ app = layers.launch() """ import os +from typing import Any, Dict import gradio as gr import numpy as np -from .interface import BaseInterface # from qim3d.processing import layers2d as l2d -from qim3d.processing import segment_layers, get_lines -from qim3d.operations import overlay_rgb_images -from qim3d.io import load -from qim3d.viz._layers2d import image_with_lines -from typing import Dict, Any +from .interface import BaseInterface -#TODO figure out how not update anything and go through processing when there are no data loaded +# TODO figure out how not update anything and go through processing when there are no data loaded # So user could play with the widgets but it doesnt throw error # Right now its only bypassed with several if statements # I opened an issue here https://github.com/gradio-app/gradio/issues/9273 @@ -38,22 +34,26 @@ from typing import Dict, Any X = 'X' Y = 'Y' Z = 'Z' -AXES = {X:2, Y:1, Z:0} +AXES = {X: 2, Y: 1, Z: 0} DEFAULT_PLOT_TYPE = 'Segmentation mask' -SEGMENTATION_COLORS = np.array([[0, 255, 255], # Cyan - [255, 195, 0], # Yellow Orange - [199, 0, 57], # Dark orange - [218, 247, 166], # Light green - [255, 0, 255], # Magenta - [65, 105, 225], # Royal blue - [138, 43, 226], # Blue violet - [255, 0, 0], #Red - ]) +SEGMENTATION_COLORS = np.array( + [ + [0, 255, 255], # Cyan + [255, 195, 0], # Yellow Orange + [199, 0, 57], # Dark orange + [218, 247, 166], # Light green + [255, 0, 255], # Magenta + [65, 105, 225], # Royal blue + [138, 43, 226], # Blue violet + [255, 0, 0], # Red + ] +) + class Interface(BaseInterface): def __init__(self): - super().__init__("Layered surfaces 2D", 1080) + super().__init__('Layered surfaces 2D', 1080) self.data = None # It important to keep the name of the attributes like this (including the capital letter) becuase of @@ -69,8 +69,6 @@ class Interface(BaseInterface): self.error = False - - def define_interface(self): with gr.Row(): with gr.Column(scale=1, min_width=320): @@ -79,62 +77,62 @@ class Interface(BaseInterface): base_path = gr.Textbox( max_lines=1, container=False, - label="Base path", + label='Base path', value=os.getcwd(), ) with gr.Column(scale=1, min_width=36): - reload_base_path = gr.Button(value="⟳") + reload_base_path = gr.Button(value='⟳') explorer = gr.FileExplorer( - ignore_glob="*/.*", + ignore_glob='*/.*', root_dir=os.getcwd(), label=os.getcwd(), render=True, - file_count="single", + file_count='single', interactive=True, - height = 230, + height=230, ) - with gr.Group(): with gr.Row(): axis = gr.Radio( - choices = [Z, Y, X], - value = Z, - label = 'Layer axis', - info = 'Specifies in which direction are the layers. The order of axes is ZYX',) + choices=[Z, Y, X], + value=Z, + label='Layer axis', + info='Specifies in which direction are the layers. The order of axes is ZYX', + ) with gr.Row(): wrap = gr.Checkbox( - label = "Lines start and end at the same level.", - info = "Used when segmenting layers of unfolded image." + label='Lines start and end at the same level.', + info='Used when segmenting layers of unfolded image.', ) - + is_inverted = gr.Checkbox( - label="Invert image before processing", - info="The algorithm effectively flips the gradient.", - ) - + label='Invert image before processing', + info='The algorithm effectively flips the gradient.', + ) + with gr.Row(): delta = gr.Slider( minimum=0, maximum=5, value=0.75, step=0.01, - interactive = True, - label="Delta value", - info="The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute. Delta above 1 is rounded down to closest lower integer", + interactive=True, + label='Delta value', + info='The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute. Delta above 1 is rounded down to closest lower integer', ) - + with gr.Row(): min_margin = gr.Slider( - minimum=1, - maximum=50, - value=10, - step=1, - interactive = True, - label="Min margin", - info="Minimum margin between layers to be detected in the image.", + minimum=1, + maximum=50, + value=10, + step=1, + interactive=True, + label='Min margin', + info='Minimum margin between layers to be detected in the image.', ) with gr.Row(): @@ -144,9 +142,9 @@ class Interface(BaseInterface): value=2, step=1, interactive=True, - label="Number of layers", - info="Number of layers to be detected in the image", - ) + label='Number of layers', + info='Number of layers to be detected in the image', + ) # with gr.Row(): # btn_run = gr.Button("Run Layers2D", variant = 'primary') @@ -162,15 +160,20 @@ class Interface(BaseInterface): change their height manually """ - self.heights = ['60em', '30em', '20em'] # em units are relative to the parent, - + self.heights = [ + '60em', + '30em', + '20em', + ] # em units are relative to the parent, - with gr.Column(scale=2,): + with gr.Column( + scale=2, + ): # with gr.Row(): # Source image outputs # input_image_kwargs = lambda axis: dict( # show_label = True, - # label = F'Slice along {axis}-axis', - # visible = True, + # label = F'Slice along {axis}-axis', + # visible = True, # height = self.heights[2] # ) @@ -178,69 +181,77 @@ class Interface(BaseInterface): # input_plot_y = gr.Image(**input_image_kwargs('Y')) # input_plot_z = gr.Image(**input_image_kwargs('Z')) - with gr.Row(): # Detected layers outputs + with gr.Row(): # Detected layers outputs output_image_kwargs = lambda axis: dict( - show_label = True, - label = F'Detected layers {axis}-axis', - visible = True, - height = self.heights[2] + show_label=True, + label=f'Detected layers {axis}-axis', + visible=True, + height=self.heights[2], ) output_plot_x = gr.Image(**output_image_kwargs('X')) output_plot_y = gr.Image(**output_image_kwargs('Y')) output_plot_z = gr.Image(**output_image_kwargs('Z')) - - with gr.Row(): # Axis position sliders + + with gr.Row(): # Axis position sliders slider_kwargs = lambda axis: dict( - minimum = 0, - maximum = 1, - value = 0.5, - step = 0.01, - label = F'{axis} position', - info = F'The 3D image is sliced along {axis}-axis' + minimum=0, + maximum=1, + value=0.5, + step=0.01, + label=f'{axis} position', + info=f'The 3D image is sliced along {axis}-axis', ) - - x_pos = gr.Slider(**slider_kwargs('X')) + + x_pos = gr.Slider(**slider_kwargs('X')) y_pos = gr.Slider(**slider_kwargs('Y')) z_pos = gr.Slider(**slider_kwargs('Z')) with gr.Row(): - x_check = gr.Checkbox(value = True, interactive=True, label = 'Show X slice') - y_check = gr.Checkbox(value = True, interactive=True, label = 'Show Y slice') - z_check = gr.Checkbox(value = True, interactive=True, label = 'Show Z slice') + x_check = gr.Checkbox( + value=True, interactive=True, label='Show X slice' + ) + y_check = gr.Checkbox( + value=True, interactive=True, label='Show Y slice' + ) + z_check = gr.Checkbox( + value=True, interactive=True, label='Show Z slice' + ) with gr.Row(): with gr.Group(): plot_type = gr.Radio( - choices= (DEFAULT_PLOT_TYPE, 'Segmentation lines',), - value = DEFAULT_PLOT_TYPE, - interactive = True, - show_label=False - ) - + choices=( + DEFAULT_PLOT_TYPE, + 'Segmentation lines', + ), + value=DEFAULT_PLOT_TYPE, + interactive=True, + show_label=False, + ) + alpha = gr.Slider( minimum=0, - maximum = 1, - step = 0.01, - label = 'Alpha value', + maximum=1, + step=0.01, + label='Alpha value', show_label=True, - value = 0.5, - visible = True, - interactive=True - ) - + value=0.5, + visible=True, + interactive=True, + ) + line_thickness = gr.Slider( minimum=0.1, - maximum = 5, - value = 2, - label = 'Line thickness', - show_label = True, - visible = False, - interactive = True - ) + maximum=5, + value=2, + label='Line thickness', + show_label=True, + visible=False, + interactive=True, + ) with gr.Row(): - btn_run = gr.Button("Run Layers2D", variant = 'primary') - + btn_run = gr.Button('Run Layers2D', variant='primary') positions = [x_pos, y_pos, z_pos] process_inputs = [axis, is_inverted, delta, min_margin, n_layers, wrap] @@ -249,20 +260,24 @@ class Interface(BaseInterface): output_plots = [output_plot_x, output_plot_y, output_plot_z] visibility_check_inputs = [x_check, y_check, z_check] - spinner_loading = gr.Text("Loading data...", visible=False) - spinner_running = gr.Text("Running pipeline...", visible=False) + spinner_loading = gr.Text('Loading data...', visible=False) + spinner_running = gr.Text('Running pipeline...', visible=False) reload_base_path.click( - fn=self.update_explorer,inputs=base_path, outputs=explorer) - + fn=self.update_explorer, inputs=base_path, outputs=explorer + ) + plot_type.change( - self.change_plot_type, inputs = plot_type, outputs = [alpha, line_thickness]).then( - fn = self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots - ) - + self.change_plot_type, inputs=plot_type, outputs=[alpha, line_thickness] + ).then( + fn=self.plot_output_img_all, inputs=plotting_inputs, outputs=output_plots + ) + gr.on( - triggers = [alpha.release, line_thickness.release], - fn = self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots + triggers=[alpha.release, line_thickness.release], + fn=self.plot_output_img_all, + inputs=plotting_inputs, + outputs=output_plots, ) """ @@ -275,54 +290,93 @@ class Interface(BaseInterface): update_component = gr.State(True) btn_run.click( - fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run).then( - fn=self.load_data, inputs = [base_path, explorer]).then( - fn = lambda state: not state, inputs = update_component, outputs = update_component) - + fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run + ).then(fn=self.load_data, inputs=[base_path, explorer]).then( + fn=lambda state: not state, + inputs=update_component, + outputs=update_component, + ) + gr.on( - triggers= (axis.change, is_inverted.change, delta.release, min_margin.release, n_layers.release, update_component.change, wrap.change), - fn=self.set_spinner, inputs = spinner_running, outputs=btn_run).then( - fn=self.process_all, inputs = [*positions, *process_inputs]).then( + triggers=( + axis.change, + is_inverted.change, + delta.release, + min_margin.release, + n_layers.release, + update_component.change, + wrap.change, + ), + fn=self.set_spinner, + inputs=spinner_running, + outputs=btn_run, + ).then(fn=self.process_all, inputs=[*positions, *process_inputs]).then( # fn=self.plot_input_img_all, outputs = input_plots, show_progress='hidden').then( - fn=self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots, show_progress='hidden').then( - fn=self.set_relaunch_button, inputs=[], outputs=btn_run) - + fn=self.plot_output_img_all, + inputs=plotting_inputs, + outputs=output_plots, + show_progress='hidden', + ).then(fn=self.set_relaunch_button, inputs=[], outputs=btn_run) + # Chnages visibility and sizes of the plots - gives user the option to see only some of the images and in bigger scale gr.on( triggers=[x_check.change, y_check.change, z_check.change], - fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = positions).then( + fn=self.change_row_visibility, + inputs=visibility_check_inputs, + outputs=positions, + ).then( # fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = input_plots).then( - fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots) - + fn=self.change_plot_size, + inputs=visibility_check_inputs, + outputs=output_plots, + ) + # for axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots): - for axis, slider, output_plot in zip([X,Y,Z], positions, output_plots): + for axis, slider, output_plot in zip([X, Y, Z], positions, output_plots): slider.change( - self.process_wrapper(axis), inputs = [slider, *process_inputs]).then( + self.process_wrapper(axis), inputs=[slider, *process_inputs] + ).then( # self.plot_input_img_wrapper(axis), outputs = input_plot).then( - self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot) - - + self.plot_output_img_wrapper(axis), + inputs=plotting_inputs, + outputs=output_plot, + ) - def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]: + def change_plot_type( + self, + plot_type: str, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: self.plot_type = plot_type if plot_type == 'Segmentation lines': - return gr.update(visible = False), gr.update(visible = True) - else: - return gr.update(visible = True), gr.update(visible = False) - - def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + return gr.update(visible=False), gr.update(visible=True) + else: + return gr.update(visible=True), gr.update(visible=False) + + def change_plot_size( + self, x_check: int, y_check: int, z_check: int + ) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: """ Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define also their height because gradio doesn't do it automatically. The values of heights were set just by eye. They are defines before defining the plot in 'define_interface' """ index = x_check + y_check + z_check - 1 - height = self.heights[index] # also used to define heights of plots in the begining - return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check) + height = self.heights[ + index + ] # also used to define heights of plots in the begining + return ( + gr.update(height=height, visible=x_check), + gr.update(height=height, visible=y_check), + gr.update(height=height, visible=z_check), + ) def change_row_visibility(self, x_check: int, y_check: int, z_check: int): - return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check) - + return ( + self.change_visibility(x_check), + self.change_visibility(y_check), + self.change_visibility(z_check), + ) + def update_explorer(self, new_path: str): # Refresh the file explorer object new_path = os.path.expanduser(new_path) @@ -337,54 +391,76 @@ class Interface(BaseInterface): return gr.update(root_dir=parent_dir, label=parent_dir, value=file_name) else: - raise ValueError("Invalid path") + raise ValueError('Invalid path') def set_relaunch_button(self): - return gr.update(value=f"Relaunch", interactive=True) + return gr.update(value='Relaunch', interactive=True) def set_spinner(self, message: str): if self.error: return gr.Button() # spinner icon/shows the user something is happeing - return gr.update(value=f"{message}", interactive=False) - + return gr.update(value=f'{message}', interactive=False) + def load_data(self, base_path: str, explorer: str): if base_path and os.path.isfile(base_path): file_path = base_path elif explorer and os.path.isfile(explorer): file_path = explorer else: - raise gr.Error("Invalid file path") + raise gr.Error('Invalid file path') try: - self.data = qim3d.io.load( - file_path, - progress_bar=False - ) + self.data = qim3d.io.load(file_path, progress_bar=False) except Exception as error_message: raise gr.Error( - f"Failed to load the image: {error_message}" + f'Failed to load the image: {error_message}' ) from error_message - - def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool): - self.process_wrapper(X)(x_pos, axis, inverted, delta, min_margin, n_layers, wrap) - self.process_wrapper(Y)(y_pos, axis, inverted, delta, min_margin, n_layers, wrap) - self.process_wrapper(Z)(z_pos, axis, inverted, delta, min_margin, n_layers, wrap) - def process_wrapper(self, slicing_axis:str): + def process_all( + self, + x_pos: float, + y_pos: float, + z_pos: float, + axis: str, + inverted: bool, + delta: float, + min_margin: int, + n_layers: int, + wrap: bool, + ): + self.process_wrapper(X)( + x_pos, axis, inverted, delta, min_margin, n_layers, wrap + ) + self.process_wrapper(Y)( + y_pos, axis, inverted, delta, min_margin, n_layers, wrap + ) + self.process_wrapper(Z)( + z_pos, axis, inverted, delta, min_margin, n_layers, wrap + ) + + def process_wrapper(self, slicing_axis: str): """ The function behaves the same in all 3 directions, however we have to know in which direction we are now. Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data and we return a function working in that direction """ - slice_key = F'{slicing_axis}_slice' - seg_key = F'{slicing_axis}_segmentation' + slice_key = f'{slicing_axis}_slice' + seg_key = f'{slicing_axis}_segmentation' slicing_axis_int = AXES[slicing_axis] - def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool): + def process( + pos: float, + segmenting_axis: str, + inverted: bool, + delta: float, + min_margin: int, + n_layers: int, + wrap: bool, + ): """ - Parameters: - ----------- + Parameters + ---------- pos: Relative position of a slice from data segmenting_axis: In which direction we want to detect layers inverted: If we want use inverted gradient @@ -392,6 +468,7 @@ class Interface(BaseInterface): min_margin: What is the minimum distance between layers. If it was 0, all layers would be the same n_layers: How many layer boarders we want to find wrap: If True, the starting point and end point will be at the same level. Useful when segmenting unfolded images. + """ slice = self.get_slice(pos, slicing_axis_int) self.__dict__[slice_key] = slice @@ -399,25 +476,33 @@ class Interface(BaseInterface): if segmenting_axis == slicing_axis: self.__dict__[seg_key] = None else: - if self.is_transposed(slicing_axis, segmenting_axis): slice = np.rot90(slice) - self.__dict__[seg_key] = qim3d.processing.segment_layers(slice, inverted = inverted, n_layers = n_layers, delta = delta, min_margin = min_margin, wrap = wrap) - + self.__dict__[seg_key] = qim3d.processing.segment_layers( + slice, + inverted=inverted, + n_layers=n_layers, + delta=delta, + min_margin=min_margin, + wrap=wrap, + ) + return process - def is_transposed(self, slicing_axis:str, segmenting_axis:str): + def is_transposed(self, slicing_axis: str, segmenting_axis: str): """ - Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is. + Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is. If it is not, we have to rotate it before we put it to segmentation algorithm """ - remaining_axis = F"{X}{Y}{Z}".replace(slicing_axis, '').replace(segmenting_axis, '') + remaining_axis = f'{X}{Y}{Z}'.replace(slicing_axis, '').replace( + segmenting_axis, '' + ) return AXES[segmenting_axis] > AXES[remaining_axis] - - def get_slice(self, pos:float, axis:int): + + def get_slice(self, pos: float, axis: int): idx = int(pos * (self.data.shape[axis] - 1)) - return np.take(self.data, idx, axis = axis) - + return np.take(self.data, idx, axis=axis) + # def plot_input_img_wrapper(self, axis:str): # slice_key = F'{axis.lower()}_slice' # def plot_input_img(): @@ -432,45 +517,55 @@ class Interface(BaseInterface): # y_plot = self.plot_input_img_wrapper('y')() # z_plot = self.plot_input_img_wrapper('z')() # return x_plot, y_plot, z_plot - - def plot_output_img_wrapper(self, slicing_axis:str): - slice_key = F'{slicing_axis}_slice' - seg_key = F'{slicing_axis}_segmentation' - def plot_output_img(segmenting_axis:str, alpha:float, line_thickness:float): + def plot_output_img_wrapper(self, slicing_axis: str): + slice_key = f'{slicing_axis}_slice' + seg_key = f'{slicing_axis}_segmentation' + + def plot_output_img(segmenting_axis: str, alpha: float, line_thickness: float): slice = self.__dict__[slice_key] seg = self.__dict__[seg_key] - if seg is None: # In case segmenting axis si the same as slicing axis + if seg is None: # In case segmenting axis si the same as slicing axis return slice - + if self.plot_type == DEFAULT_PLOT_TYPE: n_layers = len(seg) + 1 - seg = np.sum(seg, axis = 0) - seg = np.repeat(seg[..., None], 3, axis = -1) + seg = np.sum(seg, axis=0) + seg = np.repeat(seg[..., None], 3, axis=-1) for i in range(n_layers): - seg[seg[:,:,0] == i, :] = SEGMENTATION_COLORS[i] + seg[seg[:, :, 0] == i, :] = SEGMENTATION_COLORS[i] if self.is_transposed(slicing_axis, segmenting_axis): - seg = np.rot90(seg, k = 3) + seg = np.rot90(seg, k=3) # slice = 255 * (slice/np.max(slice)) - # return image_with_overlay(np.repeat(slice[..., None], 3, -1), seg, alpha) + # return image_with_overlay(np.repeat(slice[..., None], 3, -1), seg, alpha) return qim3d.operations.overlay_rgb_images(slice, seg, alpha) else: lines = qim3d.processing.get_lines(seg) if self.is_transposed(slicing_axis, segmenting_axis): - return qim3d.viz.image_with_lines(np.rot90(slice), lines, line_thickness).rotate(270, expand = True) + return qim3d.viz.image_with_lines( + np.rot90(slice), lines, line_thickness + ).rotate(270, expand=True) else: return qim3d.viz.image_with_lines(slice, lines, line_thickness) - + return plot_output_img - - def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float): - x_output = self.plot_output_img_wrapper(X)(segmenting_axis, alpha, line_thickness) - y_output = self.plot_output_img_wrapper(Y)(segmenting_axis, alpha, line_thickness) - z_output = self.plot_output_img_wrapper(Z)(segmenting_axis, alpha, line_thickness) + + def plot_output_img_all( + self, segmenting_axis: str, alpha: float, line_thickness: float + ): + x_output = self.plot_output_img_wrapper(X)( + segmenting_axis, alpha, line_thickness + ) + y_output = self.plot_output_img_wrapper(Y)( + segmenting_axis, alpha, line_thickness + ) + z_output = self.plot_output_img_wrapper(Z)( + segmenting_axis, alpha, line_thickness + ) return x_output, y_output, z_output -if __name__ == "__main__": + +if __name__ == '__main__': Interface().run_interface() - \ No newline at end of file diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py index 1aefca7d0ba1f1524e8f3c3de45f797b60dbd9cf..652f18d46af6c37ef23415ec333aba98e39444b4 100644 --- a/qim3d/gui/local_thickness.py +++ b/qim3d/gui/local_thickness.py @@ -1,16 +1,16 @@ """ !!! quote "Reference" - Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW). + Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW). <https://doi.org/10.1109/cvprw59228.2023.00456> ```bibtex - @inproceedings{Dahl_2023, title={Fast Local Thickness}, - url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456}, - DOI={10.1109/cvprw59228.2023.00456}, - booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)}, - publisher={IEEE}, - author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm}, - year={2023}, + @inproceedings{Dahl_2023, title={Fast Local Thickness}, + url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456}, + DOI={10.1109/cvprw59228.2023.00456}, + booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)}, + publisher={IEEE}, + author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm}, + year={2023}, month=jun } ``` @@ -32,29 +32,31 @@ app.launch() ``` """ + import os +import gradio as gr +import localthickness as lt + # matplotlib.use("Agg") import matplotlib.pyplot as plt -import gradio as gr import numpy as np import tifffile -import localthickness as lt -import qim3d +import qim3d class Interface(qim3d.gui.interface.InterfaceWithExamples): - def __init__(self, - img: np.ndarray = None, - verbose:bool = False, - plot_height:int = 768, - figsize:int = 6): - - super().__init__(title = "Local thickness", - height = 1024, - width = 960, - verbose = verbose) + def __init__( + self, + img: np.ndarray = None, + verbose: bool = False, + plot_height: int = 768, + figsize: int = 6, + ): + super().__init__( + title='Local thickness', height=1024, width=960, verbose=verbose + ) self.plot_height = plot_height self.figsize = figsize @@ -64,7 +66,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): # Get the temporary files from gradio temp_sets = self.interface.temp_file_sets for temp_set in temp_sets: - if "localthickness" in str(temp_set): + if 'localthickness' in str(temp_set): # Get the lsit of the temporary files temp_path_list = list(temp_set) @@ -84,7 +86,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): def define_interface(self): gr.Markdown( - "Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)" + 'Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)' ) with gr.Row(): @@ -92,12 +94,12 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): if self.img is not None: data = gr.State(value=self.img) else: - with gr.Tab("Input"): + with gr.Tab('Input'): data = gr.File( show_label=False, value=self.img, ) - with gr.Tab("Examples"): + with gr.Tab('Examples'): gr.Examples(examples=self.img_examples, inputs=data) with gr.Row(): @@ -106,17 +108,15 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): maximum=1, value=0.5, step=0.01, - label="Z position", - info="Local thickness is calculated in 3D, this slider controls the visualization only.", + label='Z position', + info='Local thickness is calculated in 3D, this slider controls the visualization only.', ) - with gr.Tab("Parameters"): + with gr.Tab('Parameters'): gr.Markdown( - "It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results." - ) - lt_scale = gr.Slider( - 0.1, 1.0, label="Scale", value=0.5, step=0.1 + 'It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results.' ) + lt_scale = gr.Slider(0.1, 1.0, label='Scale', value=0.5, step=0.1) with gr.Row(): threshold = gr.Slider( @@ -124,85 +124,83 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): 1.0, value=0.5, step=0.05, - label="Threshold", - info="Local thickness uses a binary image, so a threshold value is needed.", + label='Threshold', + info='Local thickness uses a binary image, so a threshold value is needed.', ) dark_objects = gr.Checkbox( value=False, - label="Dark objects", - info="Inverts the image before thresholding. Use in case your foreground is darker than the background.", + label='Dark objects', + info='Inverts the image before thresholding. Use in case your foreground is darker than the background.', ) - with gr.Tab("Display options"): + with gr.Tab('Display options'): cmap_original = gr.Dropdown( - value="viridis", + value='viridis', choices=plt.colormaps(), - label="Colormap - input", + label='Colormap - input', interactive=True, ) cmap_lt = gr.Dropdown( - value="magma", + value='magma', choices=plt.colormaps(), - label="Colormap - local thickness", + label='Colormap - local thickness', interactive=True, ) - nbins = gr.Slider( - 5, 50, value=25, step=1, label="Histogram bins" - ) + nbins = gr.Slider(5, 50, value=25, step=1, label='Histogram bins') # Run button with gr.Row(): with gr.Column(scale=3, min_width=64): - btn = gr.Button( - "Run local thickness", variant = "primary" - ) + btn = gr.Button('Run local thickness', variant='primary') with gr.Column(scale=1, min_width=64): - btn_clear = gr.Button("Clear", variant = "stop") + btn_clear = gr.Button('Clear', variant='stop') - with gr.Column(scale=4): + def create_uniform_image(intensity=1): """ Generates a blank image with a single color. Gradio `gr.Plot` components will flicker if there is no default value. bug fix on gradio version 4.44.0 """ - pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255) + pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int( + intensity * 255 + ) fig, ax = plt.subplots(figsize=(10, 10)) - ax.imshow(pixels, interpolation="nearest") + ax.imshow(pixels, interpolation='nearest') # Adjustments - ax.axis("off") + ax.axis('off') fig.subplots_adjust(left=0, right=1, bottom=0, top=1) return fig - + with gr.Row(): input_vol = gr.Plot( show_label=True, - label="Original", + label='Original', visible=True, value=create_uniform_image(), ) binary_vol = gr.Plot( show_label=True, - label="Binary", + label='Binary', visible=True, value=create_uniform_image(), ) output_vol = gr.Plot( show_label=True, - label="Local thickness", + label='Local thickness', visible=True, value=create_uniform_image(), ) with gr.Row(): histogram = gr.Plot( show_label=True, - label="Thickness histogram", + label='Thickness histogram', visible=True, value=create_uniform_image(), ) @@ -210,11 +208,10 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): lt_output = gr.File( interactive=False, show_label=True, - label="Output file", + label='Output file', visible=False, ) - # Run button # fmt: off viz_input = lambda zpos, cmap: self.show_slice(self.vol, zpos, self.vmin, self.vmax, cmap) @@ -246,11 +243,11 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): fn=viz_input, inputs = [zpos, cmap_original], outputs=input_vol, show_progress=False).success( fn=viz_binary, inputs = [zpos, cmap_original], outputs=binary_vol, show_progress=False).success( fn=viz_output, inputs = [zpos, cmap_lt], outputs=output_vol, show_progress=False) - + cmap_original.change( fn=viz_input, inputs = [zpos, cmap_original],outputs=input_vol, show_progress=False).success( fn=viz_binary, inputs = [zpos, cmap_original], outputs=binary_vol, show_progress=False) - + cmap_lt.change( fn=viz_output, inputs = [zpos, cmap_lt], outputs=output_vol, show_progress=False ) @@ -274,7 +271,9 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): except AttributeError: self.vol = data except AssertionError: - raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") + raise gr.Error( + f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}" + ) if dark_objects: self.vol = np.invert(self.vol) @@ -283,15 +282,22 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): self.vmin = np.min(self.vol) self.vmax = np.max(self.vol) - def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"): + def show_slice( + self, + vol: np.ndarray, + zpos: int, + vmin: float = None, + vmax: float = None, + cmap: str = 'viridis', + ): plt.close() z_idx = int(zpos * (vol.shape[0] - 1)) fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) - ax.imshow(vol[z_idx], interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) + ax.imshow(vol[z_idx], interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax) # Adjustments - ax.axis("off") + ax.axis('off') fig.subplots_adjust(left=0, right=1, bottom=0, top=1) return fig @@ -300,7 +306,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): # Make a binary volume # Nothing fancy, but we could add new features here self.vol_binary = self.vol > (threshold * np.max(self.vol)) - + def compute_localthickness(self, lt_scale: float): self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale) @@ -318,29 +324,30 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): fig, ax = plt.subplots(figsize=(6, 4)) ax.bar( - bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge" + bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge' ) # Adjustments - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["left"].set_visible(True) - ax.spines["bottom"].set_visible(True) - ax.set_yscale("log") + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['left'].set_visible(True) + ax.spines['bottom'].set_visible(True) + ax.set_yscale('log') return fig def save_lt(self): - filename = "localthickness.tif" + filename = 'localthickness.tif' # Save output image in a temp space tifffile.imwrite(filename, self.vol_thickness) return filename - + def remove_unused_file(self): # Remove localthickness.tif file from working directory # as it otherwise is not deleted os.remove('localthickness.tif') - -if __name__ == "__main__": - Interface().run_interface() \ No newline at end of file + + +if __name__ == '__main__': + Interface().run_interface() diff --git a/qim3d/gui/qim_theme.py b/qim3d/gui/qim_theme.py index 5cdc844b695a0fb64e6ff6588ea6b48f0115088e..2dcd6992ae5ed9263dc4c6d13a7d9faa4e3ebaf6 100644 --- a/qim3d/gui/qim_theme.py +++ b/qim3d/gui/qim_theme.py @@ -1,30 +1,34 @@ import gradio as gr + class QimTheme(gr.themes.Default): + """ Theme for qim3d gradio interfaces. The theming options are quite broad. However if there is something you can not achieve with this theme there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom in comments. """ + def __init__(self, force_light_mode: bool = True): """ - Parameters: - ----------- - - force_light_mode (bool, optional): Gradio themes have dark mode by default. + Parameters + ---------- + - force_light_mode (bool, optional): Gradio themes have dark mode by default. QIM platform is not ready for dark mode yet, thus the tools should also be in light mode. This sets the darkmode values to be the same as light mode values. + """ super().__init__() self.force_light_mode = force_light_mode - self.general_values() # Not color related + self.general_values() # Not color related self.set_light_mode_values() - self.set_dark_mode_values() # Checks the light mode setting inside + self.set_dark_mode_values() # Checks the light mode setting inside def general_values(self): self.set_button() self.set_h1() - + def set_light_mode_values(self): self.set_light_primary_button() self.set_light_secondary_button() @@ -34,8 +38,14 @@ class QimTheme(gr.themes.Default): def set_dark_mode_values(self): if self.force_light_mode: - for attr in [dark_attr for dark_attr in dir(self) if not dark_attr.startswith("_") and dark_attr.endswith("dark")]: - self.__dict__[attr] = self.__dict__[attr[:-5]] # ligth and dark attributes have same names except for '_dark' at the end + for attr in [ + dark_attr + for dark_attr in dir(self) + if not dark_attr.startswith('_') and dark_attr.endswith('dark') + ]: + self.__dict__[attr] = self.__dict__[ + attr[:-5] + ] # ligth and dark attributes have same names except for '_dark' at the end else: self.set_dark_primary_button() # Secondary button looks good by default in dark mode @@ -44,26 +54,28 @@ class QimTheme(gr.themes.Default): # Example looks good by default in dark mode def set_button(self): - self.button_transition = "0.15s" - self.button_large_text_weight = "normal" + self.button_transition = '0.15s' + self.button_large_text_weight = 'normal' def set_light_primary_button(self): - self.run_color = "#198754" - self.button_primary_background_fill = "#FFFFFF" + self.run_color = '#198754' + self.button_primary_background_fill = '#FFFFFF' self.button_primary_background_fill_hover = self.run_color self.button_primary_border_color = self.run_color self.button_primary_text_color = self.run_color - self.button_primary_text_color_hover = "#FFFFFF" + self.button_primary_text_color_hover = '#FFFFFF' def set_dark_primary_button(self): - self.bright_run_color = "#299764" - self.button_primary_background_fill_dark = self.button_primary_background_fill_hover + self.bright_run_color = '#299764' + self.button_primary_background_fill_dark = ( + self.button_primary_background_fill_hover + ) self.button_primary_background_fill_hover_dark = self.bright_run_color self.button_primary_border_color_dark = self.button_primary_border_color self.button_primary_border_color_hover_dark = self.bright_run_color def set_light_secondary_button(self): - self.button_secondary_background_fill = "white" + self.button_secondary_background_fill = 'white' def set_light_example(self): """ @@ -73,10 +85,10 @@ class QimTheme(gr.themes.Default): self.color_accent_soft = self.neutral_100 def set_h1(self): - self.text_xxl = "2.5rem" + self.text_xxl = '2.5rem' def set_light_checkbox(self): - light_blue = "#60a5fa" + light_blue = '#60a5fa' self.checkbox_background_color_selected = light_blue self.checkbox_border_color_selected = light_blue self.checkbox_border_color_focus = light_blue @@ -86,21 +98,20 @@ class QimTheme(gr.themes.Default): self.checkbox_border_color_focus_dark = self.checkbox_border_color_focus_dark def set_light_cancel_button(self): - self.cancel_color = "#dc3545" - self.button_cancel_background_fill = "white" + self.cancel_color = '#dc3545' + self.button_cancel_background_fill = 'white' self.button_cancel_background_fill_hover = self.cancel_color self.button_cancel_border_color = self.cancel_color self.button_cancel_text_color = self.cancel_color - self.button_cancel_text_color_hover = "white" + self.button_cancel_text_color_hover = 'white' def set_dark_cancel_button(self): self.button_cancel_background_fill_dark = self.cancel_color - self.button_cancel_background_fill_hover_dark = "red" + self.button_cancel_background_fill_hover_dark = 'red' self.button_cancel_border_color_dark = self.cancel_color - self.button_cancel_border_color_hover_dark = "red" - self.button_cancel_text_color_dark = "white" + self.button_cancel_border_color_hover_dark = 'red' + self.button_cancel_text_color_dark = 'white' # def _get_theme_css(self): # sup = super()._get_theme_css() # return "\n.svelte-182fdeq {\nbackground: rgba(255, 0, 0, 0.5) !important;\n}\n" + sup # You have to use !important, so it overrides other css - \ No newline at end of file diff --git a/qim3d/io/__init__.py b/qim3d/io/__init__.py index 8456005474cbda035509ff635a784e8330facc3c..db64b581d28e70fbfe6e63ce58d787d862b8fba7 100644 --- a/qim3d/io/__init__.py +++ b/qim3d/io/__init__.py @@ -1,6 +1,6 @@ +# from ._sync import Sync # this will be added back after future development from ._loading import load, load_mesh from ._downloader import Downloader from ._saving import save, save_mesh -# from ._sync import Sync # this will be added back after future development from ._convert import convert from ._ome_zarr import export_ome_zarr, import_ome_zarr diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py index 834e298c007f90d6350c731995d3bc235c144c5a..fbfe75a9b9425246dedbd05198b61181a8c84f1a 100644 --- a/qim3d/io/_convert.py +++ b/qim3d/io/_convert.py @@ -6,21 +6,24 @@ import nibabel as nib import numpy as np import tifffile as tiff import zarr -from tqdm import tqdm import zarr.core +import qim3d + +from tqdm import tqdm from qim3d.utils._misc import stringify_path -from qim3d.io import save class Convert: def __init__(self, **kwargs): - """Utility class to convert files to other formats without loading the entire file into memory + """ + Utility class to convert files to other formats without loading the entire file into memory Args: chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64). + """ - self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64)) + self.chunk_shape = kwargs.get('chunk_shape', (64, 64, 64)) def convert(self, input_path: str, output_path: str): def get_file_extension(file_path): @@ -29,6 +32,7 @@ class Convert: root, ext2 = os.path.splitext(root) ext = ext2 + ext return ext + # Stringify path in case it is not already a string input_path = stringify_path(input_path) input_ext = get_file_extension(input_path) @@ -37,28 +41,30 @@ class Convert: if os.path.isfile(input_path): match input_ext, output_ext: - case (".tif", ".zarr") | (".tiff", ".zarr"): + case ('.tif', '.zarr') | ('.tiff', '.zarr'): return self.convert_tif_to_zarr(input_path, output_path) - case (".nii", ".zarr") | (".nii.gz", ".zarr"): + case ('.nii', '.zarr') | ('.nii.gz', '.zarr'): return self.convert_nifti_to_zarr(input_path, output_path) case _: - raise ValueError("Unsupported file format") + raise ValueError('Unsupported file format') # Load a directory elif os.path.isdir(input_path): match input_ext, output_ext: - case (".zarr", ".tif") | (".zarr", ".tiff"): + case ('.zarr', '.tif') | ('.zarr', '.tiff'): return self.convert_zarr_to_tif(input_path, output_path) - case (".zarr", ".nii"): + case ('.zarr', '.nii'): return self.convert_zarr_to_nifti(input_path, output_path) - case (".zarr", ".nii.gz"): - return self.convert_zarr_to_nifti(input_path, output_path, compression=True) + case ('.zarr', '.nii.gz'): + return self.convert_zarr_to_nifti( + input_path, output_path, compression=True + ) case _: - raise ValueError("Unsupported file format") + raise ValueError('Unsupported file format') # Fail else: # Find the closest matching path to warn the user - parent_dir = os.path.dirname(input_path) or "." - parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" + parent_dir = os.path.dirname(input_path) or '.' + parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else '' valid_paths = [os.path.join(parent_dir, file) for file in parent_files] similar_paths = difflib.get_close_matches(input_path, valid_paths) if similar_paths: @@ -66,10 +72,11 @@ class Convert: message = f"Invalid path. Did you mean '{suggestion}'?" raise ValueError(repr(message)) else: - raise ValueError("Invalid path") + raise ValueError('Invalid path') def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array: - """Convert a tiff file to a zarr file + """ + Convert a tiff file to a zarr file Args: tif_path (str): path to the tiff file @@ -77,10 +84,15 @@ class Convert: Returns: zarr.core.Array: zarr array containing the data from the tiff file + """ vol = tiff.memmap(tif_path) z = zarr.open( - zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype + zarr_path, + mode='w', + shape=vol.shape, + chunks=self.chunk_shape, + dtype=vol.dtype, ) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) # ! Fastest way is z[:] = vol[:], but does not have a progress bar @@ -98,7 +110,8 @@ class Convert: return z def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None: - """Convert a zarr file to a tiff file + """ + Convert a zarr file to a tiff file Args: zarr_path (str): path to the zarr file @@ -106,12 +119,14 @@ class Convert: returns: None + """ z = zarr.open(zarr_path) - save(tif_path, z) + qim3d.io.save(tif_path, z) def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array: - """Convert a nifti file to a zarr file + """ + Convert a nifti file to a zarr file Args: nifti_path (str): path to the nifti file @@ -119,10 +134,15 @@ class Convert: Returns: zarr.core.Array: zarr array containing the data from the nifti file + """ vol = nib.load(nifti_path).dataobj z = zarr.open( - zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype + zarr_path, + mode='w', + shape=vol.shape, + chunks=self.chunk_shape, + dtype=vol.dtype, ) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) # ! Fastest way is z[:] = vol[:], but does not have a progress bar @@ -139,8 +159,11 @@ class Convert: return z - def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None: - """Convert a zarr file to a nifti file + def convert_zarr_to_nifti( + self, zarr_path: str, nifti_path: str, compression: bool = False + ) -> None: + """ + Convert a zarr file to a nifti file Args: zarr_path (str): path to the zarr file @@ -148,18 +171,23 @@ class Convert: Returns: None + """ z = zarr.open(zarr_path) - save(nifti_path, z, compression=compression) - + qim3d.io.save(nifti_path, z, compression=compression) + -def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None: - """Convert a file to another format without loading the entire file into memory +def convert( + input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64) +) -> None: + """ + Convert a file to another format without loading the entire file into memory Args: input_path (str): path to the input file output_path (str): path to the output file chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64). + """ converter = Convert(chunk_shape=chunk_shape) converter.convert(input_path, output_path) diff --git a/qim3d/io/_downloader.py b/qim3d/io/_downloader.py index 1de7bcd91062fc20f35d987be01719da20cf028d..bede79ef1bcb2cd628846ad67eb16ad1b0f09584 100644 --- a/qim3d/io/_downloader.py +++ b/qim3d/io/_downloader.py @@ -1,23 +1,24 @@ -"Manages downloads and access to data" +"""Manages downloads and access to data""" import os import urllib.request - from urllib.parse import quote + +import outputformat as ouf from tqdm import tqdm -from pathlib import Path from qim3d.io import load from qim3d.utils import log -import outputformat as ouf class Downloader: - """Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/). + + """ + Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/). Attributes: folder_name (str or os.PathLike): Folder class with the name of the folder in <https://data.qim.dk/> - + Methods: list_files(): Prints the downloadable files from the QIM data repository. @@ -51,25 +52,26 @@ class Downloader: Example: ```python import qim3d - + downloader = qim3d.io.Downloader() - downloader.list_files() + downloader.list_files() data = downloader.Cowry_Shell.Cowry_DOWNSAMPLED(load_file=True) qim3d.viz.slicer_orthogonal(data, color_map="magma") ```  + """ def __init__(self): folders = _extract_names() for idx, folder in enumerate(folders): - exec(f"self.{folder} = self._Myfolder(folder)") + exec(f'self.{folder} = self._Myfolder(folder)') def list_files(self): """Generate and print formatted folder, file, and size information.""" - url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository" + url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository' folders = _extract_names() @@ -78,17 +80,20 @@ class Downloader: files = _extract_names(folder) for file in files: - url = os.path.join(url_dl, folder, file).replace("\\", "/") + url = os.path.join(url_dl, folder, file).replace('\\', '/') file_size = _get_file_size(url) - formatted_file = f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}" + formatted_file = ( + f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}" + ) formatted_size = _format_file_size(file_size) path_string = f'{folder}.{formatted_file}' log.info(f'{path_string:<50}({formatted_size})') - - + class _Myfolder: - """Class for extracting the files from each folder in the Downloader class. + + """ + Class for extracting the files from each folder in the Downloader class. Args: folder(str): name of the folder of interest in the QIM data repository. @@ -99,6 +104,7 @@ class Downloader: [file_name_2](load_file,optional): Function to download file number 2 in the given folder. ... [file_name_n](load_file,optional): Function to download file number n in the given folder. + """ def __init__(self, folder: str): @@ -107,14 +113,15 @@ class Downloader: for idx, file in enumerate(files): # Changes names to usable function name. file_name = file - if ("%20" in file) or ("-" in file): - file_name = file_name.replace("%20", "_") - file_name = file_name.replace("-", "_") + if ('%20' in file) or ('-' in file): + file_name = file_name.replace('%20', '_') + file_name = file_name.replace('-', '_') setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file)) def _make_fn(self, folder: str, file: str): - """Private method that returns a function. The function downloads the chosen file from the folder. + """ + Private method that returns a function. The function downloads the chosen file from the folder. Args: folder(str): Folder where the file is located. @@ -122,23 +129,26 @@ class Downloader: Returns: function: the function used to download the file. + """ - url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository" + url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository' def _download(load_file: bool = False, virtual_stack: bool = True): - """Downloads the file and optionally also loads it. + """ + Downloads the file and optionally also loads it. Args: load_file(bool,optional): Whether to simply download or also load the file. Returns: virtual_stack: The loaded image. + """ download_file(url_dl, folder, file) if load_file == True: - log.info(f"\nLoading {file}") + log.info(f'\nLoading {file}') file_path = os.path.join(folder, file) return load(path=file_path, virtual_stack=virtual_stack) @@ -159,38 +169,40 @@ def _get_file_size(url: str): Helper function for the ´download_file()´ function. Finds the size of the file. """ - return int(urllib.request.urlopen(url).info().get("Content-Length", -1)) + return int(urllib.request.urlopen(url).info().get('Content-Length', -1)) def download_file(path: str, name: str, file: str): - """Downloads the file from path / name / file. + """ + Downloads the file from path / name / file. Args: path(str): path to the folders available. name(str): name of the folder of interest. file(str): name of the file to be downloaded. + """ if not os.path.exists(name): os.makedirs(name) - url = os.path.join(path, name, file).replace("\\", "/") # if user is on windows + url = os.path.join(path, name, file).replace('\\', '/') # if user is on windows file_path = os.path.join(name, file) if os.path.exists(file_path): - log.warning(f"File already downloaded:\n{os.path.abspath(file_path)}") + log.warning(f'File already downloaded:\n{os.path.abspath(file_path)}') return else: log.info( - f"Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}" + f'Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}' ) - if " " in url: - url = quote(url, safe=":/") + if ' ' in url: + url = quote(url, safe=':/') with tqdm( total=_get_file_size(url), - unit="B", + unit='B', unit_scale=True, unit_divisor=1024, ncols=80, @@ -203,28 +215,31 @@ def download_file(path: str, name: str, file: str): def _extract_html(url: str): - """Extracts the html content of a webpage in "utf-8" + """ + Extracts the html content of a webpage in "utf-8" Args: url(str): url to the location where all the data is stored. Returns: html_content(str): decoded html. + """ try: with urllib.request.urlopen(url) as response: html_content = response.read().decode( - "utf-8" + 'utf-8' ) # Assuming the content is in UTF-8 encoding except urllib.error.URLError as e: - log.warning(f"Failed to retrieve data from {url}. Error: {e}") + log.warning(f'Failed to retrieve data from {url}. Error: {e}') return html_content def _extract_names(name: str = None): - """Extracts the names of the folders and files. + """ + Extracts the names of the folders and files. Finds the names of either the folders if no name is given, or all the names of all files in the given folder. @@ -235,31 +250,33 @@ def _extract_names(name: str = None): Returns: list: If name is None, returns a list of all folders available. If name is not None, returns a list of all files available in the given 'name' folder. + """ - url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository" + url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository' if name: - datapath = os.path.join(url, name).replace("\\", "/") + datapath = os.path.join(url, name).replace('\\', '/') html_content = _extract_html(datapath) data_split = html_content.split( - "files/public/projects/viscomp_data_repository/" + 'files/public/projects/viscomp_data_repository/' )[3:] data_files = [ - element.split(" ")[0][(len(name) + 1) : -3] for element in data_split + element.split(' ')[0][(len(name) + 1) : -3] for element in data_split ] return data_files else: html_content = _extract_html(url) split = html_content.split('"icon-folder-open">')[2:] - folders = [element.split(" ")[0][4:-4] for element in split] + folders = [element.split(' ')[0][4:-4] for element in split] return folders + def _format_file_size(size_in_bytes): # Define size units - units = ["B", "KB", "MB", "GB", "TB", "PB"] + units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] size = float(size_in_bytes) unit_index = 0 @@ -269,4 +286,4 @@ def _format_file_size(size_in_bytes): unit_index += 1 # Format the size with 1 decimal place - return f"{size:.2f}{units[unit_index]}" + return f'{size:.2f}{units[unit_index]}' diff --git a/qim3d/io/_loading.py b/qim3d/io/_loading.py index df8320c9aefbeb25c9f02a3f176450bd58229747..cd041a1f03d348134eaa39f80451c4e1a00d120d 100644 --- a/qim3d/io/_loading.py +++ b/qim3d/io/_loading.py @@ -13,31 +13,33 @@ Example: import difflib import os import re -from pathlib import Path +from typing import Dict, Optional import dask import dask.array as da import numpy as np import tifffile +import trimesh from dask import delayed from PIL import Image, UnidentifiedImageError import qim3d -from qim3d.utils import log +from qim3d.utils import Memory, log from qim3d.utils._misc import get_file_size, sizeof, stringify_path -from qim3d.utils import Memory from qim3d.utils._progress_bar import FileLoadingProgressBar import trimesh - +from pygel3d import hmesh from typing import Optional, Dict dask.config.set(scheduler="processes") class DataLoader: - """Utility class for loading data from different file formats. - Attributes: + """ + Utility class for loading data from different file formats. + + Attributes virtual_stack (bool): Specifies whether virtual stack is enabled. dataset_name (str): Specifies the name of the dataset to be loaded (only relevant for HDF5 files) @@ -46,17 +48,19 @@ class DataLoader: contains (str): Specifies a part of the name that is common for the TIFF file stack to be loaded (only relevant for TIFF stacks) - Methods: + Methods load_tiff(path): Load a TIFF file from the specified path. load_h5(path): Load an HDF5 file from the specified path. load_tiff_stack(path): Load a stack of TIFF files from the specified path. load_txrm(path): Load a TXRM/TXM/XRM file from the specified path load_vol(path): Load a VOL file from the specified path. Path should point to the .vgi metadata file load(path): Load a file or directory based on the given path + """ def __init__(self, **kwargs): - """Initializes a new instance of the DataLoader class. + """ + Initializes a new instance of the DataLoader class. Args: virtual_stack (bool, optional): Specifies whether to use virtual @@ -69,17 +73,19 @@ class DataLoader: force_load (bool, optional): If false and user tries to load file that exceeds available memory, throws a MemoryError. If true, this error is changed to warning and dataloader tries to load the file. Default is False. dim_order (tuple, optional): The order of the dimensions in the volume. Default is (2,1,0) which corresponds to (z,y,x) - """ - self.virtual_stack = kwargs.get("virtual_stack", False) - self.dataset_name = kwargs.get("dataset_name", None) - self.return_metadata = kwargs.get("return_metadata", False) - self.contains = kwargs.get("contains", None) - self.force_load = kwargs.get("force_load", False) - self.dim_order = kwargs.get("dim_order", (2, 1, 0)) - self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp") - def load_tiff(self, path: str|os.PathLike): - """Load a TIFF file from the specified path. + """ + self.virtual_stack = kwargs.get('virtual_stack', False) + self.dataset_name = kwargs.get('dataset_name', None) + self.return_metadata = kwargs.get('return_metadata', False) + self.contains = kwargs.get('contains', None) + self.force_load = kwargs.get('force_load', False) + self.dim_order = kwargs.get('dim_order', (2, 1, 0)) + self.PIL_extensions = ('.jp2', '.jpg', 'jpeg', '.png', 'gif', '.bmp', '.webp') + + def load_tiff(self, path: str | os.PathLike): + """ + Load a TIFF file from the specified path. Args: path (str): The path to the TIFF file. @@ -98,12 +104,13 @@ class DataLoader: else: vol = tifffile.imread(path, key=range(series) if series > 1 else None) - log.info("Loaded shape: %s", vol.shape) + log.info('Loaded shape: %s', vol.shape) return vol - def load_h5(self, path: str|os.PathLike) -> tuple[np.ndarray, Optional[Dict]]: - """Load an HDF5 file from the specified path. + def load_h5(self, path: str | os.PathLike) -> tuple[np.ndarray, Optional[Dict]]: + """ + Load an HDF5 file from the specified path. Args: path (str): The path to the HDF5 file. @@ -117,11 +124,12 @@ class DataLoader: ValueError: If the specified dataset_name is not found or is invalid. ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file ValueError: If no datasets are found in the file. + """ import h5py # Read file - f = h5py.File(path, "r") + f = h5py.File(path, 'r') data_keys = _get_h5_dataset_keys(f) datasets = [] metadata = {} @@ -132,7 +140,7 @@ class DataLoader: datasets.append(key) if f[key].attrs.keys(): metadata[key] = { - "value": f[key][()], + 'value': f[key][()], **{attr_key: val for attr_key, val in f[key].attrs.items()}, } @@ -162,7 +170,7 @@ class DataLoader: ) else: raise ValueError( - f"Invalid dataset name. Please choose between the following datasets: {datasets}" + f'Invalid dataset name. Please choose between the following datasets: {datasets}' ) else: raise ValueError( @@ -171,22 +179,23 @@ class DataLoader: # No datasets were found else: - raise ValueError(f"Did not find any data in the file: {path}") + raise ValueError(f'Did not find any data in the file: {path}') if not self.virtual_stack: vol = vol[()] # Load dataset into memory f.close() - log.info("Loaded the following dataset: %s", name) - log.info("Loaded shape: %s", vol.shape) + log.info('Loaded the following dataset: %s', name) + log.info('Loaded shape: %s', vol.shape) if self.return_metadata: return vol, metadata else: return vol - def load_tiff_stack(self, path: str|os.PathLike) -> np.ndarray|np.memmap: - """Load a stack of TIFF files from the specified path. + def load_tiff_stack(self, path: str | os.PathLike) -> np.ndarray | np.memmap: + """ + Load a stack of TIFF files from the specified path. Args: path (str): The path to the stack of TIFF files. @@ -198,6 +207,7 @@ class DataLoader: Raises: ValueError: If the 'contains' argument is not specified. ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory + """ if not self.contains: raise ValueError( @@ -207,7 +217,7 @@ class DataLoader: tiff_stack = [ file for file in os.listdir(path) - if (file.endswith(".tif") or file.endswith(".tiff")) + if (file.endswith('.tif') or file.endswith('.tiff')) and self.contains in file ] tiff_stack.sort() # Ensure proper ordering @@ -217,30 +227,33 @@ class DataLoader: for filename in tiff_stack: name = os.path.splitext(filename)[0] # Remove file extension tiff_stack_only_letters.append( - "".join(filter(str.isalpha, name)) + ''.join(filter(str.isalpha, name)) ) # Remove everything else than letters from the name # Get unique elements from tiff_stack_only_letters unique_names = list(set(tiff_stack_only_letters)) if len(unique_names) > 1: raise ValueError( - f"The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded" + f'The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded' ) vol = tifffile.imread( - [os.path.join(path, file) for file in tiff_stack], out="memmap" + [os.path.join(path, file) for file in tiff_stack], out='memmap' ) if not self.virtual_stack: vol = np.copy(vol) # Copy to memory - log.info("Found %s file(s)", len(tiff_stack)) - log.info("Loaded shape: %s", vol.shape) + log.info('Found %s file(s)', len(tiff_stack)) + log.info('Loaded shape: %s', vol.shape) return vol - def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]: - """Load a TXRM/XRM/TXM file from the specified path. + def load_txrm( + self, path: str | os.PathLike + ) -> tuple[dask.array.core.Array | np.ndarray, Optional[Dict]]: + """ + Load a TXRM/XRM/TXM file from the specified path. Args: path (str): The path to the TXRM/TXM file. @@ -252,6 +265,7 @@ class DataLoader: Raises: ValueError: If the dxchange library is not installed + """ import olefile @@ -259,13 +273,13 @@ class DataLoader: import dxchange except ImportError: raise ValueError( - "The library dxchange is required to load TXRM files. Please find installation instructions at https://dxchange.readthedocs.io/en/latest/source/install.html" + 'The library dxchange is required to load TXRM files. Please find installation instructions at https://dxchange.readthedocs.io/en/latest/source/install.html' ) if self.virtual_stack: - if not path.endswith(".txm"): + if not path.endswith('.txm'): log.warning( - "Virtual stack is only thoroughly tested for reconstructed volumes in TXM format and is thus not guaranteed to load TXRM and XRM files correctly" + 'Virtual stack is only thoroughly tested for reconstructed volumes in TXM format and is thus not guaranteed to load TXRM and XRM files correctly' ) # Get metadata @@ -275,7 +289,7 @@ class DataLoader: # Compute data offsets in bytes for each slice offsets = _get_ole_offsets(ole) - if len(offsets) != metadata["number_of_images"]: + if len(offsets) != metadata['number_of_images']: raise ValueError( f'Metadata is erroneous: number of images {metadata["number_of_images"]} is different from number of data offsets {len(offsets)}' ) @@ -286,17 +300,17 @@ class DataLoader: np.memmap( path, dtype=dxchange.reader._get_ole_data_type(metadata).newbyteorder( - "<" + '<' ), - mode="r", + mode='r', offset=offset, - shape=(1, metadata["image_height"], metadata["image_width"]), + shape=(1, metadata['image_height'], metadata['image_width']), ) ) vol = da.concatenate(slices, axis=0) log.warning( - "Virtual stack volume will be returned as a dask array. To load certain slices into memory, use normal indexing followed by the compute() method, e.g. vol[:,0,:].compute()" + 'Virtual stack volume will be returned as a dask array. To load certain slices into memory, use normal indexing followed by the compute() method, e.g. vol[:,0,:].compute()' ) else: @@ -310,8 +324,9 @@ class DataLoader: else: return vol - def load_nifti(self, path: str|os.PathLike): - """Load a NIfTI file from the specified path. + def load_nifti(self, path: str | os.PathLike): + """ + Load a NIfTI file from the specified path. Args: path (str): The path to the NIfTI file. @@ -320,6 +335,7 @@ class DataLoader: numpy.ndarray, nibabel.arrayproxy.ArrayProxy or tuple: The loaded volume. If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object If 'self.return_metadata' is True, returns a tuple (volume, metadata). + """ import nibabel as nib @@ -340,19 +356,22 @@ class DataLoader: else: return vol - def load_pil(self, path: str|os.PathLike): - """Load a PIL image from the specified path + def load_pil(self, path: str | os.PathLike): + """ + Load a PIL image from the specified path Args: path (str): The path to the image supported by PIL. Returns: numpy.ndarray: The loaded image/volume. + """ return np.array(Image.open(path)) - def load_PIL_stack(self, path: str|os.PathLike): - """Load a stack of PIL files from the specified path. + def load_PIL_stack(self, path: str | os.PathLike): + """ + Load a stack of PIL files from the specified path. Args: path (str): The path to the stack of PIL files. @@ -364,6 +383,7 @@ class DataLoader: Raises: ValueError: If the 'contains' argument is not specified. ValueError: If the 'contains' argument matches multiple PIL stacks in the directory + """ if not self.contains: raise ValueError( @@ -384,18 +404,17 @@ class DataLoader: for filename in PIL_stack: name = os.path.splitext(filename)[0] # Remove file extension PIL_stack_only_letters.append( - "".join(filter(str.isalpha, name)) + ''.join(filter(str.isalpha, name)) ) # Remove everything else than letters from the name # Get unique elements unique_names = list(set(PIL_stack_only_letters)) if len(unique_names) > 1: raise ValueError( - f"The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded" + f'The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded' ) if self.virtual_stack: - full_paths = [os.path.join(path, file) for file in PIL_stack] def lazy_loader(path): @@ -411,38 +430,38 @@ class DataLoader: # Stack the images into a single Dask array dask_images = [ - da.from_delayed(img, shape=image_shape, dtype=dtype) for img in lazy_images + da.from_delayed(img, shape=image_shape, dtype=dtype) + for img in lazy_images ] stacked = da.stack(dask_images, axis=0) return stacked - + else: # Generate placeholder volume first_image = self.load_pil(os.path.join(path, PIL_stack[0])) - vol = np.zeros((len(PIL_stack), *first_image.shape), dtype=first_image.dtype) + vol = np.zeros( + (len(PIL_stack), *first_image.shape), dtype=first_image.dtype + ) # Load file sequence for idx, file_name in enumerate(PIL_stack): - vol[idx] = self.load_pil(os.path.join(path, file_name)) return vol - - # log.info("Found %s file(s)", len(PIL_stack)) # log.info("Loaded shape: %s", vol.shape) - - - def _load_vgi_metadata(self, path: str|os.PathLike): - """Helper functions that loads metadata from a VGI file + def _load_vgi_metadata(self, path: str | os.PathLike): + """ + Helper functions that loads metadata from a VGI file Args: path (str): The path to the VGI file. returns: dict: The loaded metadata. + """ meta_data = {} current_section = meta_data @@ -450,11 +469,11 @@ class DataLoader: should_indent = True - with open(path, "r") as f: + with open(path) as f: for line in f: line = line.strip() # {NAME} is start of a new object, so should indent - if line.startswith("{") and line.endswith("}"): + if line.startswith('{') and line.endswith('}'): section_name = line[1:-1] current_section[section_name] = {} section_stack.append(current_section) @@ -462,7 +481,7 @@ class DataLoader: should_indent = True # [NAME] is start of a section, so should not indent - elif line.startswith("[") and line.endswith("]"): + elif line.startswith('[') and line.endswith(']'): section_name = line[1:-1] if not should_indent: @@ -475,17 +494,18 @@ class DataLoader: should_indent = False # = is a key value pair - elif "=" in line: - key, value = line.split("=", 1) + elif '=' in line: + key, value = line.split('=', 1) current_section[key.strip()] = value.strip() - elif line == "": + elif line == '': if len(section_stack) > 1: current_section = section_stack.pop() return meta_data - def load_vol(self, path: str|os.PathLike): - """Load a VOL filed based on the VGI metadata file + def load_vol(self, path: str | os.PathLike): + """ + Load a VOL filed based on the VGI metadata file Args: path (str): The path to the VGI file. @@ -496,43 +516,44 @@ class DataLoader: returns: numpy.ndarray, numpy.memmap or tuple: The loaded volume. If 'self.return_metadata' is True, returns a tuple (volume, metadata). + """ # makes sure path point to .VGI metadata file and not the .VOL file - if path.endswith(".vol") and os.path.isfile(path.replace(".vol", ".vgi")): - path = path.replace(".vol", ".vgi") - log.warning("Corrected path to .vgi metadata file from .vol file") - elif path.endswith(".vol") and not os.path.isfile(path.replace(".vol", ".vgi")): + if path.endswith('.vol') and os.path.isfile(path.replace('.vol', '.vgi')): + path = path.replace('.vol', '.vgi') + log.warning('Corrected path to .vgi metadata file from .vol file') + elif path.endswith('.vol') and not os.path.isfile(path.replace('.vol', '.vgi')): raise ValueError( - f"Unsupported file format, should point to .vgi metadata file assumed to be in same folder as .vol file: {path}" + f'Unsupported file format, should point to .vgi metadata file assumed to be in same folder as .vol file: {path}' ) meta_data = self._load_vgi_metadata(path) # Extracts relevant information from the metadata - file_name = meta_data["volume1"]["file1"]["Name"] - path = path.rsplit("/", 1)[ + file_name = meta_data['volume1']['file1']['Name'] + path = path.rsplit('/', 1)[ 0 ] # Remove characters after the last "/" to be replaced with .vol filename vol_path = os.path.join( path, file_name ) # .vol and .vgi files are assumed to be in the same directory - dims = meta_data["volume1"]["file1"]["Size"] + dims = meta_data['volume1']['file1']['Size'] dims = [int(n) for n in dims.split() if n.isdigit()] - dt = meta_data["volume1"]["file1"]["Datatype"] + dt = meta_data['volume1']['file1']['Datatype'] match dt: - case "float": + case 'float': dt = np.float32 - case "float32": + case 'float32': dt = np.float32 - case "uint8": + case 'uint8': dt = np.uint8 - case "unsigned integer": + case 'unsigned integer': dt = np.uint16 - case "uint16": + case 'uint16': dt = np.uint16 case _: - raise ValueError(f"Unsupported data type: {dt}") + raise ValueError(f'Unsupported data type: {dt}') dims_order = ( dims[self.dim_order[0]], @@ -540,7 +561,7 @@ class DataLoader: dims[self.dim_order[2]], ) if self.virtual_stack: - vol = np.memmap(vol_path, dtype=dt, mode="r", shape=dims_order) + vol = np.memmap(vol_path, dtype=dt, mode='r', shape=dims_order) else: vol = np.fromfile(vol_path, dtype=dt, count=np.prod(dims)) vol = np.reshape(vol, dims_order) @@ -550,11 +571,13 @@ class DataLoader: else: return vol - def load_dicom(self, path: str|os.PathLike): - """Load a DICOM file + def load_dicom(self, path: str | os.PathLike): + """ + Load a DICOM file Args: path (str): Path to file + """ import pydicom @@ -565,15 +588,17 @@ class DataLoader: else: return dcm_data.pixel_array - def load_dicom_dir(self, path: str|os.PathLike): - """Load a directory of DICOM files into a numpy 3d array + def load_dicom_dir(self, path: str | os.PathLike): + """ + Load a directory of DICOM files into a numpy 3d array Args: path (str): Directory path - + returns: numpy.ndarray, numpy.memmap or tuple: The loaded volume. If 'self.return_metadata' is True, returns a tuple (volume, metadata). + """ import pydicom @@ -590,14 +615,14 @@ class DataLoader: for filename in dicom_stack: name = os.path.splitext(filename)[0] # Remove file extension dicom_stack_only_letters.append( - "".join(filter(str.isalpha, name)) + ''.join(filter(str.isalpha, name)) ) # Remove everything else than letters from the name # Get unique elements from tiff_stack_only_letters unique_names = list(set(dicom_stack_only_letters)) if len(unique_names) > 1: raise ValueError( - f"The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded" + f'The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded' ) # dicom_list contains the dicom objects with metadata @@ -609,10 +634,10 @@ class DataLoader: return vol, dicom_list else: return vol - - def load_zarr(self, path: str|os.PathLike): - """ Loads a Zarr array from disk. + def load_zarr(self, path: str | os.PathLike): + """ + Loads a Zarr array from disk. Args: path (str): The path to the Zarr array on disk. @@ -620,6 +645,7 @@ class DataLoader: Returns: dask.array | numpy.ndarray: The dask array loaded from disk. if 'self.virtual_stack' is True, returns a dask array object, else returns a numpy.ndarray object. + """ # Opens the Zarr array @@ -634,25 +660,25 @@ class DataLoader: def check_file_size(self, filename: str): """ Checks if there is enough memory where the file can be loaded. + Args: - ------------ + ---- filename: (str) Specifies path to file force_load: (bool, optional) If true, the memory error will not be raised. Warning will be printed insted and the loader will attempt to load the file. Raises: - ----------- + ------ MemoryError: If filesize is greater then available memory + """ - if ( - self.virtual_stack - ): # If virtual_stack is True, then data is loaded from the disk, no need for loading into memory + if self.virtual_stack: # If virtual_stack is True, then data is loaded from the disk, no need for loading into memory return file_size = get_file_size(filename) available_memory = Memory().free if file_size > available_memory: - message = f"The file {filename} has {sizeof(file_size)} but only {sizeof(available_memory)} of memory is available." + message = f'The file {filename} has {sizeof(file_size)} but only {sizeof(available_memory)} of memory is available.' if self.force_load: log.warning(message) else: @@ -660,7 +686,7 @@ class DataLoader: message + " Set 'force_load=True' to ignore this error." ) - def load(self, path: str|os.PathLike): + def load(self, path: str | os.PathLike): """ Load a file or directory based on the given path. @@ -677,6 +703,7 @@ class DataLoader: ValueError: If the format is not supported ValueError: If the file or directory does not exist. MemoryError: If file size exceeds available memory and force_load is not set to True. In check_size function. + """ # Stringify path in case it is not already a string @@ -686,35 +713,35 @@ class DataLoader: if os.path.isfile(path): # Choose the loader based on the file extension self.check_file_size(path) - if path.endswith(".tif") or path.endswith(".tiff"): + if path.endswith('.tif') or path.endswith('.tiff'): return self.load_tiff(path) - elif path.endswith(".h5"): + elif path.endswith('.h5'): return self.load_h5(path) - elif path.endswith((".txrm", ".txm", ".xrm")): + elif path.endswith(('.txrm', '.txm', '.xrm')): return self.load_txrm(path) - elif path.endswith((".nii", ".nii.gz")): + elif path.endswith(('.nii', '.nii.gz')): return self.load_nifti(path) - elif path.endswith((".vol", ".vgi")): + elif path.endswith(('.vol', '.vgi')): return self.load_vol(path) - elif path.endswith((".dcm", ".DCM")): + elif path.endswith(('.dcm', '.DCM')): return self.load_dicom(path) else: try: return self.load_pil(path) except UnidentifiedImageError: - raise ValueError("Unsupported file format") + raise ValueError('Unsupported file format') # Load a directory elif os.path.isdir(path): # load tiff stack if folder contains tiff files else load dicom directory if any( - [f.endswith(".tif") or f.endswith(".tiff") for f in os.listdir(path)] + [f.endswith('.tif') or f.endswith('.tiff') for f in os.listdir(path)] ): return self.load_tiff_stack(path) elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]): return self.load_PIL_stack(path) - elif path.endswith(".zarr"): + elif path.endswith('.zarr'): return self.load_zarr(path) else: return self.load_dicom_dir(path) @@ -729,7 +756,7 @@ class DataLoader: message = f"Invalid path. Did you mean '{suggestion}'?" raise ValueError(repr(message)) else: - raise ValueError("Invalid path") + raise ValueError('Invalid path') def _get_h5_dataset_keys(f): @@ -743,18 +770,18 @@ def _get_h5_dataset_keys(f): def _get_ole_offsets(ole): slice_offset = {} for stream in ole.listdir(): - if stream[0].startswith("ImageData"): + if stream[0].startswith('ImageData'): sid = ole._find(stream) direntry = ole.direntries[sid] sect_start = direntry.isectStart offset = ole.sectorsize * (sect_start + 1) - slice_offset[f"{stream[0]}/{stream[1]}"] = offset + slice_offset[f'{stream[0]}/{stream[1]}'] = offset # sort dictionary after natural sorting (https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/) sorted_keys = sorted( slice_offset.keys(), key=lambda string_: [ - int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_) + int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_) ], ) slice_offset_sorted = {key: slice_offset[key] for key in sorted_keys} @@ -763,7 +790,7 @@ def _get_ole_offsets(ole): def load( - path: str|os.PathLike, + path: str | os.PathLike, virtual_stack: bool = False, dataset_name: bool = None, return_metadata: bool = False, @@ -822,6 +849,7 @@ def load( vol = qim3d.io.load("path/to/image.tif", virtual_stack=True) ``` + """ loader = DataLoader( @@ -843,13 +871,13 @@ def load( def log_memory_info(data): mem = Memory() log.info( - "Volume using %s of memory\n", + 'Volume using %s of memory\n', sizeof(data[0].nbytes if isinstance(data, tuple) else data.nbytes), ) mem.report() if return_metadata and not isinstance(data, tuple): - log.warning("The file format does not contain metadata") + log.warning('The file format does not contain metadata') if not virtual_stack: log_memory_info(data) @@ -858,22 +886,31 @@ def load( if not isinstance( type(data[0]) if isinstance(data, tuple) else type(data), np.ndarray ): - log.info("Using virtual stack") + log.info('Using virtual stack') else: - log.warning("Virtual stack is not supported for this file format") + log.warning('Virtual stack is not supported for this file format') log_memory_info(data) return data -def load_mesh(filename: str) -> trimesh.Trimesh: + +def load_mesh(filename: str) -> hmesh.Manifold: """ - Load a mesh from an .obj file using trimesh. + Load a mesh from a specific file. + This function is based on the [PyGEL3D library's loading function implementation](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#load). + + Supported formats: + + - `X3D` + - `OBJ` + - `OFF` + - `PLY` Args: - filename (str or os.PathLike): The path to the .obj file. + filename (str or os.PathLike): The path to the file. Returns: - mesh (trimesh.Trimesh): A trimesh object containing the mesh data (vertices and faces). + mesh (hmesh.Manifold or None): A hmesh object containing the mesh data or None if loading failed. Example: ```python @@ -881,6 +918,8 @@ def load_mesh(filename: str) -> trimesh.Trimesh: mesh = qim3d.io.load_mesh("path/to/mesh.obj") ``` + """ - mesh = trimesh.load(filename) - return mesh + mesh = hmesh.load(filename) + + return mesh \ No newline at end of file diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py index d2358cca6147483a480a41af86ec3c56232f13a8..68eb92d55a1ad2e9e21af81d6b02b8f68370ada3 100644 --- a/qim3d/io/_ome_zarr.py +++ b/qim3d/io/_ome_zarr.py @@ -2,39 +2,27 @@ Exporting data to different formats. """ -import os import math +import os import shutil -import logging +from typing import List, Union +import dask.array as da import numpy as np import zarr -import tqdm +from ome_zarr import scale from ome_zarr.io import parse_url +from ome_zarr.reader import Reader +from ome_zarr.scale import dask_resize from ome_zarr.writer import ( - write_image, - _create_mip, - write_multiscale, CurrentFormat, - Format, + write_multiscale, ) -from ome_zarr.scale import dask_resize -from ome_zarr.reader import Reader -from ome_zarr import scale from scipy.ndimage import zoom -from typing import Any, Callable, Iterator, List, Tuple, Union -import dask.array as da -import dask -from dask.distributed import Client, LocalCluster - -from skimage.transform import ( - resize, -) from qim3d.utils import log -from qim3d.utils._progress_bar import OmeZarrExportProgressBar from qim3d.utils._ome_zarr import get_n_chunks - +from qim3d.utils._progress_bar import OmeZarrExportProgressBar ListOfArrayLike = Union[List[da.Array], List[np.ndarray]] ArrayLike = Union[da.Array, np.ndarray] @@ -43,10 +31,19 @@ ArrayLike = Union[da.Array, np.ndarray] class OMEScaler( scale.Scaler, ): - """Scaler in the style of OME-Zarr. - This is needed because their current zoom implementation is broken.""" - def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"): + """ + Scaler in the style of OME-Zarr. + This is needed because their current zoom implementation is broken. + """ + + def __init__( + self, + order: int = 0, + downscale: float = 2, + max_layer: int = 5, + method: str = 'scaleZYXdask', + ): self.order = order self.downscale = downscale self.max_layer = max_layer @@ -55,11 +52,11 @@ class OMEScaler( def scaleZYX(self, base: da.core.Array): """Downsample using :func:`scipy.ndimage.zoom`.""" rv = [base] - log.info(f"- Scale 0: {rv[-1].shape}") + log.info(f'- Scale 0: {rv[-1].shape}') for i in range(self.max_layer): rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order)) - log.info(f"- Scale {i+1}: {rv[-1].shape}") + log.info(f'- Scale {i+1}: {rv[-1].shape}') return list(rv) @@ -82,8 +79,8 @@ class OMEScaler( """ - def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape): + def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape): # Get the chunksize needed so that all the blocks match the new shape # This snippet comes from the original OME-Zarr-python library better_chunksize = tuple( @@ -92,7 +89,7 @@ class OMEScaler( ).astype(int) ) - log.debug(f"better chunk size: {better_chunksize}") + log.debug(f'better chunk size: {better_chunksize}') # Compute the chunk size after the downscaling new_chunk_size = tuple( @@ -100,20 +97,19 @@ class OMEScaler( ) log.debug( - f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}" + f'orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}' ) def resize_chunk(chunk, scale_factors, order): - - #print(f"zoom factors: {scale_factors}") + # print(f"zoom factors: {scale_factors}") resized_chunk = zoom( chunk, zoom=scale_factors, order=order, - mode="grid-constant", + mode='grid-constant', grid_mode=True, ) - #print(f"resized chunk shape: {resized_chunk.shape}") + # print(f"resized chunk shape: {resized_chunk.shape}") return resized_chunk @@ -121,7 +117,7 @@ class OMEScaler( # Testing new shape predicted_shape = np.multiply(vol.shape, scale_factors) - log.debug(f"predicted shape: {predicted_shape}") + log.debug(f'predicted shape: {predicted_shape}') scaled_vol = da.map_blocks( resize_chunk, vol, @@ -136,7 +132,7 @@ class OMEScaler( return scaled_vol rv = [base] - log.info(f"- Scale 0: {rv[-1].shape}") + log.info(f'- Scale 0: {rv[-1].shape}') for i in range(self.max_layer): log.debug(f"\nScale {i+1}\n{'-'*32}") @@ -147,17 +143,17 @@ class OMEScaler( np.ceil(np.multiply(base.shape, downscale_factor)).astype(int) ) - log.debug(f"target shape: {scaled_shape}") + log.debug(f'target shape: {scaled_shape}') downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float)) - log.debug(f"downscale rate: {downscale_rate}") + log.debug(f'downscale rate: {downscale_rate}') scale_factors = tuple(np.divide(1, downscale_rate)) - log.debug(f"scale factors: {scale_factors}") + log.debug(f'scale factors: {scale_factors}') - log.debug("\nResizing volume chunk-wise") + log.debug('\nResizing volume chunk-wise') scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape) rv.append(scaled_vol) - log.info(f"- Scale {i+1}: {rv[-1].shape}") + log.info(f'- Scale {i+1}: {rv[-1].shape}') return list(rv) @@ -165,10 +161,9 @@ class OMEScaler( """Downsample using the original OME-Zarr python library""" rv = [base] - log.info(f"- Scale 0: {rv[-1].shape}") + log.info(f'- Scale 0: {rv[-1].shape}') for i in range(self.max_layer): - scaled_shape = tuple( base.shape[j] // (self.downscale ** (i + 1)) for j in range(3) ) @@ -176,20 +171,20 @@ class OMEScaler( scaled = dask_resize(base, scaled_shape, order=self.order) rv.append(scaled) - log.info(f"- Scale {i+1}: {rv[-1].shape}") + log.info(f'- Scale {i+1}: {rv[-1].shape}') return list(rv) def export_ome_zarr( - path: str|os.PathLike, - data: np.ndarray|da.core.Array, + path: str | os.PathLike, + data: np.ndarray | da.core.Array, chunk_size: int = 256, downsample_rate: int = 2, order: int = 1, replace: bool = False, - method: str = "scaleZYX", + method: str = 'scaleZYX', progress_bar: bool = True, - progress_bar_repeat_time: str = "auto", + progress_bar_repeat_time: str = 'auto', ) -> None: """ Export 3D image data to OME-Zarr format with pyramidal downsampling. @@ -220,6 +215,7 @@ def export_ome_zarr( qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2) ``` + """ # Check if directory exists @@ -228,19 +224,19 @@ def export_ome_zarr( shutil.rmtree(path) else: raise ValueError( - f"Directory {path} already exists. Use replace=True to overwrite." + f'Directory {path} already exists. Use replace=True to overwrite.' ) # Check if downsample_rate is valid if downsample_rate <= 1: - raise ValueError("Downsample rate must be greater than 1.") + raise ValueError('Downsample rate must be greater than 1.') - log.info(f"Exporting data to OME-Zarr format at {path}") + log.info(f'Exporting data to OME-Zarr format at {path}') # Get the number of scales min_dim = np.max(np.shape(data)) nscales = math.ceil(math.log(min_dim / chunk_size) / math.log(downsample_rate)) - log.info(f"Number of scales: {nscales + 1}") + log.info(f'Number of scales: {nscales + 1}') # Create scaler scaler = OMEScaler( @@ -249,32 +245,31 @@ def export_ome_zarr( # write the image data os.mkdir(path) - store = parse_url(path, mode="w").store + store = parse_url(path, mode='w').store root = zarr.group(store=store) # Check if we want to process using Dask - if "dask" in method and not isinstance(data, da.Array): - log.info("\nConverting input data to Dask array") + if 'dask' in method and not isinstance(data, da.Array): + log.info('\nConverting input data to Dask array') data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size)) - log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") + log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n') - elif "dask" in method and isinstance(data, da.Array): - log.info("\nInput data will be rechunked") + elif 'dask' in method and isinstance(data, da.Array): + log.info('\nInput data will be rechunked') data = data.rechunk((chunk_size, chunk_size, chunk_size)) - log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") - + log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n') - log.info("Calculating the multi-scale pyramid") + log.info('Calculating the multi-scale pyramid') # Generate multi-scale pyramid mip = scaler.func(data) - log.info("Writing data to disk") + log.info('Writing data to disk') kwargs = dict( pyramid=mip, group=root, fmt=CurrentFormat(), - axes="zyx", + axes='zyx', name=None, compute=True, storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)), @@ -291,16 +286,14 @@ def export_ome_zarr( else: write_multiscale(**kwargs) - log.info("\nAll done!") + log.info('\nAll done!') return def import_ome_zarr( - path: str|os.PathLike, - scale: int = 0, - load: bool = True - ) -> np.ndarray: + path: str | os.PathLike, scale: int = 0, load: bool = True +) -> np.ndarray: """ Import image data from an OME-Zarr file. @@ -339,22 +332,22 @@ def import_ome_zarr( image_node = nodes[0] dask_data = image_node.data - log.info(f"Data contains {len(dask_data)} scales:") + log.info(f'Data contains {len(dask_data)} scales:') for i in np.arange(len(dask_data)): - log.info(f"- Scale {i}: {dask_data[i].shape}") + log.info(f'- Scale {i}: {dask_data[i].shape}') - if scale == "highest": + if scale == 'highest': scale = 0 - if scale == "lowest": + if scale == 'lowest': scale = len(dask_data) - 1 if scale >= len(dask_data): raise ValueError( - f"Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}." + f'Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}.' ) - log.info(f"\nLoading scale {scale} with shape {dask_data[scale].shape}") + log.info(f'\nLoading scale {scale} with shape {dask_data[scale].shape}') if load: vol = dask_data[scale].compute() diff --git a/qim3d/io/_saving.py b/qim3d/io/_saving.py index 0a1fae81cbb4de7357b9b9c57d627ed183d04813..cde1c358f041b65826572c7dc1f5d0a69896b506 100644 --- a/qim3d/io/_saving.py +++ b/qim3d/io/_saving.py @@ -5,10 +5,10 @@ Provides functionality for saving data from various file formats. Example: ```python import qim3d - + # Generate synthetic blob synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) - + qim3d.io.save("fly.tif", synthetic_blob) ``` @@ -18,7 +18,7 @@ Example: # Generate synthetic blob synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) - + qim3d.io.save("slices", synthetic_blob, basename="fly-slices", sliced_dim=0) ``` @@ -26,25 +26,29 @@ Example: import datetime import os + import dask.array as da import h5py import nibabel as nib import numpy as np import PIL import tifffile +import trimesh import zarr from pydicom.dataset import FileDataset, FileMetaDataset from pydicom.uid import UID import trimesh - +from pygel3d import hmesh from qim3d.utils import log -from qim3d.utils._misc import sizeof, stringify_path +from qim3d.utils._misc import stringify_path class DataSaver: - """Utility class for saving data to different file formats. - Attributes: + """ + Utility class for saving data to different file formats. + + Attributes replace (bool): Specifies if an existing file with identical path is replaced. compression (bool): Specifies if the file is saved with Deflate compression (lossless). basename (str): Specifies the basename for a TIFF stack saved as several files @@ -52,13 +56,15 @@ class DataSaver: sliced_dim (int): Specifies the dimension that is sliced in case a TIFF stack is saved as several files (only relevant for TIFF stacks) - Methods: + Methods save_tiff(path,data): Save data to a TIFF file to the given path. load(path,data): Save data to the given path. + """ def __init__(self, **kwargs): - """Initializes a new instance of the DataSaver class. + """ + Initializes a new instance of the DataSaver class. Args: replace (bool, optional): Specifies if an existing file with identical path should be replaced. @@ -69,35 +75,40 @@ class DataSaver: (only relevant for TIFF stacks). Default is None sliced_dim (int, optional): Specifies the dimension that is sliced in case a TIFF stack is saved as several files (only relevant for TIFF stacks). Default is 0, i.e., the first dimension. + """ - self.replace = kwargs.get("replace", False) - self.compression = kwargs.get("compression", False) - self.basename = kwargs.get("basename", None) - self.sliced_dim = kwargs.get("sliced_dim", 0) - self.chunk_shape = kwargs.get("chunk_shape", "auto") + self.replace = kwargs.get('replace', False) + self.compression = kwargs.get('compression', False) + self.basename = kwargs.get('basename', None) + self.sliced_dim = kwargs.get('sliced_dim', 0) + self.chunk_shape = kwargs.get('chunk_shape', 'auto') - def save_tiff(self, path: str|os.PathLike, data: np.ndarray): - """Save data to a TIFF file to the given path. + def save_tiff(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to a TIFF file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ tifffile.imwrite(path, data, compression=self.compression) - def save_tiff_stack(self, path: str|os.PathLike, data: np.ndarray): - """Save data as a TIFF stack containing slices in separate files to the given path. + def save_tiff_stack(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data as a TIFF stack containing slices in separate files to the given path. The slices will be named according to the basename plus a suffix with a zero-filled value corresponding to the slice number Args: path (str): The directory to save files to data (numpy.ndarray): The data to be saved + """ - extension = ".tif" + extension = '.tif' if data.ndim <= 2: - path = os.path.join(path, self.basename, ".tif") + path = os.path.join(path, self.basename, '.tif') self.save_tiff(path, data) else: # get number of total slices @@ -117,21 +128,22 @@ class DataSaver: self.save_tiff(filepath, sliced) pattern_string = ( - filepath[: -(len(extension) + zfill_val)] + "-" * zfill_val + extension + filepath[: -(len(extension) + zfill_val)] + '-' * zfill_val + extension ) log.info( f"Total of {no_slices} files saved following the pattern '{pattern_string}'" ) - def save_nifti(self, path: str|os.PathLike, data: np.ndarray): - """Save data to a NIfTI file to the given path. + def save_nifti(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to a NIfTI file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ - import nibabel as nib # Create header header = nib.Nifti1Header() @@ -141,11 +153,11 @@ class DataSaver: img = nib.Nifti1Image(data, np.eye(4), header) # nib does automatically compress if filetype ends with .gz - if self.compression and not path.endswith(".gz"): - path += ".gz" + if self.compression and not path.endswith('.gz'): + path += '.gz' log.warning("File extension '.gz' is added since compression is enabled.") - if not self.compression and path.endswith(".gz"): + if not self.compression and path.endswith('.gz'): path = path[:-3] log.warning( "File extension '.gz' is ignored since compression is disabled." @@ -154,83 +166,84 @@ class DataSaver: # Save image nib.save(img, path) - def save_vol(self, path: str|os.PathLike, data: np.ndarray): - """Save data to a VOL file to the given path. + def save_vol(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to a VOL file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ # No support for compression yet if self.compression: raise NotImplementedError( - "Saving compressed .vol files is not yet supported" + 'Saving compressed .vol files is not yet supported' ) # Create custom .vgi metadata file - metadata = "" - metadata += "{volume1}\n" # .vgi organization - metadata += "[file1]\n" # .vgi organization - metadata += "Size = {} {} {}\n".format( - data.shape[1], data.shape[2], data.shape[0] - ) # Swap axes to match .vol format - metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string - metadata += "Name = {}.vol\n".format( - path.rsplit("/", 1)[-1][:-4] + metadata = '' + metadata += '{volume1}\n' # .vgi organization + metadata += '[file1]\n' # .vgi organization + metadata += f'Size = {data.shape[1]} {data.shape[2]} {data.shape[0]}\n' # Swap axes to match .vol format + metadata += f'Datatype = {str(data.dtype)}\n' # Get datatype as string + metadata += 'Name = {}.vol\n'.format( + path.rsplit('/', 1)[-1][:-4] ) # Get filename without extension # Save metadata - with open(path[:-4] + ".vgi", "w") as f: + with open(path[:-4] + '.vgi', 'w') as f: f.write(metadata) # Save data using numpy in binary format - data.tofile(path[:-4] + ".vol") + data.tofile(path[:-4] + '.vol') def save_h5(self, path, data): - """Save data to a HDF5 file to the given path. + """ + Save data to a HDF5 file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ - import h5py - with h5py.File(path, "w") as f: + with h5py.File(path, 'w') as f: f.create_dataset( - "dataset", data=data, compression="gzip" if self.compression else None + 'dataset', data=data, compression='gzip' if self.compression else None ) - def save_dicom(self, path: str|os.PathLike, data: np.ndarray): - """Save data to a DICOM file to the given path. + def save_dicom(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to a DICOM file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ import pydicom - from pydicom.dataset import FileDataset, FileMetaDataset - from pydicom.uid import UID # based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html # Populate required values for file meta information file_meta = FileMetaDataset() - file_meta.MediaStorageSOPClassUID = UID("1.2.840.10008.5.1.4.1.1.2") - file_meta.MediaStorageSOPInstanceUID = UID("1.2.3") - file_meta.ImplementationClassUID = UID("1.2.3.4") + file_meta.MediaStorageSOPClassUID = UID('1.2.840.10008.5.1.4.1.1.2') + file_meta.MediaStorageSOPInstanceUID = UID('1.2.3') + file_meta.ImplementationClassUID = UID('1.2.3.4') # Create the FileDataset instance (initially no data elements, but file_meta # supplied) - ds = FileDataset(path, {}, file_meta=file_meta, preamble=b"\0" * 128) + ds = FileDataset(path, {}, file_meta=file_meta, preamble=b'\0' * 128) - ds.PatientName = "Test^Firstname" - ds.PatientID = "123456" - ds.StudyInstanceUID = "1.2.3.4.5" + ds.PatientName = 'Test^Firstname' + ds.PatientID = '123456' + ds.StudyInstanceUID = '1.2.3.4.5' ds.SamplesPerPixel = 1 ds.PixelRepresentation = 0 ds.BitsStored = 16 ds.BitsAllocated = 16 - ds.PhotometricInterpretation = "MONOCHROME2" + ds.PhotometricInterpretation = 'MONOCHROME2' ds.Rows = data.shape[1] ds.Columns = data.shape[2] ds.NumberOfFrames = data.shape[0] @@ -240,8 +253,8 @@ class DataSaver: # Set creation date/time dt = datetime.datetime.now() - ds.ContentDate = dt.strftime("%Y%m%d") - timeStr = dt.strftime("%H%M%S.%f") # long format with micro seconds + ds.ContentDate = dt.strftime('%Y%m%d') + timeStr = dt.strftime('%H%M%S.%f') # long format with micro seconds ds.ContentTime = timeStr # Needs to be here because of bug in pydicom ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian @@ -255,8 +268,9 @@ class DataSaver: ds.save_as(path) - def save_to_zarr(self, path: str|os.PathLike, data: da.core.Array): - """Saves a Dask array to a Zarr array on disk. + def save_to_zarr(self, path: str | os.PathLike, data: da.core.Array): + """ + Saves a Dask array to a Zarr array on disk. Args: path (str): The path to the Zarr array on disk. @@ -264,38 +278,41 @@ class DataSaver: Returns: zarr.core.Array: The Zarr array saved on disk. + """ if isinstance(data, da.Array): # If the data is a Dask array, save using dask if self.chunk_shape: - log.info("Rechunking data to shape %s", self.chunk_shape) + log.info('Rechunking data to shape %s', self.chunk_shape) data = data.rechunk(self.chunk_shape) - log.info("Saving Dask array to Zarr array on disk") + log.info('Saving Dask array to Zarr array on disk') da.to_zarr(data, path, overwrite=self.replace) else: zarr_array = zarr.open( path, - mode="w", + mode='w', shape=data.shape, chunks=self.chunk_shape, dtype=data.dtype, ) zarr_array[:] = data - def save_PIL(self, path: str|os.PathLike, data: np.ndarray): - """Save data to a PIL file to the given path. + def save_PIL(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to a PIL file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved + """ # No support for compression yet - if self.compression and path.endswith(".png"): - raise NotImplementedError("png does not support compression") - elif not self.compression and path.endswith((".jpeg", ".jpg")): - raise NotImplementedError("jpeg does not support no compression") + if self.compression and path.endswith('.png'): + raise NotImplementedError('png does not support compression') + elif not self.compression and path.endswith(('.jpeg', '.jpg')): + raise NotImplementedError('jpeg does not support no compression') # Convert to PIL image img = PIL.Image.fromarray(data) @@ -303,8 +320,9 @@ class DataSaver: # Save image img.save(path) - def save(self, path: str|os.PathLike, data: np.ndarray): - """Save data to the given path. + def save(self, path: str | os.PathLike, data: np.ndarray): + """ + Save data to the given path. Args: path (str): The path to save file to @@ -316,6 +334,7 @@ class DataSaver: ValueError: If the provided path does not exist and self.basename is not provided ValueError: If a file extension is not provided. ValueError: if a file with the specified path already exists and replace=False. + """ path = stringify_path(path) @@ -325,7 +344,7 @@ class DataSaver: # If path is an existing directory if isdir: # Check if this is a Zarr directory - if ".zarr" in path: + if '.zarr' in path: if self.replace: return self.save_to_zarr(path, data) if not self.replace: @@ -340,7 +359,7 @@ class DataSaver: else: raise ValueError( f"To save a stack as several TIFF files to the directory '{path}', please provide the keyword argument 'basename'. " - + "Otherwise, to save a single file, please provide a full path with a filename and valid extension." + + 'Otherwise, to save a single file, please provide a full path with a filename and valid extension.' ) # If path is not an existing directory @@ -353,7 +372,7 @@ class DataSaver: return self.save_tiff_stack(path, data) # Check if a parent directory exists - parentdir = os.path.dirname(path) or "." + parentdir = os.path.dirname(path) or '.' if os.path.isdir(parentdir): # If there is a file extension in the path if ext: @@ -367,53 +386,54 @@ class DataSaver: "A file with the provided path already exists. To replace it set 'replace=True'" ) - if path.endswith((".tif", ".tiff")): + if path.endswith(('.tif', '.tiff')): return self.save_tiff(path, data) - elif path.endswith((".nii", "nii.gz")): + elif path.endswith(('.nii', 'nii.gz')): return self.save_nifti(path, data) - elif path.endswith(("TXRM", "XRM", "TXM")): + elif path.endswith(('TXRM', 'XRM', 'TXM')): raise NotImplementedError( - "Saving TXRM files is not yet supported" + 'Saving TXRM files is not yet supported' ) - elif path.endswith((".h5")): + elif path.endswith('.h5'): return self.save_h5(path, data) - elif path.endswith((".vol", ".vgi")): + elif path.endswith(('.vol', '.vgi')): return self.save_vol(path, data) - elif path.endswith((".dcm", ".DCM")): + elif path.endswith(('.dcm', '.DCM')): return self.save_dicom(path, data) - elif path.endswith((".zarr")): + elif path.endswith('.zarr'): return self.save_to_zarr(path, data) - elif path.endswith((".jpeg", ".jpg", ".png")): + elif path.endswith(('.jpeg', '.jpg', '.png')): return self.save_PIL(path, data) else: - raise ValueError("Unsupported file format") + raise ValueError('Unsupported file format') # If there is no file extension in the path else: raise ValueError( - "Please provide a file extension if you want to save as a single file." - + " Otherwise, please provide a basename to save as a TIFF stack" + 'Please provide a file extension if you want to save as a single file.' + + ' Otherwise, please provide a basename to save as a TIFF stack' ) else: raise ValueError( f"The directory '{parentdir}' does not exist.\n" - + "Please provide a valid directory or specify a basename if you want to save a tiff stack as several files to a folder that does not yet exist" + + 'Please provide a valid directory or specify a basename if you want to save a tiff stack as several files to a folder that does not yet exist' ) def save( - path: str|os.PathLike, + path: str | os.PathLike, data: np.ndarray, replace: bool = False, compression: bool = False, basename: bool = None, sliced_dim: int = 0, - chunk_shape: str = "auto", + chunk_shape: str = 'auto', **kwargs, ) -> None: - """Save data to a specified file path. + """ + Save data to a specified file path. Args: - path (str or os.PathLike): The path to save file to. File format is chosen based on the extension. + path (str or os.PathLike): The path to save file to. File format is chosen based on the extension. Supported extensions are: <em>'.tif', '.tiff', '.nii', '.nii.gz', '.h5', '.vol', '.vgi', '.dcm', '.DCM', '.zarr', '.jpeg', '.jpg', '.png'</em> data (numpy.ndarray): The data to be saved replace (bool, optional): Specifies if an existing file with identical path should be replaced. @@ -452,6 +472,7 @@ def save( qim3d.io.save("slices", vol, basename="blob-slices", sliced_dim=0) ``` + """ DataSaver( @@ -464,31 +485,54 @@ def save( ).save(path, data) -def save_mesh( - filename: str, - mesh: trimesh.Trimesh - ) -> None: +# def save_mesh( +# filename: str, +# mesh: trimesh.Trimesh +# ) -> None: +# """ +# Save a trimesh object to an .obj file. + +# Args: +# filename (str or os.PathLike): The name of the file to save the mesh. +# mesh (trimesh.Trimesh): A trimesh.Trimesh object representing the mesh. + +# Example: +# ```python +# import qim3d + +# vol = qim3d.generate.noise_object(base_shape=(32, 32, 32), +# final_shape=(32, 32, 32), +# noise_scale=0.05, +# order=1, +# gamma=1.0, +# max_value=255, +# threshold=0.5) +# mesh = qim3d.mesh.from_volume(vol) +# qim3d.io.save_mesh("mesh.obj", mesh) +# ``` +# """ +# # Export the mesh to the specified filename +# mesh.export(filename) + +def save_mesh(filename: str, mesh: hmesh.Manifold) -> None: """ - Save a trimesh object to an .obj file. + Save a mesh object to a specific file. + This function is based on the [PyGEL3D library's saving function implementation](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#save). + Args: - filename (str or os.PathLike): The name of the file to save the mesh. - mesh (trimesh.Trimesh): A trimesh.Trimesh object representing the mesh. + filename (str or os.PathLike): The path to save file to. File format is chosen based on the extension. Supported extensions are: '.x3d', '.obj', '.off'. + mesh (pygel3d.hmesh.Manifold): A hmesh.Manifold object representing the mesh. Example: ```python import qim3d - vol = qim3d.generate.noise_object(base_shape=(32, 32, 32), - final_shape=(32, 32, 32), - noise_scale=0.05, - order=1, - gamma=1.0, - max_value=255, - threshold=0.5) - mesh = qim3d.mesh.from_volume(vol) - qim3d.io.save_mesh("mesh.obj", mesh) + synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) + mesh = qim3d.mesh.from_volume(synthetic_blob) + qim3d.io.save_mesh("mesh.obj", mesh) ``` + """ # Export the mesh to the specified filename - mesh.export(filename) \ No newline at end of file + hmesh.save(filename, mesh) \ No newline at end of file diff --git a/qim3d/io/_sync.py b/qim3d/io/_sync.py index 48c17a8cb8d8fba2c3ededc93f252196087b6fc2..97b2b90dbb62ca8d9a26e0cec92f825a131549a5 100644 --- a/qim3d/io/_sync.py +++ b/qim3d/io/_sync.py @@ -1,35 +1,46 @@ -""" Dataset synchronization tasks """ +"""Dataset synchronization tasks""" + import os import subprocess +from pathlib import Path + import outputformat as ouf + from qim3d.utils import log -from pathlib import Path class Sync: + """Class for dataset synchronization tasks""" def __init__(self): # Checks if rsync is available if not self._check_rsync(): raise RuntimeError( - "Could not find rsync, please check if it is installed in your system." + 'Could not find rsync, please check if it is installed in your system.' ) def _check_rsync(self): """Check if rsync is available""" try: - subprocess.run(["rsync", "--version"], capture_output=True, check=True) + subprocess.run(['rsync', '--version'], capture_output=True, check=True) return True except Exception as error: - log.error("rsync is not available") + log.error('rsync is not available') log.error(error) return False - def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]: - """Check if all files from 'source' are in 'destination' + def check_destination( + self, + source: str, + destination: str, + checksum: bool = False, + verbose: bool = True, + ) -> list[str]: + """ + Check if all files from 'source' are in 'destination' This function compares the files in the 'source' directory to those in the 'destination' directory and reports any differences or missing files. @@ -51,13 +62,13 @@ class Sync: destination = Path(destination) if checksum: - rsync_args = "-avrc" + rsync_args = '-avrc' else: - rsync_args = "-avr" + rsync_args = '-avr' command = [ - "rsync", - "-n", + 'rsync', + '-n', rsync_args, str(source) + os.path.sep, str(destination) + os.path.sep, @@ -70,18 +81,25 @@ class Sync: ) diff_files_and_folders = out.stdout.decode().splitlines()[1:-3] - diff_files = [f for f in diff_files_and_folders if not f.endswith("/")] + diff_files = [f for f in diff_files_and_folders if not f.endswith('/')] if len(diff_files) > 0 and verbose: - title = "Source files differing or missing in destination" + title = 'Source files differing or missing in destination' log.info( - ouf.showlist(diff_files, style="line", return_str=True, title=title) + ouf.showlist(diff_files, style='line', return_str=True, title=title) ) return diff_files - def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None: - """Checks whether 'source' and 'destination' directories are synchronized. + def compare_dirs( + self, + source: str, + destination: str, + checksum: bool = False, + verbose: bool = True, + ) -> None: + """ + Checks whether 'source' and 'destination' directories are synchronized. This function compares the contents of two directories ('source' and 'destination') and reports any differences. @@ -107,7 +125,7 @@ class Sync: if verbose: s_files, s_dirs = self.count_files_and_dirs(source) d_files, d_dirs = self.count_files_and_dirs(destination) - log.info("\n") + log.info('\n') s_d = self.check_destination( source, destination, checksum=checksum, verbose=False @@ -120,7 +138,7 @@ class Sync: # No differences if verbose: log.info( - "Source and destination are synchronized, no differences found." + 'Source and destination are synchronized, no differences found.' ) return @@ -128,9 +146,9 @@ class Sync: log.info( ouf.showlist( union, - style="line", + style='line', return_str=True, - title=f"{len(union)} files are not in sync", + title=f'{len(union)} files are not in sync', ) ) @@ -139,9 +157,9 @@ class Sync: log.info( ouf.showlist( intersection, - style="line", + style='line', return_str=True, - title=f"{len(intersection)} files present on both, but not equal", + title=f'{len(intersection)} files present on both, but not equal', ) ) @@ -150,9 +168,9 @@ class Sync: log.info( ouf.showlist( s_exclusive, - style="line", + style='line', return_str=True, - title=f"{len(s_exclusive)} files present only on {source}", + title=f'{len(s_exclusive)} files present only on {source}', ) ) @@ -161,15 +179,18 @@ class Sync: log.info( ouf.showlist( d_exclusive, - style="line", + style='line', return_str=True, - title=f"{len(d_exclusive)} files present only on {destination}", + title=f'{len(d_exclusive)} files present only on {destination}', ) ) return - def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]: - """Count the number of files and directories in the given path. + def count_files_and_dirs( + self, path: str | os.PathLike, verbose: bool = True + ) -> tuple[int, int]: + """ + Count the number of files and directories in the given path. This function recursively counts the number of files and directories in the specified directory 'path'. @@ -202,6 +223,6 @@ class Sync: dirs += dirs_count if verbose: - log.info(f"Total of {files} files and {dirs} directories on {path}") + log.info(f'Total of {files} files and {dirs} directories on {path}') return files, dirs diff --git a/qim3d/mesh/_common_mesh_methods.py b/qim3d/mesh/_common_mesh_methods.py index a7e358ea7eed9bade283bd411f33da6f64ffa3a0..9e1c5e814b1704ec6f0a95764e1f7b16285b9570 100644 --- a/qim3d/mesh/_common_mesh_methods.py +++ b/qim3d/mesh/_common_mesh_methods.py @@ -1,78 +1,46 @@ +from typing import Any, Tuple + import numpy as np from skimage import measure, filters -import trimesh +from pygel3d import hmesh from typing import Tuple, Any from qim3d.utils._logger import log def from_volume( volume: np.ndarray, - level: float = None, - step_size: int = 1, - allow_degenerate: bool = False, - padding: Tuple[int, int, int] = (2, 2, 2), - **kwargs: Any, -) -> trimesh.Trimesh: - """ - Convert a volume to a mesh using the Marching Cubes algorithm, with optional thresholding and padding. + **kwargs: any +) -> hmesh.Manifold: + """ Convert a 3D numpy array to a mesh object using the [volumetric_isocontour](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#volumetric_isocontour) function from Pygel3D. Args: - volume (np.ndarray): The 3D numpy array representing the volume. - level (float, optional): The threshold value for Marching Cubes. If None, Otsu's method is used. - step_size (int, optional): The step size for the Marching Cubes algorithm. - allow_degenerate (bool, optional): Whether to allow degenerate (i.e. zero-area) triangles in the end-result. If False, degenerate triangles are removed, at the cost of making the algorithm slower. Default False. - padding (tuple of ints, optional): Padding to add around the volume. - **kwargs: Additional keyword arguments to pass to `skimage.measure.marching_cubes`. + volume (np.ndarray): A 3D numpy array representing a volume. + **kwargs: Additional arguments to pass to the Pygel3D volumetric_isocontour function. + + Raises: + ValueError: If the input volume is not a 3D numpy array or if the input volume is empty. Returns: - mesh (trimesh.Trimesh): The generated mesh. + hmesh.Manifold: A Pygel3D mesh object representing the input volume. Example: + Convert a 3D numpy array to a Pygel3D mesh object: ```python import qim3d - vol = qim3d.generate.noise_object(base_shape=(128,128,128), - final_shape=(128,128,128), - noise_scale=0.03, - order=1, - gamma=1, - max_value=255, - threshold=0.5, - dtype='uint8' - ) - mesh = qim3d.mesh.from_volume(vol, step_size=3) - qim3d.viz.mesh(mesh.vertices, mesh.faces) + + # Generate a 3D blob + synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) + + # Convert the 3D numpy array to a Pygel3D mesh object + mesh = qim3d.mesh.from_volume(synthetic_blob) ``` - <iframe src="https://platform.qim.dk/k3d/mesh_visualization.html" width="100%" height="500" frameborder="0"></iframe> """ + if volume.ndim != 3: raise ValueError("The input volume must be a 3D numpy array.") + + if volume.size == 0: + raise ValueError("The input volume must not be empty.") - # Compute the threshold level if not provided - if level is None: - level = filters.threshold_otsu(volume) - log.info(f"Computed level using Otsu's method: {level}") - - # Apply padding to the volume - if padding is not None: - pad_z, pad_y, pad_x = padding - padding_value = np.min(volume) - volume = np.pad( - volume, - ((pad_z, pad_z), (pad_y, pad_y), (pad_x, pad_x)), - mode="constant", - constant_values=padding_value, - ) - log.info(f"Padded volume with {padding} to shape: {volume.shape}") - - # Call skimage.measure.marching_cubes with user-provided kwargs - verts, faces, normals, values = measure.marching_cubes( - volume, level=level, step_size=step_size, allow_degenerate=allow_degenerate, **kwargs - ) - - # Create the Trimesh object - mesh = trimesh.Trimesh(vertices=verts, faces=faces) - - # Fix face orientation to ensure normals point outwards - trimesh.repair.fix_inversion(mesh, multibody=True) - - return mesh + mesh = hmesh.volumetric_isocontour(volume, **kwargs) + return mesh \ No newline at end of file diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py index 4624be5fb111dac4b820b52f5ab072bf744c9154..1d0b5669c650d3cdbc93c38757266499c8a379be 100644 --- a/qim3d/ml/models/__init__.py +++ b/qim3d/ml/models/__init__.py @@ -1 +1 @@ -from ._unet import UNet, Hyperparameters +from ._unet import Hyperparameters, UNet diff --git a/qim3d/operations/_common_operations_methods.py b/qim3d/operations/_common_operations_methods.py index a83b25d26f223c2ad20ad2b383044f4a45a0ad38..d8eea2a02c0664b3f5e7adf5d90ee9ba8741bb3a 100644 --- a/qim3d/operations/_common_operations_methods.py +++ b/qim3d/operations/_common_operations_methods.py @@ -1,14 +1,15 @@ import numpy as np + import qim3d.filters as filters -from qim3d.utils import log -__all__ = ["remove_background", "fade_mask", "overlay_rgb_images"] +__all__ = ['remove_background', 'fade_mask', 'overlay_rgb_images'] + def remove_background( vol: np.ndarray, median_filter_size: int = 2, min_object_radius: int = 3, - background: str = "dark", + background: str = 'dark', **median_kwargs, ) -> np.ndarray: """ @@ -41,6 +42,7 @@ def remove_background( fig2 = qim3d.viz.slices_grid(vol_filtered, value_min=0, value_max=255, num_slices=5, display_figure=True) ```  + """ # Create a pipeline with a median filter and a tophat filter @@ -53,12 +55,11 @@ def remove_background( return pipeline(vol) - def fade_mask( vol: np.ndarray, decay_rate: float = 10, ratio: float = 0.5, - geometry: str = "spherical", + geometry: str = 'spherical', invert: bool = False, axis: int = 0, **kwargs, @@ -96,9 +97,9 @@ def fade_mask(  """ - if 0 > axis or axis >= vol.ndim: + if axis < 0 or axis >= vol.ndim: raise ValueError( - "Axis must be between 0 and the number of dimensions of the volume" + 'Axis must be between 0 and the number of dimensions of the volume' ) # Generate the coordinates of each point in the array @@ -112,23 +113,23 @@ def fade_mask( center = np.array([(s - 1) / 2 for s in shape]) # Calculate the distance of each point from the center - if geometry == "spherical": + if geometry == 'spherical': distance = np.linalg.norm([z - center[0], y - center[1], x - center[2]], axis=0) - elif geometry == "cylindrical": + elif geometry == 'cylindrical': distance_list = np.array([z - center[0], y - center[1], x - center[2]]) # remove the axis along which the fading is not applied distance_list = np.delete(distance_list, axis, axis=0) distance = np.linalg.norm(distance_list, axis=0) else: raise ValueError("Geometry must be 'spherical' or 'cylindrical'") - + # Compute the maximum distance from the center max_distance = np.linalg.norm(center) - + # Compute ratio to make synthetic blobs exactly cylindrical # target_max_normalized_distance = 1.4 works well to make the blobs cylindrical - if "target_max_normalized_distance" in kwargs: - target_max_normalized_distance = kwargs["target_max_normalized_distance"] + if 'target_max_normalized_distance' in kwargs: + target_max_normalized_distance = kwargs['target_max_normalized_distance'] ratio = np.max(distance) / (target_max_normalized_distance * max_distance) # Normalize the distances so that they go from 0 at the center to 1 at the farthest point @@ -154,7 +155,10 @@ def fade_mask( def overlay_rgb_images( - background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black: bool = True, + background: np.ndarray, + foreground: np.ndarray, + alpha: float = 0.5, + hide_black: bool = True, ) -> np.ndarray: """ Overlay an RGB foreground onto an RGB background using alpha blending. @@ -176,33 +180,38 @@ def overlay_rgb_images( - It ensures that the background and foreground have the same first two dimensions (image size matches). - It can handle greyscale images, values from 0 to 1, raw values which are negative or bigger than 255. - It calculates the maximum projection of the foreground and blends them onto the background. + """ - def to_uint8(image:np.ndarray): + def to_uint8(image: np.ndarray): if np.min(image) < 0: image = image - np.min(image) maxim = np.max(image) if maxim > 255: - image = (image / maxim)*255 + image = (image / maxim) * 255 elif maxim <= 1: - image = image*255 + image = image * 255 if image.ndim == 2: image = np.repeat(image[..., None], 3, -1) elif image.ndim == 3: - image = image[..., :3] # Ignoring alpha channel + image = image[..., :3] # Ignoring alpha channel else: - raise ValueError(F'Input image can not have higher dimension than 3. Yours have {image.ndim}') - + raise ValueError( + f'Input image can not have higher dimension than 3. Yours have {image.ndim}' + ) + return image.astype(np.uint8) - + background = to_uint8(background) foreground = to_uint8(foreground) # Ensure both images have the same shape if background.shape != foreground.shape: - raise ValueError(F"Input images must have the same first two dimensions. But background is of shape {background.shape} and foreground is of shape {foreground.shape}") + raise ValueError( + f'Input images must have the same first two dimensions. But background is of shape {background.shape} and foreground is of shape {foreground.shape}' + ) # Perform alpha blending foreground_max_projection = np.amax(foreground, axis=2) @@ -215,16 +224,20 @@ def overlay_rgb_images( ) # Check alpha validity if alpha < 0: - raise ValueError(F'Alpha has to be positive number. You used {alpha}') + raise ValueError(f'Alpha has to be positive number. You used {alpha}') elif alpha > 1: alpha = 1 - + # If the pixel is black, its alpha value is set to 0, so it has no effect on the image if hide_black: - alpha = np.full((background.shape[0], background.shape[1],1), alpha) - alpha[np.apply_along_axis(lambda x: (x == [0,0,0]).all(), axis = 2, arr = foreground)] = 0 + alpha = np.full((background.shape[0], background.shape[1], 1), alpha) + alpha[ + np.apply_along_axis( + lambda x: (x == [0, 0, 0]).all(), axis=2, arr=foreground + ) + ] = 0 composite = background * (1 - alpha) + foreground * alpha - composite = np.clip(composite, 0, 255).astype("uint8") + composite = np.clip(composite, 0, 255).astype('uint8') - return composite.astype("uint8") \ No newline at end of file + return composite.astype('uint8') diff --git a/qim3d/processing/__init__.py b/qim3d/processing/__init__.py index 913d6ddb12fd7a22df1ba522bc25b2db93d35dbf..bb0f25c4788c8e5a62dc54cafaedf33df98b0ce0 100644 --- a/qim3d/processing/__init__.py +++ b/qim3d/processing/__init__.py @@ -1,3 +1,3 @@ +from ._layers import get_lines, segment_layers from ._local_thickness import local_thickness from ._structure_tensor import structure_tensor -from ._layers import segment_layers, get_lines diff --git a/qim3d/processing/_layers.py b/qim3d/processing/_layers.py index b77a4834120fedf31b757b23d9a631c709cfc1f6..ef3e6e922212fc1103c9d8e625a6593f10ed4974 100644 --- a/qim3d/processing/_layers.py +++ b/qim3d/processing/_layers.py @@ -1,15 +1,16 @@ import numpy as np -from slgbuilder import GraphObject -from slgbuilder import MaxflowBuilder - -def segment_layers(data: np.ndarray, - inverted: bool = False, - n_layers: int = 1, - delta: float = 1, - min_margin: int = 10, - max_margin: int = None, - wrap: bool = False - ) -> list: +from slgbuilder import GraphObject, MaxflowBuilder + + +def segment_layers( + data: np.ndarray, + inverted: bool = False, + n_layers: int = 1, + delta: float = 1, + min_margin: int = 10, + max_margin: int = None, + wrap: bool = False, +) -> list: """ Works on 2D and 3D data. Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation @@ -56,40 +57,49 @@ def segment_layers(data: np.ndarray, if inverted: data = ~data else: - raise TypeError(F"Data has to be type np.ndarray. Your data is of type {type(data)}") - + raise TypeError( + f'Data has to be type np.ndarray. Your data is of type {type(data)}' + ) + helper = MaxflowBuilder() if not isinstance(n_layers, int): - raise TypeError(F"Number of layers has to be positive integer. You passed {type(n_layers)}") - + raise TypeError( + f'Number of layers has to be positive integer. You passed {type(n_layers)}' + ) + if n_layers == 1: layer = GraphObject(data) helper.add_object(layer) elif n_layers > 1: layers = [GraphObject(data) for _ in range(n_layers)] helper.add_objects(layers) - for i in range(len(layers)-1): - helper.add_layered_containment(layers[i], layers[i+1], min_margin=min_margin, max_margin=max_margin) + for i in range(len(layers) - 1): + helper.add_layered_containment( + layers[i], layers[i + 1], min_margin=min_margin, max_margin=max_margin + ) else: - raise ValueError(F"Number of layers has to be positive integer. You passed {n_layers}") - + raise ValueError( + f'Number of layers has to be positive integer. You passed {n_layers}' + ) + helper.add_layered_boundary_cost() if delta > 1: delta = int(delta) elif delta <= 0: - raise ValueError(F'Delta has to be positive number. You passed {delta}') - helper.add_layered_smoothness(delta=delta, wrap = bool(wrap)) + raise ValueError(f'Delta has to be positive number. You passed {delta}') + helper.add_layered_smoothness(delta=delta, wrap=bool(wrap)) helper.solve() if n_layers == 1: - segmentations =[helper.what_segments(layer)] + segmentations = [helper.what_segments(layer)] else: segmentations = [helper.what_segments(l).astype(np.int32) for l in layers] return segmentations -def get_lines(segmentations:list[np.ndarray]) -> list: + +def get_lines(segmentations: list[np.ndarray]) -> list: """ Expects list of arrays where each array is 2D segmentation with only 2 classes. This function gets the border between those two so it could be plotted. Used with qim3d.processing.segment_layers @@ -99,6 +109,7 @@ def get_lines(segmentations:list[np.ndarray]) -> list: Returns: segmentation_lines (list): List of 1D numpy arrays + """ segmentation_lines = [np.argmin(s, axis=0) - 0.5 for s in segmentations] - return segmentation_lines \ No newline at end of file + return segmentation_lines diff --git a/qim3d/processing/_local_thickness.py b/qim3d/processing/_local_thickness.py index 12d6f19b664d61bb6bd8e0bc8db1347e8cd1c9a0..87718fbd80094f707c0d453469bb140d059f31ea 100644 --- a/qim3d/processing/_local_thickness.py +++ b/qim3d/processing/_local_thickness.py @@ -1,19 +1,23 @@ """Wrapper for the local thickness function from the localthickness package including visualization functions.""" -import numpy as np from typing import Optional -from qim3d.utils import log -import qim3d + +import numpy as np from IPython.display import display +import qim3d +from qim3d.utils import log + + def local_thickness( image: np.ndarray, scale: float = 1, mask: Optional[np.ndarray] = None, visualize: bool = False, - **viz_kwargs + **viz_kwargs, ) -> np.ndarray: - """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness) + """ + Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness) The "Fast Local Thickness" by Vedrana Andersen Dahl and Anders Bjorholm Dahl from the Technical University of Denmark is a efficient algorithm for computing local thickness in 2D and 3D images. Their method significantly reduces computation time compared to traditional algorithms by utilizing iterative dilation with small structuring elements, rather than the large ones typically used. @@ -90,9 +94,7 @@ def local_thickness( # If not, binarize it using Otsu's method, log the threshold and compute the local thickness threshold = threshold_otsu(image=image) log.warning( - "Input image is not binary. It will be binarized using Otsu's method with threshold: {}".format( - threshold - ) + f"Input image is not binary. It will be binarized using Otsu's method with threshold: {threshold}" ) local_thickness = lt.local_thickness(image > threshold, scale=scale, mask=mask) else: diff --git a/qim3d/processing/_structure_tensor.py b/qim3d/processing/_structure_tensor.py index de73521274a1879ae02edf4ab9f459409f370cd2..7edd6c4d6a30f28cd0a589a6abbdd57ab5874afb 100644 --- a/qim3d/processing/_structure_tensor.py +++ b/qim3d/processing/_structure_tensor.py @@ -1,11 +1,12 @@ """Wrapper for the structure tensor function from the structure_tensor package""" -from typing import Tuple import logging +from typing import Tuple + import numpy as np -from qim3d.utils._logger import log from IPython.display import display + def structure_tensor( vol: np.ndarray, sigma: float = 1.0, @@ -13,9 +14,10 @@ def structure_tensor( base_noise: bool = True, full: bool = False, visualize: bool = False, - **viz_kwargs + **viz_kwargs, ) -> Tuple[np.ndarray, np.ndarray]: - """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/). + """ + Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/). The structure tensor algorithm is a method for analyzing the orientation of fiber-like structures in 3D images. @@ -83,7 +85,7 @@ def structure_tensor( logging.getLogger().setLevel(previous_logging_level) if vol.ndim != 3: - raise ValueError("The input volume must be 3D") + raise ValueError('The input volume must be 3D') # Ensure volume is a float if vol.dtype != np.float32 and vol.dtype != np.float64: diff --git a/qim3d/segmentation/_common_segmentation_methods.py b/qim3d/segmentation/_common_segmentation_methods.py index fc6474597a34246f31ddcbc9634f0be803b831ed..c2560fe494ee8c28213baeba2bd124d58ca1c177 100644 --- a/qim3d/segmentation/_common_segmentation_methods.py +++ b/qim3d/segmentation/_common_segmentation_methods.py @@ -1,16 +1,18 @@ import numpy as np + from qim3d.utils._logger import log -__all__ = ["watershed"] +__all__ = ['watershed'] + def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, int]: """ Apply watershed segmentation to a binary volume. Args: - bin_vol (np.ndarray): Binary volume to segment. The input should be a 3D binary image where non-zero elements + bin_vol (np.ndarray): Binary volume to segment. The input should be a 3D binary image where non-zero elements represent the objects to be segmented. - min_distance (int): Minimum number of pixels separating peaks in the distance transform. Peaks that are + min_distance (int): Minimum number of pixels separating peaks in the distance transform. Peaks that are too close will be merged, affecting the number of segmented objects. Default is 5. Returns: @@ -37,11 +39,13 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i  """ - import skimage import scipy + import skimage if len(np.unique(bin_vol)) > 2: - raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.") + raise ValueError( + 'bin_vol has to be binary volume - it must contain max 2 unique values.' + ) # Compute distance transform of binary volume distance = scipy.ndimage.distance_transform_edt(bin_vol) @@ -65,6 +69,6 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i # Extract number of objects found num_labels = len(np.unique(labeled_volume)) - 1 - log.info(f"Total number of objects found: {num_labels}") + log.info(f'Total number of objects found: {num_labels}') - return labeled_volume, num_labels \ No newline at end of file + return labeled_volume, num_labels diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py index 190266d4f6c39954e26f6f421dec893a69319a13..5efb0714656cc30c53ec4e380ab473f5aeed10bc 100644 --- a/qim3d/segmentation/_connected_components.py +++ b/qim3d/segmentation/_connected_components.py @@ -1,5 +1,6 @@ import numpy as np from scipy.ndimage import find_objects, label + from qim3d.utils._logger import log @@ -11,66 +12,73 @@ class CC: Args: connected_components (np.ndarray): The connected components. num_connected_components (int): The number of connected components. + """ self._connected_components = connected_components self.cc_count = num_connected_components - + self.shape = connected_components.shape - + def __len__(self): """ Returns the number of connected components in the object. """ return self.cc_count - def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray: + def get_cc(self, index: int | None = None, crop: bool = False) -> np.ndarray: """ Get the connected component with the given index, if index is None selects a random component. Args: - index (int): The index of the connected component. + index (int): The index of the connected component. If none returns all components. If 'random' returns a random component. crop (bool): If True, the volume is cropped to the bounding box of the connected component. Returns: np.ndarray: The connected component as a binary mask. + """ if index is None: volume = self._connected_components - elif index == "random": + elif index == 'random': index = np.random.randint(1, self.cc_count + 1) volume = self._connected_components == index else: - assert 1 <= index <= self.cc_count, "Index out of range. Needs to be in range [1, cc_count]." + assert ( + 1 <= index <= self.cc_count + ), 'Index out of range. Needs to be in range [1, cc_count].' volume = self._connected_components == index - + if crop: # As we index get_bounding_box element 0 will be the bounding box for the connected component at index - bbox = self.get_bounding_box(index)[0] + bbox = self.get_bounding_box(index)[0] volume = volume[bbox] - + return volume - def get_bounding_box(self, index: int|None =None)-> list[tuple]: - """Get the bounding boxes of the connected components. + def get_bounding_box(self, index: int | None = None) -> list[tuple]: + """ + Get the bounding boxes of the connected components. Args: index (int, optional): The index of the connected component. If none selects all components. Returns: list: A list of bounding boxes. + """ if index: - assert 1 <= index <= self.cc_count, "Index out of range." + assert 1 <= index <= self.cc_count, 'Index out of range.' return find_objects(self._connected_components == index) else: return find_objects(self._connected_components) def get_3d_cc(image: np.ndarray) -> CC: - """ Returns an object (CC) containing the connected components of the input volume. Use plot_cc to visualize the connected components. + """ + Returns an object (CC) containing the connected components of the input volume. Use plot_cc to visualize the connected components. Args: image (np.ndarray): An array-like object to be labeled. Any non-zero values in `input` are @@ -85,7 +93,8 @@ def get_3d_cc(image: np.ndarray) -> CC: vol = qim3d.examples.cement_128x128x128[50:150]<60 cc = qim3d.segmentation.get_3d_cc(vol) ``` + """ connected_components, num_connected_components = label(image) - log.info(f"Total number of connected components found: {num_connected_components}") + log.info(f'Total number of connected components found: {num_connected_components}') return CC(connected_components, num_connected_components) diff --git a/qim3d/tests/__init__.py b/qim3d/tests/__init__.py index 1fa15c27df5773c234deedf1e0bc500d41b6ecd4..2593d1fa77ada7c333b4e55a09c9ea05e6047d56 100644 --- a/qim3d/tests/__init__.py +++ b/qim3d/tests/__init__.py @@ -1,18 +1,21 @@ -"Helper functions for testing" +"""Helper functions for testing""" import os -import matplotlib -import matplotlib.pyplot as plt -from pathlib import Path import shutil -from PIL import Image import socket +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt import numpy as np +from PIL import Image + from qim3d.utils._logger import log def mock_plot(): - """Creates a mock plot of a sine wave. + """ + Creates a mock plot of a sine wave. Returns: matplotlib.figure.Figure: The generated plot figure. @@ -22,9 +25,10 @@ def mock_plot(): >>> fig = mock_plot() >>> plt.show() + """ - matplotlib.use("Agg") + matplotlib.use('Agg') fig = plt.figure(figsize=(5, 4)) axes = fig.add_axes([0.1, 0.1, 0.8, 0.8]) @@ -34,7 +38,7 @@ def mock_plot(): return fig -def mock_write_file(path, content="File created by qim3d"): +def mock_write_file(path, content='File created by qim3d'): """ Creates a file at the specified path and writes a predefined text into it. @@ -43,8 +47,9 @@ def mock_write_file(path, content="File created by qim3d"): Example: >>> mock_write_file("example.txt") + """ - _file = open(path, "w", encoding="utf-8") + _file = open(path, 'w', encoding='utf-8') _file.write(content) _file.close() @@ -60,7 +65,8 @@ def is_server_running(ip, port): def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): - """Creates a temporary folder to test deep learning tools. + """ + Creates a temporary folder to test deep learning tools. Creates two folders, 'train' and 'test', who each also have two subfolders 'images' and 'labels'. n random images are then added to all four subfolders. @@ -74,9 +80,10 @@ def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): Example: >>> tempdata('temporary_folder',n = 10, img_shape = (16,16)) + """ - folder_trte = ["train", "test"] - sub_folders = ["images", "labels"] + folder_trte = ['train', 'test'] + sub_folders = ['images', 'labels'] # Creating train/test folder path_train = Path(folder) / folder_trte[0] @@ -98,10 +105,10 @@ def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): os.makedirs(path_train_lab) os.makedirs(path_test_lab) for i in range(n): - img.save(path_train_im / f"img_train{i}.png") - img.save(path_train_lab / f"img_train{i}.png") - img.save(path_test_im / f"img_test{i}.png") - img.save(path_test_lab / f"img_test{i}.png") + img.save(path_train_im / f'img_train{i}.png') + img.save(path_train_lab / f'img_train{i}.png') + img.save(path_test_im / f'img_test{i}.png') + img.save(path_test_lab / f'img_test{i}.png') if remove: for filename in os.listdir(folder): @@ -112,6 +119,6 @@ def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: - log.warning("Failed to delete %s. Reason: %s" % (file_path, e)) + log.warning('Failed to delete %s. Reason: %s' % (file_path, e)) os.rmdir(folder) diff --git a/qim3d/tests/filters/test_filters.py b/qim3d/tests/filters/test_filters.py index adbfd0130713b469c9a624a2be5ef72596d209dd..779be645cb1a6e7326fd104c1131fcf9b2e1b30e 100644 --- a/qim3d/tests/filters/test_filters.py +++ b/qim3d/tests/filters/test_filters.py @@ -1,28 +1,33 @@ -import qim3d +import re + import numpy as np import pytest -import re + +import qim3d + def test_filter_base_initialization(): - filter_base = qim3d.filters.FilterBase(3,size=2) + filter_base = qim3d.filters.FilterBase(3, size=2) assert filter_base.args == (3,) assert filter_base.kwargs == {'size': 2} + def test_gaussian_filter(): input_image = np.random.rand(50, 50) # Testing the function - filtered_image_fn = qim3d.filters.gaussian(input_image,sigma=1.5) + filtered_image_fn = qim3d.filters.gaussian(input_image, sigma=1.5) # Testing the class method gaussian_filter_cls = qim3d.filters.Gaussian(sigma=1.5) filtered_image_cls = gaussian_filter_cls(input_image) - + # Assertions assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape - assert np.array_equal(filtered_image_fn,filtered_image_cls) + assert np.array_equal(filtered_image_fn, filtered_image_cls) assert not np.array_equal(filtered_image_fn, input_image) + def test_median_filter(): input_image = np.random.rand(50, 50) @@ -38,6 +43,7 @@ def test_median_filter(): assert np.array_equal(filtered_image_fn, filtered_image_cls) assert not np.array_equal(filtered_image_fn, input_image) + def test_maximum_filter(): input_image = np.random.rand(50, 50) @@ -53,6 +59,7 @@ def test_maximum_filter(): assert np.array_equal(filtered_image_fn, filtered_image_cls) assert not np.array_equal(filtered_image_fn, input_image) + def test_minimum_filter(): input_image = np.random.rand(50, 50) @@ -68,6 +75,7 @@ def test_minimum_filter(): assert np.array_equal(filtered_image_fn, filtered_image_cls) assert not np.array_equal(filtered_image_fn, input_image) + def test_sequential_filter_pipeline(): input_image = np.random.rand(50, 50) @@ -77,17 +85,23 @@ def test_sequential_filter_pipeline(): maximum_filter = qim3d.filters.Maximum(size=3) # Testing the sequential pipeline - sequential_pipeline = qim3d.filters.Pipeline(gaussian_filter, median_filter, maximum_filter) + sequential_pipeline = qim3d.filters.Pipeline( + gaussian_filter, median_filter, maximum_filter + ) filtered_image_pipeline = sequential_pipeline(input_image) # Testing the equivalence to maximum(median(gaussian(input,**kwargs),**kwargs),**kwargs) - expected_output = qim3d.filters.maximum(qim3d.filters.median(qim3d.filters.gaussian(input_image, sigma=1.5), size=3), size=3) + expected_output = qim3d.filters.maximum( + qim3d.filters.median(qim3d.filters.gaussian(input_image, sigma=1.5), size=3), + size=3, + ) # Assertions assert filtered_image_pipeline.shape == expected_output.shape == input_image.shape assert not np.array_equal(filtered_image_pipeline, input_image) assert np.array_equal(filtered_image_pipeline, expected_output) + def test_sequential_filter_appending(): input_image = np.random.rand(50, 50) @@ -97,30 +111,42 @@ def test_sequential_filter_appending(): maximum_filter = qim3d.filters.Maximum(size=3) # Sequential pipeline with filter initialized at the beginning - sequential_pipeline_initial = qim3d.filters.Pipeline(gaussian_filter, median_filter, maximum_filter) + sequential_pipeline_initial = qim3d.filters.Pipeline( + gaussian_filter, median_filter, maximum_filter + ) filtered_image_initial = sequential_pipeline_initial(input_image) # Sequential pipeline with filter appended - sequential_pipeline_appended = qim3d.filters.Pipeline(gaussian_filter, median_filter) + sequential_pipeline_appended = qim3d.filters.Pipeline( + gaussian_filter, median_filter + ) sequential_pipeline_appended.append(maximum_filter) filtered_image_appended = sequential_pipeline_appended(input_image) # Assertions - assert filtered_image_initial.shape == filtered_image_appended.shape == input_image.shape - assert not np.array_equal(filtered_image_appended,input_image) + assert ( + filtered_image_initial.shape + == filtered_image_appended.shape + == input_image.shape + ) + assert not np.array_equal(filtered_image_appended, input_image) assert np.array_equal(filtered_image_initial, filtered_image_appended) + def test_assertion_error_not_filterbase_subclass(): # Get valid filter classes - valid_filters = [subclass.__name__ for subclass in qim3d.filters.FilterBase.__subclasses__()] + valid_filters = [ + subclass.__name__ for subclass in qim3d.filters.FilterBase.__subclasses__() + ] # Create invalid object invalid_filter = object() # An object that is not an instance of FilterBase - # Construct error message - message = f"filters should be instances of one of the following classes: {valid_filters}" + message = ( + f'filters should be instances of one of the following classes: {valid_filters}' + ) # Use pytest.raises to catch the AssertionError with pytest.raises(AssertionError, match=re.escape(message)): - sequential_pipeline = qim3d.filters.Pipeline(invalid_filter) \ No newline at end of file + sequential_pipeline = qim3d.filters.Pipeline(invalid_filter) diff --git a/qim3d/tests/gui/test_annotation_tool.py b/qim3d/tests/gui/test_annotation_tool.py index 06a0bedfb869766c91277951f15ac4ddc7ab64f6..75ac8a4b5d7c6b4bfe13843777afa0bd9aa0f225 100644 --- a/qim3d/tests/gui/test_annotation_tool.py +++ b/qim3d/tests/gui/test_annotation_tool.py @@ -1,12 +1,13 @@ -import qim3d import multiprocessing import time +import qim3d + def test_starting_class(): app = qim3d.gui.annotation_tool.Interface() - assert app.title == "Annotation Tool" + assert app.title == 'Annotation Tool' def start_server(ip, port): @@ -15,7 +16,7 @@ def start_server(ip, port): def test_app_launch(): - ip = "localhost" + ip = 'localhost' port = 65432 proc = multiprocessing.Process(target=start_server, args=(ip, port)) diff --git a/qim3d/tests/gui/test_iso3d.py b/qim3d/tests/gui/test_iso3d.py index 1d2e4955c396cfb217c77eb416f5dcbbfa6431d4..7af53f2b7184c20430c39b7f38b9baf41f67f46f 100644 --- a/qim3d/tests/gui/test_iso3d.py +++ b/qim3d/tests/gui/test_iso3d.py @@ -1,12 +1,13 @@ -import qim3d import multiprocessing import time +import qim3d + def test_starting_class(): app = qim3d.gui.iso3d.Interface() - assert app.title == "Isosurfaces for 3D visualization" + assert app.title == 'Isosurfaces for 3D visualization' def start_server(ip, port): @@ -15,7 +16,7 @@ def start_server(ip, port): def test_app_launch(): - ip = "localhost" + ip = 'localhost' port = 65432 proc = multiprocessing.Process(target=start_server, args=(ip, port)) diff --git a/qim3d/tests/io/test_downloader.py b/qim3d/tests/io/test_downloader.py index a6791aa0a60b0133824e62e68146f469a48f9c9c..ec1834f474a6fedc1436ff13e0139e647341ae4a 100644 --- a/qim3d/tests/io/test_downloader.py +++ b/qim3d/tests/io/test_downloader.py @@ -1,14 +1,16 @@ -import qim3d -import os -import pytest -from pathlib import Path import shutil +from pathlib import Path + +import pytest + +import qim3d -@pytest.fixture + +@pytest.fixture() def setup_temp_folder(): """Fixture to create and clean up a temporary folder for tests.""" - folder = "Cowry_Shell" - file = "Cowry_DOWNSAMPLED.tif" + folder = 'Cowry_Shell' + file = 'Cowry_DOWNSAMPLED.tif' path = Path(folder) / file # Ensure clean environment before running tests @@ -30,7 +32,7 @@ def test_download(setup_temp_folder): dl.Cowry_Shell.Cowry_DOWNSAMPLED() # Verify the file was downloaded correctly - assert path.exists(), f"{path} does not exist after download." + assert path.exists(), f'{path} does not exist after download.' img = qim3d.io.load(str(path)) assert img.shape == (500, 350, 350) @@ -40,22 +42,24 @@ def test_download(setup_temp_folder): def test_get_file_size(): """Tests for correct and incorrect file size retrieval.""" - coal_file = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Coal/CoalBrikett.tif" - folder_url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository/" + coal_file = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Coal/CoalBrikett.tif' + folder_url = ( + 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository/' + ) # Correct file size size = qim3d.io._downloader._get_file_size(coal_file) - assert size == 2_400_082_900, f"Expected size mismatch for {coal_file}." + assert size == 2_400_082_900, f'Expected size mismatch for {coal_file}.' # Wrong URL (not a file) size = qim3d.io._downloader._get_file_size(folder_url) - assert size == -1, "Expected size -1 for non-file URL." + assert size == -1, 'Expected size -1 for non-file URL.' def test_extract_html(): - url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository" + url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository' html = qim3d.io._downloader._extract_html(url) - assert 'data-path="/files/public/projects/viscomp_data_repository"' in html, \ - "Expected HTML content not found in extracted HTML." - + assert ( + 'data-path="/files/public/projects/viscomp_data_repository"' in html + ), 'Expected HTML content not found in extracted HTML.' diff --git a/qim3d/tests/io/test_load.py b/qim3d/tests/io/test_load.py index 88d0f092b9576ae293c26254de7bf0ee85e14fce..c42e806d6e26ee8afd34db92febdf094674cac49 100644 --- a/qim3d/tests/io/test_load.py +++ b/qim3d/tests/io/test_load.py @@ -1,15 +1,17 @@ -import qim3d -import numpy as np -from pathlib import Path import os -import pytest import re +from pathlib import Path + +import numpy as np +import pytest + +import qim3d # Load volume into memory vol = qim3d.examples.bone_128x128x128 # Ceate memory map to blobs -volume_path = Path(qim3d.__file__).parents[0] / "examples" / "bone_128x128x128.tif" +volume_path = Path(qim3d.__file__).parents[0] / 'examples' / 'bone_128x128x128.tif' vol_memmap = qim3d.io.load(volume_path, virtual_stack=True) @@ -26,7 +28,7 @@ def test_load_type_memmap(): def test_invalid_path(): - invalid_path = os.path.join("this", "path", "doesnt", "exist.tif") + invalid_path = os.path.join('this', 'path', 'doesnt', 'exist.tif') with pytest.raises(FileNotFoundError): qim3d.io.load(invalid_path) diff --git a/qim3d/tests/io/test_save.py b/qim3d/tests/io/test_save.py index d6cbfb3e69aa2cbcbbd6098e08cd67cd99b1c1b7..f4a53b1ffee7ec20c5ba03127950e76bff3e2452 100644 --- a/qim3d/tests/io/test_save.py +++ b/qim3d/tests/io/test_save.py @@ -11,30 +11,31 @@ import qim3d def test_image_exist(): # Create random test image - test_image = np.random.randint(0,256,(100,100,100),'uint8') + test_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.tif") + image_path = os.path.join(temp_dir, 'test_image.tif') # Save to temporary directory - qim3d.io.save(image_path,test_image) + qim3d.io.save(image_path, test_image) # Assert that test image has been saved assert os.path.exists(image_path) + def test_compression(): # Get test image (should not be random in order for compression to function) test_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.tif") - compressed_image_path = os.path.join(temp_dir,"compressed_test_image.tif") + image_path = os.path.join(temp_dir, 'test_image.tif') + compressed_image_path = os.path.join(temp_dir, 'compressed_test_image.tif') # Save to temporary directory with and without compression - qim3d.io.save(image_path,test_image) - qim3d.io.save(compressed_image_path,test_image,compression=True) + qim3d.io.save(image_path, test_image) + qim3d.io.save(compressed_image_path, test_image, compression=True) # Compute file sizes file_size = os.path.getsize(image_path) @@ -43,16 +44,17 @@ def test_compression(): # Assert that compressed file size is smaller than non-compressed file size assert compressed_file_size < file_size + def test_image_matching(): # Create random test image - original_image = np.random.randint(0,256,(100,100,100),'uint8') + original_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"original_image.tif") + image_path = os.path.join(temp_dir, 'original_image.tif') # Save to temporary directory - qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path, original_image) # Load from temporary directory saved_image = qim3d.io.load(image_path) @@ -64,16 +66,17 @@ def test_image_matching(): # Assert that original image is identical to saved_image assert original_hash == saved_hash + def test_compressed_image_matching(): # Get test image (should not be random in order for compression to function) original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"original_image.tif") + image_path = os.path.join(temp_dir, 'original_image.tif') # Save to temporary directory - qim3d.io.save(image_path,original_image,compression=True) + qim3d.io.save(image_path, original_image, compression=True) # Load from temporary directory saved_image_compressed = qim3d.io.load(image_path) @@ -85,47 +88,53 @@ def test_compressed_image_matching(): # Assert that original image is identical to saved_image assert original_hash == compressed_hash + def test_file_replace(): # Create random test image - test_image1 = np.random.randint(0,256,(100,100,100),'uint8') - test_image2 = np.random.randint(0,256,(100,100,100),'uint8') + test_image1 = np.random.randint(0, 256, (100, 100, 100), 'uint8') + test_image2 = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.tif") + image_path = os.path.join(temp_dir, 'test_image.tif') # Save first test image to temporary directory - qim3d.io.save(image_path,test_image1) + qim3d.io.save(image_path, test_image1) # Get hash hash1 = calculate_image_hash(qim3d.io.load(image_path)) - + # Replace existing file - qim3d.io.save(image_path,test_image2,replace=True) + qim3d.io.save(image_path, test_image2, replace=True) # Get hash again hash2 = calculate_image_hash(qim3d.io.load(image_path)) # Assert that the file was modified by checking if the second modification time is newer than the first assert hash1 != hash2 + def test_file_already_exists(): # Create random test image - test_image1 = np.random.randint(0,256,(100,100,100),'uint8') - test_image2 = np.random.randint(0,256,(100,100,100),'uint8') + test_image1 = np.random.randint(0, 256, (100, 100, 100), 'uint8') + test_image2 = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.tif") + image_path = os.path.join(temp_dir, 'test_image.tif') # Save first test image to temporary directory - qim3d.io.save(image_path,test_image1) + qim3d.io.save(image_path, test_image1) - with pytest.raises(ValueError,match="A file with the provided path already exists. To replace it set 'replace=True'"): + with pytest.raises( + ValueError, + match="A file with the provided path already exists. To replace it set 'replace=True'", + ): # Try to save another image to the existing path - qim3d.io.save(image_path,test_image2) + qim3d.io.save(image_path, test_image2) + def test_no_file_ext(): # Create random test image - test_image = np.random.randint(0,256,(100,100,100),'uint8') + test_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create filename without extension filename = 'test_image' @@ -134,95 +143,103 @@ def test_no_file_ext(): # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,filename) + image_path = os.path.join(temp_dir, filename) - with pytest.raises(ValueError,match=message): + with pytest.raises(ValueError, match=message): # Try to save the test image to a path witout file extension - qim3d.io.save(image_path,test_image) + qim3d.io.save(image_path, test_image) + def test_folder_doesnt_exist(): # Create random test image - test_image = np.random.randint(0,256,(100,100,100),'uint8') + test_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') - # Create invalid path - invalid_path = os.path.join('this','path','doesnt','exist.tif') + # Create invalid path + invalid_path = os.path.join('this', 'path', 'doesnt', 'exist.tif') - #message = f'The directory {re.escape(os.path.dirname(invalid_path))} does not exist. Please provide a valid directory' + # message = f'The directory {re.escape(os.path.dirname(invalid_path))} does not exist. Please provide a valid directory' message = f"""The directory '{re.escape(os.path.dirname(invalid_path))}' does not exist.\nPlease provide a valid directory or specify a basename if you want to save a tiff stack as several files to a folder that does not yet exist""" - with pytest.raises(ValueError,match=message): + with pytest.raises(ValueError, match=message): # Try to save test image to an invalid path - qim3d.io.save(invalid_path,test_image) - + qim3d.io.save(invalid_path, test_image) + + def test_unsupported_file_format(): # Create random test image - test_image = np.random.randint(0,256,(100,100,100),'uint8') + test_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create filename with unsupported format filename = 'test_image.unsupported' # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,filename) + image_path = os.path.join(temp_dir, filename) - with pytest.raises(ValueError,match='Unsupported file format'): + with pytest.raises(ValueError, match='Unsupported file format'): # Try to save test image with an unsupported file extension - qim3d.io.save(image_path,test_image) + qim3d.io.save(image_path, test_image) + def test_no_basename(): # Create random test image - test_image = np.random.randint(0,256,(100,100,100),'uint8') + test_image = np.random.randint(0, 256, (100, 100, 100), 'uint8') # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: message = f"""To save a stack as several TIFF files to the directory '{re.escape(temp_dir)}', please provide the keyword argument 'basename'. Otherwise, to save a single file, please provide a full path with a filename and valid extension.""" - with pytest.raises(ValueError,match=message): + with pytest.raises(ValueError, match=message): # Try to save test image to an existing directory (indicating # that you want to save as several files) without providing a basename - qim3d.io.save(temp_dir,test_image) + qim3d.io.save(temp_dir, test_image) + def test_mkdir_tiff_stack(): # Create random test image - test_image = np.random.randint(0,256,(10,100,100),'uint8') + test_image = np.random.randint(0, 256, (10, 100, 100), 'uint8') # create temporary directory with tempfile.TemporaryDirectory() as temp_dir: # Define a folder that does not yet exist - path2save= os.path.join(temp_dir,'tempfolder') - + path2save = os.path.join(temp_dir, 'tempfolder') + # Save to this folder with a basename - qim3d.io.save(path2save,test_image,basename='test') + qim3d.io.save(path2save, test_image, basename='test') # Assert that folder is created assert os.path.isdir(path2save) + def test_tiff_stack_naming(): # Create random test image - test_image = np.random.randint(0,256,(10,100,100),'uint8') + test_image = np.random.randint(0, 256, (10, 100, 100), 'uint8') # Define expected filenames basename = 'test' - expected_filenames = [basename + str(i).zfill(2) + '.tif' for i,_ in enumerate(test_image)] - + expected_filenames = [ + basename + str(i).zfill(2) + '.tif' for i, _ in enumerate(test_image) + ] + # create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - qim3d.io.save(temp_dir,test_image,basename=basename) + qim3d.io.save(temp_dir, test_image, basename=basename) + + assert expected_filenames == sorted(os.listdir(temp_dir)) - assert expected_filenames==sorted(os.listdir(temp_dir)) - def test_tiff_stack_slicing_dim(): # Create random test image where the three dimensions are not the same length - test_image = np.random.randint(0,256,(5,10,15),'uint8') - + test_image = np.random.randint(0, 256, (5, 10, 15), 'uint8') + with tempfile.TemporaryDirectory() as temp_dir: - # Iterate thorugh all three dims and save the image as slices in + # Iterate thorugh all three dims and save the image as slices in # each dimension in separate folder and assert the number of files # match the shape of the image for dim in range(3): - path2save = os.path.join(temp_dir,'dim'+str(dim)) - qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim) - assert len(os.listdir(path2save))==test_image.shape[dim] + path2save = os.path.join(temp_dir, 'dim' + str(dim)) + qim3d.io.save(path2save, test_image, basename='test', sliced_dim=dim) + assert len(os.listdir(path2save)) == test_image.shape[dim] + def test_tiff_save_load(): # Create random test image @@ -230,10 +247,10 @@ def test_tiff_save_load(): # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.tif") + image_path = os.path.join(temp_dir, 'test_image.tif') # Save to temporary directory - qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path, original_image) # Load from temporary directory saved_image = qim3d.io.load(image_path) @@ -245,16 +262,17 @@ def test_tiff_save_load(): # Assert that original image is identical to saved_image assert original_hash == saved_hash + def test_vol_save_load(): # Create random test image original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.vol") + image_path = os.path.join(temp_dir, 'test_image.vol') # Save to temporary directory - qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path, original_image) # Load from temporary directory saved_image = qim3d.io.load(image_path) @@ -266,18 +284,19 @@ def test_vol_save_load(): # Assert that original image is identical to saved_image assert original_hash == saved_hash + def test_pil_save_load(): # Create random test image original_image = qim3d.examples.bone_128x128x128[0] # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path_png = os.path.join(temp_dir,"test_image.png") - image_path_jpg = os.path.join(temp_dir,"test_image.jpg") + image_path_png = os.path.join(temp_dir, 'test_image.png') + image_path_jpg = os.path.join(temp_dir, 'test_image.jpg') # Save to temporary directory - qim3d.io.save(image_path_png,original_image) - qim3d.io.save(image_path_jpg,original_image, compression=True) + qim3d.io.save(image_path_png, original_image) + qim3d.io.save(image_path_jpg, original_image, compression=True) # Load from temporary directory saved_image_png = qim3d.io.load(image_path_png) @@ -289,22 +308,23 @@ def test_pil_save_load(): # Assert that original image is identical to saved_image assert original_hash == saved_png_hash - + # jpg is lossy so the hashes will not match, checks that the image is the same size and similar values assert original_image.shape == saved_image_jpg.shape + def test_nifti_save_load(): # Create random test image original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.nii") - image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + image_path = os.path.join(temp_dir, 'test_image.nii') + image_path_compressed = os.path.join(temp_dir, 'test_image_compressed.nii.gz') # Save to temporary directory - qim3d.io.save(image_path,original_image) - qim3d.io.save(image_path_compressed,original_image, compression=True) + qim3d.io.save(image_path, original_image) + qim3d.io.save(image_path_compressed, original_image, compression=True) # Load from temporary directory saved_image = qim3d.io.load(image_path) @@ -333,12 +353,12 @@ def test_h5_save_load(): # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: - image_path = os.path.join(temp_dir,"test_image.h5") - image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + image_path = os.path.join(temp_dir, 'test_image.h5') + image_path_compressed = os.path.join(temp_dir, 'test_image_compressed.nii.gz') # Save to temporary directory - qim3d.io.save(image_path,original_image) - qim3d.io.save(image_path_compressed,original_image, compression=True) + qim3d.io.save(image_path, original_image) + qim3d.io.save(image_path_compressed, original_image, compression=True) # Load from temporary directory saved_image = qim3d.io.load(image_path) @@ -361,7 +381,8 @@ def test_h5_save_load(): # Assert that compressed file size is smaller than non-compressed file size assert compressed_file_size < file_size -def calculate_image_hash(image): + +def calculate_image_hash(image): image_bytes = image.tobytes() hash_object = hashlib.md5(image_bytes) - return hash_object.hexdigest() \ No newline at end of file + return hash_object.hexdigest() diff --git a/qim3d/tests/mesh/test_mesh.py b/qim3d/tests/mesh/test_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..9e19475bfd7adbbd0d11dd4e96c03a15ca2e970a --- /dev/null +++ b/qim3d/tests/mesh/test_mesh.py @@ -0,0 +1,42 @@ +import pytest +import numpy as np +from pygel3d import hmesh +import qim3d + +def test_from_volume_valid_input(): + """Test that from_volume returns a hmesh.Manifold object for a valid 3D input.""" + volume = np.random.rand(50, 50, 50).astype(np.float32) # Generate a random 3D volume + mesh = qim3d.mesh.from_volume(volume) + assert isinstance(mesh, hmesh.Manifold) # Check if output is a Manifold object + +def test_from_volume_invalid_input(): + """Test that from_volume raises ValueError for non-3D input.""" + volume = np.random.rand(50, 50) # A 2D array + with pytest.raises(ValueError, match="The input volume must be a 3D numpy array."): + qim3d.mesh.from_volume(volume) + +def test_from_volume_empty_array(): + """Test how from_volume handles an empty 3D array.""" + volume = np.empty((0, 0, 0)) # Empty 3D array + with pytest.raises(ValueError): # It should fail because it doesn't make sense to generate a mesh from empty data + qim3d.mesh.from_volume(volume) + +def test_from_volume_with_kwargs(): + """Test that from_volume correctly passes kwargs.""" + volume = np.random.rand(50, 50, 50).astype(np.float32) + + # Mock volumetric_isocontour to check if kwargs are passed + def mock_volumetric_isocontour(vol, **kwargs): + assert "isovalue" in kwargs + assert kwargs["isovalue"] == 0.5 + return hmesh.Manifold() + + # Replace the function temporarily + original_function = hmesh.volumetric_isocontour + hmesh.volumetric_isocontour = mock_volumetric_isocontour + + try: + qim3d.mesh.from_volume(volume, isovalue=0.5) + finally: + hmesh.volumetric_isocontour = original_function # Restore original function + diff --git a/qim3d/tests/ml/models/test_unet.py b/qim3d/tests/ml/models/test_unet.py index 9f32bfce12be3a21238f2d2d388e59f873a8ef1c..78ad2e5fca176ad2257747e3b85bd55ad5ffd0d6 100644 --- a/qim3d/tests/ml/models/test_unet.py +++ b/qim3d/tests/ml/models/test_unet.py @@ -1,6 +1,8 @@ -import qim3d import torch +import qim3d + + # unit tests for UNet() def test_starting_unet(): unet = qim3d.ml.models.UNet() @@ -12,11 +14,12 @@ def test_forward_pass(): unet = qim3d.ml.models.UNet() # Size: B x C x H x W - x = torch.ones([1,1,256,256]) + x = torch.ones([1, 1, 256, 256]) output = unet(x) assert x.shape == output.shape + # unit tests for Hyperparameters() def test_hyper(): unet = qim3d.ml.models.UNet() @@ -24,10 +27,11 @@ def test_hyper(): assert hyperparams.n_epochs == 10 + def test_hyper_dict(): unet = qim3d.ml.models.UNet() hyperparams = qim3d.ml.models.Hyperparameters(unet) hyper_dict = hyperparams() - assert type(hyper_dict) == dict + assert type(hyper_dict) == dict diff --git a/qim3d/tests/ml/test_models.py b/qim3d/tests/ml/test_models.py index 547310f303e5b78ca0036dd2be464673e0975384..a16f904289f8489a1c442d55f807ce530c53b499 100644 --- a/qim3d/tests/ml/test_models.py +++ b/qim3d/tests/ml/test_models.py @@ -1,7 +1,7 @@ -import qim3d import pytest from torch import ones +import qim3d from qim3d.tests import temp_data @@ -9,10 +9,10 @@ from qim3d.tests import temp_data def test_model_summary(): n = 10 img_shape = (32, 32) - folder = "folder_data" + folder = 'folder_data' temp_data(folder, img_shape=img_shape, n=n) - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') augment = qim3d.ml.Augmentation(transform_train=None) train_set, val_set, test_set = qim3d.ml.prepare_datasets( folder, 1 / 3, unet, augment @@ -30,10 +30,10 @@ def test_model_summary(): # unit test for inference() def test_inference(): - folder = "folder_data" + folder = 'folder_data' temp_data(folder) - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') augment = qim3d.ml.Augmentation(transform_train=None) train_set, _, _ = qim3d.ml.prepare_datasets(folder, 1 / 3, unet, augment) @@ -46,13 +46,13 @@ def test_inference(): # unit test for tuple ValueError(). def test_inference_tuple(): - folder = "folder_data" + folder = 'folder_data' temp_data(folder) - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') data = [1, 2, 3] - with pytest.raises(ValueError, match="Data items must be tuples"): + with pytest.raises(ValueError, match='Data items must be tuples'): qim3d.ml.inference(data, unet) temp_data(folder, remove=True) @@ -60,13 +60,13 @@ def test_inference_tuple(): # unit test for tensor ValueError(). def test_inference_tensor(): - folder = "folder_data" + folder = 'folder_data' temp_data(folder) - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') data = [(1, 2)] - with pytest.raises(ValueError, match="Data items must consist of tensors"): + with pytest.raises(ValueError, match='Data items must consist of tensors'): qim3d.ml.inference(data, unet) temp_data(folder, remove=True) @@ -74,14 +74,14 @@ def test_inference_tensor(): # unit test for dimension ValueError(). def test_inference_dim(): - folder = "folder_data" + folder = 'folder_data' temp_data(folder) - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') data = [(ones(1), ones(1))] # need the r"" for special characters - with pytest.raises(ValueError, match=r"Input image must be \(C,H,W\) format"): + with pytest.raises(ValueError, match=r'Input image must be \(C,H,W\) format'): qim3d.ml.inference(data, unet) temp_data(folder, remove=True) @@ -89,12 +89,12 @@ def test_inference_dim(): # unit test for train_model() def test_train_model(): - folder = "folder_data" + folder = 'folder_data' temp_data(folder) n_epochs = 1 - unet = qim3d.ml.models.UNet(size="small") + unet = qim3d.ml.models.UNet(size='small') augment = qim3d.ml.Augmentation(transform_train=None) hyperparams = qim3d.ml.Hyperparameters(unet, n_epochs=n_epochs) train_set, val_set, test_set = qim3d.ml.prepare_datasets( @@ -108,6 +108,6 @@ def test_train_model(): unet, hyperparams, train_loader, val_loader, plot=False, return_loss=True ) - assert len(train_loss["loss"]) == n_epochs + assert len(train_loss['loss']) == n_epochs temp_data(folder, remove=True) diff --git a/qim3d/tests/notebooks/test_notebooks.py b/qim3d/tests/notebooks/test_notebooks.py index edac804fdfb767b1ef1ee658153f1760337769a3..be69139da909bb2cad0c16d4a3edade1913ac294 100644 --- a/qim3d/tests/notebooks/test_notebooks.py +++ b/qim3d/tests/notebooks/test_notebooks.py @@ -1,25 +1,31 @@ from testbook import testbook + def test_blob_detection_notebook(): with testbook('./docs/notebooks/blob_detection.ipynb', execute=True) as tb: pass + def test_filters_notebook(): with testbook('./docs/notebooks/filters.ipynb', execute=True) as tb: pass + def test_local_thickness_notebook(): with testbook('./docs/notebooks/local_thickness.ipynb', execute=True) as tb: pass + def test_logging_notebook(): with testbook('./docs/notebooks/Logging.ipynb', execute=True) as tb: pass + def test_references_from_doi_notebook(): with testbook('./docs/notebooks/References from DOI.ipynb', execute=True) as tb: pass + def test_structure_tensor_notebook(): with testbook('./docs/notebooks/structure_tensor.ipynb', execute=True) as tb: - pass \ No newline at end of file + pass diff --git a/qim3d/tests/processing/test_local_thickness.py b/qim3d/tests/processing/test_local_thickness.py index edc18c4d31d422c8f85b1cf966f6e1a02a729474..056fa6c9773db7f1055852d6158eea93a28b9d73 100644 --- a/qim3d/tests/processing/test_local_thickness.py +++ b/qim3d/tests/processing/test_local_thickness.py @@ -1,7 +1,8 @@ -import qim3d import numpy as np from skimage.draw import disk, ellipsoid -import pytest + +import qim3d + def test_local_thickness_2d(): # Create a binary 2D image @@ -21,12 +22,17 @@ def test_local_thickness_2d(): assert np.allclose(lt, lt_manual, rtol=1e-1) + def test_local_thickness_3d(): - disk3d = ellipsoid(15,15,15) + disk3d = ellipsoid(15, 15, 15) # Remove weird border pixels border_thickness = 2 - disk3d = disk3d[border_thickness:-border_thickness, border_thickness:-border_thickness, border_thickness:-border_thickness] + disk3d = disk3d[ + border_thickness:-border_thickness, + border_thickness:-border_thickness, + border_thickness:-border_thickness, + ] disk3d = np.pad(disk3d, border_thickness, mode='constant') lt = qim3d.processing.local_thickness(disk3d) diff --git a/qim3d/tests/processing/test_structure_tensor.py b/qim3d/tests/processing/test_structure_tensor.py index 0c6421b075baab3ab263384c4957bcf80dfe5f0e..0a1f4713e0322856c640c7a9bf5dd249b23a377b 100644 --- a/qim3d/tests/processing/test_structure_tensor.py +++ b/qim3d/tests/processing/test_structure_tensor.py @@ -1,12 +1,15 @@ -import pytest import numpy as np +import pytest + import qim3d + def test_wrong_ndim(): img_2d = np.random.rand(50, 50) - with pytest.raises(ValueError, match = "The input volume must be 3D"): + with pytest.raises(ValueError, match='The input volume must be 3D'): qim3d.processing.structure_tensor(img_2d, 1.5, 1.5) + def test_structure_tensor(): volume = np.random.rand(50, 50, 50) val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5) @@ -16,6 +19,7 @@ def test_structure_tensor(): assert np.all(val[1] <= val[2]) assert np.all(val[0] <= val[2]) + def test_structure_tensor_full(): volume = np.random.rand(50, 50, 50) val, vec = qim3d.processing.structure_tensor(volume, 1.5, 1.5, full=True) @@ -23,4 +27,4 @@ def test_structure_tensor_full(): assert vec.shape == (3, 3, 50, 50, 50) assert np.all(val[0] <= val[1]) assert np.all(val[1] <= val[2]) - assert np.all(val[0] <= val[2]) \ No newline at end of file + assert np.all(val[0] <= val[2]) diff --git a/qim3d/tests/segmentation/test_connected_components.py b/qim3d/tests/segmentation/test_connected_components.py index 9d59aff94281eb08acc361a2c633b4e9e30d6916..e5abb819f973e4b33f3fef7b11fd0e20db69bac1 100644 --- a/qim3d/tests/segmentation/test_connected_components.py +++ b/qim3d/tests/segmentation/test_connected_components.py @@ -4,46 +4,54 @@ import pytest from qim3d.segmentation._connected_components import get_3d_cc -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def setup_data(): - components = np.array([[0,0,1,1,0,0], - [0,0,0,1,0,0], - [1,1,0,0,1,0], - [0,0,0,1,0,0]]) + components = np.array( + [[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [1, 1, 0, 0, 1, 0], [0, 0, 0, 1, 0, 0]] + ) num_components = 4 connected_components = get_3d_cc(components) return connected_components, components, num_components + def test_connected_components_property(setup_data): connected_components, _, _ = setup_data - components = np.array([[0,0,1,1,0,0], - [0,0,0,1,0,0], - [2,2,0,0,3,0], - [0,0,0,4,0,0]]) + components = np.array( + [[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [2, 2, 0, 0, 3, 0], [0, 0, 0, 4, 0, 0]] + ) assert np.array_equal(connected_components.get_cc(), components) + def test_num_connected_components_property(setup_data): connected_components, _, num_components = setup_data assert len(connected_components) == num_components + def test_get_connected_component_with_index(setup_data): connected_components, _, _ = setup_data - expected_component = np.array([[0,0,1,1,0,0], - [0,0,0,1,0,0], - [0,0,0,0,0,0], - [0,0,0,0,0,0]], dtype=bool) + expected_component = np.array( + [ + [0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ) print(connected_components.get_cc(index=1)) print(expected_component) assert np.array_equal(connected_components.get_cc(index=1), expected_component) + def test_get_connected_component_without_index(setup_data): connected_components, _, _ = setup_data component = connected_components.get_cc() assert np.any(component) + def test_get_connected_component_with_invalid_index(setup_data): connected_components, _, num_components = setup_data with pytest.raises(AssertionError): connected_components.get_cc(index=0) with pytest.raises(AssertionError): - connected_components.get_cc(index=num_components + 1) \ No newline at end of file + connected_components.get_cc(index=num_components + 1) diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py index 809e7ddbe2771519a5c71d4032edd7ebbd029d7f..f25598171fe5a1e92abbe8e51e550901d86087b6 100644 --- a/qim3d/tests/utils/test_augmentations.py +++ b/qim3d/tests/utils/test_augmentations.py @@ -1,13 +1,14 @@ -import qim3d import albumentations import pytest +import qim3d + # unit tests for Augmentation() def test_augmentation(): augment_class = qim3d.ml.Augmentation() - assert augment_class.resize == "crop" + assert augment_class.resize == 'crop' def test_augment(): @@ -20,7 +21,7 @@ def test_augment(): # unit tests for ValueErrors in Augmentation() def test_resize(): - resize_str = "not valid resize" + resize_str = 'not valid resize' with pytest.raises( ValueError, @@ -32,7 +33,7 @@ def test_resize(): def test_levels(): augment_class = qim3d.ml.Augmentation() - level = "Not a valid level" + level = 'Not a valid level' with pytest.raises( ValueError, diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py index 81beef5334ddf4bbdd20de34b6d52f55a13f6307..6de4f551e656e3c9d68e07e0e865cb1b9cf15279 100644 --- a/qim3d/tests/utils/test_data.py +++ b/qim3d/tests/utils/test_data.py @@ -1,62 +1,83 @@ -import qim3d import pytest - from torch.utils.data.dataloader import DataLoader + +import qim3d from qim3d.tests import temp_data + # unit tests for Dataset() def test_dataset(): - img_shape = (32,32) + img_shape = (32, 32) folder = 'folder_data' - temp_data(folder, img_shape = img_shape) - + temp_data(folder, img_shape=img_shape) + images = qim3d.ml.Dataset(folder) assert images[0][0].shape == img_shape - temp_data(folder,remove=True) + temp_data(folder, remove=True) # unit tests for check_resize() def test_check_resize(): - h_adjust,w_adjust = qim3d.ml._data.check_resize(240,240,resize = 'crop',n_channels = 6) + h_adjust, w_adjust = qim3d.ml._data.check_resize( + 240, 240, resize='crop', n_channels=6 + ) + + assert (h_adjust, w_adjust) == (192, 192) - assert (h_adjust,w_adjust) == (192,192) def test_check_resize_pad(): - h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'padding',n_channels = 6) + h_adjust, w_adjust = qim3d.ml._data.check_resize( + 16, 16, resize='padding', n_channels=6 + ) - assert (h_adjust,w_adjust) == (64,64) + assert (h_adjust, w_adjust) == (64, 64) -def test_check_resize_fail(): - with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."): - h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'crop',n_channels = 6) +def test_check_resize_fail(): + with pytest.raises( + ValueError, + match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.", + ): + h_adjust, w_adjust = qim3d.ml._data.check_resize( + 16, 16, resize='crop', n_channels=6 + ) # unit tests for prepare_datasets() def test_prepare_datasets(): n = 3 - validation = 1/3 - + validation = 1 / 3 + folder = 'folder_data' - img = temp_data(folder,n = n) + img = temp_data(folder, n=n) my_model = qim3d.ml.models.UNet() my_augmentation = qim3d.ml.Augmentation(transform_test='light') - train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,validation,my_model,my_augmentation) + train_set, val_set, test_set = qim3d.ml.prepare_datasets( + folder, validation, my_model, my_augmentation + ) - assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n) + assert (len(train_set), len(val_set), len(test_set)) == ( + int((1 - validation) * n), + int(n * validation), + n, + ) - temp_data(folder,remove=True) + temp_data(folder, remove=True) # unit test for validation in prepare_datasets() def test_validation(): validation = 10 - - with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."): - augment_class = qim3d.ml.prepare_datasets('folder',validation,'my_model','my_augmentation') + + with pytest.raises( + ValueError, match='The validation fraction must be a float between 0 and 1.' + ): + augment_class = qim3d.ml.prepare_datasets( + 'folder', validation, 'my_model', 'my_augmentation' + ) # unit test for prepare_dataloaders() @@ -67,12 +88,14 @@ def test_prepare_dataloaders(): batch_size = 1 my_model = qim3d.ml.models.UNet() my_augmentation = qim3d.ml.Augmentation() - train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,1/3,my_model,my_augmentation) + train_set, val_set, test_set = qim3d.ml.prepare_datasets( + folder, 1 / 3, my_model, my_augmentation + ) + + _, val_loader, _ = qim3d.ml.prepare_dataloaders( + train_set, val_set, test_set, batch_size, num_workers=1, pin_memory=False + ) - _,val_loader,_ = qim3d.ml.prepare_dataloaders(train_set,val_set,test_set, - batch_size,num_workers = 1, - pin_memory = False) - assert type(val_loader) == DataLoader - temp_data(folder,remove=True) \ No newline at end of file + temp_data(folder, remove=True) diff --git a/qim3d/tests/utils/test_doi.py b/qim3d/tests/utils/test_doi.py index 859ebdb92baefca4eba721b3ac71fb00e1d5dec6..003d5ab6c166fcb820fd8a2801a063b1f72a4962 100644 --- a/qim3d/tests/utils/test_doi.py +++ b/qim3d/tests/utils/test_doi.py @@ -1,15 +1,15 @@ import qim3d -doi = "https://doi.org/10.1007/s10851-021-01041-3" +doi = 'https://doi.org/10.1007/s10851-021-01041-3' def test_get_bibtex(): bibtext = qim3d.utils._doi.get_bibtex(doi) - assert "Measuring Shape Relations Using r-Parallel Sets" in bibtext + assert 'Measuring Shape Relations Using r-Parallel Sets' in bibtext def test_get_reference(): reference = qim3d.utils._doi.get_reference(doi) - assert "Stephensen" in reference + assert 'Stephensen' in reference diff --git a/qim3d/tests/utils/test_helpers.py b/qim3d/tests/utils/test_helpers.py index cdbbedc4589f9801077444605874916a8558f0c1..76e4c673b50f0dd3613ddcb18486b64af11ba180 100644 --- a/qim3d/tests/utils/test_helpers.py +++ b/qim3d/tests/utils/test_helpers.py @@ -1,8 +1,9 @@ -import qim3d import os import re from pathlib import Path +import qim3d + def test_mock_plot(): fig = qim3d.tests.mock_plot() @@ -11,12 +12,12 @@ def test_mock_plot(): def test_mock_write_file(): - filename = "test.txt" - content = "test file" + filename = 'test.txt' + content = 'test file' qim3d.tests.mock_write_file(filename, content=content) # Check contents - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding='utf-8') as f: file_content = f.read() # Remove temp file @@ -27,7 +28,7 @@ def test_mock_write_file(): def test_get_local_ip(): def validate_ip(ip_str): - reg = r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$" + reg = r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$' if re.match(reg, ip_str): return True else: @@ -40,7 +41,7 @@ def test_get_local_ip(): def test_stringify_path1(): """Test that the function converts os.PathLike objects to strings""" - blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif" + blobs_path = Path(qim3d.__file__).parents[0] / 'img_examples' / 'blobs_256x256.tif' assert str(blobs_path) == qim3d.utils._misc.stringify_path(blobs_path) @@ -48,6 +49,6 @@ def test_stringify_path1(): def test_stringify_path2(): """Test that the function returns input unchanged if input is a string""" # Create test_path - test_path = os.path.join("this", "path", "doesnt", "exist.tif") + test_path = os.path.join('this', 'path', 'doesnt', 'exist.tif') assert test_path == qim3d.utils._misc.stringify_path(test_path) diff --git a/qim3d/tests/utils/test_system.py b/qim3d/tests/utils/test_system.py index 2b1f88cdb2bbbc8be02a7c00f32a44154ccced47..0d33e134ee1c57a28d2741bc71d801705d0c68c9 100644 --- a/qim3d/tests/utils/test_system.py +++ b/qim3d/tests/utils/test_system.py @@ -1,11 +1,9 @@ import qim3d + def test_memory(): mem = qim3d.utils.Memory() - assert all([ - mem.total>0, - mem.used>0, - mem.free>0]) - + assert all([mem.total > 0, mem.used > 0, mem.free > 0]) + assert mem.used + mem.free == mem.total diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index d33d00cc54e0d989cf870e474343e2787e069cfd..cd806833e8cd877f1afd917cb63b69335d8d6917 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -1,17 +1,12 @@ -import pytest -import torch -import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt +import numpy as np import pytest -from torch import ones +import torch import qim3d from qim3d.tests import temp_data -import matplotlib.pyplot as plt -import ipywidgets as widgets - # unit tests for grid overview def test_grid_overview(): @@ -26,13 +21,13 @@ def test_grid_overview(): def test_grid_overview_tuple(): random_tuple = (torch.ones(256, 256), torch.ones(256, 256)) - with pytest.raises(ValueError, match="Data elements must be tuples"): + with pytest.raises(ValueError, match='Data elements must be tuples'): qim3d.viz.grid_overview(random_tuple, num_images=1) # unit tests for grid prediction def test_grid_pred(): - folder = "folder_data" + folder = 'folder_data' n = 4 temp_data(folder, n=n) @@ -57,8 +52,8 @@ def test_slices_numpy_array_input(): def test_slices_wrong_input_format(): - input = "not_a_volume" - with pytest.raises(ValueError, match="Data type not supported"): + input = 'not_a_volume' + with pytest.raises(ValueError, match='Data type not supported'): qim3d.viz.slices_grid(input) @@ -66,7 +61,7 @@ def test_slices_not_volume(): example_volume = np.ones((10, 10)) with pytest.raises( ValueError, - match="The provided object is not a volume as it has less than 3 dimensions.", + match='The provided object is not a volume as it has less than 3 dimensions.', ): qim3d.viz.slices_grid(example_volume) @@ -77,7 +72,7 @@ def test_slices_wrong_position_format1(): ValueError, match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', ): - qim3d.viz.slices_grid(example_volume, slice_positions="invalid_slice") + qim3d.viz.slices_grid(example_volume, slice_positions='invalid_slice') def test_slices_wrong_position_format2(): @@ -110,7 +105,7 @@ def test_slices_invalid_axis_value(): def test_slices_interpolation_option(): example_volume = np.ones((10, 10, 10)) img_width = 3 - interpolation_method = "bilinear" + interpolation_method = 'bilinear' fig = qim3d.viz.slices_grid( example_volume, num_slices=1, @@ -130,7 +125,9 @@ def test_slices_multiple_slices(): example_volume = np.ones((10, 10, 10)) image_width = 3 num_slices = 3 - fig = qim3d.viz.slices_grid(example_volume, num_slices=num_slices, image_width=image_width) + fig = qim3d.viz.slices_grid( + example_volume, num_slices=num_slices, image_width=image_width + ) # Add assertions for the expected number of subplots in the figure assert len(fig.get_axes()) == num_slices @@ -192,7 +189,7 @@ def test_slicer_with_different_parameters(): assert isinstance(slicer_obj, widgets.interactive) # Test with different colormaps - for cmap in ["viridis", "gray", "plasma"]: + for cmap in ['viridis', 'gray', 'plasma']: slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap) assert isinstance(slicer_obj, widgets.interactive) @@ -232,14 +229,18 @@ def test_orthogonal_with_torch_tensor(): def test_orthogonal_with_different_parameters(): # Test with different colormaps - for color_map in ["viridis", "gray", "plasma"]: - orthogonal_obj = qim3d.viz.slicer_orthogonal(np.random.rand(10, 10, 10), color_map=color_map) + for color_map in ['viridis', 'gray', 'plasma']: + orthogonal_obj = qim3d.viz.slicer_orthogonal( + np.random.rand(10, 10, 10), color_map=color_map + ) assert isinstance(orthogonal_obj, widgets.HBox) # Test with different image sizes for image_height, image_width in [(2, 2), (4, 4)]: orthogonal_obj = qim3d.viz.slicer_orthogonal( - np.random.rand(10, 10, 10), image_height=image_height, image_width=image_width + np.random.rand(10, 10, 10), + image_height=image_height, + image_width=image_width, ) assert isinstance(orthogonal_obj, widgets.HBox) @@ -266,7 +267,7 @@ def test_orthogonal_slider_description(): # Call the orthogonal function with the NumPy array orthogonal_obj = qim3d.viz.slicer_orthogonal(vol) for idx, slicer in enumerate(orthogonal_obj.children): - assert slicer.children[0].description == ["Z", "Y", "X"][idx] + assert slicer.children[0].description == ['Z', 'Y', 'X'][idx] # unit tests for local thickness visualization diff --git a/qim3d/tests/viz/test_visualizations.py b/qim3d/tests/viz/test_visualizations.py index c9a8beceac2ae601b03adbc77d473fd339b577d1..2831ba786244b0428b5f033b4ed7f06cfdde7788 100644 --- a/qim3d/tests/viz/test_visualizations.py +++ b/qim3d/tests/viz/test_visualizations.py @@ -3,17 +3,19 @@ import pytest import qim3d -#unit test for plot_metrics() +# unit test for plot_metrics() def test_plot_metrics(): - metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} - - fig = qim3d.viz.plot_metrics(metrics, figsize=(10,10)) + metrics = {'epoch_loss': [0.3, 0.2, 0.1], 'batch_loss': [0.3, 0.2, 0.1]} - assert (fig.get_figwidth(),fig.get_figheight()) == (10,10) + fig = qim3d.viz.plot_metrics(metrics, figsize=(10, 10)) + + assert (fig.get_figwidth(), fig.get_figheight()) == (10, 10) def test_plot_metrics_labels(): - metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]} - - with pytest.raises(ValueError,match="The number of metrics doesn't match the number of labels."): - qim3d.viz.plot_metrics(metrics,labels = ['a','b']) \ No newline at end of file + metrics = {'epoch_loss': [0.3, 0.2, 0.1], 'batch_loss': [0.3, 0.2, 0.1]} + + with pytest.raises( + ValueError, match="The number of metrics doesn't match the number of labels." + ): + qim3d.viz.plot_metrics(metrics, labels=['a', 'b']) diff --git a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py index bf4ca1fa152d07189d26525d856970063f1dbaad..0a7a174c9eb3f5a17f83cd8d3c66710380504b86 100644 --- a/qim3d/utils/__init__.py +++ b/qim3d/utils/__init__.py @@ -1,18 +1,15 @@ from ._doi import * -from ._system import Memory - from ._logger import log - from ._misc import ( - get_local_ip, - port_from_str, - gradio_header, - sizeof, + downscale_img, + get_css, get_file_size, + get_local_ip, get_port_dict, - get_css, - downscale_img, + gradio_header, + port_from_str, scale_to_float16, + sizeof, ) - from ._server import start_http_server +from ._system import Memory diff --git a/qim3d/utils/_cli.py b/qim3d/utils/_cli.py index 3dd42a5a66ff05b0a68da916f2f44e32473aba79..e2e3ff09065164d16f9c539159046deca945c630 100644 --- a/qim3d/utils/_cli.py +++ b/qim3d/utils/_cli.py @@ -1,17 +1,25 @@ import argparse -from qim3d.gui import data_explorer, iso3d, annotation_tool, local_thickness, layers2d + +from qim3d.gui import annotation_tool, data_explorer, iso3d, layers2d, local_thickness + def main(): parser = argparse.ArgumentParser(description='qim3d command-line interface.') subparsers = parser.add_subparsers(title='Subcommands', dest='subcommand') # subcommands - gui_parser = subparsers.add_parser('gui', help = 'Graphical User Interfaces.') + gui_parser = subparsers.add_parser('gui', help='Graphical User Interfaces.') - gui_parser.add_argument('--data-explorer', action='store_true', help='Run data explorer.') + gui_parser.add_argument( + '--data-explorer', action='store_true', help='Run data explorer.' + ) gui_parser.add_argument('--iso3d', action='store_true', help='Run iso3d.') - gui_parser.add_argument('--annotation-tool', action='store_true', help='Run annotation tool.') - gui_parser.add_argument('--local-thickness', action='store_true', help='Run local thickness tool.') + gui_parser.add_argument( + '--annotation-tool', action='store_true', help='Run annotation tool.' + ) + gui_parser.add_argument( + '--local-thickness', action='store_true', help='Run local thickness tool.' + ) gui_parser.add_argument('--layers2d', action='store_true', help='Run layers2d.') gui_parser.add_argument('--host', default='0.0.0.0', help='Desired host.') @@ -20,20 +28,20 @@ def main(): if args.subcommand == 'gui': arghost = args.host if args.data_explorer: - data_explorer.run_interface(arghost) elif args.iso3d: iso3d.run_interface(arghost) - + elif args.annotation_tool: annotation_tool.run_interface(arghost) - + elif args.local_thickness: local_thickness.run_interface(arghost) elif args.layers2d: layers2d.run_interface(arghost) - + + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/qim3d/utils/_doi.py b/qim3d/utils/_doi.py index 1859ed53eeb7c5e0ffa62da3066868b1de73053a..bc01eaed09066b33e9b25c57c3ca285317c376a6 100644 --- a/qim3d/utils/_doi.py +++ b/qim3d/utils/_doi.py @@ -1,21 +1,24 @@ -""" Deals with DOI for references """ +"""Deals with DOI for references""" + import json + import requests + from qim3d.utils._logger import log def _validate_response(response: requests.Response) -> bool: # Check if we got a good response if not response.ok: - log.error(f"Could not read the provided DOI ({response.reason})") + log.error(f'Could not read the provided DOI ({response.reason})') return False return True def _doi_to_url(doi: str) -> str: - if doi[:3] != "http": - url = "https://doi.org/" + doi + if doi[:3] != 'http': + url = 'https://doi.org/' + doi else: url = doi @@ -52,12 +55,14 @@ def _log_and_get_text(doi, header) -> str: def get_bibtex(doi: str): """Generates bibtex from doi""" - header = {"Accept": "application/x-bibtex"} + header = {'Accept': 'application/x-bibtex'} return _log_and_get_text(doi, header) + def custom_header(doi: str, header: str) -> str: - """Allows a custom header to be passed + """ + Allows a custom header to be passed Example: import qim3d @@ -68,15 +73,17 @@ def custom_header(doi: str, header: str) -> str: """ return _log_and_get_text(doi, header) + def get_metadata(doi: str) -> dict: """Generates a metadata dictionary from doi""" - header = {"Accept": "application/vnd.citationstyles.csl+json"} + header = {'Accept': 'application/vnd.citationstyles.csl+json'} response = _make_request(doi, header) metadata = json.loads(response.text) return metadata + def get_reference(doi: str) -> str: """Generates a metadata dictionary from doi and use it to build a reference string""" @@ -85,15 +92,18 @@ def get_reference(doi: str) -> str: return reference_string + def build_reference_string(metadata: dict) -> str: """Generates a reference string from metadata""" - authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']]) + authors = ', '.join( + [f"{author['family']} {author['given']}" for author in metadata['author']] + ) year = metadata['issued']['date-parts'][0][0] title = metadata['title'] publisher = metadata['publisher'] url = metadata['URL'] doi = metadata['DOI'] - reference_string = f"{authors} ({year}). {title}. {publisher} ({url}). DOI: {doi}" + reference_string = f'{authors} ({year}). {title}. {publisher} ({url}). DOI: {doi}' return reference_string diff --git a/qim3d/utils/_logger.py b/qim3d/utils/_logger.py index 599171bfb0a84d1384160d33df1d244d8d6bff75..cedbf8e4e4d9cf1bf32e558c067478059001e320 100644 --- a/qim3d/utils/_logger.py +++ b/qim3d/utils/_logger.py @@ -2,20 +2,23 @@ import logging -logger = logging.getLogger("qim3d") +logger = logging.getLogger('qim3d') + def set_detailed_output(): - """Configures the logging output to display detailed information. + """ + Configures the logging output to display detailed information. This function sets up a logging formatter with a specific format that includes the log level, filename, line number, and log message. Example: >>> set_detailed_output() + """ formatter = logging.Formatter( - "%(levelname)-10s%(filename)s:%(lineno)-5s%(message)s" + '%(levelname)-10s%(filename)s:%(lineno)-5s%(message)s' ) handler = logging.StreamHandler() handler.setFormatter(formatter) @@ -32,8 +35,9 @@ def set_simple_output(): Example: >>> set_simple_output() + """ - formatter = logging.Formatter("%(message)s") + formatter = logging.Formatter('%(message)s') handler = logging.StreamHandler() handler.setFormatter(formatter) logger.handlers = [] @@ -41,55 +45,66 @@ def set_simple_output(): def set_level_debug(): - """Sets the logging level of the "qim3d" logger to DEBUG. + """ + Sets the logging level of the "qim3d" logger to DEBUG. Example: >>> set_level_debug() + """ logger.setLevel(logging.DEBUG) def set_level_info(): - """Sets the logging level of the "qim3d" logger to INFO. + """ + Sets the logging level of the "qim3d" logger to INFO. Example: >>> set_level_info() + """ logger.setLevel(logging.INFO) def set_level_warning(): - """Sets the logging level of the "qim3d" logger to WARNING. + """ + Sets the logging level of the "qim3d" logger to WARNING. Example: >>> set_level_warning() + """ logger.setLevel(logging.WARNING) def set_level_error(): - """Sets the logging level of the "qim3d" logger to ERROR. + """ + Sets the logging level of the "qim3d" logger to ERROR. Example: >>> set_level_error() + """ logger.setLevel(logging.ERROR) def set_level_critical(): - """Sets the logging level of the "qim3d" logger to CRITICAL. + """ + Sets the logging level of the "qim3d" logger to CRITICAL. Example: >>> set_level_critical() + """ logger.setLevel(logging.CRITICAL) def level(log_level): - """Set the logging level based on the specified level. + """ + Set the logging level based on the specified level. Args: log_level (str or int): The logging level to set. It can be one of: @@ -104,19 +119,19 @@ def level(log_level): ValueError: If the specified level is not a valid logging level. """ - if log_level in ["DEBUG", "debug"]: + if log_level in ['DEBUG', 'debug']: set_level_debug() - elif log_level in ["INFO", "info"]: + elif log_level in ['INFO', 'info']: set_level_info() - elif log_level in ["WARNING", "warning"]: + elif log_level in ['WARNING', 'warning']: set_level_warning() - elif log_level in ["ERROR", "error"]: + elif log_level in ['ERROR', 'error']: set_level_error() - elif log_level in ["CRITICAL", "critical"]: + elif log_level in ['CRITICAL', 'critical']: set_level_critical() elif isinstance(log_level, int): @@ -130,6 +145,6 @@ def level(log_level): # Create the logger -log = logging.getLogger("qim3d") +log = logging.getLogger('qim3d') set_level_info() set_simple_output() diff --git a/qim3d/utils/_misc.py b/qim3d/utils/_misc.py index b8f61bd875ef6de8711d726d21240604e367e1c3..bd4eb317c491f13e922ee93b7989dd20202dbfe1 100644 --- a/qim3d/utils/_misc.py +++ b/qim3d/utils/_misc.py @@ -1,19 +1,22 @@ -""" Provides a collection of internal utility functions.""" +"""Provides a collection of internal utility functions.""" +import difflib import getpass import hashlib import os import socket + import numpy as np import outputformat as ouf import requests from scipy.ndimage import zoom -import difflib + import qim3d def get_local_ip() -> str: - """Retrieves the local IP address of the current machine. + """ + Retrieves the local IP address of the current machine. The function uses a socket to determine the local IP address. Then, it tries to connect to the IP address "192.255.255.255" @@ -23,20 +26,21 @@ def get_local_ip() -> str: network is not available, the function falls back to returning the loopback address "127.0.0.1". - Returns: + Returns str: The local IP address. Example usage: ip_address = get_local_ip() + """ _socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: # doesn't even have to be reachable - _socket.connect(("192.255.255.255", 1)) + _socket.connect(('192.255.255.255', 1)) ip_address = _socket.getsockname()[0] - except socket.error: - ip_address = "127.0.0.1" + except OSError: + ip_address = '127.0.0.1' finally: _socket.close() return ip_address @@ -61,12 +65,14 @@ def port_from_str(s: str) -> int: Example usage: port = port_from_str("my_specific_app_name") + """ - return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4) + return int(hashlib.sha1(s.encode('utf-8')).hexdigest(), 16) % (10**4) def gradio_header(title: str, port: int) -> None: - """Display the header for a Gradio server. + """ + Display the header for a Gradio server. Displays a formatted header containing the provided title, the port number being used, and the IP address where the server is running. @@ -93,14 +99,15 @@ def gradio_header(title: str, port: int) -> None: ouf.br(2) details = [ f'{ouf.c(title, color="rainbow", cmap="cool", bold=True, return_str=True)}', - f"Using port {port}", - f"Running at {get_local_ip()}", + f'Using port {port}', + f'Running at {get_local_ip()}', ] - ouf.showlist(details, style="box", title="Starting gradio server") + ouf.showlist(details, style='box', title='Starting gradio server') -def sizeof(num: float, suffix: str = "B") -> str: - """Converts a number to a human-readable string representing its size. +def sizeof(num: float, suffix: str = 'B') -> str: + """ + Converts a number to a human-readable string representing its size. Converts the given number to a human-readable string representing its size in a more readable format, such as bytes (B), kilobytes (KB), megabytes (MB), @@ -124,17 +131,18 @@ def sizeof(num: float, suffix: str = "B") -> str: '1.0 KB' >>> qim3d.utils.sizeof(1234567890) '1.1 GB' + """ - for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: if abs(num) < 1024.0: - return f"{num:3.1f} {unit}{suffix}" + return f'{num:3.1f} {unit}{suffix}' num /= 1024.0 - return f"{num:.1f} Y{suffix}" + return f'{num:.1f} Y{suffix}' def find_similar_paths(path: str) -> list[str]: - parent_dir = os.path.dirname(path) or "." - parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" + parent_dir = os.path.dirname(path) or '.' + parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else '' valid_paths = [os.path.join(parent_dir, file) for file in parent_files] similar_paths = difflib.get_close_matches(path, valid_paths) @@ -144,12 +152,13 @@ def find_similar_paths(path: str) -> list[str]: def get_file_size(file_path: str) -> int: """ Args: - ----- + ---- filename (str): Specifies full path to file Returns: - --------- + ------- size (int): size of file in bytes + """ try: file_size = os.path.getsize(file_path) @@ -161,8 +170,8 @@ def get_file_size(file_path: str) -> int: message = f"Invalid path. Did you mean '{suggestion}'?" raise FileNotFoundError(repr(message)) else: - raise FileNotFoundError("Invalid path") - + raise FileNotFoundError('Invalid path') + return file_size @@ -176,7 +185,7 @@ def stringify_path(path: os.PathLike) -> str: def get_port_dict() -> dict: # Gets user and port username = getpass.getuser() - url = f"https://platform.qim.dk/qim-api/get-port/{username}" + url = f'https://platform.qim.dk/qim-api/get-port/{username}' response = requests.get(url, timeout=10) # Check if the request was successful (status code 200) @@ -185,25 +194,25 @@ def get_port_dict() -> dict: port_dict = response.json() else: # Print an error message if the request was not successful - raise (f"Error: {response.status_code}") + raise (f'Error: {response.status_code}') return port_dict def get_css() -> str: - current_directory = os.path.dirname(os.path.abspath(__file__)) parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir)) - css_path = os.path.join(parent_directory, "css", "gradio.css") + css_path = os.path.join(parent_directory, 'css', 'gradio.css') - with open(css_path, "r") as file: + with open(css_path) as file: css_content = file.read() return css_content def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray: - """Downscale image if total number of voxels exceeds 512³. + """ + Downscale image if total number of voxels exceeds 512³. Args: img (np.Array): Input image. @@ -211,6 +220,7 @@ def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray: Returns: np.Array: Downscaled image if total number of voxels exceeds 512³. + """ # Calculate total number of pixels in the image @@ -231,10 +241,10 @@ def scale_to_float16(arr: np.ndarray) -> np.ndarray: """ Scale the input array to the float16 data type. - Parameters: + Parameters arr (np.ndarray): Input array to be scaled. - Returns: + Returns np.ndarray: Scaled array with dtype=np.float16. This function scales the input array to the float16 data type, ensuring that the @@ -242,6 +252,7 @@ def scale_to_float16(arr: np.ndarray) -> np.ndarray: for float16. If the maximum value of the input array exceeds the maximum representable value for float16, the array is scaled down proportionally to fit within the float16 range. + """ # Get the maximum value to comprare with the float16 maximum value diff --git a/qim3d/utils/_ome_zarr.py b/qim3d/utils/_ome_zarr.py index db7ad145ca96d0f4b0fade1c5558a6eb07aac01e..3d599db9e8ff75ef775134ed3ec6ee7ba52097b7 100644 --- a/qim3d/utils/_ome_zarr.py +++ b/qim3d/utils/_ome_zarr.py @@ -1,15 +1,17 @@ -from zarr.util import normalize_chunks, normalize_dtype, normalize_shape import numpy as np +from zarr.util import normalize_chunks, normalize_dtype, normalize_shape + -def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]: +def get_chunk_size(shape: tuple, dtype: tuple) -> tuple[int, ...]: """ - How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use + How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use in qim3d.io.export_ome_zarr function Parameters ---------- - shape: shape of the data - dtype: dtype of the data + """ object_codec = None dtype, object_codec = normalize_dtype(dtype, object_codec) @@ -29,12 +31,12 @@ def get_n_chunks(shapes: tuple, dtypes: tuple) -> int: ---------- - shapes: list of shapes of the data for each scale - dtype: dtype of the data + """ n_chunks = 0 for shape, dtype in zip(shapes, dtypes): chunk_size = np.array(get_chunk_size(shape, dtype)) shape = np.array(shape) - ratio = shape/chunk_size + ratio = shape / chunk_size n_chunks += np.prod(ratio) return int(n_chunks) - diff --git a/qim3d/utils/_progress_bar.py b/qim3d/utils/_progress_bar.py index 57f33738911a3c8493f5da855a8a3863f800f0c0..b9a792a91a30cf33961022219caebf1724407fbc 100644 --- a/qim3d/utils/_progress_bar.py +++ b/qim3d/utils/_progress_bar.py @@ -1,15 +1,16 @@ -from threading import Timer -import psutil -import sys import os +import sys from abc import ABC, abstractmethod +from threading import Timer +import psutil from tqdm.auto import tqdm from qim3d.utils._misc import get_file_size class RepeatTimer(Timer): + """ If the memory check is set as a normal thread, there is no garuantee it will switch resulting in not enough memory checks to create smooth progress bar or to make it @@ -23,19 +24,21 @@ class RepeatTimer(Timer): while not self.finished.wait(self.interval): self.function(*self.args, **self.kwargs) + class ProgressBar(ABC): - def __init__(self, tqdm_kwargs: dict, repeat_time: float, *args, **kwargs): + def __init__(self, tqdm_kwargs: dict, repeat_time: float, *args, **kwargs): """ - Context manager for ('with' statement) to track progress during a long progress over + Context manager for ('with' statement) to track progress during a long progress over which we don't have control (like loading a file) and thus can not insert the tqdm updates into loop Thus we have to run parallel thread with forced activation to check the state - Parameters: - ------------ + Parameters + ---------- - tqdm_kwargs (dict): Passed directly to tqdm constructor - repeat_time (float): How often the timer runs the function (in seconds) + """ self.timer = RepeatTimer(repeat_time, self.update_pbar) self.pbar = tqdm(**tqdm_kwargs) @@ -47,16 +50,13 @@ class ProgressBar(ABC): try: self.pbar.update(update) - except ( - AttributeError - ): # When we leave the context manager, we delete the pbar so it can not be updated anymore + except AttributeError: # When we leave the context manager, we delete the pbar so it can not be updated anymore # It's because it takes quite a long time for the timer to end and might update the pbar # one more time before ending which messes up the whole thing pass self.last_update = new_update - @abstractmethod def get_new_update(self): pass @@ -72,30 +72,30 @@ class ProgressBar(ABC): del self.pbar # So the update process can not update it anymore - class FileLoadingProgressBar(ProgressBar): def __init__(self, filename: str, repeat_time: float = 0.5, *args, **kwargs): """ Context manager ('with' statement) to track progress during loading a file into memory - Parameters: - ------------ + Parameters + ---------- - filename (str): to get size of the file - repeat_time (float, optional): How often the timer checks how many bytes were loaded. Even if very small, it doesn't make the progress bar smoother as there are only few visible changes in number of read_chars. Defaults to 0.5 + """ tqdm_kwargs = dict( total=get_file_size(filename), - desc="Loading: ", - unit="B", + desc='Loading: ', + unit='B', file=sys.stdout, unit_scale=True, unit_divisor=1024, - bar_format="{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, " - "{rate_fmt}{postfix}]", + bar_format='{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, ' + '{rate_fmt}{postfix}]', ) - super().__init__( tqdm_kwargs, repeat_time) + super().__init__(tqdm_kwargs, repeat_time) self.process = psutil.Process() def get_new_update(self) -> int: @@ -106,8 +106,9 @@ class FileLoadingProgressBar(ProgressBar): memory = counters.read_bytes + counters.other_bytes return memory + class OmeZarrExportProgressBar(ProgressBar): - def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"): + def __init__(self, path: str, n_chunks: int, reapeat_time: str = 'auto'): """ Context manager to track the exporting of OmeZarr files. @@ -118,14 +119,13 @@ class OmeZarrExportProgressBar(ProgressBar): n_chunks : int The total number of chunks to track. repeat_time : int or float, optional - The interval (in seconds) for updating the progress bar. Defaults to "auto", which + The interval (in seconds) for updating the progress bar. Defaults to "auto", which sets the update frequency based on the number of chunks. - """ - + """ # Calculate the repeat time for the progress bar - if reapeat_time == "auto": + if reapeat_time == 'auto': # Approximate the repeat time based on the number of chunks # This ratio is based on reading the HOA dataset over the network: # 620,000 files took 300 seconds to read @@ -142,11 +142,7 @@ class OmeZarrExportProgressBar(ProgressBar): self.path = path tqdm_kwargs = dict( - total = n_chunks, - unit = "Chunks", - desc = "Saving", - unit_scale = True - + total=n_chunks, unit='Chunks', desc='Saving', unit_scale=True ) super().__init__(tqdm_kwargs, reapeat_time) self.last_update = 0 @@ -154,7 +150,7 @@ class OmeZarrExportProgressBar(ProgressBar): def get_new_update(self): def file_count(folder_path: str) -> int: """ - Goes recursively through the folders and counts how many files are there, + Goes recursively through the folders and counts how many files are there, Doesn't count metadata json files """ count = 0 @@ -162,7 +158,7 @@ class OmeZarrExportProgressBar(ProgressBar): new_path = os.path.join(folder_path, path) if os.path.isfile(new_path): filename = os.path.basename(os.path.normpath(new_path)) - if not filename.startswith("."): + if not filename.startswith('.'): count += 1 else: count += file_count(new_path) diff --git a/qim3d/utils/_server.py b/qim3d/utils/_server.py index f61651585428ad293353bd9bdc837d29a6275acc..4515ca0ec5abb08aac63d8df6af07edfd14992ed 100644 --- a/qim3d/utils/_server.py +++ b/qim3d/utils/_server.py @@ -1,73 +1,79 @@ import os -from http.server import SimpleHTTPRequestHandler, HTTPServer import threading +from http.server import HTTPServer, SimpleHTTPRequestHandler + from qim3d.utils._logger import log + class CustomHTTPRequestHandler(SimpleHTTPRequestHandler): def end_headers(self): """Add CORS headers to each response.""" # Allow requests from any origin, or restrict to specific domains by specifying the origin - self.send_header("Access-Control-Allow-Origin", "*") + self.send_header('Access-Control-Allow-Origin', '*') # Allow specific methods - self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS') # Allow specific headers (if needed) - self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type") + self.send_header( + 'Access-Control-Allow-Headers', 'X-Requested-With, Content-Type' + ) super().end_headers() - def list_directory(self, path: str|os.PathLike): + def list_directory(self, path: str | os.PathLike): """Helper to produce a directory listing, includes hidden files.""" try: file_list = os.listdir(path) except OSError: - self.send_error(404, "No permission to list directory") + self.send_error(404, 'No permission to list directory') return None - + # Sort the file list file_list.sort(key=lambda a: a.lower()) - + # Format the list with hidden files included displaypath = os.path.basename(path) r = ['<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">'] - r.append(f"<html>\n<title>Directory listing for {displaypath}</title>\n") - r.append(f"<body>\n<h2>Directory listing for {displaypath}</h2>\n") - r.append("<hr>\n<ul>") + r.append(f'<html>\n<title>Directory listing for {displaypath}</title>\n') + r.append(f'<body>\n<h2>Directory listing for {displaypath}</h2>\n') + r.append('<hr>\n<ul>') for name in file_list: fullname = os.path.join(path, name) displayname = linkname = name - + # Append the files and directories to the HTML list if os.path.isdir(fullname): - displayname = name + "/" - linkname = name + "/" + displayname = name + '/' + linkname = name + '/' r.append(f'<li><a href="{linkname}">{displayname}</a></li>') - r.append("</ul>\n<hr>\n</body>\n</html>\n") - encoded = "\n".join(r).encode('utf-8', 'surrogateescape') + r.append('</ul>\n<hr>\n</body>\n</html>\n') + encoded = '\n'.join(r).encode('utf-8', 'surrogateescape') self.send_response(200) - self.send_header("Content-Type", "text/html; charset=utf-8") - self.send_header("Content-Length", str(len(encoded))) + self.send_header('Content-Type', 'text/html; charset=utf-8') + self.send_header('Content-Length', str(len(encoded))) self.end_headers() # Write the encoded HTML directly to the response self.wfile.write(encoded) + def start_http_server(directory: str, port: int = 8000) -> HTTPServer: """ Starts an HTTP server serving the specified directory on the given port with CORS enabled. - - Parameters: + + Parameters directory (str): The directory to serve. port (int): The port number to use (default is 8000). + """ # Change the working directory to the specified directory os.chdir(directory) - + # Create the server - server = HTTPServer(("", port), CustomHTTPRequestHandler) - + server = HTTPServer(('', port), CustomHTTPRequestHandler) + # Run the server in a separate thread so it doesn't block execution thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() - + log.info(f"Serving directory '{directory}'\nhttp://localhost:{port}/") - + return server diff --git a/qim3d/utils/_system.py b/qim3d/utils/_system.py index ccccb6724812cc162d948b805cde89844ce47431..e720a3f8d2283d9c7879405c86567ecf9592b082 100644 --- a/qim3d/utils/_system.py +++ b/qim3d/utils/_system.py @@ -1,20 +1,26 @@ """Provides tools for obtaining information about the system.""" + import os import time + +import numpy as np import psutil -from qim3d.utils._misc import sizeof + from qim3d.utils._logger import log -import numpy as np +from qim3d.utils._misc import sizeof class Memory: - """Class for obtaining current memory information - Attributes: + """ + Class for obtaining current memory information + + Attributes total (int): Total system memory in bytes free (int): Free system memory in bytes used (int): Used system memory in bytes used_pct (float): Used system memory in percentage + """ def __init__(self): @@ -28,7 +34,7 @@ class Memory: def report(self): log.info( - "System memory:\n • Total.: %s\n • Used..: %s (%s%%)\n • Free..: %s (%s%%)", + 'System memory:\n • Total.: %s\n • Used..: %s (%s%%)\n • Free..: %s (%s%%)', sizeof(self.total), sizeof(self.used), round(self.used_pct, 1), @@ -36,8 +42,11 @@ class Memory: round(self.free_pct, 1), ) -def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]: - ''' + +def _test_disk_speed( + file_size_bytes: int = 1024, ntimes: int = 10 +) -> tuple[float, float, float, float]: + """ Test the write and read speed of the disk by writing a file of a given size and then reading it back. @@ -56,7 +65,8 @@ def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[floa print(f"Write speed: {write_speed:.2f} GB/s") print(f"Read speed: {read_speed:.2f} GB/s") ``` - ''' + + """ write_speeds = [] read_speeds = [] @@ -64,26 +74,26 @@ def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[floa for _ in range(ntimes): # Generate random data for the file data = os.urandom(file_size_bytes) - + # Write data to a temporary file with open('temp_file.bin', 'wb') as f: start_write = time.time() f.write(data) end_write = time.time() - + # Read data from the temporary file with open('temp_file.bin', 'rb') as f: start_read = time.time() f.read() end_read = time.time() - + # Calculate read and write speed (GB/s) write_speed = file_size_bytes / (end_write - start_write) / (1024**3) read_speed = file_size_bytes / (end_read - start_read) / (1024**3) write_speeds.append(write_speed) read_speeds.append(read_speed) - + # Clean up temporary file os.remove('temp_file.bin') @@ -91,12 +101,12 @@ def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[floa write_speed_std = np.std(write_speeds) avg_read_speed = np.mean(read_speeds) read_speed_std = np.std(read_speeds) - + return avg_write_speed, write_speed_std, avg_read_speed, read_speed_std def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None: - ''' + """ Report the average write and read speed of the disk. Args: @@ -108,16 +118,19 @@ def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None: qim3d.io.logger.level("info") qim3d.utils.system.disk_report() ``` - ''' + + """ # Test disk speed - avg_write_speed, write_speed_std, avg_read_speed, read_speed_std = _test_disk_speed(file_size_bytes=file_size, ntimes=ntimes) - + avg_write_speed, write_speed_std, avg_read_speed, read_speed_std = _test_disk_speed( + file_size_bytes=file_size, ntimes=ntimes + ) + # Print disk information log.info( - "Disk:\n • Write speed..: %.2f GB/s (± %.2f GB/s)\n • Read speed...: %.2f GB/s (± %.2f GB/s)", + 'Disk:\n • Write speed..: %.2f GB/s (± %.2f GB/s)\n • Read speed...: %.2f GB/s (± %.2f GB/s)', avg_write_speed, write_speed_std, avg_read_speed, - read_speed_std + read_speed_std, ) diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 7f21d7b9046e7337f88e9c54eb81ff1abce0db2d..f785319d6eb90429839d5fc1a4770b4783e6c8d5 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,18 +1,19 @@ -from . import colormaps +from . import _layers2d, colormaps from ._cc import plot_cc -from ._detection import circles from ._data_exploration import ( + chunks, fade_mask, + histogram, + line_profile, slicer, slicer_orthogonal, slices_grid, - chunks, - histogram, + threshold, ) -from .itk_vtk_viewer import itk_vtk -from ._k3d import volumetric, mesh +from ._detection import circles +from ._k3d import mesh, volumetric from ._local_thickness import local_thickness -from ._structure_tensor import vectors -from ._metrics import plot_metrics, grid_overview, grid_pred, vol_masked +from ._metrics import grid_overview, grid_pred, plot_metrics, vol_masked from ._preview import image_preview -from . import _layers2d +from ._structure_tensor import vectors +from .itk_vtk_viewer import itk_vtk diff --git a/qim3d/viz/_cc.py b/qim3d/viz/_cc.py index 8ace8abf830f3cc3e0cc7eb5219f90c90f378242..89fb978fcaa0474edcf74add95014e8db577e9b9 100644 --- a/qim3d/viz/_cc.py +++ b/qim3d/viz/_cc.py @@ -1,8 +1,10 @@ import matplotlib.pyplot as plt import numpy as np + import qim3d -from qim3d.utils._logger import log from qim3d.segmentation._connected_components import CC +from qim3d.utils._logger import log + def plot_cc( connected_components: CC, @@ -11,7 +13,7 @@ def plot_cc( overlay: np.ndarray = None, crop: bool = False, display_figure: bool = True, - color_map: str = "viridis", + color_map: str = 'viridis', value_min: float = None, value_max: float = None, **kwargs, @@ -19,7 +21,7 @@ def plot_cc( """ Plots the connected components from a `qim3d.processing.cc.CC` object. If an overlay image is provided, the connected component will be masked to the overlay image. - Parameters: + Parameters connected_components (CC): The connected components object. component_indexs (list or tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None. max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32. @@ -31,7 +33,7 @@ def plot_cc( value_max (float or None, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None **kwargs (Any): Additional keyword arguments to pass to `qim3d.viz.slices_grid`. - Returns: + Returns figs (list[plt.Figure]): List of figures, if `display_figure=False`. Example: @@ -45,12 +47,13 @@ def plot_cc( ```   + """ # if no components are given, plot the first max_cc_to_plot=32 components if component_indexs is None: if len(connected_components) > max_cc_to_plot: log.warning( - f"More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components." + f'More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components.' ) component_indexs = range( 1, min(max_cc_to_plot + 1, len(connected_components) + 1) @@ -61,7 +64,7 @@ def plot_cc( if overlay is not None: assert ( overlay.shape == connected_components.shape - ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." + ), f'Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}.' # plots overlay masked to connected component if crop: @@ -75,16 +78,25 @@ def plot_cc( cc = connected_components.get_cc(component, crop=False) overlay_crop = np.where(cc == 0, 0, overlay) fig = qim3d.viz.slices_grid( - overlay_crop, display_figure=display_figure, color_map=color_map, value_min=value_min, value_max=value_max, **kwargs + overlay_crop, + display_figure=display_figure, + color_map=color_map, + value_min=value_min, + value_max=value_max, + **kwargs, ) else: # assigns discrete color map to each connected component if not given - if "color_map" not in kwargs: - kwargs["color_map"] = qim3d.viz.colormaps.segmentation(len(component_indexs)) + if 'color_map' not in kwargs: + kwargs['color_map'] = qim3d.viz.colormaps.segmentation( + len(component_indexs) + ) # Plot the connected component without overlay fig = qim3d.viz.slices_grid( - connected_components.get_cc(component, crop=crop), display_figure=display_figure, **kwargs + connected_components.get_cc(component, crop=crop), + display_figure=display_figure, + **kwargs, ) figs.append(fig) diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index 08f2ba9b8b5c28c9719cdd88aee02f5b50e00eb6..e8d3e0f4b4c72401db9dcb586e02c9406b27a92f 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -1,25 +1,33 @@ -""" +""" Provides a collection of visualization functions. """ import math import warnings - -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import dask.array as da import ipywidgets as widgets +import matplotlib import matplotlib.figure import matplotlib.pyplot as plt -import matplotlib -from IPython.display import SVG, display -import matplotlib import numpy as np -import zarr -from qim3d.utils._logger import log import seaborn as sns +import skimage.measure +from skimage.filters import ( + threshold_otsu, + threshold_isodata, + threshold_li, + threshold_mean, + threshold_minimum, + threshold_triangle, + threshold_yen, +) + +from IPython.display import clear_output, display import qim3d +from qim3d.utils._logger import log def slices_grid( @@ -41,7 +49,8 @@ def slices_grid( color_bar_style: str = "small", **matplotlib_imshow_kwargs, ) -> matplotlib.figure.Figure: - """Displays one or several slices from a 3d volume. + """ + Displays one or several slices from a 3d volume. By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices. If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position. @@ -82,6 +91,7 @@ def slices_grid( qim3d.viz.slices_grid(vol, num_slices=15) ```  + """ if image_size: image_height = image_size @@ -330,7 +340,8 @@ def slicer( color_bar: str = None, **matplotlib_imshow_kwargs, ) -> widgets.interactive: - """Interactive widget for visualizing slices of a 3D volume. + """ + Interactive widget for visualizing slices of a 3D volume. Args: volume (np.ndarray): The 3D volume to be sliced. @@ -355,34 +366,35 @@ def slicer( qim3d.viz.slicer(vol) ```  + """ if image_size: image_height = image_size image_width = image_size - color_bar_options = [None, 'slices', 'volume'] + color_bar_options = [None, "slices", "volume"] if color_bar not in color_bar_options: raise ValueError( f"Unrecognized value '{color_bar}' for parameter color_bar. " f"Expected one of {color_bar_options}." ) show_color_bar = color_bar is not None - if color_bar == 'slices': + if color_bar == "slices": # Precompute the minimum and maximum along each slice for faster widget sliding. non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis) slice_mins = np.min(volume, axis=non_slice_axes) slice_maxs = np.max(volume, axis=non_slice_axes) - + # Create the interactive widget def _slicer(slice_positions): - if color_bar == 'slices': + if color_bar == "slices": dynamic_min = slice_mins[slice_positions] dynamic_max = slice_maxs[slice_positions] else: dynamic_min = value_min dynamic_max = value_max - + fig = slices_grid( volume, slice_axis=slice_axis, @@ -424,8 +436,9 @@ def slicer_orthogonal( display_positions: bool = False, interpolation: Optional[str] = None, image_size: int = None, -)-> widgets.interactive: - """Interactive widget for visualizing orthogonal slices of a 3D volume. +) -> widgets.interactive: + """ + Interactive widget for visualizing orthogonal slices of a 3D volume. Args: volume (np.ndarray): The 3D volume to be sliced. @@ -448,6 +461,7 @@ def slicer_orthogonal( qim3d.viz.slicer_orthogonal(vol, color_map="magma") ```  + """ if image_size: @@ -483,8 +497,9 @@ def fade_mask( color_map: str = "magma", value_min: float = None, value_max: float = None, -)-> widgets.interactive: - """Interactive widget for visualizing the effect of edge fading on a 3D volume. +) -> widgets.interactive: + """ + Interactive widget for visualizing the effect of edge fading on a 3D volume. This can be used to select the best parameters before applying the mask. @@ -497,7 +512,7 @@ def fade_mask( Returns: slicer_obj (widgets.HBox): The interactive widget for visualizing fade mask on slices of a 3D volume. - + Example: ```python import qim3d @@ -605,7 +620,8 @@ def fade_mask( # Create the Checkbox widget invert_checkbox = widgets.Checkbox( - value=False, description="Invert" # default value + value=False, + description="Invert", # default value ) slicer_obj = widgets.interactive( @@ -621,7 +637,7 @@ def fade_mask( return slicer_obj -def chunks(zarr_path: str, **kwargs)-> widgets.interactive: +def chunks(zarr_path: str, **kwargs) -> widgets.interactive: """ Function to visualize chunks of a Zarr dataset using the specified visualization method. @@ -644,6 +660,7 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive: qim3d.viz.chunks("Escargot.zarr") ```  + """ # Load the Zarr dataset @@ -794,7 +811,6 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive: # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0 def update_coordinate_dropdowns(scale): - disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns multiscale_shape = zarr_data[scale].shape @@ -873,75 +889,64 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive: def histogram( volume: np.ndarray, bins: Union[int, str] = "auto", - slice_idx: Union[int, str] = None, + slice_idx: Union[int, str, None] = None, + vertical_line: int = None, axis: int = 0, kde: bool = True, log_scale: bool = False, despine: bool = True, show_title: bool = True, color: str = "qim3d", - edgecolor: str|None = None, - figsize: tuple[float, float] = (8, 4.5), + edgecolor: Optional[str] = None, + figsize: Tuple[float, float] = (8, 4.5), element: str = "step", return_fig: bool = False, show: bool = True, - **sns_kwargs, -) -> None|matplotlib.figure.Figure: + ax: Optional[plt.Axes] = None, + **sns_kwargs: Union[str, float, int, bool] +) -> Optional[Union[plt.Figure, plt.Axes]]: """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. - + Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Args: volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. - bins (int or str, optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". + bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". axis (int, optional): Axis along which to take a slice. Default is 0. - slice_idx (int or str or None, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. + slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None. + vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None. kde (bool, optional): Whether to overlay a kernel density estimate. Default is True. log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False. despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True. show_title (bool, optional): If True, displays a title with slice information. Default is True. color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d". edgecolor (str, optional): Color for the edges of the histogram bars. Default is None. - figsize (tuple of floats, optional): Size of the figure (width, height). Default is (8, 4.5). + figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5). element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step". return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False. show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True. - **sns_kwargs (Any): Additional keyword arguments for `seaborn.histplot`. + ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None. + **sns_kwargs: Additional keyword arguments for `seaborn.histplot`. Returns: - fig (Optional[matplotlib.figure.Figure]): If `return_fig` is True, returns the generated figure object. Otherwise, returns None. + Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]: + If `return_fig` is True, returns the generated figure object. + If `return_fig` is False and `ax` is provided, returns the `Axes` object. + Otherwise, returns None. Raises: ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `slice_idx` is an integer and is out of range for the specified axis. - - Example: - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol) - ``` -  - - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True) - ``` -  """ - if not (0 <= axis < volume.ndim): raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.") if slice_idx == "middle": slice_idx = volume.shape[axis] // 2 - if slice_idx: + if slice_idx is not None: if 0 <= slice_idx < volume.shape[axis]: img_slice = np.take(volume, indices=slice_idx, axis=axis) data = img_slice.ravel() @@ -954,10 +959,14 @@ def histogram( data = volume.ravel() title = f"Intensity histogram for whole volume {volume.shape}" - fig, ax = plt.subplots(figsize=figsize) + # Use provided Axes or create new figure + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = None if log_scale: - plt.yscale("log") + ax.set_yscale("log") if color == "qim3d": color = qim3d.viz.colormaps.qim(1.0) @@ -969,13 +978,23 @@ def histogram( color=color, element=element, edgecolor=edgecolor, + ax=ax, # Plot directly on the specified Axes **sns_kwargs, ) + if vertical_line is not None: + ax.axvline( + x=vertical_line, + color='red', + linestyle="--", + linewidth=2, + + ) + if despine: sns.despine( fig=None, - ax=None, + ax=ax, top=True, right=True, left=False, @@ -984,17 +1003,521 @@ def histogram( trim=True, ) - plt.xlabel("Voxel Intensity") - plt.ylabel("Frequency") + ax.set_xlabel("Voxel Intensity") + ax.set_ylabel("Frequency") if show_title: - plt.title(title, fontsize=10) + ax.set_title(title, fontsize=10) # Handle show and return - if show: + if show and fig is not None: plt.show() - else: - plt.close(fig) if return_fig: return fig + elif ax is not None: + return ax + + + +class _LineProfile: + def __init__( + self, + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ): + self.volume = volume + self.slice_axis = slice_axis + + self.dims = np.array(volume.shape) + self.pad = 1 # Padding on pivot point to avoid border issues + self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1] + + self.initialize_widgets() + self.update_slice_axis(slice_axis) + self.slice_index_widget.value = slice_index + self.x_widget.value = horizontal_position + self.y_widget.value = vertical_position + self.angle_widget.value = angle + self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]] + + def update_slice_axis(self, slice_axis): + self.slice_axis = slice_axis + self.slice_index_widget.max = self.volume.shape[slice_axis] - 1 + self.slice_index_widget.value = self.volume.shape[slice_axis] // 2 + + self.x_max, self.y_max = np.delete(self.dims, self.slice_axis) - 1 + self.x_widget.max = self.x_max - self.pad + self.x_widget.value = self.x_max // 2 + self.y_widget.max = self.y_max - self.pad + self.y_widget.value = self.y_max // 2 + + def initialize_widgets(self): + layout = widgets.Layout(width="300px", height="auto") + self.x_widget = widgets.IntSlider( + min=self.pad, step=1, description="", layout=layout + ) + self.y_widget = widgets.IntSlider( + min=self.pad, step=1, description="", layout=layout + ) + self.angle_widget = widgets.IntSlider( + min=0, max=360, step=1, value=0, description="", layout=layout + ) + self.line_fraction_widget = widgets.FloatRangeSlider( + min=0, max=1, step=0.01, value=[0, 1], description="", layout=layout + ) + + self.slice_axis_widget = widgets.Dropdown( + options=[0, 1, 2], value=self.slice_axis, description="Slice axis" + ) + self.slice_axis_widget.layout.width = "250px" + + self.slice_index_widget = widgets.IntSlider( + min=0, step=1, description="Slice index", layout=layout + ) + self.slice_index_widget.layout.width = "400px" + + def calculate_line_endpoints(self, x, y, angle): + """ + Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)] + """ + if np.isclose(angle, 0): + return [0, y], [self.x_max, y] + elif np.isclose(angle, np.pi / 2): + return [x, 0], [x, self.y_max] + elif np.isclose(angle, np.pi): + return [self.x_max, y], [0, y] + elif np.isclose(angle, 3 * np.pi / 2): + return [x, self.y_max], [x, 0] + elif np.isclose(angle, 2 * np.pi): + return [0, y], [self.x_max, y] + + t_left = -x / np.cos(angle) + t_bottom = -y / np.sin(angle) + t_right = (self.x_max - x) / np.cos(angle) + t_top = (self.y_max - y) / np.sin(angle) + t_values = np.array([t_left, t_top, t_right, t_bottom]) + t_pos = np.min(t_values[t_values > 0]) + t_neg = np.max(t_values[t_values < 0]) + + src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)] + dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)] + return src, dst + + def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range): + if slice_axis != self.slice_axis: + self.update_slice_axis(slice_axis) + x = self.x_widget.value + y = self.y_widget.value + slice_index = self.slice_index_widget.value + + clear_output(wait=True) + + image = np.take(self.volume, slice_index, slice_axis) + angle = np.radians(angle_deg) + src, dst = ( + np.array(point, dtype="float32") + for point in self.calculate_line_endpoints(x, y, angle) + ) + + # Rescale endpoints + line_vec = dst - src + dst = src + fraction_range[1] * line_vec + src = src + fraction_range[0] * line_vec + + y_pline = skimage.measure.profile_line(image, src, dst) + + fig, ax = plt.subplots(1, 2, figsize=(10, 5)) + + # Image with color-gradiented line + num_segments = 100 + x_seg = np.linspace(src[0], dst[0], num_segments) + y_seg = np.linspace(src[1], dst[1], num_segments) + segments = np.stack( + [ + np.column_stack([y_seg[:-2], x_seg[:-2]]), + np.column_stack([y_seg[2:], x_seg[2:]]), + ], + axis=1, + ) + norm = plt.Normalize(vmin=0, vmax=num_segments - 1) + colors = self.cmap(norm(np.arange(num_segments - 1))) + lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2) + + ax[0].imshow(image, cmap="gray") + ax[0].add_collection(lc) + # pivot point + ax[0].plot(y, x, marker="s", linestyle="", color="cyan", markersize=4) + ax[0].set_xlabel(f"axis {np.delete(np.arange(3), self.slice_axis)[1]}") + ax[0].set_ylabel(f"axis {np.delete(np.arange(3), self.slice_axis)[0]}") + + # Profile intensity plot + norm = plt.Normalize(0, vmax=len(y_pline) - 1) + x_pline = np.arange(len(y_pline)) + points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :] + segments = np.concatenate([points[:-1], points[1:]], axis=1) + lc = matplotlib.collections.LineCollection( + segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2 + ) + + ax[1].add_collection(lc) + ax[1].autoscale() + ax[1].set_xlabel("Distance along line") + ax[1].grid(True) + plt.tight_layout() + plt.show() + + def build_interactive(self): + # Group widgets into two columns + title_style = ( + "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" + ) + title_column1 = widgets.HTML( + f"<div style='{title_style}'>Line parameterization</div>" + ) + title_column2 = widgets.HTML( + f"<div style='{title_style}'>Slice selection</div>" + ) + + # Make label widgets instead of descriptions which have different lengths. + label_layout = widgets.Layout(width="120px") + label_x = widgets.Label("Vertical position", layout=label_layout) + label_y = widgets.Label("Horizontal position", layout=label_layout) + label_angle = widgets.Label("Angle (°)", layout=label_layout) + label_fraction = widgets.Label("Fraction range", layout=label_layout) + + row_x = widgets.HBox([label_x, self.x_widget]) + row_y = widgets.HBox([label_y, self.y_widget]) + row_angle = widgets.HBox([label_angle, self.angle_widget]) + row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget]) + + controls_column1 = widgets.VBox( + [title_column1, row_x, row_y, row_angle, row_fraction] + ) + controls_column2 = widgets.VBox( + [title_column2, self.slice_axis_widget, self.slice_index_widget] + ) + controls = widgets.HBox([controls_column1, controls_column2]) + + interactive_plot = widgets.interactive_output( + self.update, + { + "slice_axis": self.slice_axis_widget, + "slice_index": self.slice_index_widget, + "x": self.x_widget, + "y": self.y_widget, + "angle_deg": self.angle_widget, + "fraction_range": self.line_fraction_widget, + }, + ) + + return widgets.VBox([controls, interactive_plot]) + + +def line_profile( + volume: np.ndarray, + slice_axis: int = 0, + slice_index: int | str = "middle", + vertical_position: int | str = "middle", + horizontal_position: int | str = "middle", + angle: int = 0, + fraction_range: Tuple[float, float] = (0.00, 1.00), +) -> widgets.interactive: + """ + Returns an interactive widget for visualizing the intensity profiles of lines on slices. + + Args: + volume (np.ndarray): The 3D volume of interest. + slice_axis (int, optional): Specifies the initial axis along which to slice. + slice_index (int or str, optional): Specifies the initial slice index along slice_axis. + vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point. + horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point. + angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. + fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0]. + + Returns: + widget (widgets.widget_box.VBox): The interactive widget. + + + Example: + ```python + import qim3d + + vol = qim3d.examples.bone_128x128x128 + qim3d.viz.line_profile(vol) + ``` +  + + """ + + def parse_position(pos, pos_range, name): + if isinstance(pos, int): + if not pos_range[0] <= pos < pos_range[1]: + raise ValueError( + f"Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]" + ) + return pos + elif isinstance(pos, str): + pos = pos.lower() + if pos == "start": + return pos_range[0] + elif pos == "middle": + return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 + elif pos == "end": + return pos_range[1] + else: + raise ValueError( + f"Invalid string '{pos}' for {name}. " + "Must be 'start', 'middle', or 'end'." + ) + else: + raise TypeError("Axis position must be of type int or str.") + + if not isinstance(volume, (np.ndarray, da.core.Array)): + raise ValueError("Data type for volume not supported.") + if volume.ndim != 3: + raise ValueError("Volume must be 3D.") + + dims = volume.shape + slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), "slice_index") + # the omission of the ends for the pivot point is due to border issues. + vertical_position = parse_position( + vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), "vertical_position" + ) + horizontal_position = parse_position( + horizontal_position, + (1, np.delete(dims, slice_axis)[1] - 2), + "horizontal_position", + ) + + if not isinstance(angle, int | float): + raise ValueError("Invalid type for angle.") + angle = round(angle) % 360 + + if not ( + 0.0 <= fraction_range[0] <= 1.0 + and 0.0 <= fraction_range[1] <= 1.0 + and fraction_range[0] <= fraction_range[1] + ): + raise ValueError("Invalid values for fraction_range.") + + lp = _LineProfile( + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ) + return lp.build_interactive() + + +def threshold( + volume: np.ndarray, + cmap_image: str = 'magma', + vmin: float = None, + vmax: float = None, +) -> widgets.VBox: + """ + This function provides an interactive interface to explore thresholding on a + 3D volume slice-by-slice. Users can either manually set the threshold value + using a slider or select an automatic thresholding method from `skimage`. + + The visualization includes the original image slice, a binary mask showing regions above the + threshold and an overlay combining the binary mask and the original image. + + Args: + volume (np.ndarray): 3D volume to threshold. + cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'. + cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'. + vmin (float, optional): Minimum value for the colormap. Defaults to None. + vmax (float, optional): Maximum value for the colormap. Defaults to None. + + Returns: + slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume. + + Interactivity: + - **Manual Thresholding**: + Select 'Manual' from the dropdown menu to manually adjust the threshold + using the slider. + - **Automatic Thresholding**: + Choose a method from the dropdown menu to apply an automatic thresholding + algorithm. Available methods include: + - Otsu + - Isodata + - Li + - Mean + - Minimum + - Triangle + - Yen + + The threshold slider will display the computed value and will be disabled + in this mode. + + + ```python + import qim3d + + # Load a sample volume + vol = qim3d.examples.bone_128x128x128 + + # Visualize interactive thresholding + qim3d.viz.threshold(vol) + ``` +  + + """ + + # Centralized state dictionary to track current parameters + state = { + "position": volume.shape[0] // 2, + "threshold": int((volume.min() + volume.max()) / 2), + "method": "Manual", + } + + threshold_methods = { + "Otsu": threshold_otsu, + "Isodata": threshold_isodata, + "Li": threshold_li, + "Mean": threshold_mean, + "Minimum": threshold_minimum, + "Triangle": threshold_triangle, + "Yen": threshold_yen, + } + + # Create an output widget to display the plot + output = widgets.Output() + + # Function to update the state and trigger visualization + def update_state(change): + # Update state based on widget values + state["position"] = position_slider.value + state["method"] = method_dropdown.value + + if state["method"] == "Manual": + state["threshold"] = threshold_slider.value + threshold_slider.disabled = False + else: + threshold_func = threshold_methods.get(state["method"]) + if threshold_func: + slice_img = volume[state["position"], :, :] + computed_threshold = threshold_func(slice_img) + state["threshold"] = computed_threshold + + # Programmatically update the slider without triggering callbacks + threshold_slider.unobserve_all() + threshold_slider.value = computed_threshold + threshold_slider.disabled = True + threshold_slider.observe(update_state, names="value") + else: + raise ValueError(f"Unsupported thresholding method: {state['method']}") + + # Trigger visualization + update_visualization() + + # Visualization function + def update_visualization(): + slice_img = volume[state["position"], :, :] + with output: + output.clear_output(wait=True) # Clear previous plot + fig, axes = plt.subplots(1, 4, figsize=(25, 5)) + + # Original image + new_vmin = ( + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin + ) + new_vmax = ( + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax + ) + axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) + axes[0].set_title("Original") + axes[0].axis("off") + + # Histogram + histogram( + volume=volume, + bins=32, + slice_idx=state["position"], + vertical_line=state["threshold"], + axis=1, + kde=False, + ax=axes[1], + show=False, + ) + axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}") + + # Binary mask + mask = slice_img > state["threshold"] + axes[2].imshow(mask, cmap="gray") + axes[2].set_title("Binary mask") + axes[2].axis("off") + + # Overlay + mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + mask_rgb[:, :, 0] = mask + masked_volume = qim3d.operations.overlay_rgb_images( + background=slice_img, + foreground=mask_rgb, + ) + axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax) + axes[3].set_title("Overlay") + axes[3].axis("off") + + plt.show() + + # Widgets + position_slider = widgets.IntSlider( + value=state["position"], + min=0, + max=volume.shape[0] - 1, + description="Slice", + ) + + threshold_slider = widgets.IntSlider( + value=state["threshold"], + min=volume.min(), + max=volume.max(), + description="Threshold", + ) + + method_dropdown = widgets.Dropdown( + options=[ + "Manual", + "Otsu", + "Isodata", + "Li", + "Mean", + "Minimum", + "Triangle", + "Yen", + ], + value=state["method"], + description="Method", + ) + + # Attach the state update function to widgets + position_slider.observe(update_state, names="value") + threshold_slider.observe(update_state, names="value") + method_dropdown.observe(update_state, names="value") + + # Layout + controls_left = widgets.VBox([position_slider, threshold_slider]) + controls_right = widgets.VBox([method_dropdown]) + controls_layout = widgets.HBox( + [controls_left, controls_right], + layout=widgets.Layout(justify_content="flex-start"), + ) + interactive_ui = widgets.VBox([controls_layout, output]) + update_visualization() + + return interactive_ui diff --git a/qim3d/viz/_detection.py b/qim3d/viz/_detection.py index 5b0ad9fc7059636f5d49a5d6ef4aac02c12fac2b..7839e7d4b74776e492682d9bb000421e82aee53f 100644 --- a/qim3d/viz/_detection.py +++ b/qim3d/viz/_detection.py @@ -1,12 +1,18 @@ +import ipywidgets as widgets import matplotlib.pyplot as plt -from qim3d.utils._logger import log import numpy as np -import ipywidgets as widgets from IPython.display import clear_output, display + import qim3d -def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive: +def circles( + blobs: tuple[float, float, float, float], + vol: np.ndarray, + alpha: float = 0.5, + color: str = '#ff9900', + **kwargs, +) -> widgets.interactive: """ Plots the blobs found on a slice of the volume. @@ -47,23 +53,23 @@ def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue') ```  + """ def _slicer(z_slice): clear_output(wait=True) fig = qim3d.viz.slices_grid( - vol[z_slice:z_slice + 1], + vol[z_slice : z_slice + 1], num_slices=1, - color_map="gray", + color_map='gray', display_figure=False, display_positions=False, - **kwargs + **kwargs, ) # Add circles from deteced blobs for detected in blobs: z, y, x, s = detected if abs(z - z_slice) < s: # The blob is in the slice - # Adjust the radius based on the distance from the center of the sphere distance_from_center = abs(z - z_slice) angle = ( @@ -89,10 +95,10 @@ def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float value=vol.shape[0] // 2, min=0, max=vol.shape[0] - 1, - description="Slice", + description='Slice', continuous_update=True, ) slicer_obj = widgets.interactive(_slicer, z_slice=position_slider) - slicer_obj.layout = widgets.Layout(align_items="flex-start") + slicer_obj.layout = widgets.Layout(align_items='flex-start') return slicer_obj diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py index 6fa107d4a21cebd432bf49e49f18a52db4830788..5d1ecb5532187fbabe2217e2b5eb826a03e7e843 100644 --- a/qim3d/viz/_k3d.py +++ b/qim3d/viz/_k3d.py @@ -7,26 +7,30 @@ Volumetric visualization using K3D """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np from matplotlib.colors import Colormap + from qim3d.utils._logger import log from qim3d.utils._misc import downscale_img, scale_to_float16 - +from pygel3d import hmesh +from pygel3d import jupyter_display as jd +import k3d +from typing import Optional def volumetric( img: np.ndarray, - aspectmode: str = "data", + aspectmode: str = 'data', show: bool = True, save: bool = False, grid_visible: bool = False, color_map: str = 'magma', constant_opacity: bool = False, - vmin: float|None = None, - vmax: float|None = None, - samples: int|str = "auto", + vmin: float | None = None, + vmax: float | None = None, + samples: int | str = 'auto', max_voxels: int = 512**3, - data_type: str = "scaled_float16", + data_type: str = 'scaled_float16', **kwargs, ): """ @@ -81,11 +85,10 @@ def volumetric( ``` """ - import k3d pixel_count = img.shape[0] * img.shape[1] * img.shape[2] # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html - if samples == "auto": + if samples == 'auto': y1, x1 = 256, 16777216 # 256 samples at res 256*256*256=16.777.216 y2, x2 = 32, 134217728 # 32 samples at res 512*512*512=134.217.728 @@ -97,7 +100,7 @@ def volumetric( else: samples = int(samples) # make sure it's an integer - if aspectmode.lower() not in ["data", "cube"]: + if aspectmode.lower() not in ['data', 'cube']: raise ValueError("aspectmode should be either 'data' or 'cube'") # check if image should be downsampled for visualization original_shape = img.shape @@ -107,7 +110,7 @@ def volumetric( if original_shape != new_shape: log.warning( - f"Downsampled image for visualization, from {original_shape} to {new_shape}" + f'Downsampled image for visualization, from {original_shape} to {new_shape}' ) # Scale the image to float16 if needed @@ -115,8 +118,7 @@ def volumetric( # When saving, we need float64 img = img.astype(np.float64) else: - - if data_type == "scaled_float16": + if data_type == 'scaled_float16': img = scale_to_float16(img) else: img = img.astype(data_type) @@ -151,7 +153,7 @@ def volumetric( img, bounds=( [0, img.shape[2], 0, img.shape[1], 0, img.shape[0]] - if aspectmode.lower() == "data" + if aspectmode.lower() == 'data' else None ), color_map=color_map, @@ -164,7 +166,7 @@ def volumetric( plot += plt_volume if save: # Save html to disk - with open(str(save), "w", encoding="utf-8") as fp: + with open(str(save), 'w', encoding='utf-8') as fp: fp.write(plot.get_snapshot()) if show: @@ -172,87 +174,106 @@ def volumetric( else: return plot - def mesh( - verts: np.ndarray, - faces: np.ndarray, + mesh, + backend: str = "pygel3d", wireframe: bool = True, flat_shading: bool = True, grid_visible: bool = False, show: bool = True, save: bool = False, **kwargs, -): - """ - Visualizes a 3D mesh using K3D. - +)-> Optional[k3d.Plot]: + """Visualize a 3D mesh using `pygel3d` or `k3d`. + Args: - verts (numpy.ndarray): A 2D array (Nx3) containing the vertices of the mesh. - faces (numpy.ndarray): A 2D array (Mx3) containing the indices of the mesh faces. - wireframe (bool, optional): If True, the mesh is rendered as a wireframe. Defaults to True. - flat_shading (bool, optional): If True, flat shading is applied to the mesh. Defaults to True. - grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. - show (bool, optional): If True, displays the visualization inline. Defaults to True. + mesh (pygel3d.hmesh.Manifold): The input mesh object. + backend (str, optional): The visualization backend to use. + Choose between `pygel3d` (default) and `k3d`. + wireframe (bool, optional): If True, displays the mesh as a wireframe. + Works both with `pygel3d` and `k3d`. Defaults to True. + flat_shading (bool, optional): If True, applies flat shading to the mesh. + Works only with `k3d`. Defaults to True. + grid_visible (bool, optional): If True, shows a grid in the visualization. + Works only with `k3d`. Defaults to False. + show (bool, optional): If True, displays the visualization inline. + Works only with `k3d`. Defaults to True. save (bool or str, optional): If True, saves the visualization as an HTML file. If a string is provided, it's interpreted as the file path where the HTML - file will be saved. Defaults to False. - **kwargs (Any): Additional keyword arguments to be passed to the `k3d.plot` function. - + file will be saved. Works only with `k3d`. Defaults to False. + **kwargs (Any): Additional keyword arguments specific to the chosen backend: + + - `k3d.plot` kwargs: Arguments that customize the [`k3d.plot`](https://k3d-jupyter.org/reference/factory.plot.html) visualization. + + - `pygel3d.display` kwargs: Arguments that customize the [`pygel3d.display`](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/jupyter_display.html#display) visualization. + Returns: - plot (k3d.plot): If `show=False`, returns the K3D plot object. + k3d.Plot or None: + + - If `backend="k3d"`, returns a `k3d.Plot` object. + - If `backend="pygel3d"`, the function displays the mesh but does not return a plot object. + + Raises: + ValueError: If `backend` is not `pygel3d` or `k3d`. Example: ```python import qim3d - - vol = qim3d.generate.noise_object(base_shape=(128,128,128), - final_shape=(128,128,128), - noise_scale=0.03, - order=1, - gamma=1, - max_value=255, - threshold=0.5, - dtype='uint8' - ) - mesh = qim3d.mesh.from_volume(vol, step_size=3) - qim3d.viz.mesh(mesh.vertices, mesh.faces) + synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015) + mesh = qim3d.mesh.from_volume(synthetic_blob) + qim3d.viz.mesh(mesh, backend="pygel3d") # or qim3d.viz.mesh(mesh, backend="k3d") ``` - <iframe src="https://platform.qim.dk/k3d/mesh_visualization.html" width="100%" height="500" frameborder="0"></iframe> +  """ - import k3d - - # Validate the inputs - if verts.shape[1] != 3: - raise ValueError("Vertices array must have shape (N, 3)") - if faces.shape[1] != 3: - raise ValueError("Faces array must have shape (M, 3)") - - # Ensure the correct data types and memory layout - verts = np.ascontiguousarray( - verts.astype(np.float32) - ) # Cast and ensure C-contiguous layout - faces = np.ascontiguousarray( - faces.astype(np.uint32) - ) # Cast and ensure C-contiguous layout - - # Create the mesh plot - plt_mesh = k3d.mesh( - vertices=verts, - indices=faces, - wireframe=wireframe, - flat_shading=flat_shading, - ) - # Create plot - plot = k3d.plot(grid_visible=grid_visible, **kwargs) - plot += plt_mesh - if save: - # Save html to disk - with open(str(save), "w", encoding="utf-8") as fp: - fp.write(plot.get_snapshot()) + if backend not in ["k3d", "pygel3d"]: + raise ValueError("Invalid backend. Choose 'pygel3d' or 'k3d'.") - if show: - plot.display() - else: - return plot + # Extract vertex positions and face indices + face_indices = list(mesh.faces()) + vertices_array = np.array(mesh.positions()) + + # Extract face vertex indices + face_vertices = [ + list(mesh.circulate_face(int(fid), mode="v"))[:3] for fid in face_indices + ] + face_vertices = np.array(face_vertices, dtype=np.uint32) + + # Validate the mesh structure + if vertices_array.shape[1] != 3 or face_vertices.shape[1] != 3: + raise ValueError("Vertices must have shape (N, 3) and faces (M, 3)") + + # Separate valid kwargs for each backend + valid_k3d_kwargs = {k: v for k, v in kwargs.items() if k not in ["smooth", "data"]} + valid_pygel_kwargs = {k: v for k, v in kwargs.items() if k in ["smooth", "data"]} + + if backend == "k3d": + vertices_array = np.ascontiguousarray(vertices_array.astype(np.float32)) + face_vertices = np.ascontiguousarray(face_vertices) + + mesh_plot = k3d.mesh( + vertices=vertices_array, + indices=face_vertices, + wireframe=wireframe, + flat_shading=flat_shading, + ) + + # Create plot + plot = k3d.plot(grid_visible=grid_visible, **valid_k3d_kwargs) + plot += mesh_plot + + if save: + # Save html to disk + with open(str(save), "w", encoding="utf-8") as fp: + fp.write(plot.get_snapshot()) + + if show: + plot.display() + else: + return plot + + + elif backend == "pygel3d": + jd.set_export_mode(True) + return jd.display(mesh, wireframe=wireframe, **valid_pygel_kwargs) diff --git a/qim3d/viz/_layers2d.py b/qim3d/viz/_layers2d.py index 676845a57180964107c880c08b32ff2e1c5bc2ed..feabcc02e9a0352ec179e4e5bece7ec4f9510b5e 100644 --- a/qim3d/viz/_layers2d.py +++ b/qim3d/viz/_layers2d.py @@ -1,10 +1,9 @@ -""" Provides a collection of visualisation functions for the Layers2d class.""" +"""Provides a collection of visualisation functions for the Layers2d class.""" + import io import matplotlib.pyplot as plt import numpy as np - - from PIL import Image @@ -17,21 +16,21 @@ def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> I lines: list of 1D arrays to be plotted on top of the image line_thickness: how thick is the line supposed to be - Returns: - ---------- - image_with_lines: + Returns + ------- + image_with_lines: + """ fig, ax = plt.subplots() - ax.imshow(image, cmap = 'gray') + ax.imshow(image, cmap='gray') ax.axis('off') for line in lines: - ax.plot(line, linewidth = line_thickness) + ax.plot(line, linewidth=line_thickness) buf = io.BytesIO() plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) plt.close() buf.seek(0) - return Image.open(buf).resize(size = image.squeeze().shape[::-1]) - + return Image.open(buf).resize(size=image.squeeze().shape[::-1]) diff --git a/qim3d/viz/_local_thickness.py b/qim3d/viz/_local_thickness.py index 9336ee32acb23ac4b9e9a43df2c5a1e045b3ea23..2788a092c42c40bd04c419e20c7302b57e9a6a52 100644 --- a/qim3d/viz/_local_thickness.py +++ b/qim3d/viz/_local_thickness.py @@ -1,8 +1,11 @@ -from qim3d.utils._logger import log -import numpy as np -import matplotlib.pyplot as plt -from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union + import ipywidgets as widgets +import matplotlib.pyplot as plt +import numpy as np + +from qim3d.utils._logger import log + def local_thickness( image: np.ndarray, @@ -13,7 +16,8 @@ def local_thickness( show: bool = False, figsize: Tuple[int, int] = (15, 5), ) -> Union[plt.Figure, widgets.interactive]: - """Visualizes the local thickness of a 2D or 3D image. + """ + Visualizes the local thickness of a 2D or 3D image. Args: image (np.ndarray): 2D or 3D NumPy array representing the image/volume. @@ -48,7 +52,7 @@ def local_thickness( ```  - + """ def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None): @@ -56,24 +60,24 @@ def local_thickness( image = image.take(slice_idx, axis=axis) image_lt = image_lt.take(slice_idx, axis=axis) - fig, axs = plt.subplots(1, 3, figsize=figsize, layout="constrained") + fig, axs = plt.subplots(1, 3, figsize=figsize, layout='constrained') - axs[0].imshow(image, cmap="gray") - axs[0].set_title("Original image") - axs[0].axis("off") + axs[0].imshow(image, cmap='gray') + axs[0].set_title('Original image') + axs[0].axis('off') - axs[1].imshow(image_lt, cmap="viridis") - axs[1].set_title("Local thickness") - axs[1].axis("off") + axs[1].imshow(image_lt, cmap='viridis') + axs[1].set_title('Local thickness') + axs[1].axis('off') plt.colorbar( - axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical" + axs[1].imshow(image_lt, cmap='viridis'), ax=axs[1], orientation='vertical' ) - axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor="black") - axs[2].set_title("Local thickness histogram") - axs[2].set_xlabel("Local thickness") - axs[2].set_ylabel("Count") + axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor='black') + axs[2].set_title('Local thickness histogram') + axs[2].set_xlabel('Local thickness') + axs[2].set_ylabel('Count') if show: plt.show() @@ -87,7 +91,7 @@ def local_thickness( if max_projection: if slice_idx is not None: log.warning( - "slice_idx is not used for max_projection. It will be ignored." + 'slice_idx is not used for max_projection. It will be ignored.' ) image = image.max(axis=axis) image_lt = image_lt.max(axis=axis) @@ -98,7 +102,7 @@ def local_thickness( elif isinstance(slice_idx, float): if slice_idx < 0 or slice_idx > 1: raise ValueError( - "Values of slice_idx of float type must be between 0 and 1." + 'Values of slice_idx of float type must be between 0 and 1.' ) slice_idx = int(slice_idx * image.shape[0]) - 1 slide_idx_slider = widgets.IntSlider( @@ -106,8 +110,8 @@ def local_thickness( max=image.shape[axis] - 1, step=1, value=slice_idx, - description="Slice index", - layout=widgets.Layout(width="450px"), + description='Slice index', + layout=widgets.Layout(width='450px'), ) widget_obj = widgets.interactive( _local_thickness, @@ -118,15 +122,15 @@ def local_thickness( axis=widgets.fixed(axis), slice_idx=slide_idx_slider, ) - widget_obj.layout = widgets.Layout(align_items="center") + widget_obj.layout = widgets.Layout(align_items='center') if show: display(widget_obj) return widget_obj else: if max_projection: log.warning( - "max_projection is only used for 3D images. It will be ignored." + 'max_projection is only used for 3D images. It will be ignored.' ) if slice_idx is not None: - log.warning("slice_idx is only used for 3D images. It will be ignored.") - return _local_thickness(image, image_lt, show, figsize) \ No newline at end of file + log.warning('slice_idx is only used for 3D images. It will be ignored.') + return _local_thickness(image, image_lt, show, figsize) diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py index 778f4bff7d524bdaddb67cc4833d81d44bd4a0d0..878ba3b69a7cec1af453ed1f44dbbe22bde04967 100644 --- a/qim3d/viz/_metrics.py +++ b/qim3d/viz/_metrics.py @@ -1,21 +1,24 @@ """Visualization tools""" +import matplotlib import matplotlib.figure -import numpy as np import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap +import numpy as np +import torch from matplotlib import colormaps +from matplotlib.colors import LinearSegmentedColormap + from qim3d.utils._logger import log -import matplotlib + def plot_metrics( *metrics: tuple[dict[str, float]], - linestyle: str = "-", - batch_linestyle: str = "dotted", + linestyle: str = '-', + batch_linestyle: str = 'dotted', labels: list | None = None, figsize: tuple = (16, 6), - show: bool = False + show: bool = False, ): """ Plots the metrics over epochs and batches. @@ -35,6 +38,7 @@ def plot_metrics( train_loss = {'epoch_loss' : [...], 'batch_loss': [...]} val_loss = {'epoch_loss' : [...], 'batch_loss': [...]} plot_metrics(train_loss,val_loss, labels=['Train','Valid.']) + """ import seaborn as snb @@ -44,9 +48,9 @@ def plot_metrics( raise ValueError("The number of metrics doesn't match the number of labels.") # plotting parameters - snb.set_style("darkgrid") + snb.set_style('darkgrid') snb.set(font_scale=1.5) - plt.rcParams["lines.linewidth"] = 2 + plt.rcParams['lines.linewidth'] = 2 fig = plt.figure(figsize=figsize) @@ -68,10 +72,10 @@ def plot_metrics( plt.legend() plt.ylabel(metric_name) - plt.xlabel("epoch") + plt.xlabel('epoch') # reset plotting parameters - snb.set_style("white") + snb.set_style('white') if show: plt.show() @@ -81,14 +85,15 @@ def plot_metrics( def grid_overview( - data: list, + data: list | torch.utils.data.Dataset, num_images: int = 7, - cmap_im: str = "gray", - cmap_segm: str = "viridis", + cmap_im: str = 'gray', + cmap_segm: str = 'viridis', alpha: float = 0.5, show: bool = False, ) -> matplotlib.figure.Figure: - """Displays an overview grid of images, labels, and masks (if they exist). + """ + Displays an overview grid of images, labels, and masks (if they exist). Labels are the annotated target segmentations Masks are applied to the output and target prior to the loss calculation in case of @@ -120,6 +125,7 @@ def grid_overview( - The number of displayed images is limited to the minimum between `num_images` and the length of the data. - The grid layout and dimensions vary based on the presence of a mask. + """ import torch @@ -128,12 +134,12 @@ def grid_overview( # Check if image data is RGB and inform the user if it's the case if len(data[0][0].squeeze().shape) > 2: - log.info("Input images are RGB: color map is ignored") + log.info('Input images are RGB: color map is ignored') # Check if dataset have at least specified number of images if len(data) < num_images: log.warning( - "Not enough images in the dataset. Changing num_images=%d to num_images=%d", + 'Not enough images in the dataset. Changing num_images=%d to num_images=%d', num_images, len(data), ) @@ -142,14 +148,14 @@ def grid_overview( # Adapt segmentation cmap so that background is transparent colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) colors_segm[:128, 3] = 0 - custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) + custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm) # Check if data have the right format if not isinstance(data[0], tuple): - raise ValueError("Data elements must be tuples") + raise ValueError('Data elements must be tuples') # Define row titles - row_titles = ["Input images", "Ground truth segmentation", "Mask"] + row_titles = ['Input images', 'Ground truth segmentation', 'Mask'] # Make new list such that possible augmentations remain identical for all three rows plot_data = [data[idx] for idx in range(num_images)] @@ -169,10 +175,10 @@ def grid_overview( if row in [1, 2]: # Ground truth segmentation and mask ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im) ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha) - ax.axis("off") + ax.axis('off') else: ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im) - ax.axis("off") + ax.axis('off') if show: plt.show() @@ -184,12 +190,13 @@ def grid_overview( def grid_pred( in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray], num_images: int = 7, - cmap_im: str = "gray", - cmap_segm: str = "viridis", + cmap_im: str = 'gray', + cmap_segm: str = 'viridis', alpha: float = 0.5, show: bool = False, ) -> matplotlib.figure.Figure: - """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. + """ + Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. Displays a grid of subplots representing different aspects of the input images and segmentations. The grid includes the following rows: @@ -221,25 +228,26 @@ def grid_pred( model = MySegmentationModel() in_targ_preds = qim3d.ml.inference(dataset,model) qim3d.viz.grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5) + """ import torch # Check if dataset have at least specified number of images if len(in_targ_preds[0]) < num_images: log.warning( - "Not enough images in the dataset. Changing num_images=%d to num_images=%d", + 'Not enough images in the dataset. Changing num_images=%d to num_images=%d', num_images, len(in_targ_preds[0]), ) num_images = len(in_targ_preds[0]) # Take only the number of images from in_targ_preds - inputs, targets, preds = [items[:num_images] for items in in_targ_preds] + inputs, targets, preds = (items[:num_images] for items in in_targ_preds) # Adapt segmentation cmap so that background is transparent colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) colors_segm[:128, 3] = 0 - custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) + custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm) N = num_images H = inputs[0].shape[-2] @@ -251,10 +259,10 @@ def grid_pred( comp_rgb[:, 3, :, :] = targets.logical_or(preds) row_titles = [ - "Input images", - "Predicted segmentation", - "Ground truth segmentation", - "True vs. predicted segmentation", + 'Input images', + 'Predicted segmentation', + 'Ground truth segmentation', + 'True vs. predicted segmentation', ] fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True) @@ -269,20 +277,20 @@ def grid_pred( for col, ax in enumerate(np.atleast_1d(axs)): if row == 0: ax.imshow(inputs[col], cmap=cmap_im) - ax.axis("off") + ax.axis('off') elif row == 1: # Predicted segmentation ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha) - ax.axis("off") + ax.axis('off') elif row == 2: # Ground truth segmentation ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha) - ax.axis("off") + ax.axis('off') else: ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha) - ax.axis("off") + ax.axis('off') if show: plt.show() @@ -315,8 +323,8 @@ def vol_masked( """ - background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1 - foreground = (vol.astype("float") + viz_delta) * vol_mask + background = (vol.astype('float') + viz_delta) * (1 - vol_mask) * -1 + foreground = (vol.astype('float') + viz_delta) * vol_mask vol_masked_result = background + foreground return vol_masked_result diff --git a/qim3d/viz/_preview.py b/qim3d/viz/_preview.py index a41b7b2ae51ea4154db50804f7940e01e4fe2fe1..c39830d2a37db1f7eff4ec89172c4fa16e1f9f7a 100644 --- a/qim3d/viz/_preview.py +++ b/qim3d/viz/_preview.py @@ -1,13 +1,13 @@ import numpy as np from PIL import Image -# These are fixed because of unicode characters bitmaps. +# These are fixed because of unicode characters bitmaps. # It could only be flexible if each character had a function that generated the bitmap based on size X_STRIDE = 4 -Y_STRIDE = 8 +Y_STRIDE = 8 -BACK_TO_NORMAL = "\u001b[0m" +BACK_TO_NORMAL = '\u001b[0m' END_MARKER = -10 """ @@ -18,77 +18,115 @@ like in a field 4x8. BITMAPS = [ # Block graphics # 0xffff0000, 0x2580, // upper 1/2; redundant with inverse lower 1/2 - 0x00000000, '\u00a0', - 0x0000000f, '\u2581', # lower 1/8 - 0x000000ff, '\u2582', # lower 1/4 - 0x00000fff, '\u2583', - 0x0000ffff, '\u2584', # lower 1/2 - 0x000fffff, '\u2585', - 0x00ffffff, '\u2586', # lower 3/4 - 0x0fffffff, '\u2587', + 0x00000000, + '\u00a0', + 0x0000000F, + '\u2581', # lower 1/8 + 0x000000FF, + '\u2582', # lower 1/4 + 0x00000FFF, + '\u2583', + 0x0000FFFF, + '\u2584', # lower 1/2 + 0x000FFFFF, + '\u2585', + 0x00FFFFFF, + '\u2586', # lower 3/4 + 0x0FFFFFFF, + '\u2587', # 0xffffffff, 0x2588, # full; redundant with inverse space - - 0xeeeeeeee, '\u258a', # left 3/4 - 0xcccccccc, '\u258c', # left 1/2 - 0x88888888, '\u258e', # left 1/4 - - 0x0000cccc, '\u2596', # quadrant lower left - 0x00003333, '\u2597', # quadrant lower right - 0xcccc0000, '\u2598', # quadrant upper left + 0xEEEEEEEE, + '\u258a', # left 3/4 + 0xCCCCCCCC, + '\u258c', # left 1/2 + 0x88888888, + '\u258e', # left 1/4 + 0x0000CCCC, + '\u2596', # quadrant lower left + 0x00003333, + '\u2597', # quadrant lower right + 0xCCCC0000, + '\u2598', # quadrant upper left # 0xccccffff, 0x2599, # 3/4 redundant with inverse 1/4 - 0xcccc3333, '\u259a', # diagonal 1/2 + 0xCCCC3333, + '\u259a', # diagonal 1/2 # 0xffffcccc, 0x259b, # 3/4 redundant # 0xffff3333, 0x259c, # 3/4 redundant - 0x33330000, '\u259d', # quadrant upper right + 0x33330000, + '\u259d', # quadrant upper right # 0x3333cccc, 0x259e, # 3/4 redundant # 0x3333ffff, 0x259f, # 3/4 redundant - # Line drawing subset: no double lines, no complex light lines - - 0x000ff000, '\u2501', # Heavy horizontal - 0x66666666, '\u2503', # Heavy vertical - - 0x00077666, '\u250f', # Heavy down and right - 0x000ee666, '\u2513', # Heavy down and left - 0x66677000, '\u2517', # Heavy up and right - 0x666ee000, '\u251b', # Heavy up and left - - 0x66677666, '\u2523', # Heavy vertical and right - 0x666ee666, '\u252b', # Heavy vertical and left - 0x000ff666, '\u2533', # Heavy down and horizontal - 0x666ff000, '\u253b', # Heavy up and horizontal - 0x666ff666, '\u254b', # Heavy cross - - 0x000cc000, '\u2578', # Bold horizontal left - 0x00066000, '\u2579', # Bold horizontal up - 0x00033000, '\u257a', # Bold horizontal right - 0x00066000, '\u257b', # Bold horizontal down - - 0x06600660, '\u254f', # Heavy double dash vertical - - 0x000f0000, '\u2500', # Light horizontal - 0x0000f000, '\u2500', # - 0x44444444, '\u2502', # Light vertical - 0x22222222, '\u2502', - - 0x000e0000, '\u2574', # light left - 0x0000e000, '\u2574', # light left - 0x44440000, '\u2575', # light up - 0x22220000, '\u2575', # light up - 0x00030000, '\u2576', # light right - 0x00003000, '\u2576', # light right - 0x00004444, '\u2577', # light down - 0x00002222, '\u2577', # light down - - 0x11224488, '\u2571', # diagonals - 0x88442211, '\u2572', - 0x99666699, '\u2573', - - 0, END_MARKER, 0 # End marker + 0x000FF000, + '\u2501', # Heavy horizontal + 0x66666666, + '\u2503', # Heavy vertical + 0x00077666, + '\u250f', # Heavy down and right + 0x000EE666, + '\u2513', # Heavy down and left + 0x66677000, + '\u2517', # Heavy up and right + 0x666EE000, + '\u251b', # Heavy up and left + 0x66677666, + '\u2523', # Heavy vertical and right + 0x666EE666, + '\u252b', # Heavy vertical and left + 0x000FF666, + '\u2533', # Heavy down and horizontal + 0x666FF000, + '\u253b', # Heavy up and horizontal + 0x666FF666, + '\u254b', # Heavy cross + 0x000CC000, + '\u2578', # Bold horizontal left + 0x00066000, + '\u2579', # Bold horizontal up + 0x00033000, + '\u257a', # Bold horizontal right + 0x00066000, + '\u257b', # Bold horizontal down + 0x06600660, + '\u254f', # Heavy double dash vertical + 0x000F0000, + '\u2500', # Light horizontal + 0x0000F000, + '\u2500', # + 0x44444444, + '\u2502', # Light vertical + 0x22222222, + '\u2502', + 0x000E0000, + '\u2574', # light left + 0x0000E000, + '\u2574', # light left + 0x44440000, + '\u2575', # light up + 0x22220000, + '\u2575', # light up + 0x00030000, + '\u2576', # light right + 0x00003000, + '\u2576', # light right + 0x00004444, + '\u2577', # light down + 0x00002222, + '\u2577', # light down + 0x11224488, + '\u2571', # diagonals + 0x88442211, + '\u2572', + 0x99666699, + '\u2573', + 0, + END_MARKER, + 0, # End marker ] + class Color: - def __init__(self, red:int, green:int, blue:int): + def __init__(self, red: int, green: int, blue: int): self.check_value(red) self.check_value(green) self.check_value(blue) @@ -96,19 +134,21 @@ class Color: self.green = green self.blue = blue - def check_value(sel, value:int): - assert isinstance(value, int), F"Color value has to be integer, this is {type(value)}" - assert value < 256, F"Color value has to be between 0 and 255, this is {value}" - assert value >= 0, F"Color value has to be between 0 and 255, this is {value}" - + def check_value(sel, value: int): + assert isinstance( + value, int + ), f'Color value has to be integer, this is {type(value)}' + assert value < 256, f'Color value has to be between 0 and 255, this is {value}' + assert value >= 0, f'Color value has to be between 0 and 255, this is {value}' + def __str__(self): """ Returns the string in ansi color format """ - return F"{self.red};{self.green};{self.blue}" + return f'{self.red};{self.green};{self.blue}' -def chardata(unicodeChar: str, character_color:Color, background_color:Color) -> str: +def chardata(unicodeChar: str, character_color: Color, background_color: Color) -> str: """ Given the character and colors, it creates the string, which when printed in terminal simulates pixels. """ @@ -117,40 +157,42 @@ def chardata(unicodeChar: str, character_color:Color, background_color:Color) -> assert isinstance(character_color, Color) assert isinstance(background_color, Color) assert isinstance(unicodeChar, str) - return F"\033[38;2;{character_color}m\033[48;2;{background_color}m{unicodeChar}" + return f'\033[38;2;{character_color}m\033[48;2;{background_color}m{unicodeChar}' -def get_best_unicode_pattern(bitmap:int) -> tuple[int, str, bool]: + +def get_best_unicode_pattern(bitmap: int) -> tuple[int, str, bool]: """ Goes through the list of unicode characters and looks for the best match for bitmap representing the given segment It computes the difference by counting 1s after XORing the two. If they are identical, the count will be 0. This character will be printed - Parameters: - ----------- + Parameters + ---------- - bitmap (int): int representing the bitmap the image segment. - Returns: - ---------- + Returns + ------- - best_pattern (int): int representing the pattern that was the best match, is then used to calculate colors - unicode (str): the unicode character that represents the given bitmap the best and is then printed - - inverse (bool): The list does't contain unicode characters that are inverse of each other. The match can be achieved by simply using + - inverse (bool): The list does't contain unicode characters that are inverse of each other. The match can be achieved by simply using the inversed bitmap. But then we need to know if we have to switch background and foreground color. + """ best_diff = 8 - best_pattern = 0x0000ffff + best_pattern = 0x0000FFFF unicode = '\u2584' inverse = False - bit_not = lambda n: (1 << 32) - 1 - n + bit_not = lambda n: (1 << 32) - 1 - n i = 0 - while BITMAPS[i+1] != END_MARKER: + while BITMAPS[i + 1] != END_MARKER: pattern = BITMAPS[i] for j in range(2): diff = (pattern ^ bitmap).bit_count() if diff < best_diff: best_pattern = pattern - unicode = BITMAPS[i+1] + unicode = BITMAPS[i + 1] best_diff = diff inverse = bool(j) pattern = bit_not(pattern) @@ -158,17 +200,19 @@ def get_best_unicode_pattern(bitmap:int) -> tuple[int, str, bool]: i += 2 return best_pattern, unicode, inverse - -def int_bitmap_from_ndarray(array_bitmap:np.ndarray)->int: + + +def int_bitmap_from_ndarray(array_bitmap: np.ndarray) -> int: """ Flattens the array Changes all numbers to strings Creates a string representing binary number Casts it to integer """ - return int(F"0b{''.join([str(i) for i in array_bitmap.flatten()])}", base = 2) + return int(f"0b{''.join([str(i) for i in array_bitmap.flatten()])}", base=2) + -def ndarray_from_int_bitmap(bitmap:int, shape:tuple = (8, 4))-> np.ndarray: +def ndarray_from_int_bitmap(bitmap: int, shape: tuple = (8, 4)) -> np.ndarray: """ Gets the binary representation Gets rid of leading '0b @@ -178,67 +222,81 @@ def ndarray_from_int_bitmap(bitmap:int, shape:tuple = (8, 4))-> np.ndarray: """ string = str(bin(bitmap))[2:].zfill(shape[0] * shape[1]) return np.array([int(i) for i in string]).reshape(shape) - -def create_bitmap(image_segment:np.ndarray)->int: + + +def create_bitmap(image_segment: np.ndarray) -> int: """ - Parameters: - ------------ + Parameters + ---------- image_segment: np.ndarray of shape (x, y, 3) - Returns: - ---------- + Returns + ------- bitmap: int, each bit says if the unicode character should cover this bit or not + """ - max_color = np.max(np.max(image_segment, axis=0), axis = 0) - min_color = np.min(np.min(image_segment, axis=0), axis = 0) + max_color = np.max(np.max(image_segment, axis=0), axis=0) + min_color = np.min(np.min(image_segment, axis=0), axis=0) rng = np.absolute(max_color - min_color) max_index = np.argmax(rng) if np.sum(rng) == 0: return 0 - split_threshold = rng[max_index]/2 + min_color[max_index] - bitmap = np.array(image_segment[:, :, max_index] <= split_threshold, dtype = int) - + split_threshold = rng[max_index] / 2 + min_color[max_index] + bitmap = np.array(image_segment[:, :, max_index] <= split_threshold, dtype=int) return int_bitmap_from_ndarray(bitmap) -def get_color(image_segment:np.ndarray, char_array:np.ndarray) -> Color: + +def get_color(image_segment: np.ndarray, char_array: np.ndarray) -> Color: """ Computes the average color of the segment from pixels specified in charr_array The color is then average over the part then unicode character covers or the background - Parameters: - ----------- + Parameters + ---------- - image_segment: 4x8 part of the image with the original values so average color can be calculated - char_array: indices saying which pixels out of the 4x8 should be used for color calculation - Returns: - --------- + Returns + ------- - color: containing the average color over defined pixels + """ colors = [] for channel_index in range(image_segment.shape[2]): - channel = image_segment[:,:,channel_index] + channel = image_segment[:, :, channel_index] colors.append(int(np.average(channel[char_array]))) - return Color(colors[0], colors[1], colors[2]) if len(colors) == 3 else Color(colors[0], colors[0], colors[0]) + return ( + Color(colors[0], colors[1], colors[2]) + if len(colors) == 3 + else Color(colors[0], colors[0], colors[0]) + ) -def get_colors(image_segment:np.ndarray, char_array:np.ndarray) -> tuple[Color, Color]: + +def get_colors( + image_segment: np.ndarray, char_array: np.ndarray +) -> tuple[Color, Color]: """ - Parameters: + Parameters ---------- - image_segment - char_array - Returns: - ---------- + Returns + ------- - Foreground color - Background color + """ - return get_color(image_segment, char_array == 1), get_color(image_segment, char_array == 0) + return get_color(image_segment, char_array == 1), get_color( + image_segment, char_array == 0 + ) -def segment_string(image_segment:np.ndarray)-> str: + +def segment_string(image_segment: np.ndarray) -> str: """ Creates bitmap so its best represent the color distribution Finds the best match in unicode characters @@ -257,42 +315,42 @@ def segment_string(image_segment:np.ndarray)-> str: bg_color, fg_color = fg_color, bg_color return chardata(unicode, fg_color, bg_color) -def image_ansi_string(image:np.ndarray) -> str: + +def image_ansi_string(image: np.ndarray) -> str: """ For each segment 4x8 finds the string with colored unicode character Create the string for whole image - Parameters: - ----------- + Parameters + ---------- - image: image to be displayed in terminal - Returns: - ---------- + Returns + ------- - ansi_string: when printed, will render the image + """ string = [] for y in range(0, image.shape[0], Y_STRIDE): for x in range(0, image.shape[1], X_STRIDE): - - this_segment = image[y:y+Y_STRIDE, x:x+X_STRIDE, :] + this_segment = image[y : y + Y_STRIDE, x : x + X_STRIDE, :] if this_segment.shape[0] != Y_STRIDE: segment = np.zeros((Y_STRIDE, X_STRIDE, this_segment.shape[2])) - segment[:this_segment.shape[0], :, :] = this_segment + segment[: this_segment.shape[0], :, :] = this_segment this_segment = segment string.append(segment_string(this_segment)) - string.append(F"{BACK_TO_NORMAL}\n") + string.append(f'{BACK_TO_NORMAL}\n') return ''.join(string) - - ################################################################### # Image preparation ################################################################### -def rescale_image(image:np.ndarray, size:tuple)->np.ndarray: + +def rescale_image(image: np.ndarray, size: tuple) -> np.ndarray: """ The unicode bitmaps are hardcoded for 4x8 segments, they cannot be scaled Thus the image must be scaled to fit the desired resolution @@ -306,52 +364,57 @@ def rescale_image(image:np.ndarray, size:tuple)->np.ndarray: return image -def check_and_adjust_image_dims(image:np.ndarray) -> np.ndarray: +def check_and_adjust_image_dims(image: np.ndarray) -> np.ndarray: if image.ndim == 2: image = np.expand_dims(image, 2) elif image.ndim == 3: - if image.shape[2] == 1: # grayscale image - pass - elif image.shape[2] == 3: # colorful image + if image.shape[2] == 1 or image.shape[2] == 3: # grayscale image pass - elif image.shape[2] == 4: # contains alpha channel - image = image[:,:,:3] - elif image.shape[0] == 3: # torch images have color channels as the first axis + elif image.shape[2] == 4: # contains alpha channel + image = image[:, :, :3] + elif image.shape[0] == 3: # torch images have color channels as the first axis image = np.moveaxis(image, 0, -1) else: - raise ValueError(F"Image must have 2 (grayscale) or 3 (colorful) dimensions. Yours has {image.ndim}") - + raise ValueError( + f'Image must have 2 (grayscale) or 3 (colorful) dimensions. Yours has {image.ndim}' + ) + return image -def check_and_adjust_values(image:np.ndarray, relative_intensity:bool = True) -> np.ndarray: + +def check_and_adjust_values( + image: np.ndarray, relative_intensity: bool = True +) -> np.ndarray: """ Checks if the values are between 0 and 255 If not, normalizes the values so they are in that interval - Parameters: - ------------- + Parameters + ---------- - image - - relative_intensity: If maximum values are pretty low, they will be barely visible. If true, it normalizes + - relative_intensity: If maximum values are pretty low, they will be barely visible. If true, it normalizes the values, so that the maximum is at 255 - Returns: - ----------- + Returns + ------- - adjusted_image + """ m = np.max(image) if m > 255: - image = np.array(255*image/m, dtype = np.uint8) + image = np.array(255 * image / m, dtype=np.uint8) elif m < 1: - image = np.array(255*image, dtype = np.uint8) + image = np.array(255 * image, dtype=np.uint8) if relative_intensity: m = np.max(image) - image = np.array((image/m)*255, dtype = np.uint8) + image = np.array((image / m) * 255, dtype=np.uint8) return image -def choose_slice(image:np.ndarray, axis:int = None, slice:int = None): + +def choose_slice(image: np.ndarray, axis: int = None, slice: int = None): """ Preview give the possibility to choose axis to be sliced and slice to be displayed """ @@ -359,22 +422,29 @@ def choose_slice(image:np.ndarray, axis:int = None, slice:int = None): image = np.moveaxis(image, axis, -1) if slice is None: - slice = image.shape[2]//2 + slice = image.shape[2] // 2 else: if slice > image.shape[2]: - slice = image.shape[2]-1 - return image[:,:, slice] + slice = image.shape[2] - 1 + return image[:, :, slice] + ################################################################### # Main function ################################################################### -def image_preview(image:np.ndarray, image_width:int = 80, axis:int = None, slice:int = None, relative_intensity:bool = True): + +def image_preview( + image: np.ndarray, + image_width: int = 80, + axis: int = None, + slice: int = None, + relative_intensity: bool = True, +): if image.ndim == 3 and image.shape[2] > 4: image = choose_slice(image, axis, slice) image = check_and_adjust_image_dims(image) - ratio = X_STRIDE*image_width/image.shape[1] + ratio = X_STRIDE * image_width / image.shape[1] image = check_and_adjust_values(image, relative_intensity) - image = rescale_image(image, (X_STRIDE*image_width, int(ratio * image.shape[0]))) + image = rescale_image(image, (X_STRIDE * image_width, int(ratio * image.shape[0]))) print(image_ansi_string(image)) - diff --git a/qim3d/viz/_structure_tensor.py b/qim3d/viz/_structure_tensor.py index 13d45e1f252e139c5a9cd2518280136a821e9443..342285b31eb0e6524334bc157beb8c42b0ab4469 100644 --- a/qim3d/viz/_structure_tensor.py +++ b/qim3d/viz/_structure_tensor.py @@ -1,15 +1,14 @@ -import numpy as np -from typing import Optional, Union, Tuple -import matplotlib.pyplot as plt -from matplotlib.gridspec import GridSpec -import ipywidgets as widgets import logging -from qim3d.utils._logger import log +from typing import Tuple, Union + +import ipywidgets as widgets +import matplotlib.pyplot as plt +import numpy as np +from qim3d.utils._logger import log previous_logging_level = logging.getLogger().getEffectiveLevel() logging.getLogger().setLevel(logging.CRITICAL) -import structure_tensor as st logging.getLogger().setLevel(previous_logging_level) @@ -18,10 +17,10 @@ def vectors( volume: np.ndarray, vec: np.ndarray, axis: int = 0, - volume_cmap:str = 'grey', - vmin: float|None = None, - vmax: float|None = None, - slice_idx: Union[int, float]|None = None, + volume_cmap: str = 'grey', + vmin: float | None = None, + vmax: float | None = None, + slice_idx: Union[int, float] | None = None, grid_size: int = 10, interactive: bool = True, figsize: Tuple[int, int] = (10, 5), @@ -94,10 +93,9 @@ def vectors( if grid_size < min_grid_size or grid_size > max_grid_size: # Adjust grid size as little as possible to be within the limits grid_size = min(max(min_grid_size, grid_size), max_grid_size) - log.warning(f"Adjusting grid size to {grid_size} as it is out of bounds.") + log.warning(f'Adjusting grid size to {grid_size} as it is out of bounds.') def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show): - # Choose the appropriate slice based on the specified dimension if axis == 0: data_slice = volume[slice_idx, :, :] @@ -118,10 +116,10 @@ def vectors( vectors_slice_z = vec[0, :, :, slice_idx] else: - raise ValueError("Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.") + raise ValueError('Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.') # Create three subplots - fig, ax = plt.subplots(1, 3, figsize=figsize, layout="constrained") + fig, ax = plt.subplots(1, 3, figsize=figsize, layout='constrained') blend_hue_saturation = ( lambda hue, sat: hue * (1 - sat) + 0.5 * sat @@ -164,7 +162,7 @@ def vectors( vectors_slice_x[g, g], vectors_slice_y[g, g], color=rgba_quiver_flat, - angles="xy", + angles='xy', ) ax[0].quiver( ymesh[g, g], @@ -172,14 +170,14 @@ def vectors( -vectors_slice_x[g, g], -vectors_slice_y[g, g], color=rgba_quiver_flat, - angles="xy", + angles='xy', ) - ax[0].imshow(data_slice, cmap = volume_cmap, vmin = vmin, vmax = vmax) + ax[0].imshow(data_slice, cmap=volume_cmap, vmin=vmin, vmax=vmax) ax[0].set_title( - f"Orientation vectors (slice {slice_idx})" + f'Orientation vectors (slice {slice_idx})' if not interactive - else "Orientation vectors" + else 'Orientation vectors' ) ax[0].set_axis_off() @@ -218,14 +216,14 @@ def vectors( ) ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin) - ax[1].set_xlabel("Angle [radians]") + ax[1].set_xlabel('Angle [radians]') ax[1].set_xlim([0, np.pi]) ax[1].set_aspect(np.pi / ax[1].get_ylim()[1]) ax[1].set_xticks([0, np.pi / 2, np.pi]) - ax[1].set_xticklabels(["0", "$\\frac{\\pi}{2}$", "$\\pi$"]) + ax[1].set_xticklabels(['0', '$\\frac{\\pi}{2}$', '$\\pi$']) ax[1].set_yticks([]) - ax[1].set_ylabel("Frequency") - ax[1].set_title(f"Histogram over orientation angles") + ax[1].set_ylabel('Frequency') + ax[1].set_title('Histogram over orientation angles') # ----- Subplot 3: Image slice colored according to orientation ----- # # Calculate z-component (saturation) @@ -240,13 +238,13 @@ def vectors( # Grayscale image slice blended with orientation colors data_slice_orientation_colored = ( blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255 - ).astype("uint8") + ).astype('uint8') ax[2].imshow(data_slice_orientation_colored) ax[2].set_title( - f"Colored orientations (slice {slice_idx})" + f'Colored orientations (slice {slice_idx})' if not interactive - else "Colored orientations" + else 'Colored orientations' ) ax[2].set_axis_off() @@ -260,7 +258,7 @@ def vectors( if vec.ndim == 5: vec = vec[0, ...] log.warning( - "Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used." + 'Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used.' ) if slice_idx is None: @@ -269,7 +267,7 @@ def vectors( elif isinstance(slice_idx, float): if slice_idx < 0 or slice_idx > 1: raise ValueError( - "Values of slice_idx of float type must be between 0 and 1." + 'Values of slice_idx of float type must be between 0 and 1.' ) slice_idx = int(slice_idx * volume.shape[0]) - 1 @@ -279,8 +277,8 @@ def vectors( max=volume.shape[axis] - 1, step=1, value=slice_idx, - description="Slice index", - layout=widgets.Layout(width="450px"), + description='Slice index', + layout=widgets.Layout(width='450px'), ) grid_size_slider = widgets.IntSlider( @@ -288,8 +286,8 @@ def vectors( max=max_grid_size, step=1, value=grid_size, - description="Grid size", - layout=widgets.Layout(width="450px"), + description='Grid size', + layout=widgets.Layout(width='450px'), ) widget_obj = widgets.interactive( @@ -305,7 +303,7 @@ def vectors( # Arrange sliders horizontally sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider]) widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]]) - widget_obj.layout.align_items = "center" + widget_obj.layout.align_items = 'center' if show: display(widget_obj) diff --git a/qim3d/viz/colormaps/__init__.py b/qim3d/viz/colormaps/__init__.py index 9807422f1fec28680c17a383a7d7377096220d6c..bc677f09be832ef499392d47a5b5ecaffe63563f 100644 --- a/qim3d/viz/colormaps/__init__.py +++ b/qim3d/viz/colormaps/__init__.py @@ -1,2 +1,2 @@ +from ._qim_colors import qim from ._segmentation import segmentation -from ._qim_colors import qim \ No newline at end of file diff --git a/qim3d/viz/colormaps/_qim_colors.py b/qim3d/viz/colormaps/_qim_colors.py index 3ac7a4005ee2d92f12a00a3588acb482fe308dc5..5dad22517287e7ded72ff17caf91ebac2b7e06f9 100644 --- a/qim3d/viz/colormaps/_qim_colors.py +++ b/qim3d/viz/colormaps/_qim_colors.py @@ -1,9 +1,8 @@ from matplotlib import colormaps from matplotlib.colors import LinearSegmentedColormap - qim = LinearSegmentedColormap.from_list( - "qim", + 'qim', [ (0.6, 0.0, 0.0), # 990000 (1.0, 0.6, 0.0), # ff9900 diff --git a/qim3d/viz/colormaps/_segmentation.py b/qim3d/viz/colormaps/_segmentation.py index 715eb6da5cbdac3c7a50487782e24a40c028b0bd..390a9f84148f2134f7b211c84a599449d5b958c4 100644 --- a/qim3d/viz/colormaps/_segmentation.py +++ b/qim3d/viz/colormaps/_segmentation.py @@ -3,9 +3,10 @@ This module provides a collection of colormaps useful for 3D visualization. """ import colorsys -from typing import Union, Tuple -import numpy as np import math +from typing import Tuple, Union + +import numpy as np from matplotlib.colors import LinearSegmentedColormap @@ -34,7 +35,7 @@ def rearrange_colors(randRGBcolors_old, min_dist=0.5): def segmentation( num_labels: int, - style: str = "bright", + style: str = 'bright', first_color_background: bool = True, last_color_background: bool = False, background_color: Union[Tuple[float, float, float], str] = (0.0, 0.0, 0.0), @@ -90,23 +91,24 @@ def segmentation( ```python qim3d.viz.slices_grid(segmented_volume, color_map = 'objects') ``` - which automatically detects number of unique classes + which automatically detects number of unique classes and creates the colormap object with defualt arguments. Tip: The `min_dist` parameter can be used to control the distance between neighboring colors.  + """ from skimage import color # Check style - if style not in ("bright", "soft", "earth", "ocean"): + if style not in ('bright', 'soft', 'earth', 'ocean'): raise ValueError( f'Please choose "bright", "soft", "earth" or "ocean" for style in qim3dCmap not "{style}"' ) # Translate strings to background color - color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)} + color_dict = {'black': (0.0, 0.0, 0.0), 'white': (1.0, 1.0, 1.0)} if not isinstance(background_color, tuple): try: background_color = color_dict[background_color] @@ -122,7 +124,7 @@ def segmentation( rng = np.random.default_rng(seed) # Generate color map for bright colors, based on hsv - if style == "bright": + if style == 'bright': randHSVcolors = [ ( rng.uniform(low=0.0, high=1), @@ -140,7 +142,7 @@ def segmentation( ) # Generate soft pastel colors, by limiting the RGB spectrum - if style == "soft": + if style == 'soft': low = 0.6 high = 0.95 randRGBcolors = [ @@ -153,7 +155,7 @@ def segmentation( ] # Generate color map for earthy colors, based on LAB - if style == "earth": + if style == 'earth': randLABColors = [ ( rng.uniform(low=25, high=110), @@ -169,7 +171,7 @@ def segmentation( randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist()) # Generate color map for ocean colors, based on LAB - if style == "ocean": + if style == 'ocean': randLABColors = [ ( rng.uniform(low=0, high=110), @@ -195,8 +197,6 @@ def segmentation( randRGBcolors[-1] = background_color # Create colormap - objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=num_labels) + objects = LinearSegmentedColormap.from_list('objects', randRGBcolors, N=num_labels) return objects - - diff --git a/qim3d/viz/itk_vtk_viewer/helpers.py b/qim3d/viz/itk_vtk_viewer/helpers.py index bd426b11ffb0ba217e68b4bd64932107f5ff50b3..4f85f5062bd114490c8f58812cc310d956824024 100644 --- a/qim3d/viz/itk_vtk_viewer/helpers.py +++ b/qim3d/viz/itk_vtk_viewer/helpers.py @@ -1,45 +1,59 @@ -from pathlib import Path import os import platform +from pathlib import Path from typing import Callable import qim3d -class NotInstalledError(Exception): pass -SOURCE_FNM = "fnm env --use-on-cd | Out-String | Invoke-Expression;" +class NotInstalledError(Exception): + pass + + +SOURCE_FNM = 'fnm env --use-on-cd | Out-String | Invoke-Expression;' LINUX = 'Linux' WINDOWS = 'Windows' MAC = 'Darwin' + def get_itk_dir() -> Path: - qim_dir = Path(qim3d.__file__).parents[0] #points to .../qim3d/qim3d/ - dir = qim_dir.joinpath("viz/itk_vtk_viewer") + qim_dir = Path(qim3d.__file__).parents[0] # points to .../qim3d/qim3d/ + dir = qim_dir.joinpath('viz/itk_vtk_viewer') return dir -def get_nvm_dir(dir:Path = None) -> Path: + +def get_nvm_dir(dir: Path = None) -> Path: if platform.system() in [LINUX, MAC]: - following_folder = ".nvm" + following_folder = '.nvm' elif platform.system() == WINDOWS: following_folder = '' - return dir.joinpath(following_folder) if dir is not None else get_qim_dir().joinpath(following_folder) + return ( + dir.joinpath(following_folder) + if dir is not None + else get_qim_dir().joinpath(following_folder) + ) + -def get_node_binaries_dir(nvm_dir:Path = None) -> Path: +def get_node_binaries_dir(nvm_dir: Path = None) -> Path: """ Versions could change in time. This makes sure we use the newest one. For windows we have to pass the argument nvm_dir and it is the itk-vtk_dir """ if platform.system() in [LINUX, MAC]: - following_folder = "versions/node" + following_folder = 'versions/node' binaries_folder = 'bin' elif platform.system() == WINDOWS: following_folder = 'node-versions' binaries_folder = 'installation' - node_folder = nvm_dir.joinpath(following_folder) if nvm_dir is not None else get_nvm_dir().joinpath(following_folder) - + node_folder = ( + nvm_dir.joinpath(following_folder) + if nvm_dir is not None + else get_nvm_dir().joinpath(following_folder) + ) + # We don't wanna throw an error # Instead we return None and check the returned value in run.py if not os.path.isdir(node_folder): @@ -50,19 +64,28 @@ def get_node_binaries_dir(nvm_dir:Path = None) -> Path: path = node_folder.joinpath(name) if os.path.isdir(path): return path.joinpath(binaries_folder) - -def get_viewer_dir(dir:Path = None) -> Path: - following_folder = "viewer_app" - return dir.joinpath(following_folder) if dir is not None else get_qim_dir().joinpath(following_folder) -def get_viewer_binaries(viewer_dir:Path = None) -> Path: + +def get_viewer_dir(dir: Path = None) -> Path: + following_folder = 'viewer_app' + return ( + dir.joinpath(following_folder) + if dir is not None + else get_qim_dir().joinpath(following_folder) + ) + + +def get_viewer_binaries(viewer_dir: Path = None) -> Path: following_folder1 = 'node_modules' following_folder2 = '.bin' if viewer_dir is None: viewer_dir = get_viewer_dir() return viewer_dir.joinpath(following_folder1).joinpath(following_folder2) -def run_for_platform(linux_func:Callable, windows_func:Callable, macos_func:Callable): + +def run_for_platform( + linux_func: Callable, windows_func: Callable, macos_func: Callable +): this_platform = platform.system() if this_platform == LINUX: return linux_func() @@ -70,6 +93,7 @@ def run_for_platform(linux_func:Callable, windows_func:Callable, macos_func:Call return windows_func() elif this_platform == MAC: return macos_func() - + + def lambda_raise(err): raise err diff --git a/qim3d/viz/itk_vtk_viewer/installation.py b/qim3d/viz/itk_vtk_viewer/installation.py index b77375337c8c197131e2209733e2faf2d58a3e56..083329f2e6b3dc1a7af1a11ce2f6afe12cdfacad 100644 --- a/qim3d/viz/itk_vtk_viewer/installation.py +++ b/qim3d/viz/itk_vtk_viewer/installation.py @@ -1,22 +1,35 @@ -from pathlib import Path -import subprocess import os import platform +import subprocess +from pathlib import Path -from .helpers import get_itk_dir, get_nvm_dir, get_node_binaries_dir, get_viewer_dir, SOURCE_FNM, NotInstalledError, run_for_platform - +from .helpers import ( + SOURCE_FNM, + NotInstalledError, + get_itk_dir, + get_node_binaries_dir, + get_nvm_dir, + get_viewer_dir, + run_for_platform, +) class Installer: + """ - Implements installation procedure of itk-vtk-viewer for each OS. + Implements installation procedure of itk-vtk-viewer for each OS. Also goes for minimal installation: checking if the necessary binaries aren't already installed """ + def __init__(self): self.platform = platform.system() - self.install_functions = (self.install_node_manager, self.install_node, self.install_viewer) + self.install_functions = ( + self.install_node_manager, + self.install_node, + self.install_viewer, + ) - self.dir = get_itk_dir() # itk_vtk_viewer folder within qim3d.viz + self.dir = get_itk_dir() # itk_vtk_viewer folder within qim3d.viz # If nvm was already installed, there should be this environment variable # However it could have also been installed via our process, or user deleted the folder but didn't adjusted the bashrc, that's why we check again @@ -32,80 +45,101 @@ class Installer: """ Checks for global and local installation of nvm (Node Version Manager) """ + def _linux() -> bool: - command_f = lambda nvmsh: F'/bin/bash -c "source {nvmsh} && nvm"' + command_f = lambda nvmsh: f'/bin/bash -c "source {nvmsh} && nvm"' if self.os_nvm_dir is not None: nvmsh = self.os_nvm_dir.joinpath('nvm.sh') - output = subprocess.run(command_f(nvmsh), shell = True, capture_output = True) + output = subprocess.run( + command_f(nvmsh), shell=True, capture_output=True + ) if not output.stderr: self.nvm_dir = self.os_nvm_dir return True - + nvmsh = self.qim_nvm_dir.joinpath('nvm.sh') - output = subprocess.run(command_f(nvmsh), shell = True, capture_output = True) + output = subprocess.run(command_f(nvmsh), shell=True, capture_output=True) self.nvm_dir = self.qim_nvm_dir - return not bool(output.stderr) # If there is an error running the above command then it is not installed (not in expected location) - + return not bool( + output.stderr + ) # If there is an error running the above command then it is not installed (not in expected location) + def _windows() -> bool: - output = subprocess.run(['powershell.exe', 'fnm --version'], capture_output=True) + output = subprocess.run( + ['powershell.exe', 'fnm --version'], capture_output=True + ) return not bool(output.stderr) - - return run_for_platform(linux_func=_linux, windows_func=_windows,macos_func= _linux) + + return run_for_platform( + linux_func=_linux, windows_func=_windows, macos_func=_linux + ) @property def is_node_already_installed(self) -> bool: """ Checks for global and local installation of Node.js and npm (Node Package Manager) """ + def _linux() -> bool: # get_node_binaries_dir might return None if the folder is not there # In that case there is 'None' added to the PATH, thats not a problem # the command will return an error to the output and it will be evaluated as not installed - command = F'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm version' + command = f'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm version' - output = subprocess.run(command, shell = True, capture_output = True) + output = subprocess.run(command, shell=True, capture_output=True) return not bool(output.stderr) - + def _windows() -> bool: # Didn't figure out how to install the viewer and run it properly when using global npm return False - - return run_for_platform(linux_func=_linux,windows_func= _windows,macos_func= _linux) - + return run_for_platform( + linux_func=_linux, windows_func=_windows, macos_func=_linux + ) def install(self): """ - First check if some of the binaries are not installed already. + First check if some of the binaries are not installed already. If node.js is already installed (it was able to call npm without raising an error) it only has to install the viewer and doesn't have to go through the process """ if self.is_node_manager_already_installed: self.install_status = 1 - print("Node manager already installed") + print('Node manager already installed') if self.is_node_already_installed: self.install_status = 2 - print("Node.js already installed") - + print('Node.js already installed') + else: self.install_status = 0 - - for install_function in self.install_functions[self.install_status:]: + + for install_function in self.install_functions[self.install_status :]: install_function() def install_node_manager(self): def _linux(): - print(F'Installing Node manager into {self.nvm_dir}...') - _ = subprocess.run([F'export NVM_DIR={self.nvm_dir} && bash {self.dir.joinpath("install_nvm.sh")}'], shell = True, capture_output=True) + print(f'Installing Node manager into {self.nvm_dir}...') + _ = subprocess.run( + [ + f'export NVM_DIR={self.nvm_dir} && bash {self.dir.joinpath("install_nvm.sh")}' + ], + shell=True, + capture_output=True, + ) def _windows(): - print("Installing node manager...") - subprocess.run(["powershell.exe", F'$env:XDG_DATA_HOME = "{self.dir}";', "winget install Schniz.fnm"]) - + print('Installing node manager...') + subprocess.run( + [ + 'powershell.exe', + f'$env:XDG_DATA_HOME = "{self.dir}";', + 'winget install Schniz.fnm', + ] + ) # self._run_for_platform(_linux, None, _windows) - run_for_platform(linux_func=_linux,windows_func= _windows,macos_func= _linux) - print("Node manager installed") + run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux) + print('Node manager installed') def install_node(self): def _linux(): @@ -114,49 +148,63 @@ class Installer: We have to source that file either way, to be able to call nvm function If it was't installed before, we need to export NVM_DIR in order to install npm to correct location """ - print(F'Installing node.js into {self.nvm_dir}...') + print(f'Installing node.js into {self.nvm_dir}...') if self.install_status == 0: nvm_dir = self.nvm_dir - prefix = F'export NVM_DIR={nvm_dir} && ' + prefix = f'export NVM_DIR={nvm_dir} && ' elif self.install_status == 1: nvm_dir = self.os_nvm_dir prefix = '' - + nvmsh = Path(nvm_dir).joinpath('nvm.sh') command = f'{prefix}/bin/bash -c "source {nvmsh} && nvm install 22"' - output = subprocess.run(command, shell = True, capture_output=True) + output = subprocess.run(command, shell=True, capture_output=True) def _windows(): - subprocess.run(["powershell.exe",F'$env:XDG_DATA_HOME = "{self.dir}";', SOURCE_FNM, F"fnm use --fnm-dir {self.dir} --install-if-missing 22"]) - - print(F'Installing node.js...') - run_for_platform(linux_func = _linux, windows_func=_windows, macos_func=_linux) - print("Node.js installed") + subprocess.run( + [ + 'powershell.exe', + f'$env:XDG_DATA_HOME = "{self.dir}";', + SOURCE_FNM, + f'fnm use --fnm-dir {self.dir} --install-if-missing 22', + ] + ) + + print('Installing node.js...') + run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux) + print('Node.js installed') def install_viewer(self): def _linux(): - # Adds local binaries to the path in case we had to install node first (locally into qim folder), but shouldnt interfere even if + # Adds local binaries to the path in case we had to install node first (locally into qim folder), but shouldnt interfere even if # npm is installed globally - command = F'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm install --prefix {self.viewer_dir} itk-vtk-viewer' + command = f'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm install --prefix {self.viewer_dir} itk-vtk-viewer' output = subprocess.run([command], shell=True, capture_output=True) # print(output.stderr) def _windows(): try: node_bin = get_node_binaries_dir(self.dir) - print(F'Installing into {self.viewer_dir}') - subprocess.run(["powershell.exe", F'$env:PATH=$env:PATH + \';{node_bin}\';', F"npm install --prefix {self.viewer_dir} itk-vtk-viewer"], capture_output=True) - except NotInstalledError: # Not installed in qim - subprocess.run(["powershell.exe", SOURCE_FNM, F"npm install itk-vtk-viewer"], capture_output=True) - + print(f'Installing into {self.viewer_dir}') + subprocess.run( + [ + 'powershell.exe', + f"$env:PATH=$env:PATH + ';{node_bin}';", + f'npm install --prefix {self.viewer_dir} itk-vtk-viewer', + ], + capture_output=True, + ) + except NotInstalledError: # Not installed in qim + subprocess.run( + ['powershell.exe', SOURCE_FNM, 'npm install itk-vtk-viewer'], + capture_output=True, + ) self.viewer_dir = get_viewer_dir(self.dir) if not os.path.isdir(self.viewer_dir): os.mkdir(self.viewer_dir) - print(F"Installing itk-vtk-viewer...") + print('Installing itk-vtk-viewer...') run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux) - print("Itk-vtk-viewer installed") - - \ No newline at end of file + print('Itk-vtk-viewer installed') diff --git a/qim3d/viz/itk_vtk_viewer/run.py b/qim3d/viz/itk_vtk_viewer/run.py index 33aade01793b69ec3e0c87d19302cf3bb171670e..8c505f522b6d56ff2a5c6b42d2a3eaa146574c82 100644 --- a/qim3d/viz/itk_vtk_viewer/run.py +++ b/qim3d/viz/itk_vtk_viewer/run.py @@ -1,9 +1,9 @@ -import subprocess -from pathlib import Path import os -import webbrowser +import subprocess import threading import time +import webbrowser +from pathlib import Path import qim3d.utils from qim3d.utils._logger import log @@ -11,10 +11,8 @@ from qim3d.utils._logger import log from .helpers import * from .installation import Installer - - # Start viewer -START_COMMAND = "itk-vtk-viewer -s" +START_COMMAND = 'itk-vtk-viewer -s' # Lock, so two threads can safely read and write to is_installed c = threading.Condition() @@ -23,12 +21,12 @@ is_installed = True def run_global(port=3000): linux_func = lambda: subprocess.run( - START_COMMAND+f" -p {port}", shell=True, stderr=subprocess.DEVNULL + START_COMMAND + f' -p {port}', shell=True, stderr=subprocess.DEVNULL ) # First sourcing the node.js, if sourcing via fnm doesnt help and user would have to do it any other way, it would throw an error and suggest to install viewer to qim library windows_func = lambda: subprocess.run( - ["powershell.exe", SOURCE_FNM, START_COMMAND+f" -p {port}"], + ['powershell.exe', SOURCE_FNM, START_COMMAND + f' -p {port}'], shell=True, stderr=subprocess.DEVNULL, ) @@ -38,7 +36,7 @@ def run_global(port=3000): ) -def run_within_qim_dir(port=3000): +def run_within_qim_dir(port=3000): dir = get_itk_dir() viewer_dir = get_viewer_dir(dir) viewer_bin = get_viewer_binaries(viewer_dir) @@ -48,7 +46,7 @@ def run_within_qim_dir(port=3000): node_bin = get_node_binaries_dir(get_nvm_dir(dir)) if node_bin is None: # Didn't find node binaries there so it looks for enviroment variable to tell it where is nvm folder - node_bin = get_node_binaries_dir(Path(str(os.getenv("NVM_DIR")))) + node_bin = get_node_binaries_dir(Path(str(os.getenv('NVM_DIR')))) if node_bin is not None: subprocess.run( @@ -62,9 +60,9 @@ def run_within_qim_dir(port=3000): if node_bin is not None: subprocess.run( [ - "powershell.exe", + 'powershell.exe', f"$env:PATH = $env:PATH + ';{viewer_bin};{node_bin}';", - START_COMMAND+f" -p {port}", + START_COMMAND + f' -p {port}', ], stderr=subprocess.DEVNULL, ) @@ -118,7 +116,7 @@ def try_opening_itk_vtk( <pre style="margin-left: 12px; margin-right: 12px; color:#454545"> Downloading Okinawa_Foram_1.tif https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Okinawa_Forams/Okinawa_Foram_1.tif - 1.85GB [00:17, 111MB/s] + 1.85GB [00:17, 111MB/s] Loading Okinawa_Foram_1.tif Loading: 100% @@ -149,7 +147,7 @@ def try_opening_itk_vtk( http://localhost:3000/?rotate=false&fileToLoad=http://localhost:8042/Okinawa_Foram_1.zarr </pre> - +  """ @@ -169,7 +167,6 @@ def try_opening_itk_vtk( global is_installed c.acquire() if is_installed: - # Normalize the filename. This is necessary for trailing slashes by the end of the path filename_norm = os.path.normpath(os.path.abspath(filename)) @@ -178,12 +175,12 @@ def try_opening_itk_vtk( os.path.dirname(filename_norm), port=file_server_port ) - viz_url = f"http://localhost:{viewer_port}/?rotate=false&fileToLoad=http://localhost:{file_server_port}/{os.path.basename(filename_norm)}" + viz_url = f'http://localhost:{viewer_port}/?rotate=false&fileToLoad=http://localhost:{file_server_port}/{os.path.basename(filename_norm)}' if open_browser: webbrowser.open_new_tab(viz_url) - log.info(f"\nVisualization url:\n{viz_url}\n") + log.info(f'\nVisualization url:\n{viz_url}\n') c.release() # Start the delayed open in a separate thread @@ -214,28 +211,32 @@ def itk_vtk( filename: str = None, open_browser: bool = True, file_server_port: int = 8042, - viewer_port: int = 3000 - ): + viewer_port: int = 3000, +): """ Command to run in cli/__init__.py. Tries to run the vizualization, if that fails, asks the user to install it. This function is needed - here so we don't have to import NotInstalledError and Installer, + here so we don't have to import NotInstalledError and Installer, which exposes these to user. """ try: - try_opening_itk_vtk(filename, - open_browser=open_browser, - file_server_port = file_server_port, - viewer_port = viewer_port) + try_opening_itk_vtk( + filename, + open_browser=open_browser, + file_server_port=file_server_port, + viewer_port=viewer_port, + ) except NotInstalledError: message = "Itk-vtk-viewer is not installed or qim3d can not find it.\nYou can either:\n\to Use 'qim3d viz SOURCE -m k3d' to display data using different method\n\to Install itk-vtk-viewer yourself following https://kitware.github.io/itk-vtk-viewer/docs/cli.html#Installation\n\to Let qim3D install itk-vtk-viewer now (it will also install node.js in qim3d library)\nDo you want qim3D to install itk-vtk-viewer now?" print(message) - answer = input("[Y/n]:") - if answer in "Yy": + answer = input('[Y/n]:') + if answer in 'Yy': Installer().install() - try_opening_itk_vtk(filename, - open_browser=open_browser, - file_server_port = file_server_port, - viewer_port = viewer_port) \ No newline at end of file + try_opening_itk_vtk( + filename, + open_browser=open_browser, + file_server_port=file_server_port, + viewer_port=viewer_port, + ) diff --git a/requirements.txt b/requirements.txt index 6dc8be727846326f2bfb077f1520caf21dc254c5..c17889959dbab43dd7ada03e9adac084c65f86b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,8 +12,8 @@ scipy>=1.11.2 seaborn>=0.12.2 pydicom==2.4.4 setuptools>=68.0.0 -imagecodecs==2023.7.10 -tifffile==2023.8.12 +imagecodecs>=2024.12.30 +tifffile>=2025.1.10 torch>=2.0.1 torchvision>=0.15.2 torchinfo>=1.8.0 @@ -31,4 +31,4 @@ ome_zarr>=0.9.0 dask-image>=2024.5.3 trimesh>=4.4.9 slgbuilder>=0.2.1 -testbook>=0.4.2 +PyGEL3D>=0.5.2 \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..04e3c4ff65f978e3819145dff45b78896032664f --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,3 @@ +pre-commit>=4.1.0 +ruff>=0.9.3 +testbook>=0.4.2 \ No newline at end of file diff --git a/setup.py b/setup.py index d109c860caea7e854c75e81abe1f0497f164cca1..1da74828d168f805e4ad3d530be16c5865b87244 100644 --- a/setup.py +++ b/setup.py @@ -1,49 +1,48 @@ import os import re + from setuptools import find_packages, setup # Read the contents of your README file -with open("README.md", "r", encoding="utf-8") as f: +with open('README.md', encoding='utf-8') as f: long_description = f.read() + # Read the version from the __init__.py file def read_version(): - with open(os.path.join("qim3d", "__init__.py"), "r", encoding="utf-8") as f: + with open(os.path.join('qim3d', '__init__.py'), encoding='utf-8') as f: version_file = f.read() version_match = re.search(r'^__version__ = ["\']([^"\']*)["\']', version_file, re.M) if version_match: return version_match.group(1) - raise RuntimeError("Unable to find version string.") + raise RuntimeError('Unable to find version string.') + setup( - name="qim3d", + name='qim3d', version=read_version(), - author="Felipe Delestro", - author_email="fima@dtu.dk", - description="QIM tools and user interfaces for volumetric imaging", + author='Felipe Delestro', + author_email='fima@dtu.dk', + description='QIM tools and user interfaces for volumetric imaging', long_description=long_description, - long_description_content_type="text/markdown", - url="https://platform.qim.dk/qim3d", + long_description_content_type='text/markdown', + url='https://platform.qim.dk/qim3d', packages=find_packages(), include_package_data=True, - entry_points = { - 'console_scripts': [ - 'qim3d=qim3d.cli:main' - ] - }, + entry_points={'console_scripts': ['qim3d=qim3d.cli:main']}, classifiers=[ - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "Natural Language :: English", - "Operating System :: OS Independent", - "Topic :: Scientific/Engineering :: Image Processing", - "Topic :: Scientific/Engineering :: Visualization", - "Topic :: Software Development :: User Interfaces", + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Topic :: Scientific/Engineering :: Image Processing', + 'Topic :: Scientific/Engineering :: Visualization', + 'Topic :: Software Development :: User Interfaces', ], - python_requires=">=3.10", + python_requires='>=3.10', install_requires=[ "gradio==4.44", "h5py>=3.9.0", @@ -57,8 +56,8 @@ setup( "scipy>=1.11.2", "seaborn>=0.12.2", "setuptools>=68.0.0", - "tifffile==2023.8.12", - "imagecodecs==2023.7.10", + "tifffile>=2025.1.10", + "imagecodecs>=2024.12.30", "tqdm>=4.65.0", "nibabel>=5.2.0", "ipywidgets>=8.1.2", @@ -72,7 +71,8 @@ setup( "ome_zarr>=0.9.0", "dask-image>=2024.5.3", "scikit-image>=0.24.0", - "trimesh>=4.4.9" + "trimesh>=4.4.9", + "PyGEL3D>=0.5.2" ], extras_require={ "deep-learning": [ @@ -81,6 +81,9 @@ setup( "torchvision>=0.15.2", "torchinfo>=1.8.0", "monai>=1.2.0", + ], + 'test': [ + 'testbook>=0.4.2' ] } )