diff --git a/README.md b/README.md index dcda6af6d8c24b1dde0af913fa9dec28884e6b4a..88203b9eafc9139bd683ec51fde41577917eb834 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# QIM3D (Quantitative Imaging in 3D) +# Qim3D (Quantitative Imaging in 3D) -The `qim3d` library is designed to make it easier to work with 3D imaging data in Python. It offers a range of features, including data loading and manipulation, image processing and filtering, visualization of 3D data, and analysis of imaging results. +The `qim3d` (kɪm θriː diː) library is designed to make it easier to work with 3D imaging data in Python. It offers a range of features, including data loading and manipulation, image processing and filtering, visualization of 3D data, and analysis of imaging results. You can easily load and process 3D image data from various file formats, apply filters and transformations to the data, visualize the results using interactive plots and 3D rendering, and perform quantitative analysis on the images. diff --git a/docs/assets/qim3d.mp3 b/docs/assets/qim3d.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..ccf1bda48006dd28ab17e3aba322acded5ab8f0b Binary files /dev/null and b/docs/assets/qim3d.mp3 differ diff --git a/docs/generate.md b/docs/generate.md index 2082eb33b5025b6e0e34a241cb0b72354bc2b07b..fef26fe50f0381bfffd871e32c2d564e298cb6d4 100644 --- a/docs/generate.md +++ b/docs/generate.md @@ -6,4 +6,4 @@ The `qim3d` library provides a set of methods for generating volumes consisting options: members: - blob - - collection \ No newline at end of file + - collection diff --git a/docs/index.md b/docs/index.md index ea6111ebbe6f4a893f5aaaf65c99f42418f43386..c61719a598a9cb33a76e39352a3dbf429b13934c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,45 +1,81 @@ -# { width="256" } +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css"> + +<audio id="audio" src="assets/qim3d.mp3"></audio> + +<script> +document.addEventListener("DOMContentLoaded", function() { + const audio = document.getElementById("audio"); + const playButton = document.getElementById("playButton"); + + playButton.addEventListener("click", function() { + const icon = playButton.querySelector("i"); + if (audio.paused) { + audio.play(); + icon.classList.remove("fa-circle-play"); + icon.classList.add("fa-circle-pause"); + } else { + audio.pause(); + icon.classList.remove("fa-circle-pause"); + icon.classList.add("fa-circle-play"); + } + }); + + audio.addEventListener("ended", function() { + const icon = playButton.querySelector("i"); + icon.classList.remove("fa-circle-pause"); + icon.classList.add("fa-circle-play"); + }); +}); +</script> + +# { width="25%" } [](https://badge.fury.io/py/qim3d) [](https://pepy.tech/project/qim3d) +The **`qim3d`** (kɪm θriː diː <button id="playButton"><i class="fa-regular fa-circle-play"></i></button>) library is designed for **Quantitative Imaging in 3D** using Python. It offers a range of features, including data loading and manipulation, image processing and filtering, data visualization, and analysis of imaging results. -The `qim3d` library is designed to make it easier to work with 3D imaging data in Python. It offers a range of features, including data loading and manipulation, image processing and filtering, visualization of 3D data, and analysis of imaging results.
- -You can easily load and process 3D image data from various file formats, apply filters and transformations to the data, visualize the results using interactive plots and 3D rendering, and perform quantitative analysis on the images. +You can easily load and process 3D image data from various file formats, apply filters and transformations to the data, and visualize the results using interactive plots and 3D volumetric rendering. Whether you are working with medical imaging data, materials science data, or any other type of 3D imaging data, `qim3d` provides a convenient and powerful set of tools to help you analyze and understand your data. - -!!! Example +!!! Example "Interactive volume slicer" ```python import qim3d - import qim3d.processing.filters as filters - # Get data - vol = qim3d.examples.fly_150x256x256 + vol = qim3d.examples.bone_128x128x128 + qim3d.viz.slicer(vol) + ``` +  - # Show original - qim3d.viz.slices(vol, show=True) +!!! Example "Synthetic data generation" + ```python + import qim3d - # Create filter pipeline - pipeline = filters.Pipeline( - filters.Median(size=5), - filters.Gaussian(sigma=3)) + # Generate synthetic collection of blobs + num_objects = 15 + synthetic_collection, labels = qim3d.generate.collection(num_objects = num_objects) - # Apply pipeline - filtered_vol = pipeline(vol) + # Visualize synthetic collection + qim3d.viz.vol(synthetic_collection) + ``` + <iframe src="https://platform.qim.dk/k3d/synthetic_collection_default.html" width="100%" height="500" frameborder="0"></iframe> +!!! Example "Structure tensor" + ```python + import qim3d - # Show filtered - qim3d.viz.slices(filtered_vol) + vol = qim3d.examples.NT_128x128x128 + val, vec = qim3d.processing.structure_tensor(vol, visualize = True, axis = 2) ``` -  + +  ## Installation -Creating a `conda` environment is not required but recommended. +### Create environment +Creating a `conda` environment is not required but recommended. ??? info "Miniconda installation and setup" @@ -92,38 +128,35 @@ Creating a `conda` environment is not required but recommended. ``` Once you have `conda` installed, create a new environment: -``` -conda create -n qim3d python=3.11 -``` -After the environment is created, activate it by running: -``` -conda activate qim3d -``` - - + conda create -n qim3d python=3.11 +After the environment is created, activate it by running: + conda activate qim3d ### Install using `pip` The latest stable version can be simply installed using `pip`: -``` -pip install qim3d -``` + pip install qim3d !!! note - Installing `qim3d` may take a bit of time due to its dependencies. Thank you for your patience! + The base installation of `qim3d` does not include deep-learning dependencies by design, keeping the library lightweight for scenarios where these dependencies are unnecessary. -### Upgrade + If you need to use deep-learning features, you can install the additional dependencies by running: `pip install qim3d['deep-learning']` + +## Troubleshooting + +### Get the latest version The library is under constant development, so make sure to keep your installation updated: -``` -pip install --upgrade qim3d -``` + + pip install --upgrade qim3d + ## Collaboration + +Contributions to `qim3d` are welcome! + Contributions to `qim3d` are welcome! If you find a bug, have a feature request, or would like to contribute code, please open an issue or submit a pull request.
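As a quick smoke test after installation — a minimal sketch that relies only on names appearing elsewhere in this diff (`qim3d.__version__`, the bundled `bone_128x128x128` example, and `qim3d.viz.slicer`):

```python
import qim3d

# Version string defined in qim3d/__init__.py ("0.4.0" in this release)
print(qim3d.__version__)

# First attribute access triggers the lazy import of qim3d.examples,
# which loads the bundled .tif volumes as module attributes
vol = qim3d.examples.bone_128x128x128
print(vol.shape, vol.dtype)  # expected shape: (128, 128, 128)

# Interactive slicer, as in the documentation example above
qim3d.viz.slicer(vol)
```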
@@ -151,4 +184,4 @@ Below is a list of contributors to the project, arranged in chronological order The development of `qim3d` is supported by: -{ width="256" } \ No newline at end of file +{ width="256" } diff --git a/docs/io.md b/docs/io.md index da1364ecf4492c2337a2d44af0111e24fd5b3b4e..fdb27c4fbf1eb43f9c2c7eccfa8b88eab812a4c5 100644 --- a/docs/io.md +++ b/docs/io.md @@ -8,5 +8,4 @@ Currently, it is possible to directly load `tiff`, `h5`, `nii`,`txm`, `vol` and members: - load - save - - Downloader - - ImgExamples \ No newline at end of file + - Downloader \ No newline at end of file diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index cecdfdbbcc71dea3723fb53903a59ea383be21f1..824d104e4c752cb75864498b15ab69cc38f87110 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,6 +1,6 @@ # Adding notebooks -Jupyter notebooks can be added to this directory, but following this guidelines: +Jupyter notebooks can be added to this directory, but following these guidelines: - File size should be kept under 5MB - Make a clean run of the notebook before saving and pushing - Add descriptions of the intent and processes happening diff --git a/docs/releases.md b/docs/releases.md index 92748cc5425b5a8a9e4b60574d7231189dfe53b7..08c1d2049bcc13552bd51a5c7345a23e7ecf16db 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -9,7 +9,16 @@ As the library is still in its early development stages, **there may be breaking And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that you have the latest features! +### v0.4.0 (29/07/2024) + +- Refactored imports to use lazy loading +- Namespace `utils` reserved for internal tools only +- All deep-learning related functions moved to `models` +- Running `pip install qim3d` does not install DL requirements. For those, use `pip install qim3d['deep-learning']` + + ### v0.3.9 (24/07/2024) + - Loading and saving for Zarr files - File conversion using the CLI, including Zarr - Refactoring for the GUIs - Introduction of `qim3d.generate.collection` 🎉  ### v0.3.8 (20/06/2024) + - Minor refactoring and bug fixes ### v0.3.7 (17/06/2024) + - Performance improvements when importing - Refactoring for blob detection  ### v0.3.6 (30/05/2024) + - Refactoring for performance improvement - Welcome message for the CLI - Introduction of `qim3d.processing.fade_mask` 🎉  @@ -33,6 +45,7 @@ And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that ### v0.3.5 (27/05/2024) + - Added runtime and memory usage in the documentation - Introduction of `qim3d.utils.generate_volume` 🎉  - CLI refactoring, adding welcome message to the user  ### v0.3.4 (22/05/2024) + - Documentation for `qim3d.viz.plot_cc` - Fixed issue with Annotation tool and recent Gradio versions - New colormap: `qim3d.viz.colormaps.qim`, showcasing the Qim colors!
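The lazy loading mentioned in the v0.4.0 notes is implemented in `qim3d/__init__.py` later in this diff. A condensed, self-contained sketch of the same pattern, using `json` as a stand-in for a heavy submodule:

```python
import importlib


class LazyLoader:
    """Defer importing a module until one of its attributes is accessed."""

    def __init__(self, module_name):
        self.module_name = module_name
        self.module = None

    def __getattr__(self, item):
        # Only runs for attributes not set in __init__, i.e. on first real use
        if self.module is None:
            self.module = importlib.import_module(self.module_name)
        return getattr(self.module, item)


lazy_json = LazyLoader("json")          # nothing imported yet
print(lazy_json.dumps({"lazy": True}))  # the real import happens here
```

This keeps `import qim3d` fast while still exposing `qim3d.processing`, `qim3d.viz`, etc. as ordinary attributes.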
@@ -52,6 +66,7 @@ And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that - Aspect ratio issue for k3d fixed ### v0.3.3 (11/04/2024) + - Introduction of `qim3d.viz.slicer` (and also `qim3d.viz.orthogonal` ) 🎉 - Introduction of `qim3d.gui.annotation_tool` 🎉 - Introduction of `qim3d.processing.Blob` for blob detection 🎉 diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index d5d3f1dc12f7d5e3b213b22eee651ee3cef48ae2..885bf7a51e3b2b2359385e489d80c47f9c372982 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -100,4 +100,10 @@ code { .md-typeset .example>.admonition-title:after, .md-typeset .example>summary:after { color: #ff9900; -} \ No newline at end of file +} + +#playButton { + cursor: pointer; + margin-top: 0.05em; + vertical-align: middle; +} diff --git a/mkdocs.yml b/mkdocs.yml index 5ef8bb6660eee863a1e91005363cc1694b31efc1..d0c7a0a2117933e039c044b549a55d4beb58a9ed 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,7 +13,6 @@ nav: - Data Generation: generate.md - Processing: processing.md - Visualization: viz.md - - Utils: utils.md - GUIs: gui.md - ML Models: models.md - CLI: cli.md diff --git a/qim3d/__init__.py b/qim3d/__init__.py index 3441196ad28ddde1be69c5d224851c9dd9c49d77..11591ba58f1a6a1d0dfd6d8992b68f82a3fd84e6 100644 --- a/qim3d/__init__.py +++ b/qim3d/__init__.py @@ -2,23 +2,49 @@ The qim3d library is designed to make it easier to work with 3D imaging data in Python. It offers a range of features, including data loading and manipulation, - image processing and filtering, visualization of 3D data, and analysis of imaging results. +image processing and filtering, visualization of 3D data, and analysis of imaging results. Documentation available at https://platform.qim.dk/qim3d/ """ -__version__ = "0.3.9" +__version__ = "0.4.0" -from . import io -from . import gui -from . import viz -from . import utils -from . import processing -from . import generate -# commented out to avoid torch import -# from . 
import models +import importlib as _importlib -examples = io.ImgExamples() -io.logger.set_level_info() + +class _LazyLoader: + """Lazy loader to load submodules only when they are accessed""" + + def __init__(self, module_name): + self.module_name = module_name + self.module = None + + def _load(self): + if self.module is None: + self.module = _importlib.import_module(self.module_name) + return self.module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + +# List of submodules +_submodules = [ + "examples", + "generate", + "gui", + "io", + "models", + "processing", + "tests", + "utils", + "viz", + "cli", +] + +# Creating lazy loaders for each submodule +for submodule in _submodules: + globals()[submodule] = _LazyLoader(f"qim3d.{submodule}") diff --git a/qim3d/utils/cli.py b/qim3d/cli.py similarity index 79% rename from qim3d/utils/cli.py rename to qim3d/cli.py index 6dfff26d7c2184780b3fddbc92abdb719606ec4f..032751b6ccddcc0c0f9a537f0c08abcbf4bd8d91 100644 --- a/qim3d/utils/cli.py +++ b/qim3d/cli.py @@ -1,22 +1,19 @@ import argparse import webbrowser import outputformat as ouf - -from qim3d.gui import annotation_tool, data_explorer, iso3d, local_thickness -from qim3d.io.loading import DataLoader -from qim3d.utils import image_preview -from qim3d import __version__ as version -import qim3d.io +import qim3d QIM_TITLE = ouf.rainbow( - f"\n _ _____ __ \n ____ _(_)___ ___ |__ /____/ / \n / __ `/ / __ `__ \ /_ </ __ / \n/ /_/ / / / / / / /__/ / /_/ / \n\__, /_/_/ /_/ /_/____/\__,_/ \n /_/ v{version}\n\n", + f"\n _ _____ __ \n ____ _(_)___ ___ |__ /____/ / \n / __ `/ / __ `__ \ /_ </ __ / \n/ /_/ / / / / / / /__/ / /_/ / \n\__, /_/_/ /_/ /_/____/\__,_/ \n /_/ v{qim3d.__version__}\n\n", return_str=True, cmap="hot", ) + def parse_tuple(arg): # Remove parentheses if they are included and split by comma - return tuple(map(int, arg.strip('()').split(','))) + return tuple(map(int, arg.strip("()").split(","))) + def main(): parser = argparse.ArgumentParser(description="Qim3d command-line interface.") @@ -43,9 +40,9 @@ def main(): "--no-browser", action="store_true", help="Do not launch browser." ) - # K3D + # Viz viz_parser = subparsers.add_parser("viz", help="Volumetric visualization.") - viz_parser.add_argument("--source", default=False, help="Path to the image file") + viz_parser.add_argument("filename", default=False, help="Path to the image file") viz_parser.add_argument( "--destination", default="k3d.html", help="Path to save html file." ) @@ -111,41 +108,43 @@ def main(): "--chunks", type=parse_tuple, metavar="Chunk shape", - default=(64,64,64), + default=(64, 64, 64), help="Chunk size for the zarr file. 
Defaults to (64, 64, 64).", ) args = parser.parse_args() if args.subcommand == "gui": + arghost = args.host inbrowser = not args.no_browser # Should automatically open in browser - + interface = None if args.data_explorer: - interface_class = data_explorer.Interface + interface_class = qim3d.gui.data_explorer.Interface elif args.iso3d: - interface_class = iso3d.Interface + interface_class = qim3d.gui.iso3d.Interface elif args.annotation_tool: - interface_class = annotation_tool.Interface + interface_class = qim3d.gui.annotation_tool.Interface elif args.local_thickness: - interface_class = local_thickness.Interface + interface_class = qim3d.gui.local_thickness.Interface else: - print("Please select a tool by choosing one of the following flags:\n\t--data-explorer\n\t--iso3d\n\t--annotation-tool\n\t--local-thickness") + print( + "Please select a tool by choosing one of the following flags:\n\t--data-explorer\n\t--iso3d\n\t--annotation-tool\n\t--local-thickness" + ) return - interface = interface_class() # called here if we add another arguments to initialize + interface = ( + interface_class() + ) # called here if we add additional arguments to initialize if args.platform: - interface.run_interface(host = arghost) + interface.run_interface(host=arghost) else: - interface.launch(inbrowser = inbrowser, force_light_mode = False) + interface.launch(inbrowser=inbrowser, force_light_mode=False) elif args.subcommand == "viz": - if not args.source: - print("Please specify a source file using the argument --source") - return # Load the data - print(f"Loading data from {args.source}") - volume = qim3d.io.load(str(args.source)) + print(f"Loading data from {args.filename}") + volume = qim3d.io.load(str(args.filename)) print(f"Done, volume shape: {volume.shape}") # Make k3d plot @@ -158,9 +157,10 @@ def main(): webbrowser.open_new_tab(args.destination) elif args.subcommand == "preview": - image = DataLoader().load(args.filename) - image_preview( + image = qim3d.io.load(args.filename) + + qim3d.viz.image_preview( image, image_width=args.resolution, axis=args.axis, @@ -169,6 +169,7 @@ def main(): ) elif args.subcommand == "convert": + qim3d.io.convert(args.input_path, args.output_path, chunk_shape=args.chunks) elif args.subcommand is None: @@ -183,9 +184,6 @@ def main(): parser.print_help() print("\n") - elif args.subcommand == 'convert': - qim3d.io.convert(args.input_path, args.output_path) - if __name__ == "__main__": main() diff --git a/qim3d/img_examples/NT_128x128x128.tif b/qim3d/examples/NT_128x128x128.tif similarity index 100% rename from qim3d/img_examples/NT_128x128x128.tif rename to qim3d/examples/NT_128x128x128.tif diff --git a/qim3d/examples/__init__.py b/qim3d/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf57469608720b5ca8c7c7b6b77c6f4cc8cd1f1b --- /dev/null +++ b/qim3d/examples/__init__.py @@ -0,0 +1,17 @@ +""" Example images for testing and demonstration purposes.
""" + +from pathlib import Path as _Path +from qim3d.utils.logger import log as _log +from qim3d.io import load as _load + +# Save the original log level and set to ERROR +# to suppress the log messages during loading +_original_log_level = _log.level +_log.setLevel("ERROR") + +# Load image examples +for _file_path in _Path(__file__).resolve().parent.glob("*.tif"): + globals().update({_file_path.stem: _load(_file_path, progress_bar=False)}) + +# Restore the original log level +_log.setLevel(_original_log_level) diff --git a/qim3d/img_examples/bone_128x128x128.tif b/qim3d/examples/bone_128x128x128.tif similarity index 100% rename from qim3d/img_examples/bone_128x128x128.tif rename to qim3d/examples/bone_128x128x128.tif diff --git a/qim3d/img_examples/cement_128x128x128.tif b/qim3d/examples/cement_128x128x128.tif similarity index 100% rename from qim3d/img_examples/cement_128x128x128.tif rename to qim3d/examples/cement_128x128x128.tif diff --git a/qim3d/img_examples/fly_150x256x256.tif b/qim3d/examples/fly_150x256x256.tif similarity index 100% rename from qim3d/img_examples/fly_150x256x256.tif rename to qim3d/examples/fly_150x256x256.tif diff --git a/qim3d/img_examples/shell_225x128x128.tif b/qim3d/examples/shell_225x128x128.tif similarity index 100% rename from qim3d/img_examples/shell_225x128x128.tif rename to qim3d/examples/shell_225x128x128.tif diff --git a/qim3d/generate/collection_.py b/qim3d/generate/collection_.py index 2e78633f30b83782ad95063c6ca9bf3a5a9e189b..f07e500c0d0e7bf51827e9b020528b7dfc3bcb0d 100644 --- a/qim3d/generate/collection_.py +++ b/qim3d/generate/collection_.py @@ -1,11 +1,9 @@ import numpy as np import scipy.ndimage from tqdm.notebook import tqdm -from skimage.filters import threshold_li -from qim3d.generate import blob as generate_blob -from qim3d.processing import get_3d_cc -from qim3d.io.logger import log +import qim3d.generate +from qim3d.utils.logger import log def random_placement( @@ -192,7 +190,7 @@ def collection( num_objects = 15 synthetic_collection, labels = qim3d.generate.collection(num_objects = num_objects) - # Visualize synthetic collection + # Visualize synthetic collection qim3d.viz.vol(synthetic_collection) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_collection_default.html" width="100%" height="500" frameborder="0"></iframe> @@ -214,7 +212,7 @@ def collection( ```python import qim3d - # Generate synthetic collection of dense blobs + # Generate synthetic collection of dense blobs synthetic_collection, labels = qim3d.generate.collection( min_high_value = 255, max_high_value = 255, @@ -225,12 +223,12 @@ def collection( min_gamma = 0.02, max_gamma = 0.02) - # Visualize synthetic collection + # Visualize synthetic collection qim3d.viz.vol(synthetic_collection) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_collection_dense.html" width="100%" height="500" frameborder="0"></iframe> - + Example: ```python @@ -252,7 +250,7 @@ def collection( max_gamma = 0.03 ) - # Visualize synthetic collection + # Visualize synthetic collection qim3d.viz.vol(synthetic_collection) ``` <iframe src="https://platform.qim.dk/k3d/synthetic_collection_tubular.html" width="100%" height="500" frameborder="0"></iframe> @@ -271,7 +269,7 @@ def collection( if len(min_shape) != len(max_shape): raise ValueError("Object shapes must be tuples of the same length") - + # if not isinstance(blob_shapes, list) or \ # len(blob_shapes) != 2 or len(blob_shapes[0]) != 3 or len(blob_shapes[1]) != 3: # raise TypeError("Blob shapes must be a list of two tuples with three 
dimensions (z, y, x)") @@ -320,7 +318,7 @@ def collection( log.debug(f"- Threshold: {threshold:.3f}") # Generate synthetic blob - blob = generate_blob( + blob = qim3d.generate.blob( base_shape=blob_shape, final_shape=tuple(l * r for l, r in zip(blob_shape, object_shape_zoom)), noise_scale=noise_scale, diff --git a/qim3d/gui/__init__.py b/qim3d/gui/__init__.py index e06589b43ac728f3021a270356279e845eb31abd..f1f7591b8a47e33c4ba90f391fcefc4325f5d0d7 100644 --- a/qim3d/gui/__init__.py +++ b/qim3d/gui/__init__.py @@ -1,4 +1,35 @@ +from fastapi import FastAPI +import qim3d.utils from . import data_explorer from . import iso3d from . import local_thickness -from . import annotation_tool \ No newline at end of file +from . import annotation_tool +from .qim_theme import QimTheme + + +def run_gradio_app(gradio_interface, host="0.0.0.0"): + import gradio as gr + import uvicorn + + # Get port using the QIM API + port_dict = qim3d.utils.get_port_dict() + + if "gradio_port" in port_dict: + port = port_dict["gradio_port"] + elif "port" in port_dict: + port = port_dict["port"] + else: + raise Exception("Port not specified from QIM API") + + qim3d.utils.gradio_header(gradio_interface.title, port) + + # Create FastAPI with mounted gradio interface + app = FastAPI() + path = f"/gui/{port_dict['username']}/{port}/" + app = gr.mount_gradio_app(app, gradio_interface, path=path) + + # Full path + print(f"http://{host}:{port}{path}") + + # Run the FastAPI server usign uvicorn + uvicorn.run(app, host=host, port=int(port)) diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py index 35b6083557d5c53ced3276d986b5909ac842b184..4d2b1f51bb807de6bb226616091588c91d0c4837 100644 --- a/qim3d/gui/data_explorer.py +++ b/qim3d/gui/data_explorer.py @@ -25,8 +25,8 @@ import numpy as np import outputformat as ouf from qim3d.io import load -from qim3d.io.logger import log -from qim3d.utils import internal_tools +from qim3d.utils.logger import log +from qim3d.utils import misc from qim3d.gui.interface import BaseInterface @@ -550,7 +550,7 @@ class Interface(BaseInterface): def show_data_summary(self): summary_dict = { "Last modified": datetime.datetime.fromtimestamp(os.path.getmtime(self.file_path)).strftime("%Y-%m-%d %H:%M"), - "File size": internal_tools.sizeof(os.path.getsize(self.file_path)), + "File size": misc.sizeof(os.path.getsize(self.file_path)), "Z-size": str(self.vol.shape[self.axis_dict["Z"]]), "Y-size": str(self.vol.shape[self.axis_dict["Y"]]), "X-size": str(self.vol.shape[self.axis_dict["X"]]), diff --git a/qim3d/gui/images/qim_platform-icon.svg b/qim3d/gui/images/qim_platform-icon.svg deleted file mode 100644 index b40bab5f4008d6b876a90481e89f603c00b7d530..0000000000000000000000000000000000000000 --- a/qim3d/gui/images/qim_platform-icon.svg +++ /dev/null @@ -1,104 +0,0 @@ -<?xml version="1.0" encoding="UTF-8" standalone="no"?> -<!-- Created with Inkscape (http://www.inkscape.org/) --> - -<svg - width="8.3444281mm" - height="7.7738566mm" - viewBox="0 0 8.344428 7.7738567" - version="1.1" - id="svg1" - inkscape:version="1.3.2 (1:1.3.2+202311252150+091e20ef0f)" - sodipodi:docname="qim_platform-icon.svg" - inkscape:export-filename="qim_platform-icon.png" - inkscape:export-xdpi="779.25055" - inkscape:export-ydpi="779.25055" - xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" - xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" - xmlns="http://www.w3.org/2000/svg" - xmlns:svg="http://www.w3.org/2000/svg"> - <sodipodi:namedview - id="namedview1" - pagecolor="#ffffff" - 
bordercolor="#000000" - borderopacity="0.25" - inkscape:showpageshadow="2" - inkscape:pageopacity="0.0" - inkscape:pagecheckerboard="0" - inkscape:deskcolor="#d1d1d1" - inkscape:document-units="mm" - inkscape:zoom="3.1345226" - inkscape:cx="90.603908" - inkscape:cy="61.253347" - inkscape:window-width="1266" - inkscape:window-height="631" - inkscape:window-x="3830" - inkscape:window-y="56" - inkscape:window-maximized="0" - inkscape:current-layer="layer1" /> - <defs - id="defs1" /> - <g - inkscape:label="Layer 1" - inkscape:groupmode="layer" - id="layer1" - transform="translate(-49.212495,-93.140004)"> - <g - id="g95" - transform="matrix(1.0725239,0,0,1.0727401,-76.182898,-88.386314)" - style="stroke-width:0.932287"> - <path - id="path94" - style="fill:#ff9900;fill-opacity:1;stroke:none;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 117.13806,171.16039 v -1.08393 c 0,-0.35299 0.2857,-0.63715 0.64059,-0.63715 v 0 h 6.05523 c 0.35489,0 0.64059,0.28416 0.64059,0.63715 v 1.08393" - sodipodi:nodetypes="csscssc" /> - <path - id="rect75" - style="fill:none;fill-opacity:1;stroke:#000000;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 117.13806,172.211 v -2.13454 c 0,-0.35299 0.2857,-0.63715 0.64059,-0.63715 v 0 h 6.05523 c 0.35489,0 0.64059,0.28416 0.64059,0.63715 v 3.9123" - sodipodi:nodetypes="csscssc" /> - <path - id="path76" - style="fill:none;fill-opacity:1;stroke:#000000;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 124.47447,174.94459 v 0.66053 c 0,0.35299 -0.2857,0.63715 -0.64059,0.63715 h -6.05523 c -0.35489,0 -0.64059,-0.28416 -0.64059,-0.63715 v -2.43526" - sodipodi:nodetypes="cssssc" /> - <path - style="fill:#cd4d00;fill-opacity:1;stroke:#000000;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 118.00096,170.34725 h 0.28389" - id="path72" - sodipodi:nodetypes="cc" /> - <path - style="fill:#cd4d00;fill-opacity:1;stroke:#000000;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 120.63414,170.34725 h 2.93385" - id="path74" - sodipodi:nodetypes="cc" /> - <path - style="fill:#cd4d00;fill-opacity:1;stroke:#000000;stroke-width:0.443769;stroke-linecap:round;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1;paint-order:normal" - d="m 119.04851,170.34725 h 0.28389" - id="path75" - sodipodi:nodetypes="cc" /> - <g - id="g93" - transform="matrix(0.36505828,0,0,0.35972642,83.784309,128.37775)" - style="stroke-width:2.57263"> - <path - id="path91" - d="m 100.83326,125.49607 -5.27832,2.18097 c -0.37814,0.14893 -0.41489,0.7496 -0.0409,0.91436 l 5.17239,2.13082 c 0.4877,0.18116 1.00239,0.18129 1.45479,0 l 5.17239,-2.13082 c 0.37397,-0.16476 0.33723,-0.76543 -0.0409,-0.91436 l -5.27831,-2.18097 c -0.44426,-0.18069 -0.70409,-0.18273 -1.1611,0 z" - fill="#ffab61" - paint-order="normal" - style="fill:#990000;fill-opacity:1;stroke:#000000;stroke-width:1.2246;stroke-dasharray:none;stroke-opacity:1" /> - <path - id="path92" - d="m 100.83326,123.02262 -5.27832,2.18097 c -0.37814,0.14893 -0.41489,0.7496 -0.0409,0.91436 l 5.17239,2.13082 c 0.4877,0.18116 1.00239,0.18129 1.45479,0 l 5.17239,-2.13082 c 0.37397,-0.16476 0.33723,-0.76543 -0.0409,-0.91436 l -5.27831,-2.18097 c -0.44426,-0.18069 
-0.70409,-0.18272 -1.1611,0 z" - fill="#ffc861" - paint-order="normal" - style="fill:#cd4d00;fill-opacity:1;stroke:#000000;stroke-width:1.2246;stroke-dasharray:none;stroke-opacity:1" /> - <path - id="path93" - d="m 100.83325,120.5494 -5.27832,2.18098 c -0.37814,0.14893 -0.41489,0.7496 -0.0409,0.91436 l 5.17239,2.13082 c 0.4877,0.18115 1.00239,0.18128 1.45479,0 l 5.17239,-2.13082 c 0.37397,-0.16477 0.33723,-0.76544 -0.0409,-0.91436 l -5.27831,-2.18098 c -0.44426,-0.18069 -0.70409,-0.18271 -1.1611,0 z" - fill="#55a1ff" - paint-order="normal" - style="fill:#ff9900;fill-opacity:1;stroke:#000000;stroke-width:1.2246;stroke-dasharray:none;stroke-opacity:1" /> - </g> - </g> - </g> -</svg> diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py index ccce9a2540c59fa7bebd3039531f6f5aeb252bbd..64e5808c49c4514293701625c15a3455957e229b 100644 --- a/qim3d/gui/interface.py +++ b/qim3d/gui/interface.py @@ -5,48 +5,56 @@ from os import path import gradio as gr from .qim_theme import QimTheme -from qim3d.utils import internal_tools -import qim3d -#TODO: when offline it throws an error in cli +import qim3d.gui + + +# TODO: when offline it throws an error in cli class BaseInterface(ABC): """ Annotation tool and Data explorer as those don't need any examples. """ - def __init__(self, - title:str, - height:int, - width:int = "100%", - verbose: bool = False, - custom_css:str = None): + + def __init__( + self, + title: str, + height: int, + width: int = "100%", + verbose: bool = False, + custom_css: str = None, + ): """ title: Is displayed in tab height, width: If inline in launch method is True, sets the paramters of the widget. Inline defaults to True in py notebooks, otherwise is False verbose: If True, updates are printed into terminal - custom_css: Only the name of the file in the css folder. + custom_css: Only the name of the file in the css folder. """ - + self.title = title self.height = height self.width = width self.verbose = bool(verbose) - self.interface = None + self.interface = None self.qim_dir = Path(qim3d.__file__).parents[0] - self.custom_css = path.join(self.qim_dir, "css", custom_css) if custom_css is not None else None + self.custom_css = ( + path.join(self.qim_dir, "css", custom_css) + if custom_css is not None + else None + ) def set_visible(self): return gr.update(visible=True) - + def set_invisible(self): - return gr.update(visible = False) + return gr.update(visible=False) - def launch(self, img = None, force_light_mode:bool = True, **kwargs): + def launch(self, img=None, force_light_mode: bool = True, **kwargs): """ img: If None, user can upload image after the interface is launched. If defined, the interface will be launched with the image already there - This argument is used especially in jupyter notebooks, where you can launch + This argument is used especially in jupyter notebooks, where you can launch interface in loop with different picture every step - force_light_mode: The qim platform doesn't have night mode. The qim_theme thus + force_light_mode: The qim platform doesn't have night mode. The qim_theme thus has option to display only light mode so it corresponds with the website. Preferably will be removed as we add night mode to the website. 
""" @@ -54,25 +62,29 @@ class BaseInterface(ABC): # Create gradio interface if img is not None: self.img = img - self.interface = self.create_interface(force_light_mode = force_light_mode) - + self.interface = self.create_interface(force_light_mode=force_light_mode) self.interface.launch( - quiet= not self.verbose, - height = self.height, - width = self.width, - favicon_path = Path(qim3d.__file__).parents[0] / "gui/images/qim_platform-icon.svg", + quiet=not self.verbose, + height=self.height, + width=self.width, + favicon_path=Path(qim3d.__file__).parents[0] + / "../docs/assets/qim3d-icon.svg", **kwargs, ) - + def clear(self): """Used to reset outputs with the clear button""" return None - - def create_interface(self, force_light_mode:bool = True, **kwargs): + + def create_interface(self, force_light_mode: bool = True, **kwargs): # kwargs["img"] = self.img - with gr.Blocks(theme = QimTheme(force_light_mode=force_light_mode), title = self.title, css=self.custom_css) as gradio_interface: - gr.Markdown(F"# {self.title}") + with gr.Blocks( + theme=qim3d.gui.QimTheme(force_light_mode=force_light_mode), + title=self.title, + css=self.custom_css, + ) as gradio_interface: + gr.Markdown(f"# {self.title}") self.define_interface(**kwargs) return gradio_interface @@ -80,19 +92,23 @@ class BaseInterface(ABC): def define_interface(self, **kwargs): pass - def run_interface(self, host:str = "0.0.0.0"): - internal_tools.run_gradio_app(self.create_interface(), host) + def run_interface(self, host: str = "0.0.0.0"): + qim3d.gui.run_gradio_app(self.create_interface(), host) + class InterfaceWithExamples(BaseInterface): """ For Iso3D and Local Thickness """ - def __init__(self, - title:str, - height:int, - width:int, - verbose: bool = False, - custom_css:str = None): + + def __init__( + self, + title: str, + height: int, + width: int, + verbose: bool = False, + custom_css: str = None, + ): super().__init__(title, height, width, verbose, custom_css) self._set_examples_list() @@ -108,6 +124,4 @@ class InterfaceWithExamples(BaseInterface): ] self.img_examples = [] for example in examples: - self.img_examples.append( - [path.join(self.qim_dir, "img_examples", example)] - ) + self.img_examples.append([path.join(self.qim_dir, "img_examples", example)]) diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py index 811af41edd12b15d5639d824e4630c49117c6be1..16bf8ffe67eb08618a1585dd8941b6dcd5b3aef0 100644 --- a/qim3d/gui/iso3d.py +++ b/qim3d/gui/iso3d.py @@ -23,7 +23,7 @@ import plotly.graph_objects as go from scipy import ndimage from qim3d.io import load -from qim3d.io.logger import log +from qim3d.utils.logger import log from qim3d.gui.interface import InterfaceWithExamples diff --git a/qim3d/img_examples/NT_10x200x100.tif b/qim3d/img_examples/NT_10x200x100.tif deleted file mode 100644 index 283a98acf626cb07b5e146c7a017a16e10baa077..0000000000000000000000000000000000000000 Binary files a/qim3d/img_examples/NT_10x200x100.tif and /dev/null differ diff --git a/qim3d/img_examples/blobs_256x256.tif b/qim3d/img_examples/blobs_256x256.tif deleted file mode 100644 index 84bdc64ea2cf8febbf686932b945ce81b155c0c7..0000000000000000000000000000000000000000 Binary files a/qim3d/img_examples/blobs_256x256.tif and /dev/null differ diff --git a/qim3d/img_examples/blobs_256x256x256.tif b/qim3d/img_examples/blobs_256x256x256.tif deleted file mode 100644 index ad5325fdd7ed82b2f5fee1e5b03c6c1b07f893e9..0000000000000000000000000000000000000000 Binary files a/qim3d/img_examples/blobs_256x256x256.tif and /dev/null differ diff --git 
a/qim3d/io/__init__.py b/qim3d/io/__init__.py index 7efd57b821d824d32f8d39e7bc8403cea044fabd..b9c0e64c9bc055c30f87d779bf40ea8cdd570465 100644 --- a/qim3d/io/__init__.py +++ b/qim3d/io/__init__.py @@ -1,6 +1,6 @@ -from .loading import DataLoader, load, ImgExamples +from .loading import DataLoader, load from .downloader import Downloader from .saving import DataSaver, save from .sync import Sync from .convert import convert -from . import logger \ No newline at end of file +from ..utils import logger diff --git a/qim3d/io/convert.py b/qim3d/io/convert.py index f2af9e6b4f201621bf55190247bc8ac6d90a1038..411c35b274a7b653969515299db5957512a9670e 100644 --- a/qim3d/io/convert.py +++ b/qim3d/io/convert.py @@ -8,7 +8,7 @@ import tifffile as tiff import zarr from tqdm import tqdm -from qim3d.utils.internal_tools import stringify_path +from qim3d.utils.misc import stringify_path from qim3d.io.saving import save diff --git a/qim3d/io/downloader.py b/qim3d/io/downloader.py index 0c347a6c0ebc84cbbbbea9b6c39b32210d0af379..af20024b26419cf1906ee996aa977f853372185c 100644 --- a/qim3d/io/downloader.py +++ b/qim3d/io/downloader.py @@ -8,7 +8,7 @@ from tqdm import tqdm from pathlib import Path from qim3d.io import load -from qim3d.io.logger import log +from qim3d.utils.logger import log import outputformat as ouf diff --git a/qim3d/io/loading.py b/qim3d/io/loading.py index 3bf903528bc750eae1c5178eb1f48ff164089670..9ae4a5e1184e33e10d619d6bbd29998663cae8df 100644 --- a/qim3d/io/loading.py +++ b/qim3d/io/loading.py @@ -23,8 +23,8 @@ from dask import delayed from PIL import Image, UnidentifiedImageError import qim3d -from qim3d.io.logger import log -from qim3d.utils.internal_tools import get_file_size, sizeof, stringify_path +from qim3d.utils.logger import log +from qim3d.utils.misc import get_file_size, sizeof, stringify_path from qim3d.utils.system import Memory from qim3d.utils import ProgressBar @@ -723,10 +723,8 @@ class DataLoader: # Fails else: # Find the closest matching path to warn the user - parent_dir = os.path.dirname(path) or "." - parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" - valid_paths = [os.path.join(parent_dir, file) for file in parent_files] - similar_paths = difflib.get_close_matches(path, valid_paths) + similar_paths = qim3d.utils.misc.find_similar_paths(path) + if similar_paths: suggestion = similar_paths[0] # Get the closest match message = f"Invalid path. Did you mean '{suggestion}'?" @@ -863,39 +861,3 @@ def load( return data - -class ImgExamples: - """Image examples - - Attributes: - blobs_256x256 (numpy.ndarray): A 2D image of blobs. - blobs_256x256x256 (numpy.ndarray): A 3D volume of blobs. - bone_128x128x128 (numpy.ndarray): A 3D volume of bone. - cement_128x128x128 (numpy.ndarray): A 3D volume of cement. - fly_150x256x256 (numpy.ndarray): A 3D volume of a fly. - NT_10x200x100 (numpy.ndarray): A 3D volume of a neuron. - NT_128x128x128 (numpy.ndarray): A 3D volume of a neuron. - shell_225x128x128 (numpy.ndarray): A 3D volume of a shell. - - Tip: - Simply call `qim3d.examples.<name>` to access the image examples. 
- - Example: - ```python - import qim3d - - vol = qim3d.examples.shell_225x128x128 - qim3d.viz.slices(vol, n_slices=15) - ``` -  - - - - """ - - def __init__(self): - img_examples_path = Path(qim3d.__file__).parents[0] / "img_examples" - img_paths = list(img_examples_path.glob("*.tif")) - - update_dict = {path.stem: load(path, progress_bar = False) for path in img_paths} - self.__dict__.update(update_dict) diff --git a/qim3d/io/saving.py b/qim3d/io/saving.py index 4a4c3fc1413c83e02c86d0414e81e681dfc986d4..ed5012aad8e8b4090c4d0ea34f5d55833673b120 100644 --- a/qim3d/io/saving.py +++ b/qim3d/io/saving.py @@ -35,8 +35,8 @@ import zarr from pydicom.dataset import FileDataset, FileMetaDataset from pydicom.uid import UID -from qim3d.io.logger import log -from qim3d.utils.internal_tools import sizeof, stringify_path +from qim3d.utils.logger import log +from qim3d.utils.misc import sizeof, stringify_path class DataSaver: diff --git a/qim3d/io/sync.py b/qim3d/io/sync.py index c7358cf85a7e18fb1638e137f865a87726588089..3992cd33f29acf26949f44b0e89e244ced0aa20e 100644 --- a/qim3d/io/sync.py +++ b/qim3d/io/sync.py @@ -2,7 +2,7 @@ import os import subprocess import outputformat as ouf -from qim3d.io.logger import log +from qim3d.utils.logger import log from pathlib import Path diff --git a/qim3d/models/__init__.py b/qim3d/models/__init__.py index b5701e8190c734eaa5fd5af05595c8afc107d5a6..a145a44fd76a561174a7f28154bd20dc0134d6fd 100644 --- a/qim3d/models/__init__.py +++ b/qim3d/models/__init__.py @@ -1 +1,4 @@ -from .unet import UNet, Hyperparameters \ No newline at end of file +from .unet import UNet, Hyperparameters +from .augmentations import Augmentation +from .data import Dataset, prepare_dataloaders, prepare_datasets +from .models import inference, model_summary, train_model diff --git a/qim3d/utils/augmentations.py b/qim3d/models/augmentations.py similarity index 100% rename from qim3d/utils/augmentations.py rename to qim3d/models/augmentations.py diff --git a/qim3d/utils/data.py b/qim3d/models/data.py similarity index 99% rename from qim3d/utils/data.py rename to qim3d/models/data.py index ae91017997608628c632cc53bcc9d8b0f99f0548..fc61d262f753ebc637c27e7d60e38d0730a714c8 100644 --- a/qim3d/utils/data.py +++ b/qim3d/models/data.py @@ -1,7 +1,7 @@ """Provides a custom Dataset class for building a PyTorch dataset.""" from pathlib import Path from PIL import Image -from qim3d.io.logger import log +from qim3d.utils.logger import log import torch import numpy as np diff --git a/qim3d/utils/models.py b/qim3d/models/models.py similarity index 82% rename from qim3d/utils/models.py rename to qim3d/models/models.py index 82cfc78918eda7d9e7d84a9253b1d6b36aaa8c61..f04052498907ec6f0bda81097c93bfada7faf8ed 100644 --- a/qim3d/utils/models.py +++ b/qim3d/models/models.py @@ -1,19 +1,28 @@ """ Tools performed with models.""" + import torch import numpy as np -import matplotlib.pyplot as plt from torchinfo import summary -from qim3d.io.logger import log, level -from qim3d.viz.visualizations import plot_metrics +from qim3d.utils.logger import log +from qim3d.viz.metrics import plot_metrics from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True, return_loss = False): - """ Function for training Neural Network models. 
- +def train_model( + model, + hyperparameters, + train_loader, + val_loader, + eval_every=1, + print_every=5, + plot=True, + return_loss=False, +): + """Function for training Neural Network models. + Args: model (torch.nn.Module): PyTorch model. hyperparameters (class): Dictionary with n_epochs, optimizer and criterion. @@ -23,17 +32,17 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 print_every (int, optional): Frequency of log for model performance. Defaults to every 5 epochs. plot (bool, optional): If True, plots the training and validation loss after the model is done training. return_loss (bool, optional), If True, returns a dictionary with the history of the train and validation losses. - + Returns: if return_loss = True: tuple: train_loss (dict): Dictionary with average losses and batch losses for training loop. val_loss (dict): Dictionary with average losses and batch losses for validation loop. - + Example: # defining the model. model = qim3d.utils.UNet() - + # choosing the hyperparameters hyperparameters = qim3d.utils.hyperparameters(model) @@ -45,76 +54,76 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 train_loss,val_loss = train_model(model, hyperparameters, train_loader, val_loader) """ params_dict = hyperparameters() - n_epochs = params_dict['n_epochs'] - optimizer = params_dict['optimizer'] - criterion = params_dict['criterion'] + n_epochs = params_dict["n_epochs"] + optimizer = params_dict["optimizer"] + criterion = params_dict["criterion"] # Choosing best device available. - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) - + # Avoid logging twice. log.propagate = False - - train_loss = {'loss' : [],'batch_loss': []} - val_loss = {'loss' : [], 'batch_loss' : []} + + train_loss = {"loss": [], "batch_loss": []} + val_loss = {"loss": [], "batch_loss": []} with logging_redirect_tqdm(): - for epoch in tqdm(range(n_epochs)): + for epoch in tqdm(range(n_epochs)): epoch_loss = 0 step = 0 - + model.train() - + for data in train_loader: inputs, targets = data inputs = inputs.to(device) targets = targets.to(device).unsqueeze(1) - + optimizer.zero_grad() outputs = model(inputs) - + loss = criterion(outputs, targets) - + # Backpropagation loss.backward() optimizer.step() - + epoch_loss += loss.detach().item() step += 1 # Log and store batch training loss. - train_loss['batch_loss'].append(loss.detach().item()) - + train_loss["batch_loss"].append(loss.detach().item()) + # Log and store average training loss per epoch. epoch_loss = epoch_loss / step - train_loss['loss'].append(epoch_loss) - - if epoch % eval_every ==0: + train_loss["loss"].append(epoch_loss) + + if epoch % eval_every == 0: eval_loss = 0 step = 0 - + model.eval() - + for data in val_loader: inputs, targets = data inputs = inputs.to(device) targets = targets.to(device).unsqueeze(1) - + with torch.no_grad(): outputs = model(inputs) loss = criterion(outputs, targets) - + eval_loss += loss.item() step += 1 # Log and store batch validation loss. - val_loss['batch_loss'].append(loss.item()) - + val_loss["batch_loss"].append(loss.item()) + # Log and store average validation loss. 
eval_loss = eval_loss / step - val_loss['loss'].append(eval_loss) - + val_loss["loss"].append(eval_loss) + if epoch % print_every == 0: log.info( f"Epoch {epoch: 3}, train loss: {train_loss['loss'][epoch]:.4f}, " @@ -122,13 +131,13 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 ) if plot: - plot_metrics(train_loss, val_loss, labels = ['Train','Valid.'], show = True) + plot_metrics(train_loss, val_loss, labels=["Train", "Valid."], show=True) if return_loss: - return train_loss,val_loss + return train_loss, val_loss -def model_summary(dataloader,model): +def model_summary(dataloader, model): """Prints the summary of a PyTorch model. Args: @@ -144,23 +153,23 @@ def model_summary(dataloader,model): summary = model_summary(model, dataloader) print(summary) """ - images,_ = next(iter(dataloader)) + images, _ = next(iter(dataloader)) batch_size = tuple(images.shape) - model_s = summary(model,batch_size,depth = torch.inf) - + model_s = summary(model, batch_size, depth=torch.inf) + return model_s -def inference(data,model): +def inference(data, model): """Performs inference on input data using the specified model. - + Performs inference on the input data using the provided model. The input data should be in the form of a list, where each item is a tuple containing the input image tensor and the corresponding target label tensor. The function checks the format and validity of the input data, ensures the model is in evaluation mode, and generates predictions using the model. The input images, target labels, and predicted labels are returned as a tuple. - + Args: data (torch.utils.data.Dataset): A Torch dataset containing input image and ground truth label data. @@ -181,7 +190,7 @@ def inference(data,model): model = MySegmentationModel() inference(data,model) """ - + # Get device device = "cuda" if torch.cuda.is_available() else "cpu" @@ -205,7 +214,7 @@ def inference(data,model): # Make new list such that possible augmentations remain identical for all three rows plot_data = [data[idx] for idx in range(len(data))] - + # Create input and target batch inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device) targets = torch.stack([item[1] for item in plot_data], dim=0) @@ -218,7 +227,9 @@ def inference(data,model): inputs = inputs.cpu().squeeze() targets = targets.squeeze() if outputs.shape[1] == 1: - preds = outputs.cpu().squeeze() > 0.5 # TODO: outputs from model are not between [0,1] yet, need to implement that + preds = ( + outputs.cpu().squeeze() > 0.5 + ) # TODO: outputs from model are not between [0,1] yet, need to implement that else: preds = outputs.cpu().argmax(axis=1) @@ -227,12 +238,12 @@ def inference(data,model): inputs = inputs.unsqueeze(0) targets = targets.unsqueeze(0) preds = preds.unsqueeze(0) - - return inputs,targets,preds + + return inputs, targets, preds def volume_inference(volume, model, threshold=0.5): - ''' + """ Compute on the entire volume Args: volume (numpy.ndarray): A 3D numpy array representing the input volume. @@ -242,10 +253,10 @@ def volume_inference(volume, model, threshold=0.5): numpy.ndarray: A 3D numpy array representing the model predictions for each slice of the input volume. Raises: ValueError: If the input volume is not a 3D numpy array. 
- ''' + """ if len(volume.shape) != 3: raise ValueError("Input volume must be a 3D numpy array") - + device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() @@ -257,9 +268,9 @@ def volume_inference(volume, model, threshold=0.5): input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device) input_tensor = input_tensor.unsqueeze(0) output = model(input_tensor) > threshold - output = output.cpu() if device == 'cuda' else output + output = output.cpu() if device == "cuda" else output output_detached = output.detach() output_numpy = output_detached.numpy()[0, 0, :, :] inference_vol[idx] = output_numpy - return inference_vol \ No newline at end of file + return inference_vol diff --git a/qim3d/models/unet.py b/qim3d/models/unet.py index 6ca19bcb890b96513960262f661af7d27e0e225a..3d25a78af9f2ecc5305807228c39d790e06c6960 100644 --- a/qim3d/models/unet.py +++ b/qim3d/models/unet.py @@ -2,7 +2,7 @@ import torch.nn as nn -from qim3d.io.logger import log +from qim3d.utils.logger import log class UNet(nn.Module): diff --git a/qim3d/processing/cc.py b/qim3d/processing/cc.py index 48cc004a56bfade258ab9bc370ed60f6a072728a..43283c230fd71cd36f841861f58c28a1a3ec7777 100644 --- a/qim3d/processing/cc.py +++ b/qim3d/processing/cc.py @@ -1,7 +1,6 @@ import numpy as np -import torch from scipy.ndimage import find_objects, label -from qim3d.io.logger import log +from qim3d.utils.logger import log class CC: @@ -70,11 +69,11 @@ class CC: return find_objects(self._connected_components) -def get_3d_cc(image: np.ndarray | torch.Tensor) -> CC: +def get_3d_cc(image: np.ndarray) -> CC: """ Returns an object (CC) containing the connected components of the input volume. Use plot_cc to visualize the connected components. Args: - image (np.ndarray | torch.Tensor): An array-like object to be labeled. Any non-zero values in `input` are + image (np.ndarray): An array-like object to be labeled. Any non-zero values in `input` are counted as features and zero values are considered the background. 
Returns: diff --git a/qim3d/processing/detection.py b/qim3d/processing/detection.py index b967e84472d5ce15e0ac80268425deac9187efb9..e3650cf111b71427e5e673b952244f9064e74c36 100644 --- a/qim3d/processing/detection.py +++ b/qim3d/processing/detection.py @@ -1,7 +1,7 @@ """ Blob detection using Difference of Gaussian (DoG) method """ import numpy as np -from qim3d.io.logger import log +from qim3d.utils.logger import log def blob_detection( vol: np.ndarray, diff --git a/qim3d/processing/filters.py b/qim3d/processing/filters.py index fc2d89134ded623e2e478f24f75a4e3735224169..e188c475fce5ce5a89bcc2fb79cf80ddcf340f7a 100644 --- a/qim3d/processing/filters.py +++ b/qim3d/processing/filters.py @@ -6,7 +6,7 @@ import numpy as np from scipy import ndimage from skimage import morphology -from qim3d.io.logger import log +from qim3d.utils.logger import log __all__ = [ "Gaussian", diff --git a/qim3d/processing/local_thickness_.py b/qim3d/processing/local_thickness_.py index 7a9e30aaeb4f00d14cebd34d46c656f75638b015..96968519c0f9723db9828ea9798b2f00e9834d30 100644 --- a/qim3d/processing/local_thickness_.py +++ b/qim3d/processing/local_thickness_.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional -from qim3d.io.logger import log +from qim3d.utils.logger import log import qim3d diff --git a/qim3d/processing/operations.py b/qim3d/processing/operations.py index 72d3c45c44754f46710e8e2347bbee2b15968d6a..70e15ed71987f19de76cba680e1fbaf544ef33ff 100644 --- a/qim3d/processing/operations.py +++ b/qim3d/processing/operations.py @@ -1,6 +1,6 @@ import numpy as np import qim3d.processing.filters as filters -from qim3d.io.logger import log +from qim3d.utils.logger import log def remove_background( diff --git a/qim3d/tests/__init__.py b/qim3d/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c08e68dde17af7ac5c45d2fab81447f168750f6a --- /dev/null +++ b/qim3d/tests/__init__.py @@ -0,0 +1,117 @@ +"Helper functions for testing" + +import os +import matplotlib +import matplotlib.pyplot as plt +from pathlib import Path +import shutil +from PIL import Image +import socket +import numpy as np +from qim3d.utils.logger import log + + +def mock_plot(): + """Creates a mock plot of a sine wave. + + Returns: + matplotlib.figure.Figure: The generated plot figure. + + Example: + Creates a mock plot of a sine wave and displays the plot using `plt.show()`. + + >>> fig = mock_plot() + >>> plt.show() + """ + + matplotlib.use("Agg") + + fig = plt.figure(figsize=(5, 4)) + axes = fig.add_axes([0.1, 0.1, 0.8, 0.8]) + values = np.arange(0, 2 * np.pi, 0.01) + axes.plot(values, np.sin(values)) + + return fig + + +def mock_write_file(path, content="File created by qim3d"): + """ + Creates a file at the specified path and writes a predefined text into it. + + Args: + path (str): The path to the file to be created. + + Example: + >>> mock_write_file("example.txt") + """ + _file = open(path, "w", encoding="utf-8") + _file.write(content) + _file.close() + + +def is_server_running(ip, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect((ip, int(port))) + s.shutdown(2) + return True + except: + return False + + +def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): + """Creates a temporary folder to test deep learning tools. + + Creates two folders, 'train' and 'test', each of which also has two subfolders 'images' and 'labels'. + n random images are then added to all four subfolders. + If the 'remove' variable is True, the folders and their content are removed.
+ + Args: + folder (str): The path where the folders should be placed. + remove (bool, optional): If True, all folders are removed from their location. + n (int, optional): Number of random images and labels in the temporary dataset. + img_shape (tuple, optional): Tuple with the height and width of the images and labels. + + Example: + >>> temp_data('temporary_folder',n = 10, img_shape = (16,16)) + """ + folder_trte = ["train", "test"] + sub_folders = ["images", "labels"] + + # Creating train/test folder + path_train = Path(folder) / folder_trte[0] + path_test = Path(folder) / folder_trte[1] + + # Creating folders for images and labels + path_train_im = path_train / sub_folders[0] + path_train_lab = path_train / sub_folders[1] + path_test_im = path_test / sub_folders[0] + path_test_lab = path_test / sub_folders[1] + + # Random image + img = np.random.randint(2, size=img_shape, dtype=np.uint8) + img = Image.fromarray(img) + + if not os.path.exists(path_train): + os.makedirs(path_train_im) + os.makedirs(path_test_im) + os.makedirs(path_train_lab) + os.makedirs(path_test_lab) + for i in range(n): + img.save(path_train_im / f"img_train{i}.png") + img.save(path_train_lab / f"img_train{i}.png") + img.save(path_test_im / f"img_test{i}.png") + img.save(path_test_lab / f"img_test{i}.png") + + if remove: + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + log.warning("Failed to delete %s. Reason: %s" % (file_path, e)) + + os.rmdir(folder) diff --git a/qim3d/tests/gui/test_annotation_tool.py b/qim3d/tests/gui/test_annotation_tool.py index 2b80b9fb178df7d57f6e55633152b78e7876da20..06a0bedfb869766c91277951f15ac4ddc7ab64f6 100644 --- a/qim3d/tests/gui/test_annotation_tool.py +++ b/qim3d/tests/gui/test_annotation_tool.py @@ -6,7 +6,7 @@ import time def test_starting_class(): app = qim3d.gui.annotation_tool.Interface() - assert app.title == "Annotation tool" + assert app.title == "Annotation Tool" def start_server(ip, port): @@ -27,7 +27,7 @@ def test_app_launch(): check = 0 server_running = False while check < max_checks and not server_running: - server_running = qim3d.utils.internal_tools.is_server_running(ip, port) + server_running = qim3d.tests.is_server_running(ip, port) time.sleep(1) check += 1 diff --git a/qim3d/tests/gui/test_iso3d.py b/qim3d/tests/gui/test_iso3d.py index 864ede391c6fb4300ac0b3d4bb54afef44dfbd9c..1d2e4955c396cfb217c77eb416f5dcbbfa6431d4 100644 --- a/qim3d/tests/gui/test_iso3d.py +++ b/qim3d/tests/gui/test_iso3d.py @@ -27,7 +27,7 @@ def test_app_launch(): check = 0 server_running = False while check < max_checks and not server_running: - server_running = qim3d.utils.internal_tools.is_server_running(ip, port) + server_running = qim3d.tests.is_server_running(ip, port) time.sleep(1) check += 1 diff --git a/qim3d/tests/io/test_load.py b/qim3d/tests/io/test_load.py index 8df32d00150e8e8daa82c142c0effd4edb59f4a2..88d0f092b9576ae293c26254de7bf0ee85e14fce 100644 --- a/qim3d/tests/io/test_load.py +++ b/qim3d/tests/io/test_load.py @@ -5,33 +5,36 @@ import os import pytest import re -# Load blobs volume into memory -vol = qim3d.examples.blobs_256x256 +# Load volume into memory +vol = qim3d.examples.bone_128x128x128 + +# Create memory map to the volume +volume_path = Path(qim3d.__file__).parents[0] / "examples" / "bone_128x128x128.tif" +vol_memmap = qim3d.io.load(volume_path, virtual_stack=True)
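The memory-mapped path exercised by this test is also the way to work with volumes larger than RAM; a minimal sketch, assuming the bundled example file resolves exactly as in the test above:

```python
from pathlib import Path

import numpy as np
import qim3d

# Same path construction as the test above
volume_path = Path(qim3d.__file__).parents[0] / "examples" / "bone_128x128x128.tif"

# virtual_stack=True returns a numpy memmap backed by the file on disk,
# instead of reading the whole volume into memory
vol = qim3d.io.load(volume_path, virtual_stack=True)
assert isinstance(vol, np.memmap)
print(vol.shape)  # (128, 128, 128)
```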
-# Ceate memory map to blobs -blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif" -vol_memmap = qim3d.io.load(blobs_path,virtual_stack=True) def test_load_shape(): - assert vol.shape == vol_memmap.shape == (256,256) - + assert vol.shape == vol_memmap.shape == (128, 128, 128) + + def test_load_type(): - assert isinstance(vol,np.ndarray) + assert isinstance(vol, np.ndarray) + def test_load_type_memmap(): - assert isinstance(vol_memmap,np.memmap) + assert isinstance(vol_memmap, np.memmap) + def test_invalid_path(): - invalid_path = os.path.join('this','path','doesnt','exist.tif') + invalid_path = os.path.join("this", "path", "doesnt", "exist.tif") - with pytest.raises(ValueError,match='Invalid path'): + with pytest.raises(FileNotFoundError): qim3d.io.load(invalid_path) + def test_did_you_mean(): # Remove last two characters from the path - blobs_path_misspelled = str(blobs_path)[:-2] - - message = f"Invalid path. Did you mean '{blobs_path}'?" + path_misspelled = str(volume_path)[:-2] - with pytest.raises(ValueError,match=re.escape(repr(message))): - qim3d.io.load(blobs_path_misspelled) \ No newline at end of file + with pytest.raises(FileNotFoundError, match=re.escape(repr(str(volume_path)))): + qim3d.io.load(path_misspelled) diff --git a/qim3d/tests/io/test_save.py b/qim3d/tests/io/test_save.py index ad61587213c36a52e569f88726a7fddef3257ebb..d6cbfb3e69aa2cbcbbd6098e08cd67cd99b1c1b7 100644 --- a/qim3d/tests/io/test_save.py +++ b/qim3d/tests/io/test_save.py @@ -25,7 +25,7 @@ def test_image_exist(): def test_compression(): # Get test image (should not be random in order for compression to function) - test_image = qim3d.examples.blobs_256x256 + test_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -66,7 +66,7 @@ def test_image_matching(): def test_compressed_image_matching(): # Get test image (should not be random in order for compression to function) - original_image = qim3d.examples.blobs_256x256 + original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -226,7 +226,7 @@ def test_tiff_stack_slicing_dim(): def test_tiff_save_load(): # Create random test image - original_image = qim3d.examples.blobs_256x256 + original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -247,7 +247,7 @@ def test_tiff_save_load(): def test_vol_save_load(): # Create random test image - original_image = qim3d.examples.blobs_256x256x256 + original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -268,7 +268,7 @@ def test_vol_save_load(): def test_pil_save_load(): # Create random test image - original_image = qim3d.examples.blobs_256x256 + original_image = qim3d.examples.bone_128x128x128[0] # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -295,7 +295,7 @@ def test_pil_save_load(): def test_nifti_save_load(): # Create random test image - original_image = qim3d.examples.blobs_256x256 + original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as temp_dir: @@ -329,7 +329,7 @@ def test_nifti_save_load(): def test_h5_save_load(): # Create random test image - original_image = qim3d.examples.blobs_256x256x256 + original_image = qim3d.examples.bone_128x128x128 # Create temporary directory with tempfile.TemporaryDirectory() as 
temp_dir: diff --git a/qim3d/tests/models/test_models.py b/qim3d/tests/models/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..83543b76299d400ac6c848fd6e6213a428e1bfd0 --- /dev/null +++ b/qim3d/tests/models/test_models.py @@ -0,0 +1,113 @@ +import qim3d +import pytest +from torch import ones + +from qim3d.tests import temp_data + + +# unit test for model summary() +def test_model_summary(): + n = 10 + img_shape = (32, 32) + folder = "folder_data" + temp_data(folder, img_shape=img_shape, n=n) + + unet = qim3d.models.UNet(size="small") + augment = qim3d.models.Augmentation(transform_train=None) + train_set, val_set, test_set = qim3d.models.prepare_datasets( + folder, 1 / 3, unet, augment + ) + + _, val_loader, _ = qim3d.models.prepare_dataloaders( + train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False + ) + summary = qim3d.models.model_summary(val_loader, unet) + + assert summary.input_size[0] == (1, 1) + img_shape + + temp_data(folder, remove=True) + + +# unit test for inference() +def test_inference(): + folder = "folder_data" + temp_data(folder) + + unet = qim3d.models.UNet(size="small") + augment = qim3d.models.Augmentation(transform_train=None) + train_set, _, _ = qim3d.models.prepare_datasets(folder, 1 / 3, unet, augment) + + _, targ, _ = qim3d.models.inference(train_set, unet) + + assert tuple(targ[0].unique()) == (0, 1) + + temp_data(folder, remove=True) + + +# unit test for tuple ValueError(). +def test_inference_tuple(): + folder = "folder_data" + temp_data(folder) + + unet = qim3d.models.UNet(size="small") + + data = [1, 2, 3] + with pytest.raises(ValueError, match="Data items must be tuples"): + qim3d.models.inference(data, unet) + + temp_data(folder, remove=True) + + +# unit test for tensor ValueError(). +def test_inference_tensor(): + folder = "folder_data" + temp_data(folder) + + unet = qim3d.models.UNet(size="small") + + data = [(1, 2)] + with pytest.raises(ValueError, match="Data items must consist of tensors"): + qim3d.models.inference(data, unet) + + temp_data(folder, remove=True) + + +# unit test for dimension ValueError(). 
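+# inference() expects each input image as a (C, H, W) tensor, so the
+# 1-dimensional tensors used below should be rejected with a ValueError.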
+def test_inference_dim(): + folder = "folder_data" + temp_data(folder) + + unet = qim3d.models.UNet(size="small") + + data = [(ones(1), ones(1))] + # need the r"" for special characters + with pytest.raises(ValueError, match=r"Input image must be \(C,H,W\) format"): + qim3d.models.inference(data, unet) + + temp_data(folder, remove=True) + + +# unit test for train_model() +def test_train_model(): + folder = "folder_data" + temp_data(folder) + + n_epochs = 1 + + unet = qim3d.models.UNet(size="small") + augment = qim3d.models.Augmentation(transform_train=None) + hyperparams = qim3d.models.Hyperparameters(unet, n_epochs=n_epochs) + train_set, val_set, test_set = qim3d.models.prepare_datasets( + folder, 1 / 3, unet, augment + ) + train_loader, val_loader, _ = qim3d.models.prepare_dataloaders( + train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False + ) + + train_loss, _ = qim3d.models.train_model( + unet, hyperparams, train_loader, val_loader, plot=False, return_loss=True + ) + + assert len(train_loss["loss"]) == n_epochs + + temp_data(folder, remove=True) diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py index da6c490d2706011a4dddfe274bb661e3dc303127..289ebcb4a723c6ab9432ca375d51e4a4cd01479a 100644 --- a/qim3d/tests/utils/test_augmentations.py +++ b/qim3d/tests/utils/test_augmentations.py @@ -2,31 +2,40 @@ import qim3d import albumentations import pytest + # unit tests for Augmentation() def test_augmentation(): - augment_class = qim3d.utils.Augmentation() + augment_class = qim3d.models.Augmentation() + + assert augment_class.resize == "crop" - assert augment_class.resize == 'crop' def test_augment(): - augment_class = qim3d.utils.Augmentation() + augment_class = qim3d.models.Augmentation() - album_augment = augment_class.augment(256,256) + album_augment = augment_class.augment(256, 256) assert type(album_augment) == albumentations.core.composition.Compose + # unit tests for ValueErrors in Augmentation() def test_resize(): - resize_str = 'not valid resize' + resize_str = "not valid resize" - with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."): - augment_class = qim3d.utils.Augmentation(resize = resize_str) + with pytest.raises( + ValueError, + match=f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'.", + ): + augment_class = qim3d.models.Augmentation(resize=resize_str) def test_levels(): - augment_class = qim3d.utils.Augmentation() + augment_class = qim3d.models.Augmentation() - level = 'Not a valid level' + level = "Not a valid level" - with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."): - augment_class.augment(256,256,level) \ No newline at end of file + with pytest.raises( + ValueError, + match=f"Invalid transformation level: {level}. 
Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.", + ): + augment_class.augment(256, 256, level) diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py index 928ef26f2f785b7760c073f2519f37eb9ebb61f6..32b290a86b55a0c23a7d75d551559437b3df7b7e 100644 --- a/qim3d/tests/utils/test_data.py +++ b/qim3d/tests/utils/test_data.py @@ -2,7 +2,7 @@ import qim3d import pytest from torch.utils.data.dataloader import DataLoader -from qim3d.utils.internal_tools import temp_data +from qim3d.tests import temp_data # unit tests for Dataset() def test_dataset(): @@ -10,7 +10,7 @@ def test_dataset(): folder = 'folder_data' temp_data(folder, img_shape = img_shape) - images = qim3d.utils.Dataset(folder) + images = qim3d.models.Dataset(folder) assert images[0][0].shape == img_shape @@ -19,19 +19,19 @@ def test_dataset(): # unit tests for check_resize() def test_check_resize(): - h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6) + h_adjust,w_adjust = qim3d.models.data.check_resize(240,240,resize = 'crop',n_channels = 6) assert (h_adjust,w_adjust) == (192,192) def test_check_resize_pad(): - h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'padding',n_channels = 6) + h_adjust,w_adjust = qim3d.models.data.check_resize(16,16,resize = 'padding',n_channels = 6) assert (h_adjust,w_adjust) == (64,64) def test_check_resize_fail(): with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."): - h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'crop',n_channels = 6) + h_adjust,w_adjust = qim3d.models.data.check_resize(16,16,resize = 'crop',n_channels = 6) # unit tests for prepare_datasets() @@ -43,8 +43,8 @@ def test_prepare_datasets(): img = temp_data(folder,n = n) my_model = qim3d.models.UNet() - my_augmentation = qim3d.utils.Augmentation(transform_test='light') - train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation) + my_augmentation = qim3d.models.Augmentation(transform_test='light') + train_set, val_set, test_set = qim3d.models.prepare_datasets(folder,validation,my_model,my_augmentation) assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n) @@ -56,7 +56,7 @@ def test_validation(): validation = 10 with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."): - augment_class = qim3d.utils.prepare_datasets('folder',validation,'my_model','my_augmentation') + augment_class = qim3d.models.prepare_datasets('folder',validation,'my_model','my_augmentation') # unit test for prepare_dataloaders() @@ -66,10 +66,10 @@ def test_prepare_dataloaders(): batch_size = 1 my_model = qim3d.models.UNet() - my_augmentation = qim3d.utils.Augmentation() - train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation) + my_augmentation = qim3d.models.Augmentation() + train_set, val_set, test_set = qim3d.models.prepare_datasets(folder,1/3,my_model,my_augmentation) - _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, + _,val_loader,_ = qim3d.models.prepare_dataloaders(train_set,val_set,test_set, batch_size,num_workers = 1, pin_memory = False) diff --git a/qim3d/tests/utils/test_internal_tools.py b/qim3d/tests/utils/test_helpers.py similarity index 67% rename from qim3d/tests/utils/test_internal_tools.py rename to 
qim3d/tests/utils/test_helpers.py index ed318e7a5c6a93ca2ca0caba2085bb3076e62add..04ae8e63db39d1b47a9fe1f6ae24aecc29c29ac4 100644 --- a/qim3d/tests/utils/test_internal_tools.py +++ b/qim3d/tests/utils/test_helpers.py @@ -5,7 +5,7 @@ from pathlib import Path def test_mock_plot(): - fig = qim3d.utils.internal_tools.mock_plot() + fig = qim3d.tests.mock_plot() assert fig.get_figwidth() == 5.0 @@ -13,7 +13,7 @@ def test_mock_plot(): def test_mock_write_file(): filename = "test.txt" content = "test file" - qim3d.utils.internal_tools.mock_write_file(filename, content=content) + qim3d.tests.mock_write_file(filename, content=content) # Check contents with open(filename, "r", encoding="utf-8") as f: @@ -33,21 +33,21 @@ def test_get_local_ip(): else: return False - local_ip = qim3d.utils.internal_tools.get_local_ip() - + local_ip = qim3d.utils.misc.get_local_ip() + assert validate_ip(local_ip) == True + def test_stringify_path1(): - """Test that the function converts os.PathLike objects to strings - """ + """Test that the function converts os.PathLike objects to strings""" blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif" - - assert str(blobs_path) == qim3d.utils.internal_tools.stringify_path(blobs_path) + + assert str(blobs_path) == qim3d.utils.misc.stringify_path(blobs_path) + def test_stringify_path2(): - """Test that the function returns input unchanged if input is a string - """ - # Create test_path - test_path = os.path.join('this','path','doesnt','exist.tif') + """Test that the function returns input unchanged if input is a string""" + # Create test_path + test_path = os.path.join("this", "path", "doesnt", "exist.tif") - assert test_path == qim3d.utils.internal_tools.stringify_path(test_path) + assert test_path == qim3d.utils.misc.stringify_path(test_path) diff --git a/qim3d/tests/utils/test_models.py b/qim3d/tests/utils/test_models.py deleted file mode 100644 index 37262ad6517b3e603c972de6526e7219cccf2611..0000000000000000000000000000000000000000 --- a/qim3d/tests/utils/test_models.py +++ /dev/null @@ -1,107 +0,0 @@ -import qim3d -import pytest -from torch import ones - -from qim3d.utils.internal_tools import temp_data - -# unit test for model summary() -def test_model_summary(): - n = 10 - img_shape = (32,32) - folder = 'folder_data' - temp_data(folder,img_shape=img_shape,n = n) - - unet = qim3d.models.UNet(size = 'small') - augment = qim3d.utils.Augmentation(transform_train=None) - train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) - - _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, - batch_size = 1,num_workers = 1, - pin_memory = False) - summary = qim3d.utils.model_summary(val_loader,unet) - - assert summary.input_size[0] == (1,1) + img_shape - - temp_data(folder,remove=True) - - -# unit test for inference() -def test_inference(): - folder = 'folder_data' - temp_data(folder) - - unet = qim3d.models.UNet(size = 'small') - augment = qim3d.utils.Augmentation(transform_train=None) - train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) - - _, targ,_ = qim3d.utils.inference(train_set,unet) - - assert tuple(targ[0].unique()) == (0,1) - - temp_data(folder,remove=True) - - -#unit test for tuple ValueError(). 
-def test_inference_tuple(): - folder = 'folder_data' - temp_data(folder) - - unet = qim3d.models.UNet(size = 'small') - - data = [1,2,3] - with pytest.raises(ValueError,match="Data items must be tuples"): - qim3d.utils.inference(data,unet) - - temp_data(folder,remove=True) - - -#unit test for tensor ValueError(). -def test_inference_tensor(): - folder = 'folder_data' - temp_data(folder) - - unet = qim3d.models.UNet(size = 'small') - - data = [(1,2)] - with pytest.raises(ValueError,match="Data items must consist of tensors"): - qim3d.utils.inference(data,unet) - - temp_data(folder,remove=True) - - -#unit test for dimension ValueError(). -def test_inference_dim(): - folder = 'folder_data' - temp_data(folder) - - unet = qim3d.models.UNet(size = 'small') - - data = [(ones(1),ones(1))] - # need the r"" for special characters - with pytest.raises(ValueError,match=r"Input image must be \(C,H,W\) format"): - qim3d.utils.inference(data,unet) - - temp_data(folder,remove=True) - - -# unit test for train_model() -def test_train_model(): - folder = 'folder_data' - temp_data(folder) - - n_epochs = 1 - - unet = qim3d.models.UNet(size = 'small') - augment = qim3d.utils.Augmentation(transform_train=None) - hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs) - train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) - train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, - batch_size = 1,num_workers = 1, - pin_memory = False) - - train_loss,_ = qim3d.utils.train_model(unet,hyperparams,train_loader,val_loader, - plot = False, return_loss = True) - - assert len(train_loss['loss']) == n_epochs - - temp_data(folder,remove=True) \ No newline at end of file diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 84dd4ee591c54c98d9694df83059dc7cc320a912..bcbed39d1cfca3dc5179c8b33c19527a9a5db81e 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -7,45 +7,46 @@ import pytest from torch import ones import qim3d -from qim3d.utils.internal_tools import temp_data +from qim3d.tests import temp_data import matplotlib.pyplot as plt import ipywidgets as widgets + # unit tests for grid overview def test_grid_overview(): - random_tuple = (torch.ones(1,256,256),torch.ones(256,256)) - n_images = 10 + random_tuple = (torch.ones(1, 256, 256), torch.ones(256, 256)) + n_images = 10 train_set = [random_tuple for t in range(n_images)] - fig = qim3d.viz.grid_overview(train_set,num_images=n_images) - assert fig.get_figwidth() == 2*n_images + fig = qim3d.viz.grid_overview(train_set, num_images=n_images) + assert fig.get_figwidth() == 2 * n_images def test_grid_overview_tuple(): - random_tuple = (torch.ones(256,256),torch.ones(256,256)) + random_tuple = (torch.ones(256, 256), torch.ones(256, 256)) - with pytest.raises(ValueError,match="Data elements must be tuples"): - qim3d.viz.grid_overview(random_tuple,num_images=1) + with pytest.raises(ValueError, match="Data elements must be tuples"): + qim3d.viz.grid_overview(random_tuple, num_images=1) # unit tests for grid prediction def test_grid_pred(): - folder = 'folder_data' + folder = "folder_data" n = 4 - temp_data(folder,n = n) + temp_data(folder, n=n) model = qim3d.models.UNet() - augmentation = qim3d.utils.Augmentation() - train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation) + augmentation = qim3d.models.Augmentation() + train_set, _, _ = qim3d.models.prepare_datasets(folder, 0.1, model, augmentation) - in_targ_pred = 
qim3d.utils.models.inference(train_set,model) + in_targ_pred = qim3d.models.inference(train_set, model) fig = qim3d.viz.grid_pred(in_targ_pred) - - assert (fig.get_figwidth(),fig.get_figheight()) == (2*(n),10) - temp_data(folder,remove = True) + assert (fig.get_figwidth(), fig.get_figheight()) == (2 * (n), 10) + + temp_data(folder, remove=True) # unit tests for slices function @@ -54,47 +55,68 @@ def test_slices_numpy_array_input(): fig = qim3d.viz.slices(example_volume, n_slices=1) assert isinstance(fig, plt.Figure) -def test_slices_torch_tensor_input(): - example_volume = torch.ones((10,10,10)) - img_width = 3 - fig = qim3d.viz.slices(example_volume,n_slices = 1) - assert isinstance(fig, plt.Figure) def test_slices_wrong_input_format(): - input = 'not_a_volume' - with pytest.raises(ValueError, match = 'Data type not supported'): + input = "not_a_volume" + with pytest.raises(ValueError, match="Data type not supported"): qim3d.viz.slices(input) + def test_slices_not_volume(): - example_volume = np.ones((10,10)) - with pytest.raises(ValueError, match = 'The provided object is not a volume as it has less than 3 dimensions.'): + example_volume = np.ones((10, 10)) + with pytest.raises( + ValueError, + match="The provided object is not a volume as it has less than 3 dimensions.", + ): qim3d.viz.slices(example_volume) + def test_slices_wrong_position_format1(): - example_volume = np.ones((10,10,10)) - with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): - qim3d.viz.slices(example_volume, position = 'invalid_slice') + example_volume = np.ones((10, 10, 10)) + with pytest.raises( + ValueError, + match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', + ): + qim3d.viz.slices(example_volume, position="invalid_slice") + def test_slices_wrong_position_format2(): - example_volume = np.ones((10,10,10)) - with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): - qim3d.viz.slices(example_volume, position = 1.5) + example_volume = np.ones((10, 10, 10)) + with pytest.raises( + ValueError, + match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', + ): + qim3d.viz.slices(example_volume, position=1.5) + def test_slices_wrong_position_format3(): - example_volume = np.ones((10,10,10)) - with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): - qim3d.viz.slices(example_volume, position = [1, 2, 3.5]) + example_volume = np.ones((10, 10, 10)) + with pytest.raises( + ValueError, + match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".', + ): + qim3d.viz.slices(example_volume, position=[1, 2, 3.5]) + def test_slices_invalid_axis_value(): - example_volume = np.ones((10,10,10)) - with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"): - qim3d.viz.slices(example_volume, axis = 3) + example_volume = np.ones((10, 10, 10)) + with pytest.raises( + ValueError, + match="Invalid value for 'axis'. 
It should be an integer between 0 and 2", + ): + qim3d.viz.slices(example_volume, axis=3) + def test_slices_interpolation_option(): - example_volume = torch.ones((10, 10, 10)) + example_volume = np.ones((10, 10, 10)) img_width = 3 - interpolation_method = 'bilinear' - fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, interpolation=interpolation_method) + interpolation_method = "bilinear" + fig = qim3d.viz.slices( + example_volume, + n_slices=1, + img_width=img_width, + interpolation=interpolation_method, + ) for ax in fig.get_axes(): # Access the interpolation method used for each Axes object @@ -103,6 +125,7 @@ def test_slices_interpolation_option(): # Assert that the actual interpolation method matches the expected method assert actual_interpolation == interpolation_method + def test_slices_multiple_slices(): example_volume = np.ones((10, 10, 10)) img_width = 3 @@ -111,20 +134,37 @@ def test_slices_multiple_slices(): # Add assertions for the expected number of subplots in the figure assert len(fig.get_axes()) == n_slices + def test_slices_axis_argument(): # Non-symmetric input example_volume = np.arange(1000).reshape((10, 10, 10)) img_width = 3 # Call the function with different values of the axis - fig_axis_0 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=0) - fig_axis_1 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=1) - fig_axis_2 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=2) + fig_axis_0 = qim3d.viz.slices( + example_volume, n_slices=1, img_width=img_width, axis=0 + ) + fig_axis_1 = qim3d.viz.slices( + example_volume, n_slices=1, img_width=img_width, axis=1 + ) + fig_axis_2 = qim3d.viz.slices( + example_volume, n_slices=1, img_width=img_width, axis=2 + ) # Ensure that different axes result in different plots - assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array()) - assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array()) - assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array()) + assert not np.allclose( + fig_axis_0.get_axes()[0].images[0].get_array(), + fig_axis_1.get_axes()[0].images[0].get_array(), + ) + assert not np.allclose( + fig_axis_1.get_axes()[0].images[0].get_array(), + fig_axis_2.get_axes()[0].images[0].get_array(), + ) + assert not np.allclose( + fig_axis_2.get_axes()[0].images[0].get_array(), + fig_axis_0.get_axes()[0].images[0].get_array(), + ) + # unit tests for slicer function def test_slicer_with_numpy_array(): @@ -135,6 +175,7 @@ def test_slicer_with_numpy_array(): # Assert that the slicer object is created successfully assert isinstance(slicer_obj, widgets.interactive) + def test_slicer_with_torch_tensor(): # Create a sample PyTorch tensor vol = torch.rand(10, 10, 10) @@ -143,6 +184,7 @@ def test_slicer_with_torch_tensor(): # Assert that the slicer object is created successfully assert isinstance(slicer_obj, widgets.interactive) + def test_slicer_with_different_parameters(): # Test with different axis values for axis in range(3): @@ -156,14 +198,19 @@ def test_slicer_with_different_parameters(): # Test with different image sizes for img_height, img_width in [(2, 2), (4, 4)]: - slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width) + slicer_obj = qim3d.viz.slicer( + np.random.rand(10, 10, 10), img_height=img_height, 
img_width=img_width + ) assert isinstance(slicer_obj, widgets.interactive) # Test with show_position set to True and False for show_position in [True, False]: - slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), show_position=show_position) + slicer_obj = qim3d.viz.slicer( + np.random.rand(10, 10, 10), show_position=show_position + ) assert isinstance(slicer_obj, widgets.interactive) + # unit tests for orthogonal function def test_orthogonal_with_numpy_array(): # Create a sample NumPy array @@ -173,6 +220,7 @@ def test_orthogonal_with_numpy_array(): # Assert that the orthogonal object is created successfully assert isinstance(orthogonal_obj, widgets.HBox) + def test_orthogonal_with_torch_tensor(): # Create a sample PyTorch tensor vol = torch.rand(10, 10, 10) @@ -181,6 +229,7 @@ def test_orthogonal_with_torch_tensor(): # Assert that the orthogonal object is created successfully assert isinstance(orthogonal_obj, widgets.HBox) + def test_orthogonal_with_different_parameters(): # Test with different colormaps for cmap in ["viridis", "gray", "plasma"]: @@ -189,43 +238,47 @@ def test_orthogonal_with_different_parameters(): # Test with different image sizes for img_height, img_width in [(2, 2), (4, 4)]: - orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width) + orthogonal_obj = qim3d.viz.orthogonal( + np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width + ) assert isinstance(orthogonal_obj, widgets.HBox) # Test with show_position set to True and False for show_position in [True, False]: - orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), show_position=show_position) + orthogonal_obj = qim3d.viz.orthogonal( + np.random.rand(10, 10, 10), show_position=show_position + ) assert isinstance(orthogonal_obj, widgets.HBox) + def test_orthogonal_initial_slider_value(): # Create a sample NumPy array vol = np.random.rand(10, 7, 19) # Call the orthogonal function with the NumPy array orthogonal_obj = qim3d.viz.orthogonal(vol) - for idx,slicer in enumerate(orthogonal_obj.children): - assert slicer.children[0].value == vol.shape[idx]//2 + for idx, slicer in enumerate(orthogonal_obj.children): + assert slicer.children[0].value == vol.shape[idx] // 2 + def test_orthogonal_slider_description(): # Create a sample NumPy array vol = np.random.rand(10, 10, 10) # Call the orthogonal function with the NumPy array orthogonal_obj = qim3d.viz.orthogonal(vol) - for idx,slicer in enumerate(orthogonal_obj.children): - assert slicer.children[0].description == ['Z', 'Y', 'X'][idx] - - - + for idx, slicer in enumerate(orthogonal_obj.children): + assert slicer.children[0].description == ["Z", "Y", "X"][idx] # unit tests for local thickness visualization def test_local_thickness_2d(): - blobs = qim3d.examples.blobs_256x256 - lt = qim3d.processing.local_thickness(blobs) - fig = qim3d.viz.local_thickness(blobs, lt) + vol = qim3d.examples.fly_150x256x256[0] + lt = qim3d.processing.local_thickness(vol) + fig = qim3d.viz.local_thickness(vol, lt) # Assert that returned figure is a matplotlib figure assert isinstance(fig, plt.Figure) + def test_local_thickness_3d(): fly = qim3d.examples.fly_150x256x256 lt = qim3d.processing.local_thickness(fly) @@ -234,6 +287,7 @@ def test_local_thickness_3d(): # Assert that returned object is an interactive widget assert isinstance(obj, widgets.interactive) + def test_local_thickness_3d_max_projection(): fly = qim3d.examples.fly_150x256x256 lt = qim3d.processing.local_thickness(fly) diff --git 
a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py index 32e43b19a41f2562f0ac86b1cccae3765ae4dac0..61f688060d76da9a2f7064770d4ef2d98e8659f2 100644 --- a/qim3d/utils/__init__.py +++ b/qim3d/utils/__init__.py @@ -1,8 +1,15 @@ -#from .doi import get_bibtex, get_reference -from . import doi, internal_tools -from .augmentations import Augmentation -from .data import Dataset, prepare_dataloaders, prepare_datasets -from .models import inference, model_summary, train_model -from .preview import image_preview -from .loading_progress_bar import ProgressBar +from . import doi +from .progress_bar import ProgressBar from .system import Memory + +from .misc import ( + get_local_ip, + port_from_str, + gradio_header, + sizeof, + get_file_size, + get_port_dict, + get_css, + downscale_img, + scale_to_float16, +) diff --git a/qim3d/utils/doi.py b/qim3d/utils/doi.py index 71a253ebcbc8dce5e71c2973b84f23e5710e1fb2..60e01d76281159a86a08e3868527a7442f2a230d 100644 --- a/qim3d/utils/doi.py +++ b/qim3d/utils/doi.py @@ -1,7 +1,7 @@ """ Deals with DOI for references """ import json import requests -from qim3d.io.logger import log +from qim3d.utils.logger import log def _validate_response(response): diff --git a/qim3d/io/logger.py b/qim3d/utils/logger.py similarity index 96% rename from qim3d/io/logger.py rename to qim3d/utils/logger.py index dbc88180713a0c38a74754c7261591090bf4519f..e62b2151912c0059259dda3f124813654bb595b4 100644 --- a/qim3d/io/logger.py +++ b/qim3d/utils/logger.py @@ -132,5 +132,5 @@ def level(log_level): # create the logger log = logging.getLogger("qim3d") -set_simple_output() #TODO: This used to work, but now it gives duplicated messages. Need to be investigated. -#set_level_warning() +set_level_info() +set_simple_output() diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/misc.py similarity index 59% rename from qim3d/utils/internal_tools.py rename to qim3d/utils/misc.py index 58b97a37a857199863cddd056c4ec531f5c59abe..1c542a3ed47b8aa157afaa283fae0bf841a70bf9 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/misc.py @@ -3,62 +3,13 @@ import getpass import hashlib import os -import shutil import socket -from pathlib import Path - -import gradio as gr -import matplotlib -import matplotlib.pyplot as plt import numpy as np import outputformat as ouf import requests -from fastapi import FastAPI -from PIL import Image from scipy.ndimage import zoom -from uvicorn import run - -from qim3d.io.logger import log - - -def mock_plot(): - """Creates a mock plot of a sine wave. - - Returns: - matplotlib.figure.Figure: The generated plot figure. - - Example: - Creates a mock plot of a sine wave and displays the plot using `plt.show()`. - - >>> fig = mock_plot() - >>> plt.show() - """ - - # TODO: Check if using Agg backend conflicts with other pipelines - - matplotlib.use("Agg") - - fig = plt.figure(figsize=(5, 4)) - axes = fig.add_axes([0.1, 0.1, 0.8, 0.8]) - values = np.arange(0, 2 * np.pi, 0.01) - axes.plot(values, np.sin(values)) - - return fig - - -def mock_write_file(path, content="File created by qim3d"): - """ - Creates a file at the specified path and writes a predefined text into it. - - Args: - path (str): The path to the file to be created. 
- - Example: - >>> mock_write_file("example.txt") - """ - _file = open(path, "w", encoding="utf-8") - _file.write(content) - _file.close() +import difflib +import qim3d def get_local_ip(): @@ -179,7 +130,17 @@ def sizeof(num, suffix="B"): num /= 1024.0 return f"{num:.1f} Y{suffix}" -def get_file_size(filename:str) -> int: + +def find_similar_paths(path): + parent_dir = os.path.dirname(path) or "." + parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" + valid_paths = [os.path.join(parent_dir, file) for file in parent_files] + similar_paths = difflib.get_close_matches(path, valid_paths) + + return similar_paths + + +def get_file_size(file_path: str) -> int: """ Args: ----- @@ -189,75 +150,19 @@ def get_file_size(filename:str) -> int: --------- size (int): size of file in bytes """ - - return os.path.getsize(filename) - -def is_server_running(ip, port): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - s.connect((ip, int(port))) - s.shutdown(2) - return True - except: - return False - + file_size = os.path.getsize(file_path) + except FileNotFoundError: + similar_paths = qim3d.utils.misc.find_similar_paths(file_path) -def temp_data(folder, remove=False, n=3, img_shape=(32, 32)): - """Creates a temporary folder to test deep learning tools. - - Creates two folders, 'train' and 'test', who each also have two subfolders 'images' and 'labels'. - n random images are then added to all four subfolders. - If the 'remove' variable is True, the folders and their content are removed. - - Args: - folder (str): The path where the folders should be placed. - remove (bool, optional): If True, all folders are removed from their location. - n (int, optional): Number of random images and labels in the temporary dataset. - img_shape (tuple, options): Tuple with the height and width of the images and labels. - - Example: - >>> tempdata('temporary_folder',n = 10, img_shape = (16,16)) - """ - folder_trte = ["train", "test"] - sub_folders = ["images", "labels"] - - # Creating train/test folder - path_train = Path(folder) / folder_trte[0] - path_test = Path(folder) / folder_trte[1] - - # Creating folders for images and labels - path_train_im = path_train / sub_folders[0] - path_train_lab = path_train / sub_folders[1] - path_test_im = path_test / sub_folders[0] - path_test_lab = path_test / sub_folders[1] - - # Random image - img = np.random.randint(2, size=img_shape, dtype=np.uint8) - img = Image.fromarray(img) - - if not os.path.exists(path_train): - os.makedirs(path_train_im) - os.makedirs(path_test_im) - os.makedirs(path_train_lab) - os.makedirs(path_test_lab) - for i in range(n): - img.save(path_train_im / f"img_train{i}.png") - img.save(path_train_lab / f"img_train{i}.png") - img.save(path_test_im / f"img_test{i}.png") - img.save(path_test_lab / f"img_test{i}.png") - - if remove: - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - log.warning("Failed to delete %s. Reason: %s" % (file_path, e)) - - os.rmdir(folder) + if similar_paths: + suggestion = similar_paths[0] # Get the closest match + message = f"Invalid path. Did you mean '{suggestion}'?" 
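+            # Embed the closest existing path in the error message so that
+            # typos are easy to spot (exercised by test_did_you_mean)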
+ raise FileNotFoundError(repr(message)) + else: + raise FileNotFoundError("Invalid path") + + return file_size def stringify_path(path): @@ -284,32 +189,6 @@ def get_port_dict(): return port_dict -def run_gradio_app(gradio_interface, host="0.0.0.0"): - - # Get port using the QIM API - port_dict = get_port_dict() - - if "gradio_port" in port_dict: - port = port_dict["gradio_port"] - elif "port" in port_dict: - port = port_dict["port"] - else: - raise Exception("Port not specified from QIM API") - - gradio_header(gradio_interface.title, port) - - # Create FastAPI with mounted gradio interface - app = FastAPI() - path = f"/gui/{port_dict['username']}/{port}/" - app = gr.mount_gradio_app(app, gradio_interface, path=path) - - # Full path - print(f"http://{host}:{port}{path}") - - # Run the FastAPI server usign uvicorn - run(app, host=host, port=int(port)) - - def get_css(): current_directory = os.path.dirname(os.path.abspath(__file__)) @@ -321,8 +200,9 @@ def get_css(): return css_content + def downscale_img(img, max_voxels=512**3): - """ Downscale image if total number of voxels exceeds 512³. + """Downscale image if total number of voxels exceeds 512³. Args: img (np.Array): Input image. @@ -340,11 +220,12 @@ def downscale_img(img, max_voxels=512**3): return img # Calculate zoom factor - zoom_factor = (max_voxels / total_voxels) ** (1/3) + zoom_factor = (max_voxels / total_voxels) ** (1 / 3) # Downscale image return zoom(img, zoom_factor) + def scale_to_float16(arr: np.ndarray): """ Scale the input array to the float16 data type. @@ -373,4 +254,4 @@ def scale_to_float16(arr: np.ndarray): # Convert the scaled array to float16 data type arr = arr.astype(np.float16) - return arr \ No newline at end of file + return arr diff --git a/qim3d/utils/loading_progress_bar.py b/qim3d/utils/progress_bar.py similarity index 64% rename from qim3d/utils/loading_progress_bar.py rename to qim3d/utils/progress_bar.py index 980349ea10389cb9efb1aea831aac1a7d70d7914..7964ac6ffccdb52e6442b44acfde6375c654b29a 100644 --- a/qim3d/utils/loading_progress_bar.py +++ b/qim3d/utils/progress_bar.py @@ -4,7 +4,7 @@ import sys from tqdm.auto import tqdm -from qim3d.utils.internal_tools import get_file_size +from qim3d.utils.misc import get_file_size class RepeatTimer(Timer): @@ -14,35 +14,40 @@ class RepeatTimer(Timer): work at all. Thus we have to use timer, which runs the function at (approximately) the given time. With this subclass from https://stackoverflow.com/a/48741004/11514359 - we don't have to guess how many timers we will need and create multiple timers. + we don't have to guess how many timers we will need and create multiple timers. """ + def run(self): while not self.finished.wait(self.interval): self.function(*self.args, **self.kwargs) + class ProgressBar: - def __init__(self, filename:str, repeat_time:float = 0.5, *args, **kwargs): + def __init__(self, filename: str, repeat_time: float = 0.5, *args, **kwargs): """ Creates class for 'with' statement to track progress during loading a file into memory Parameters: ------------ - filename (str): to get size of the file - - repeat_time (float, optional): How often the timer checks how many bytes were loaded. Even if very small, + - repeat_time (float, optional): How often the timer checks how many bytes were loaded. Even if very small, it doesn't make the progress bar smoother as there are only few visible changes in number of read_chars. 
Defaults to 0.25 """ self.timer = RepeatTimer(repeat_time, self.memory_check) - self.pbar = tqdm(total = get_file_size(filename), - desc = "Loading: ", - unit = "B", - file = sys.stdout, - unit_scale = True, - unit_divisor = 1024, - bar_format = '{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]') + self.pbar = tqdm( + total=get_file_size(filename), + desc="Loading: ", + unit="B", + file=sys.stdout, + unit_scale=True, + unit_divisor=1024, + bar_format="{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, " + "{rate_fmt}{postfix}]", + ) self.last_memory = 0 self.process = psutil.Process() - + def memory_check(self): counters = self.process.io_counters() try: @@ -50,15 +55,15 @@ class ProgressBar: except AttributeError: memory = counters.read_bytes + counters.other_bytes - try: self.pbar.update(memory - self.last_memory) - except AttributeError: # When we leave the context manager, we delete the pbar so it can not be updated anymore - # It's because it takes quite a long time for the timer to end and might update the pbar - # one more time before ending which messes up the whole thing + except ( + AttributeError + ): # When we leave the context manager, we delete the pbar so it can not be updated anymore + # It's because it takes quite a long time for the timer to end and might update the pbar + # one more time before ending which messes up the whole thing pass - self.last_memory = memory def __enter__(self): @@ -69,5 +74,4 @@ class ProgressBar: self.pbar.clear() self.pbar.n = self.pbar.total self.pbar.display() - del self.pbar # So the update process can not update it anymore - + del self.pbar # So the update process can not update it anymore diff --git a/qim3d/utils/system.py b/qim3d/utils/system.py index 9ee8f8cb5b98d32c93470b31c7f6e7a4efef7887..39dc3af72ffd133e6d758d99070d0e0c1e36819d 100644 --- a/qim3d/utils/system.py +++ b/qim3d/utils/system.py @@ -2,8 +2,8 @@ import os import time import psutil -from qim3d.utils.internal_tools import sizeof -from qim3d.io.logger import log +from qim3d.utils.misc import sizeof +from qim3d.utils.logger import log import numpy as np diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 080a4ebd7a812cb659c5b817d18767a3d287568a..91257925b765752bb0311775985ba579a15052f6 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,16 +1,14 @@ from . 
import colormaps from .cc import plot_cc from .detection import circles -from .img import ( - grid_overview, - grid_pred, +from .explore import ( interactive_fade_mask, orthogonal, slicer, slices, - vol_masked, ) from .k3d import vol from .local_thickness_ import local_thickness from .structure_tensor import vectors -from .visualizations import plot_metrics +from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked +from .preview import image_preview diff --git a/qim3d/viz/cc.py b/qim3d/viz/cc.py index 8daee7dca7cbeb0f206ac215954521b4564b2393..d0bdb42771f2568912cffda52f78f217513d1810 100644 --- a/qim3d/viz/cc.py +++ b/qim3d/viz/cc.py @@ -1,10 +1,9 @@ import matplotlib.pyplot as plt import numpy as np -from qim3d.io.logger import log -from qim3d.processing.cc import CC -from qim3d.viz.colormaps import objects as qim3dCmap -import qim3d +from qim3d.utils.logger import log +import qim3d.viz.colormaps + def plot_cc( connected_components, @@ -51,11 +50,13 @@ def plot_cc( component_indexs = range( 1, min(max_cc_to_plot + 1, len(connected_components) + 1) ) - + figs = [] for component in component_indexs: if overlay is not None: - assert (overlay.shape == connected_components.shape), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." + assert ( + overlay.shape == connected_components.shape + ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." # plots overlay masked to connected component if crop: @@ -71,16 +72,18 @@ def plot_cc( overlay_crop = np.where(cc == 0, 0, overlay) fig = qim3d.viz.slices(overlay_crop, show=show, **kwargs) else: - # assigns discrete color map to each connected component if not given + # assigns discrete color map to each connected component if not given if "cmap" not in kwargs: - kwargs["cmap"] = qim3dCmap(len(component_indexs)) - + kwargs["cmap"] = qim3d.viz.colormaps.objects(len(component_indexs)) + # Plot the connected component without overlay - fig = qim3d.viz.slices(connected_components.get_cc(component, crop=crop), show=show, **kwargs) + fig = qim3d.viz.slices( + connected_components.get_cc(component, crop=crop), show=show, **kwargs + ) figs.append(fig) if not show: return figs - return \ No newline at end of file + return diff --git a/qim3d/viz/detection.py b/qim3d/viz/detection.py index 3e83a458662e71ef7e3e29805b2bf5b533ca2829..fcf493cd7e550aa183f09e2dbae56c9e7bc3de7c 100644 --- a/qim3d/viz/detection.py +++ b/qim3d/viz/detection.py @@ -1,5 +1,5 @@ import matplotlib.pyplot as plt -from qim3d.io.logger import log +from qim3d.utils.logger import log import numpy as np import ipywidgets as widgets from IPython.display import clear_output, display diff --git a/qim3d/viz/img.py b/qim3d/viz/explore.py similarity index 53% rename from qim3d/viz/img.py rename to qim3d/viz/explore.py index f8b9bfee14c8e1e5cb4cc2b950b0440065cc99e5..2300ad7b96f0cb95835f864b5f42025c5f0bb891 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/explore.py @@ -9,219 +9,12 @@ import dask.array as da import ipywidgets as widgets import matplotlib.pyplot as plt import numpy as np -import torch -from matplotlib import colormaps -from matplotlib.colors import LinearSegmentedColormap import qim3d -from qim3d.io.logger import log - - -def grid_overview( - data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False -): - """Displays an overview grid of 
images, labels, and masks (if they exist). - - Labels are the annotated target segmentations - Masks are applied to the output and target prior to the loss calculation in case of - sparse labeled data - - Args: - data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image, label, (and mask data). - num_images (int, optional): The maximum number of images to display. Defaults to 7. - cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'. - cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'. - alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5. - show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. - - Raises: - ValueError: If the data elements are not tuples. - - - Returns: - fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels. - - Example: - ```python - data = [(image1, label1, mask1), (image2, label2, mask2)] - grid_overview(data, num_images=5, cmap_im='viridis', cmap_segm='hot', alpha=0.8) - ``` - - Notes: - - If the image data is RGB, the color map is ignored and the user is informed. - - The number of displayed images is limited to the minimum between `num_images` - and the length of the data. - - The grid layout and dimensions vary based on the presence of a mask. - """ - - # Check if data has a mask - has_mask = len(data[0]) > 2 and data[0][-1] is not None - - # Check if image data is RGB and inform the user if it's the case - if len(data[0][0].squeeze().shape) > 2: - log.info("Input images are RGB: color map is ignored") - - # Check if dataset have at least specified number of images - if len(data) < num_images: - log.warning( - "Not enough images in the dataset. Changing num_images=%d to num_images=%d", - num_images, - len(data), - ) - num_images = len(data) - - # Adapt segmentation cmap so that background is transparent - colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) - colors_segm[:128, 3] = 0 - custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) - - # Check if data have the right format - if not isinstance(data[0], tuple): - raise ValueError("Data elements must be tuples") - - # Define row titles - row_titles = ["Input images", "Ground truth segmentation", "Mask"] - - # Make new list such that possible augmentations remain identical for all three rows - plot_data = [data[idx] for idx in range(num_images)] - - fig = plt.figure( - figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True - ) - - # create 2 (3) x 1 subfigs - subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1) - for row, subfig in enumerate(subfigs): - subfig.suptitle(row_titles[row], fontsize=22) - - # create 1 x num_images subplots per subfig - axs = subfig.subplots(nrows=1, ncols=num_images) - for col, ax in enumerate(np.atleast_1d(axs)): - if row in [1, 2]: # Ground truth segmentation and mask - ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im) - ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha) - ax.axis("off") - else: - ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im) - ax.axis("off") - - if show: - plt.show() - plt.close() - - return fig - - -def grid_pred( - in_targ_preds, - num_images=7, - cmap_im="gray", - cmap_segm="viridis", - alpha=0.5, - show=False, -): - """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. 
- - Displays a grid of subplots representing different aspects of the input images and segmentations. - The grid includes the following rows: - - Row 1: Input images - - Row 2: Predicted segmentations overlaying input images - - Row 3: Ground truth segmentations overlaying input images - - Row 4: Comparison between true and predicted segmentations overlaying input images - - Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset. - The function utilizes various color maps for visualization and applies transparency to the segmentations. - - Args: - in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations. - num_images (int, optional): Number of images to display. Defaults to 7. - cmap_im (str, optional): Color map for input images. Defaults to "gray". - cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis". - alpha (float, optional): Alpha value for transparency. Defaults to 0.5. - show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. - - Returns: - fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models. - - Raises: - None - - Example: - dataset = MySegmentationDataset() - model = MySegmentationModel() - in_targ_preds = qim3d.utils.models.inference(dataset,model) - grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5) - """ - - # Check if dataset have at least specified number of images - if len(in_targ_preds[0]) < num_images: - log.warning( - "Not enough images in the dataset. Changing num_images=%d to num_images=%d", - num_images, - len(in_targ_preds[0]), - ) - num_images = len(in_targ_preds[0]) - - # Take only the number of images from in_targ_preds - inputs, targets, preds = [items[:num_images] for items in in_targ_preds] - - # Adapt segmentation cmap so that background is transparent - colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) - colors_segm[:128, 3] = 0 - custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) - - N = num_images - H = inputs[0].shape[-2] - W = inputs[0].shape[-1] - - comp_rgb = torch.zeros((N, 4, H, W)) - comp_rgb[:, 1, :, :] = targets.logical_and(preds) - comp_rgb[:, 0, :, :] = targets.logical_xor(preds) - comp_rgb[:, 3, :, :] = targets.logical_or(preds) - - row_titles = [ - "Input images", - "Predicted segmentation", - "Ground truth segmentation", - "True vs. 
predicted segmentation",
-    ]
-
-    fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True)
-
-    # create 3 x 1 subfigs
-    subfigs = fig.subfigures(nrows=4, ncols=1)
-    for row, subfig in enumerate(subfigs):
-        subfig.suptitle(row_titles[row], fontsize=22)
-
-        # create 1 x num_images subplots per subfig
-        axs = subfig.subplots(nrows=1, ncols=num_images)
-        for col, ax in enumerate(np.atleast_1d(axs)):
-            if row == 0:
-                ax.imshow(inputs[col], cmap=cmap_im)
-                ax.axis("off")
-
-            elif row == 1:  # Predicted segmentation
-                ax.imshow(inputs[col], cmap=cmap_im)
-                ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha)
-                ax.axis("off")
-            elif row == 2:  # Ground truth segmentation
-                ax.imshow(inputs[col], cmap=cmap_im)
-                ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
-                ax.axis("off")
-            else:
-                ax.imshow(inputs[col], cmap=cmap_im)
-                ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
-                ax.axis("off")
-
-    if show:
-        plt.show()
-    plt.close()
-
-    return fig
 
 
 def slices(
-    vol: Union[np.ndarray, torch.Tensor],
+    vol: np.ndarray,
     axis: int = 0,
     position: Optional[Union[str, int, List[int]]] = None,
     n_slices: int = 5,
@@ -232,7 +25,7 @@ def slices(
     show: bool = False,
     show_position: bool = True,
     interpolation: Optional[str] = "none",
-    img_size = None,
+    img_size=None,
     **imshow_kwargs,
 ) -> plt.Figure:
     """Displays one or several slices from a 3d volume.
@@ -242,7 +35,7 @@ def slices(
     If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
 
     Args:
-        vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
+        vol (np.ndarray): The 3D volume to be sliced.
         axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
         position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
         n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
@@ -258,7 +51,7 @@ def slices(
         fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
 
     Raises:
-        ValueError: If the input is not a numpy.ndarray or torch.Tensor.
+        ValueError: If the input is not a numpy.ndarray or da.core.Array.
         ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
         ValueError: If the file or array is not a volume with at least 3 dimensions.
         ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
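For orientation, here is a minimal sketch of calling the now NumPy-only `slices` function. It uses only names that appear in this diff (the `fly_150x256x256` example volume and the documented `axis`, `n_slices`, `position` and `show` parameters) and is meant as an illustration, not an official snippet:

```python
import qim3d

vol = qim3d.examples.fly_150x256x256

# Five linearly spaced slices along the first axis
qim3d.viz.slices(vol, axis=0, n_slices=5, show=True)

# An explicit list of positions overrides n_slices
qim3d.viz.slices(vol, axis=1, position=[0, 64, 128], show=True)
```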
@@ -277,7 +70,7 @@ def slices( img_width = img_size # Numpy array or Torch tensor input - if not isinstance(vol, (np.ndarray, torch.Tensor, da.core.Array)): + if not isinstance(vol, (np.ndarray, da.core.Array)): raise ValueError("Data type not supported") if vol.ndim < 3: @@ -287,7 +80,7 @@ def slices( if isinstance(vol, da.core.Array): vol = vol.compute() - + # Ensure axis is a valid choice if not (0 <= axis < vol.ndim): raise ValueError( @@ -334,9 +127,7 @@ def slices( axs = [axs] # Convert to a list for uniformity # Convert to NumPy array in order to use the numpy.take method - if isinstance(vol, torch.Tensor): - vol = vol.numpy() - elif isinstance(vol, da.core.Array): + if isinstance(vol, da.core.Array): vol = vol.compute() # Run through each ax of the grid @@ -406,20 +197,20 @@ def _get_slice_range(position: int, n_slices: int, n_total): def slicer( - vol: Union[np.ndarray, torch.Tensor], + vol: np.ndarray, axis: int = 0, cmap: str = "viridis", img_height: int = 3, img_width: int = 3, show_position: bool = False, interpolation: Optional[str] = "none", - img_size = None, + img_size=None, **imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. Args: - vol (np.ndarray or torch.Tensor): The 3D volume to be sliced. + vol (np.ndarray): The 3D volume to be sliced. axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". img_height (int, optional): Height of the figure. Defaults to 3. @@ -475,18 +266,18 @@ def slicer( def orthogonal( - vol: Union[np.ndarray, torch.Tensor], + vol: np.ndarray, cmap: str = "viridis", img_height: int = 3, img_width: int = 3, show_position: bool = False, interpolation: Optional[str] = None, - img_size = None, + img_size=None, ): """Interactive widget for visualizing orthogonal slices of a 3D volume. Args: - vol (np.ndarray or torch.Tensor): The 3D volume to be sliced. + vol (np.ndarray): The 3D volume to be sliced. cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". img_height(int, optional): Height of the figure. img_width(int, optional): Width of the figure. @@ -545,36 +336,8 @@ def orthogonal( return widgets.HBox([z_slicer, y_slicer, x_slicer]) -def vol_masked(vol, vol_mask, viz_delta=128): - """ - Applies masking to a volume based on a binary volume mask. - - This function takes a volume array `vol` and a corresponding binary volume mask `vol_mask`. - It computes the masked volume where pixels outside the mask are set to the background value, - and pixels inside the mask are set to foreground. - - - Args: - vol (ndarray): The input volume as a NumPy array. - vol_mask (ndarray): The binary mask volume as a NumPy array with the same shape as `vol`. - viz_delta (int, optional): Value added to the volume before applying the mask to visualize masked regions. - Defaults to 128. - - Returns: - ndarray: The masked volume with the same shape as `vol`, where pixels outside the mask are set - to the background value (negative). - - - """ - - background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1 - foreground = (vol.astype("float") + viz_delta) * vol_mask - vol_masked = background + foreground - - return vol_masked - def interactive_fade_mask(vol: np.ndarray, axis: int = 0): - """ Interactive widget for visualizing the effect of edge fading on a 3D volume. + """Interactive widget for visualizing the effect of edge fading on a 3D volume. 
@@ -545,36 +336,8 @@ def orthogonal(
     return widgets.HBox([z_slicer, y_slicer, x_slicer])


-def vol_masked(vol, vol_mask, viz_delta=128):
-    """
-    Applies masking to a volume based on a binary volume mask.
-
-    This function takes a volume array `vol` and a corresponding binary volume mask `vol_mask`.
-    It computes the masked volume where pixels outside the mask are set to the background value,
-    and pixels inside the mask are set to foreground.
-
-
-    Args:
-        vol (ndarray): The input volume as a NumPy array.
-        vol_mask (ndarray): The binary mask volume as a NumPy array with the same shape as `vol`.
-        viz_delta (int, optional): Value added to the volume before applying the mask to visualize masked regions.
-            Defaults to 128.
-
-    Returns:
-        ndarray: The masked volume with the same shape as `vol`, where pixels outside the mask are set
-            to the background value (negative).
-
-
-    """
-
-    background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1
-    foreground = (vol.astype("float") + viz_delta) * vol_mask
-    vol_masked = background + foreground
-
-    return vol_masked
-

 def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
-    """ Interactive widget for visualizing the effect of edge fading on a 3D volume.
+    """Interactive widget for visualizing the effect of edge fading on a 3D volume.

     This can be used to select the best parameters before applying the mask.
@@ -586,9 +349,9 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
         ```python
         import qim3d
         vol = qim3d.examples.cement_128x128x128
-        qim3d.viz.interactive_fade_mask(vol) 
+        qim3d.viz.interactive_fade_mask(vol)
         ```
-        ![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif) 
+        ![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif)

     """

@@ -596,26 +359,40 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):

     def _slicer(position, decay_rate, ratio, geometry, invert):
         fig, axes = plt.subplots(1, 3, figsize=(9, 3))

-        axes[0].imshow(vol[position, :, :], cmap='viridis')
-        axes[0].set_title('Original')
-        axes[0].axis('off')
+        axes[0].imshow(vol[position, :, :], cmap="viridis")
+        axes[0].set_title("Original")
+        axes[0].axis("off")

-        mask = qim3d.processing.operations.fade_mask(np.ones_like(vol), decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert)
-        axes[1].imshow(mask[position, :, :], cmap='viridis')
-        axes[1].set_title('Mask')
-        axes[1].axis('off')
+        mask = qim3d.processing.operations.fade_mask(
+            np.ones_like(vol),
+            decay_rate=decay_rate,
+            ratio=ratio,
+            geometry=geometry,
+            axis=axis,
+            invert=invert,
+        )
+        axes[1].imshow(mask[position, :, :], cmap="viridis")
+        axes[1].set_title("Mask")
+        axes[1].axis("off")

-        masked_vol = qim3d.processing.operations.fade_mask(vol, decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert)
-        axes[2].imshow(masked_vol[position, :, :], cmap='viridis')
-        axes[2].set_title('Masked')
-        axes[2].axis('off')
+        masked_vol = qim3d.processing.operations.fade_mask(
+            vol,
+            decay_rate=decay_rate,
+            ratio=ratio,
+            geometry=geometry,
+            axis=axis,
+            invert=invert,
+        )
+        axes[2].imshow(masked_vol[position, :, :], cmap="viridis")
+        axes[2].set_title("Masked")
+        axes[2].axis("off")

         return fig
-    
+
     shape_dropdown = widgets.Dropdown(
-        options=['sphere', 'cilinder'],
-        value='sphere',  # default value
-        description='Geometry',
+        options=["sphere", "cilinder"],
+        value="sphere",  # default value
+        description="Geometry",
     )

     position_slider = widgets.IntSlider(
@@ -637,18 +414,24 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
         value=0.5,
         min=0.1,
         max=1,
-        step=0.01, 
+        step=0.01,
         description="Ratio",
         continuous_update=False,
     )

     # Create the Checkbox widget
     invert_checkbox = widgets.Checkbox(
-        value=False,  # default value
-        description='Invert'
+        value=False, description="Invert"  # default value
     )

-    slicer_obj = widgets.interactive(_slicer, position=position_slider, decay_rate=decay_rate_slider, ratio=ratio_slider, geometry=shape_dropdown, invert=invert_checkbox)
+    slicer_obj = widgets.interactive(
+        _slicer,
+        position=position_slider,
+        decay_rate=decay_rate_slider,
+        ratio=ratio_slider,
+        geometry=shape_dropdown,
+        invert=invert_checkbox,
+    )
     slicer_obj.layout = widgets.Layout(align_items="flex-start")

     return slicer_obj
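Once suitable parameters have been chosen in the widget, the mask can be applied non-interactively with the same `fade_mask` call that `_slicer` makes above. A sketch with illustrative parameter values only:

```python
import qim3d

vol = qim3d.examples.cement_128x128x128

# Fixed-parameter version of the call made inside the widget
masked_vol = qim3d.processing.operations.fade_mask(
    vol, decay_rate=10, ratio=0.5, geometry="sphere", axis=0, invert=False
)
```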
diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py
index fc3a2cf434e46c2e512efc7675f9591bffa29b85..767d84dfeed699a2d75f9fd3058fd00fab9ca46a 100644
--- a/qim3d/viz/k3d.py
+++ b/qim3d/viz/k3d.py
@@ -8,8 +8,8 @@ Volumetric visualization using K3D
 """

 import numpy as np
-from qim3d.io.logger import log
-from qim3d.utils.internal_tools import downscale_img, scale_to_float16
+from qim3d.utils.logger import log
+from qim3d.utils.misc import downscale_img, scale_to_float16


 def vol(
diff --git a/qim3d/viz/local_thickness_.py b/qim3d/viz/local_thickness_.py
index 86326e8ae6bdba1273dbf0bb642d77e4102219ee..e1195e7dfae54058f18e66e5a9a83867c6dea503 100644
--- a/qim3d/viz/local_thickness_.py
+++ b/qim3d/viz/local_thickness_.py
@@ -1,4 +1,4 @@
-from qim3d.io.logger import log
+from qim3d.utils.logger import log
 import numpy as np
 import matplotlib.pyplot as plt
 from typing import Optional, Union, Tuple
diff --git a/qim3d/viz/metrics.py b/qim3d/viz/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cbd08ea71decebbd729fe6e20c6c42a50171629
--- /dev/null
+++ b/qim3d/viz/metrics.py
@@ -0,0 +1,311 @@
+"""Visualization tools"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import LinearSegmentedColormap
+from matplotlib import colormaps
+from qim3d.utils.logger import log
+
+
+def plot_metrics(
+    *metrics,
+    linestyle="-",
+    batch_linestyle="dotted",
+    labels: list = None,
+    figsize: tuple = (16, 6),
+    show=False
+):
+    """
+    Plots the metrics over epochs and batches.
+
+    Args:
+        *metrics: Variable-length argument list of dictionaries containing the metrics per epoch and per batch.
+        linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
+        batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
+        labels (list[str], optional): Labels for the plotted lines. Defaults to None.
+        figsize (Tuple[int, int], optional): Figure size (width, height) in inches. Defaults to (16, 6).
+        show (bool, optional): If True, displays the plot. Defaults to False.
+
+    Returns:
+        fig (matplotlib.figure.Figure): plot with metrics.
+
+    Example:
+        train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
+        val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
+        plot_metrics(train_loss, val_loss, labels=['Train', 'Valid.'])
+    """
+    import seaborn as snb
+
+    if labels is None:
+        labels = [None] * len(metrics)
+    elif len(metrics) != len(labels):
+        raise ValueError("The number of metrics doesn't match the number of labels.")
+
+    # plotting parameters
+    snb.set_style("darkgrid")
+    snb.set(font_scale=1.5)
+    plt.rcParams["lines.linewidth"] = 2
+
+    fig = plt.figure(figsize=figsize)
+
+    palette = snb.color_palette(None, len(metrics))
+
+    for i, metric in enumerate(metrics):
+        metric_name = list(metric.keys())[0]
+        epoch_metric = metric[list(metric.keys())[0]]
+        batch_metric = metric[list(metric.keys())[1]]
+
+        x_axis = np.linspace(0, len(epoch_metric) - 1, len(batch_metric))
+
+        plt.plot(epoch_metric, linestyle=linestyle, color=palette[i], label=labels[i])
+        plt.plot(
+            x_axis, batch_metric, linestyle=batch_linestyle, color=palette[i], alpha=0.4
+        )
+
+    if labels[0] is not None:
+        plt.legend()
+
+    plt.ylabel(metric_name)
+    plt.xlabel("epoch")
+
+    # reset plotting parameters
+    snb.set_style("white")
+
+    if show:
+        plt.show()
+    plt.close()
+
+    return fig
+
+
+def grid_overview(
+    data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
+):
+    """Displays an overview grid of images, labels, and masks (if they exist).
+
+    Labels are the annotated target segmentations.
+    Masks are applied to the output and target prior to the loss calculation in case of
+    sparse labeled data.
+
+    Args:
+        data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image, label, (and mask data).
+        num_images (int, optional): The maximum number of images to display. Defaults to 7.
+        cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
+        cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
+        alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
+        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
+
+    Raises:
+        ValueError: If the data elements are not tuples.
+
+    Returns:
+        fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels.
+
+    Example:
+        ```python
+        data = [(image1, label1, mask1), (image2, label2, mask2)]
+        grid_overview(data, num_images=5, cmap_im='viridis', cmap_segm='hot', alpha=0.8)
+        ```
+
+    Notes:
+        - If the image data is RGB, the color map is ignored and the user is informed.
+        - The number of displayed images is limited to the minimum between `num_images`
+          and the length of the data.
+        - The grid layout and dimensions vary based on the presence of a mask.
+    """
+
+    # Check if data has a mask
+    has_mask = len(data[0]) > 2 and data[0][-1] is not None
+
+    # Check if image data is RGB and inform the user if it's the case
+    if len(data[0][0].squeeze().shape) > 2:
+        log.info("Input images are RGB: color map is ignored")
+
+    # Check if the dataset has at least the specified number of images
+    if len(data) < num_images:
+        log.warning(
+            "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
+            num_images,
+            len(data),
+        )
+        num_images = len(data)
+
+    # Adapt segmentation cmap so that background is transparent
+    colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
+    colors_segm[:128, 3] = 0
+    custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
+
+    # Check if data has the right format
+    if not isinstance(data[0], tuple):
+        raise ValueError("Data elements must be tuples")
+
+    # Define row titles
+    row_titles = ["Input images", "Ground truth segmentation", "Mask"]
+
+    # Make a new list so that possible augmentations remain identical for all three rows
+    plot_data = [data[idx] for idx in range(num_images)]
+
+    fig = plt.figure(
+        figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True
+    )
+
+    # create 2 (or 3) x 1 subfigs
+    subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1)
+    for row, subfig in enumerate(subfigs):
+        subfig.suptitle(row_titles[row], fontsize=22)
+
+        # create 1 x num_images subplots per subfig
+        axs = subfig.subplots(nrows=1, ncols=num_images)
+        for col, ax in enumerate(np.atleast_1d(axs)):
+            if row in [1, 2]:  # Ground truth segmentation and mask
+                ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im)
+                ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha)
+                ax.axis("off")
+            else:
+                ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
+                ax.axis("off")
+
+    if show:
+        plt.show()
+    plt.close()
+
+    return fig
+
+
+def grid_pred(
+    in_targ_preds,
+    num_images=7,
+    cmap_im="gray",
+    cmap_segm="viridis",
+    alpha=0.5,
+    show=False,
+):
+    """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
+
+    Displays a grid of subplots representing different aspects of the input images and segmentations.
+    The grid includes the following rows:
+        - Row 1: Input images
+        - Row 2: Predicted segmentations overlaying input images
+        - Row 3: Ground truth segmentations overlaying input images
+        - Row 4: Comparison between true and predicted segmentations overlaying input images
+
+    Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset.
+    The function utilizes various color maps for visualization and applies transparency to the segmentations.
+
+    Args:
+        in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations.
+        num_images (int, optional): Number of images to display. Defaults to 7.
+        cmap_im (str, optional): Color map for input images. Defaults to "gray".
+        cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
+        alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
+        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
+
+    Returns:
+        fig (matplotlib.figure.Figure): The figure with images, labels and the label predictions from the trained models.
+
+    Raises:
+        None
+
+    Example:
+        dataset = MySegmentationDataset()
+        model = MySegmentationModel()
+        in_targ_preds = qim3d.utils.models.inference(dataset, model)
+        grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
+    """
+    import torch
+
+    # Check if the dataset has at least the specified number of images
+    if len(in_targ_preds[0]) < num_images:
+        log.warning(
+            "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
+            num_images,
+            len(in_targ_preds[0]),
+        )
+        num_images = len(in_targ_preds[0])
+
+    # Take only the number of images from in_targ_preds
+    inputs, targets, preds = [items[:num_images] for items in in_targ_preds]
+
+    # Adapt segmentation cmap so that background is transparent
+    colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
+    colors_segm[:128, 3] = 0
+    custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
+
+    N = num_images
+    H = inputs[0].shape[-2]
+    W = inputs[0].shape[-1]
+
+    comp_rgb = torch.zeros((N, 4, H, W))
+    comp_rgb[:, 1, :, :] = targets.logical_and(preds)
+    comp_rgb[:, 0, :, :] = targets.logical_xor(preds)
+    comp_rgb[:, 3, :, :] = targets.logical_or(preds)
+
+    row_titles = [
+        "Input images",
+        "Predicted segmentation",
+        "Ground truth segmentation",
+        "True vs. predicted segmentation",
+    ]
+
+    fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True)
+
+    # create 4 x 1 subfigs
+    subfigs = fig.subfigures(nrows=4, ncols=1)
+    for row, subfig in enumerate(subfigs):
+        subfig.suptitle(row_titles[row], fontsize=22)
+
+        # create 1 x num_images subplots per subfig
+        axs = subfig.subplots(nrows=1, ncols=num_images)
+        for col, ax in enumerate(np.atleast_1d(axs)):
+            if row == 0:
+                ax.imshow(inputs[col], cmap=cmap_im)
+                ax.axis("off")
+
+            elif row == 1:  # Predicted segmentation
+                ax.imshow(inputs[col], cmap=cmap_im)
+                ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha)
+                ax.axis("off")
+            elif row == 2:  # Ground truth segmentation
+                ax.imshow(inputs[col], cmap=cmap_im)
+                ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
+                ax.axis("off")
+            else:
+                ax.imshow(inputs[col], cmap=cmap_im)
+                ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
+                ax.axis("off")
+
+    if show:
+        plt.show()
+    plt.close()
+
+    return fig
+
+
+def vol_masked(vol, vol_mask, viz_delta=128):
+    """
+    Applies masking to a volume based on a binary volume mask.
+
+    This function takes a volume array `vol` and a corresponding binary volume mask `vol_mask`.
+    It computes the masked volume, where pixels outside the mask are set to the (negative) background value
+    and pixels inside the mask are kept as foreground.
+
+    Args:
+        vol (ndarray): The input volume as a NumPy array.
+        vol_mask (ndarray): The binary mask volume as a NumPy array with the same shape as `vol`.
+        viz_delta (int, optional): Value added to the volume before applying the mask to visualize masked regions.
+            Defaults to 128.
+
+    Returns:
+        ndarray: The masked volume with the same shape as `vol`, where pixels outside the mask are set
+            to the background value (negative).
+    """
+
+    background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1
+    foreground = (vol.astype("float") + viz_delta) * vol_mask
+    vol_masked_result = background + foreground
+
+    return vol_masked_result
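A short usage sketch for the relocated `vol_masked` helper. The thresholded mask is a hypothetical stand-in for any binary mask of matching shape, and the import path assumes the new module location shown above:

```python
import qim3d
from qim3d.viz.metrics import vol_masked

vol = qim3d.examples.cement_128x128x128
mask = (vol > 60).astype(int)  # hypothetical mask; any same-shape binary array works

# Outside-mask voxels come out negative, so a diverging colormap separates the regions
masked = vol_masked(vol, mask, viz_delta=128)
```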
diff --git a/qim3d/utils/preview.py b/qim3d/viz/preview.py
similarity index 100%
rename from qim3d/utils/preview.py
rename to qim3d/viz/preview.py
diff --git a/qim3d/viz/visualizations.py b/qim3d/viz/visualizations.py
deleted file mode 100644
index 4d64b0ce5b397486c90d4e3af69363e4a333dc2c..0000000000000000000000000000000000000000
--- a/qim3d/viz/visualizations.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""Visualization tools"""
-
-import numpy as np
-import matplotlib.pyplot as plt
-import seaborn as snb
-
-def plot_metrics(*metrics,
-                 linestyle = '-',
-                 batch_linestyle = 'dotted',
-                 labels:list = None,
-                 figsize:tuple = (16,6),
-                 show = False):
-    """
-    Plots the metrics over epochs and batches.
-
-    Args:
-        *metrics: Variable-length argument list of dictionary containing the metrics per epochs and per batches.
-        linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
-        batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
-        labels (list[str], optional): Labels for the plotted lines. Defaults to None.
-        figsize (Tuple[int, int], optional): Figure size (width, height) in inches. Defaults to (16, 6).
-        show (bool, optional): If True, displays the plot. Defaults to False.
-
-    Returns:
-        fig (matplotlib.figure.Figure): plot with metrics.
-
-    Example:
-        train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
-        val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
-        plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
-    """
-    if labels == None:
-        labels = [None]*len(metrics)
-    elif len(metrics) != len(labels):
-        raise ValueError("The number of metrics doesn't match the number of labels.")
-
-    # plotting parameters
-    snb.set_style('darkgrid')
-    snb.set(font_scale=1.5)
-    plt.rcParams['lines.linewidth'] = 2
-
-    fig = plt.figure(figsize = figsize)
-
-    palette = snb.color_palette(None,len(metrics))
-
-    for i,metric in enumerate(metrics):
-        metric_name = list(metric.keys())[0]
-        epoch_metric = metric[list(metric.keys())[0]]
-        batch_metric = metric[list(metric.keys())[1]]
-
-        x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric))
-
-        plt.plot(epoch_metric,linestyle = linestyle, color = palette[i], label = labels[i])
-        plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = palette[i], alpha = 0.4)
-
-    if labels[0] != None:
-        plt.legend()
-
-    plt.ylabel(metric_name)
-    plt.xlabel('epoch')
-
-    # reset plotting parameters
-    snb.set_style('white')
-
-    if show:
-        plt.show()
-    plt.close()
-
-    return fig
\ No newline at end of file
diff --git a/setup.py b/setup.py
index cba1ddc7f2059b8deaeee38950e6435c2337c87c..0838da68652a2331ff134eef4e222034382f48b5 100644
--- a/setup.py
+++ b/setup.py
@@ -1,18 +1,26 @@
 import os
-
+import re
 from setuptools import find_packages, setup

 # Read the contents of your README file
 with open("README.md", "r", encoding="utf-8") as f:
     long_description = f.read()

+# Read the version from the __init__.py file
+def read_version():
+    with open(os.path.join("qim3d", "__init__.py"), "r", encoding="utf-8") as f:
+        version_file = f.read()
+    version_match = re.search(r'^__version__ = ["\']([^"\']*)["\']', version_file, re.M)
+    if version_match:
+        return version_match.group(1)
+    raise RuntimeError("Unable to find version string.")
 setup(
     name="qim3d",
-    version="0.3.9",
+    version=read_version(),
author="Felipe Delestro", author_email="fima@dtu.dk", - description="QIM tools and user interfaces", + description="QIM tools and user interfaces for volumetric imaging", long_description=long_description, long_description_content_type="text/markdown", url="https://platform.qim.dk/qim3d", @@ -20,7 +28,7 @@ setup( include_package_data=True, entry_points = { 'console_scripts': [ - 'qim3d=qim3d.utils.cli:main' + 'qim3d=qim3d.cli:main' ] }, classifiers=[ @@ -37,13 +45,11 @@ setup( ], python_requires=">=3.10", install_requires=[ - "albumentations>=1.3.1", "gradio>=4.27.0", "h5py>=3.9.0", "localthickness>=0.1.2", "matplotlib>=3.8.0", "pydicom>=2.4.4", - "monai>=1.2.0", "numpy>=1.26.0", "outputformat>=0.1.3", "Pillow>=10.0.1", @@ -52,9 +58,6 @@ setup( "seaborn>=0.12.2", "setuptools>=68.0.0", "tifffile>=2023.4.12", - "torch>=2.0.1", - "torchvision>=0.15.2", - "torchinfo>=1.8.0", "tqdm>=4.65.0", "nibabel>=5.2.0", "ipywidgets>=8.1.2", @@ -65,5 +68,15 @@ setup( "structure-tensor>=0.2.1", "noise>=1.2.2", "zarr>=2.18.2", + "scikit-image>=0.24.0" ], + extras_require={ + "deep-learning": [ + "albumentations>=1.3.1", + "torch>=2.0.1", + "torchvision>=0.15.2", + "torchinfo>=1.8.0", + "monai>=1.2.0", + ] +} ) diff --git a/temp.har b/temp.har new file mode 100644 index 0000000000000000000000000000000000000000..b1c44090f3f6abd52a0c74e0f20b5209fa155f61 --- /dev/null +++ b/temp.har @@ -0,0 +1 @@ +{"log":{"version":"1.2","creator":{"name":"python","version":"3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]"},"browser":{"name":"python","version":"3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]"},"pages":[{"startedDateTime":"1991-07-05T00:00:00.000Z","id":"page0","title":"`import qim3d` 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]","pageTimings":{"onContentLoad":-1,"onLoad":-1}}],"entries":[{"page_ref":"page0","startedDateTime":"1991-07-05T00:00:00.000Z","time":72,"request":{"method":"GET","url":"qim3d","httpVersion":"HTTP/0.0","cookies":[],"headers":[],"queryString":[],"headersSize":0,"bodySize":0},"response":{"status":200,"statusText":"OK","httpVersion":"HTTP/0.0","cookies":[],"headers":[],"content":{"size":0,"mimeType":"text/x-python"},"redirectURL":"","headersSize":0,"bodySize":0},"cache":{},"timings":{"blocked":0,"dns":0,"connect":0,"ssl":0,"send":0,"wait":0,"receive":72},"serverIpAddress":"0.0.0.0"}]}}