Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes
Showing
with 955 additions and 327 deletions
......@@ -4,43 +4,51 @@ import pytest
from qim3d.segmentation._connected_components import get_3d_cc
@pytest.fixture(scope="module")
@pytest.fixture(scope='module')
def setup_data():
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[1,1,0,0,1,0],
[0,0,0,1,0,0]])
components = np.array(
[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [1, 1, 0, 0, 1, 0], [0, 0, 0, 1, 0, 0]]
)
num_components = 4
connected_components = get_3d_cc(components)
return connected_components, components, num_components
def test_connected_components_property(setup_data):
connected_components, _, _ = setup_data
components = np.array([[0,0,1,1,0,0],
[0,0,0,1,0,0],
[2,2,0,0,3,0],
[0,0,0,4,0,0]])
components = np.array(
[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [2, 2, 0, 0, 3, 0], [0, 0, 0, 4, 0, 0]]
)
assert np.array_equal(connected_components.get_cc(), components)
def test_num_connected_components_property(setup_data):
connected_components, _, num_components = setup_data
assert len(connected_components) == num_components
def test_get_connected_component_with_index(setup_data):
connected_components, _, _ = setup_data
expected_component = np.array([[0,0,1,1,0,0],
expected_component = np.array(
[
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0],
[0,0,0,0,0,0]], dtype=bool)
[0, 0, 0, 0, 0, 0],
],
dtype=bool,
)
print(connected_components.get_cc(index=1))
print(expected_component)
assert np.array_equal(connected_components.get_cc(index=1), expected_component)
def test_get_connected_component_without_index(setup_data):
connected_components, _, _ = setup_data
component = connected_components.get_cc()
assert np.any(component)
def test_get_connected_component_with_invalid_index(setup_data):
connected_components, _, num_components = setup_data
with pytest.raises(AssertionError):
......
import qim3d
import albumentations
import pytest
import qim3d
# unit tests for Augmentation()
def test_augmentation():
augment_class = qim3d.ml.Augmentation()
assert augment_class.resize == "crop"
assert augment_class.resize == 'crop'
def test_augment():
......@@ -20,7 +21,7 @@ def test_augment():
# unit tests for ValueErrors in Augmentation()
def test_resize():
resize_str = "not valid resize"
resize_str = 'not valid resize'
with pytest.raises(
ValueError,
......@@ -32,7 +33,7 @@ def test_resize():
def test_levels():
augment_class = qim3d.ml.Augmentation()
level = "Not a valid level"
level = 'Not a valid level'
with pytest.raises(
ValueError,
......
import qim3d
import pytest
from torch.utils.data.dataloader import DataLoader
import qim3d
from qim3d.tests import temp_data
# unit tests for Dataset()
def test_dataset():
img_shape = (32, 32)
......@@ -19,19 +20,29 @@ def test_dataset():
# unit tests for check_resize()
def test_check_resize():
h_adjust,w_adjust = qim3d.ml._data.check_resize(240,240,resize = 'crop',n_channels = 6)
h_adjust, w_adjust = qim3d.ml._data.check_resize(
240, 240, resize='crop', n_channels=6
)
assert (h_adjust, w_adjust) == (192, 192)
def test_check_resize_pad():
h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'padding',n_channels = 6)
h_adjust, w_adjust = qim3d.ml._data.check_resize(
16, 16, resize='padding', n_channels=6
)
assert (h_adjust, w_adjust) == (64, 64)
def test_check_resize_fail():
with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'crop',n_channels = 6)
def test_check_resize_fail():
with pytest.raises(
ValueError,
match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.",
):
h_adjust, w_adjust = qim3d.ml._data.check_resize(
16, 16, resize='crop', n_channels=6
)
# unit tests for prepare_datasets()
......@@ -44,9 +55,15 @@ def test_prepare_datasets():
my_model = qim3d.ml.models.UNet()
my_augmentation = qim3d.ml.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,validation,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, validation, my_model, my_augmentation
)
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
assert (len(train_set), len(val_set), len(test_set)) == (
int((1 - validation) * n),
int(n * validation),
n,
)
temp_data(folder, remove=True)
......@@ -55,8 +72,12 @@ def test_prepare_datasets():
def test_validation():
validation = 10
with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."):
augment_class = qim3d.ml.prepare_datasets('folder',validation,'my_model','my_augmentation')
with pytest.raises(
ValueError, match='The validation fraction must be a float between 0 and 1.'
):
augment_class = qim3d.ml.prepare_datasets(
'folder', validation, 'my_model', 'my_augmentation'
)
# unit test for prepare_dataloaders()
......@@ -67,11 +88,13 @@ def test_prepare_dataloaders():
batch_size = 1
my_model = qim3d.ml.models.UNet()
my_augmentation = qim3d.ml.Augmentation()
train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,1/3,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, 1 / 3, my_model, my_augmentation
)
_,val_loader,_ = qim3d.ml.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1,
pin_memory = False)
_, val_loader, _ = qim3d.ml.prepare_dataloaders(
train_set, val_set, test_set, batch_size, num_workers=1, pin_memory=False
)
assert type(val_loader) == DataLoader
......
import qim3d
doi = "https://doi.org/10.1007/s10851-021-01041-3"
doi = 'https://doi.org/10.1007/s10851-021-01041-3'
def test_get_bibtex():
bibtext = qim3d.utils._doi.get_bibtex(doi)
assert "Measuring Shape Relations Using r-Parallel Sets" in bibtext
assert 'Measuring Shape Relations Using r-Parallel Sets' in bibtext
def test_get_reference():
reference = qim3d.utils._doi.get_reference(doi)
assert "Stephensen" in reference
assert 'Stephensen' in reference
import qim3d
import os
import re
from pathlib import Path
import qim3d
def test_mock_plot():
fig = qim3d.tests.mock_plot()
......@@ -11,12 +12,12 @@ def test_mock_plot():
def test_mock_write_file():
filename = "test.txt"
content = "test file"
filename = 'test.txt'
content = 'test file'
qim3d.tests.mock_write_file(filename, content=content)
# Check contents
with open(filename, "r", encoding="utf-8") as f:
with open(filename, encoding='utf-8') as f:
file_content = f.read()
# Remove temp file
......@@ -27,7 +28,7 @@ def test_mock_write_file():
def test_get_local_ip():
def validate_ip(ip_str):
reg = r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$"
reg = r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$'
if re.match(reg, ip_str):
return True
else:
......@@ -40,7 +41,7 @@ def test_get_local_ip():
def test_stringify_path1():
"""Test that the function converts os.PathLike objects to strings"""
blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif"
blobs_path = Path(qim3d.__file__).parents[0] / 'img_examples' / 'blobs_256x256.tif'
assert str(blobs_path) == qim3d.utils._misc.stringify_path(blobs_path)
......@@ -48,6 +49,6 @@ def test_stringify_path1():
def test_stringify_path2():
"""Test that the function returns input unchanged if input is a string"""
# Create test_path
test_path = os.path.join("this", "path", "doesnt", "exist.tif")
test_path = os.path.join('this', 'path', 'doesnt', 'exist.tif')
assert test_path == qim3d.utils._misc.stringify_path(test_path)
import qim3d
def test_memory():
mem = qim3d.utils.Memory()
assert all([
mem.total>0,
mem.used>0,
mem.free>0])
assert all([mem.total > 0, mem.used > 0, mem.free > 0])
assert mem.used + mem.free == mem.total
import pytest
import torch
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pytest
from torch import ones
import torch
import qim3d
from qim3d.tests import temp_data
import matplotlib.pyplot as plt
import ipywidgets as widgets
# unit tests for grid overview
def test_grid_overview():
......@@ -26,13 +21,13 @@ def test_grid_overview():
def test_grid_overview_tuple():
random_tuple = (torch.ones(256, 256), torch.ones(256, 256))
with pytest.raises(ValueError, match="Data elements must be tuples"):
with pytest.raises(ValueError, match='Data elements must be tuples'):
qim3d.viz.grid_overview(random_tuple, num_images=1)
# unit tests for grid prediction
def test_grid_pred():
folder = "folder_data"
folder = 'folder_data'
n = 4
temp_data(folder, n=n)
......@@ -57,8 +52,8 @@ def test_slices_numpy_array_input():
def test_slices_wrong_input_format():
input = "not_a_volume"
with pytest.raises(ValueError, match="Data type not supported"):
input = 'not_a_volume'
with pytest.raises(ValueError, match='Data type not supported'):
qim3d.viz.slices_grid(input)
......@@ -66,7 +61,7 @@ def test_slices_not_volume():
example_volume = np.ones((10, 10))
with pytest.raises(
ValueError,
match="The provided object is not a volume as it has less than 3 dimensions.",
match='The provided object is not a volume as it has less than 3 dimensions.',
):
qim3d.viz.slices_grid(example_volume)
......@@ -77,7 +72,7 @@ def test_slices_wrong_position_format1():
ValueError,
match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".',
):
qim3d.viz.slices_grid(example_volume, slice_positions="invalid_slice")
qim3d.viz.slices_grid(example_volume, slice_positions='invalid_slice')
def test_slices_wrong_position_format2():
......@@ -110,7 +105,7 @@ def test_slices_invalid_axis_value():
def test_slices_interpolation_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
interpolation_method = "bilinear"
interpolation_method = 'bilinear'
fig = qim3d.viz.slices_grid(
example_volume,
num_slices=1,
......@@ -130,7 +125,9 @@ def test_slices_multiple_slices():
example_volume = np.ones((10, 10, 10))
image_width = 3
num_slices = 3
fig = qim3d.viz.slices_grid(example_volume, num_slices=num_slices, image_width=image_width)
fig = qim3d.viz.slices_grid(
example_volume, num_slices=num_slices, image_width=image_width
)
# Add assertions for the expected number of subplots in the figure
assert len(fig.get_axes()) == num_slices
......@@ -192,7 +189,7 @@ def test_slicer_with_different_parameters():
assert isinstance(slicer_obj, widgets.interactive)
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
for cmap in ['viridis', 'gray', 'plasma']:
slicer_obj = qim3d.viz.slicer(np.random.rand(10, 10, 10), cmap=cmap)
assert isinstance(slicer_obj, widgets.interactive)
......@@ -232,14 +229,18 @@ def test_orthogonal_with_torch_tensor():
def test_orthogonal_with_different_parameters():
# Test with different colormaps
for color_map in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.slicer_orthogonal(np.random.rand(10, 10, 10), color_map=color_map)
for color_map in ['viridis', 'gray', 'plasma']:
orthogonal_obj = qim3d.viz.slicer_orthogonal(
np.random.rand(10, 10, 10), color_map=color_map
)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with different image sizes
for image_height, image_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.slicer_orthogonal(
np.random.rand(10, 10, 10), image_height=image_height, image_width=image_width
np.random.rand(10, 10, 10),
image_height=image_height,
image_width=image_width,
)
assert isinstance(orthogonal_obj, widgets.HBox)
......@@ -266,7 +267,7 @@ def test_orthogonal_slider_description():
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.slicer_orthogonal(vol)
for idx, slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].description == ["Z", "Y", "X"][idx]
assert slicer.children[0].description == ['Z', 'Y', 'X'][idx]
# unit tests for local thickness visualization
......
......@@ -15,5 +15,7 @@ def test_plot_metrics():
def test_plot_metrics_labels():
metrics = {'epoch_loss': [0.3, 0.2, 0.1], 'batch_loss': [0.3, 0.2, 0.1]}
with pytest.raises(ValueError,match="The number of metrics doesn't match the number of labels."):
with pytest.raises(
ValueError, match="The number of metrics doesn't match the number of labels."
):
qim3d.viz.plot_metrics(metrics, labels=['a', 'b'])
from ._doi import *
from ._system import Memory
from ._logger import log
from ._misc import (
get_local_ip,
port_from_str,
gradio_header,
sizeof,
downscale_img,
get_css,
get_file_size,
get_local_ip,
get_port_dict,
get_css,
downscale_img,
gradio_header,
port_from_str,
scale_to_float16,
sizeof,
)
from ._server import start_http_server
from ._system import Memory
import argparse
from qim3d.gui import data_explorer, iso3d, annotation_tool, local_thickness, layers2d
from qim3d.gui import annotation_tool, data_explorer, iso3d, layers2d, local_thickness
def main():
parser = argparse.ArgumentParser(description='qim3d command-line interface.')
......@@ -8,10 +10,16 @@ def main():
# subcommands
gui_parser = subparsers.add_parser('gui', help='Graphical User Interfaces.')
gui_parser.add_argument('--data-explorer', action='store_true', help='Run data explorer.')
gui_parser.add_argument(
'--data-explorer', action='store_true', help='Run data explorer.'
)
gui_parser.add_argument('--iso3d', action='store_true', help='Run iso3d.')
gui_parser.add_argument('--annotation-tool', action='store_true', help='Run annotation tool.')
gui_parser.add_argument('--local-thickness', action='store_true', help='Run local thickness tool.')
gui_parser.add_argument(
'--annotation-tool', action='store_true', help='Run annotation tool.'
)
gui_parser.add_argument(
'--local-thickness', action='store_true', help='Run local thickness tool.'
)
gui_parser.add_argument('--layers2d', action='store_true', help='Run layers2d.')
gui_parser.add_argument('--host', default='0.0.0.0', help='Desired host.')
......@@ -20,7 +28,6 @@ def main():
if args.subcommand == 'gui':
arghost = args.host
if args.data_explorer:
data_explorer.run_interface(arghost)
elif args.iso3d:
......@@ -35,5 +42,6 @@ def main():
elif args.layers2d:
layers2d.run_interface(arghost)
if __name__ == '__main__':
main()
"""Deals with DOI for references"""
import json
import requests
from qim3d.utils._logger import log
def _validate_response(response: requests.Response) -> bool:
# Check if we got a good response
if not response.ok:
log.error(f"Could not read the provided DOI ({response.reason})")
log.error(f'Could not read the provided DOI ({response.reason})')
return False
return True
def _doi_to_url(doi: str) -> str:
if doi[:3] != "http":
url = "https://doi.org/" + doi
if doi[:3] != 'http':
url = 'https://doi.org/' + doi
else:
url = doi
......@@ -52,12 +55,14 @@ def _log_and_get_text(doi, header) -> str:
def get_bibtex(doi: str):
"""Generates bibtex from doi"""
header = {"Accept": "application/x-bibtex"}
header = {'Accept': 'application/x-bibtex'}
return _log_and_get_text(doi, header)
def custom_header(doi: str, header: str) -> str:
"""Allows a custom header to be passed
"""
Allows a custom header to be passed
Example:
import qim3d
......@@ -68,15 +73,17 @@ def custom_header(doi: str, header: str) -> str:
"""
return _log_and_get_text(doi, header)
def get_metadata(doi: str) -> dict:
"""Generates a metadata dictionary from doi"""
header = {"Accept": "application/vnd.citationstyles.csl+json"}
header = {'Accept': 'application/vnd.citationstyles.csl+json'}
response = _make_request(doi, header)
metadata = json.loads(response.text)
return metadata
def get_reference(doi: str) -> str:
"""Generates a metadata dictionary from doi and use it to build a reference string"""
......@@ -85,15 +92,18 @@ def get_reference(doi: str) -> str:
return reference_string
def build_reference_string(metadata: dict) -> str:
"""Generates a reference string from metadata"""
authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']])
authors = ', '.join(
[f"{author['family']} {author['given']}" for author in metadata['author']]
)
year = metadata['issued']['date-parts'][0][0]
title = metadata['title']
publisher = metadata['publisher']
url = metadata['URL']
doi = metadata['DOI']
reference_string = f"{authors} ({year}). {title}. {publisher} ({url}). DOI: {doi}"
reference_string = f'{authors} ({year}). {title}. {publisher} ({url}). DOI: {doi}'
return reference_string
......@@ -2,20 +2,23 @@
import logging
logger = logging.getLogger("qim3d")
logger = logging.getLogger('qim3d')
def set_detailed_output():
"""Configures the logging output to display detailed information.
"""
Configures the logging output to display detailed information.
This function sets up a logging formatter with a specific format that
includes the log level, filename, line number, and log message.
Example:
>>> set_detailed_output()
"""
formatter = logging.Formatter(
"%(levelname)-10s%(filename)s:%(lineno)-5s%(message)s"
'%(levelname)-10s%(filename)s:%(lineno)-5s%(message)s'
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
......@@ -32,8 +35,9 @@ def set_simple_output():
Example:
>>> set_simple_output()
"""
formatter = logging.Formatter("%(message)s")
formatter = logging.Formatter('%(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.handlers = []
......@@ -41,55 +45,66 @@ def set_simple_output():
def set_level_debug():
"""Sets the logging level of the "qim3d" logger to DEBUG.
"""
Sets the logging level of the "qim3d" logger to DEBUG.
Example:
>>> set_level_debug()
"""
logger.setLevel(logging.DEBUG)
def set_level_info():
"""Sets the logging level of the "qim3d" logger to INFO.
"""
Sets the logging level of the "qim3d" logger to INFO.
Example:
>>> set_level_info()
"""
logger.setLevel(logging.INFO)
def set_level_warning():
"""Sets the logging level of the "qim3d" logger to WARNING.
"""
Sets the logging level of the "qim3d" logger to WARNING.
Example:
>>> set_level_warning()
"""
logger.setLevel(logging.WARNING)
def set_level_error():
"""Sets the logging level of the "qim3d" logger to ERROR.
"""
Sets the logging level of the "qim3d" logger to ERROR.
Example:
>>> set_level_error()
"""
logger.setLevel(logging.ERROR)
def set_level_critical():
"""Sets the logging level of the "qim3d" logger to CRITICAL.
"""
Sets the logging level of the "qim3d" logger to CRITICAL.
Example:
>>> set_level_critical()
"""
logger.setLevel(logging.CRITICAL)
def level(log_level):
"""Set the logging level based on the specified level.
"""
Set the logging level based on the specified level.
Args:
log_level (str or int): The logging level to set. It can be one of:
......@@ -104,19 +119,19 @@ def level(log_level):
ValueError: If the specified level is not a valid logging level.
"""
if log_level in ["DEBUG", "debug"]:
if log_level in ['DEBUG', 'debug']:
set_level_debug()
elif log_level in ["INFO", "info"]:
elif log_level in ['INFO', 'info']:
set_level_info()
elif log_level in ["WARNING", "warning"]:
elif log_level in ['WARNING', 'warning']:
set_level_warning()
elif log_level in ["ERROR", "error"]:
elif log_level in ['ERROR', 'error']:
set_level_error()
elif log_level in ["CRITICAL", "critical"]:
elif log_level in ['CRITICAL', 'critical']:
set_level_critical()
elif isinstance(log_level, int):
......@@ -130,6 +145,6 @@ def level(log_level):
# Create the logger
log = logging.getLogger("qim3d")
log = logging.getLogger('qim3d')
set_level_info()
set_simple_output()
"""Provides a collection of internal utility functions."""
import difflib
import getpass
import hashlib
import os
import socket
import numpy as np
import outputformat as ouf
import requests
from scipy.ndimage import zoom
import difflib
import qim3d
def get_local_ip() -> str:
"""Retrieves the local IP address of the current machine.
"""
Retrieves the local IP address of the current machine.
The function uses a socket to determine the local IP address.
Then, it tries to connect to the IP address "192.255.255.255"
......@@ -23,20 +26,21 @@ def get_local_ip() -> str:
network is not available, the function falls back to returning
the loopback address "127.0.0.1".
Returns:
Returns
str: The local IP address.
Example usage:
ip_address = get_local_ip()
"""
_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
_socket.connect(("192.255.255.255", 1))
_socket.connect(('192.255.255.255', 1))
ip_address = _socket.getsockname()[0]
except socket.error:
ip_address = "127.0.0.1"
except OSError:
ip_address = '127.0.0.1'
finally:
_socket.close()
return ip_address
......@@ -61,12 +65,14 @@ def port_from_str(s: str) -> int:
Example usage:
port = port_from_str("my_specific_app_name")
"""
return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4)
return int(hashlib.sha1(s.encode('utf-8')).hexdigest(), 16) % (10**4)
def gradio_header(title: str, port: int) -> None:
"""Display the header for a Gradio server.
"""
Display the header for a Gradio server.
Displays a formatted header containing the provided title,
the port number being used, and the IP address where the server is running.
......@@ -93,14 +99,15 @@ def gradio_header(title: str, port: int) -> None:
ouf.br(2)
details = [
f'{ouf.c(title, color="rainbow", cmap="cool", bold=True, return_str=True)}',
f"Using port {port}",
f"Running at {get_local_ip()}",
f'Using port {port}',
f'Running at {get_local_ip()}',
]
ouf.showlist(details, style="box", title="Starting gradio server")
ouf.showlist(details, style='box', title='Starting gradio server')
def sizeof(num: float, suffix: str = "B") -> str:
"""Converts a number to a human-readable string representing its size.
def sizeof(num: float, suffix: str = 'B') -> str:
"""
Converts a number to a human-readable string representing its size.
Converts the given number to a human-readable string representing its size in
a more readable format, such as bytes (B), kilobytes (KB), megabytes (MB),
......@@ -124,17 +131,18 @@ def sizeof(num: float, suffix: str = "B") -> str:
'1.0 KB'
>>> qim3d.utils.sizeof(1234567890)
'1.1 GB'
"""
for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(num) < 1024.0:
return f"{num:3.1f} {unit}{suffix}"
return f'{num:3.1f} {unit}{suffix}'
num /= 1024.0
return f"{num:.1f} Y{suffix}"
return f'{num:.1f} Y{suffix}'
def find_similar_paths(path: str) -> list[str]:
parent_dir = os.path.dirname(path) or "."
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
parent_dir = os.path.dirname(path) or '.'
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ''
valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
similar_paths = difflib.get_close_matches(path, valid_paths)
......@@ -144,12 +152,13 @@ def find_similar_paths(path: str) -> list[str]:
def get_file_size(file_path: str) -> int:
"""
Args:
-----
----
filename (str): Specifies full path to file
Returns:
---------
-------
size (int): size of file in bytes
"""
try:
file_size = os.path.getsize(file_path)
......@@ -161,7 +170,7 @@ def get_file_size(file_path: str) -> int:
message = f"Invalid path. Did you mean '{suggestion}'?"
raise FileNotFoundError(repr(message))
else:
raise FileNotFoundError("Invalid path")
raise FileNotFoundError('Invalid path')
return file_size
......@@ -176,7 +185,7 @@ def stringify_path(path: os.PathLike) -> str:
def get_port_dict() -> dict:
# Gets user and port
username = getpass.getuser()
url = f"https://platform.qim.dk/qim-api/get-port/{username}"
url = f'https://platform.qim.dk/qim-api/get-port/{username}'
response = requests.get(url, timeout=10)
# Check if the request was successful (status code 200)
......@@ -185,25 +194,25 @@ def get_port_dict() -> dict:
port_dict = response.json()
else:
# Print an error message if the request was not successful
raise (f"Error: {response.status_code}")
raise (f'Error: {response.status_code}')
return port_dict
def get_css() -> str:
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
css_path = os.path.join(parent_directory, "css", "gradio.css")
css_path = os.path.join(parent_directory, 'css', 'gradio.css')
with open(css_path, "r") as file:
with open(css_path) as file:
css_content = file.read()
return css_content
def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
"""Downscale image if total number of voxels exceeds 512³.
"""
Downscale image if total number of voxels exceeds 512³.
Args:
img (np.Array): Input image.
......@@ -211,6 +220,7 @@ def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
Returns:
np.Array: Downscaled image if total number of voxels exceeds 512³.
"""
# Calculate total number of pixels in the image
......@@ -231,10 +241,10 @@ def scale_to_float16(arr: np.ndarray) -> np.ndarray:
"""
Scale the input array to the float16 data type.
Parameters:
Parameters
arr (np.ndarray): Input array to be scaled.
Returns:
Returns
np.ndarray: Scaled array with dtype=np.float16.
This function scales the input array to the float16 data type, ensuring that the
......@@ -242,6 +252,7 @@ def scale_to_float16(arr: np.ndarray) -> np.ndarray:
for float16. If the maximum value of the input array exceeds the maximum
representable value for float16, the array is scaled down proportionally
to fit within the float16 range.
"""
# Get the maximum value to comprare with the float16 maximum value
......
from zarr.util import normalize_chunks, normalize_dtype, normalize_shape
import numpy as np
from zarr.util import normalize_chunks, normalize_dtype, normalize_shape
def get_chunk_size(shape: tuple, dtype: tuple) -> tuple[int, ...]:
"""
......@@ -10,6 +11,7 @@ def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]:
----------
- shape: shape of the data
- dtype: dtype of the data
"""
object_codec = None
dtype, object_codec = normalize_dtype(dtype, object_codec)
......@@ -29,6 +31,7 @@ def get_n_chunks(shapes: tuple, dtypes: tuple) -> int:
----------
- shapes: list of shapes of the data for each scale
- dtype: dtype of the data
"""
n_chunks = 0
for shape, dtype in zip(shapes, dtypes):
......@@ -37,4 +40,3 @@ def get_n_chunks(shapes: tuple, dtypes: tuple) -> int:
ratio = shape / chunk_size
n_chunks += np.prod(ratio)
return int(n_chunks)
from threading import Timer
import psutil
import sys
import os
import sys
from abc import ABC, abstractmethod
from threading import Timer
import psutil
from tqdm.auto import tqdm
from qim3d.utils._misc import get_file_size
class RepeatTimer(Timer):
"""
If the memory check is set as a normal thread, there is no garuantee it will switch
resulting in not enough memory checks to create smooth progress bar or to make it
......@@ -23,6 +24,7 @@ class RepeatTimer(Timer):
while not self.finished.wait(self.interval):
self.function(*self.args, **self.kwargs)
class ProgressBar(ABC):
def __init__(self, tqdm_kwargs: dict, repeat_time: float, *args, **kwargs):
"""
......@@ -32,10 +34,11 @@ class ProgressBar(ABC):
Thus we have to run parallel thread with forced activation to check the state
Parameters:
------------
Parameters
----------
- tqdm_kwargs (dict): Passed directly to tqdm constructor
- repeat_time (float): How often the timer runs the function (in seconds)
"""
self.timer = RepeatTimer(repeat_time, self.update_pbar)
self.pbar = tqdm(**tqdm_kwargs)
......@@ -47,16 +50,13 @@ class ProgressBar(ABC):
try:
self.pbar.update(update)
except (
AttributeError
): # When we leave the context manager, we delete the pbar so it can not be updated anymore
except AttributeError: # When we leave the context manager, we delete the pbar so it can not be updated anymore
# It's because it takes quite a long time for the timer to end and might update the pbar
# one more time before ending which messes up the whole thing
pass
self.last_update = new_update
@abstractmethod
def get_new_update(self):
pass
......@@ -72,28 +72,28 @@ class ProgressBar(ABC):
del self.pbar # So the update process can not update it anymore
class FileLoadingProgressBar(ProgressBar):
def __init__(self, filename: str, repeat_time: float = 0.5, *args, **kwargs):
"""
Context manager ('with' statement) to track progress during loading a file into memory
Parameters:
------------
Parameters
----------
- filename (str): to get size of the file
- repeat_time (float, optional): How often the timer checks how many bytes were loaded. Even if very small,
it doesn't make the progress bar smoother as there are only few visible changes in number of read_chars.
Defaults to 0.5
"""
tqdm_kwargs = dict(
total=get_file_size(filename),
desc="Loading: ",
unit="B",
desc='Loading: ',
unit='B',
file=sys.stdout,
unit_scale=True,
unit_divisor=1024,
bar_format="{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, "
"{rate_fmt}{postfix}]",
bar_format='{l_bar}{bar}| {n_fmt}{unit}/{total_fmt}{unit} [{elapsed}<{remaining}, '
'{rate_fmt}{postfix}]',
)
super().__init__(tqdm_kwargs, repeat_time)
self.process = psutil.Process()
......@@ -106,8 +106,9 @@ class FileLoadingProgressBar(ProgressBar):
memory = counters.read_bytes + counters.other_bytes
return memory
class OmeZarrExportProgressBar(ProgressBar):
def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"):
def __init__(self, path: str, n_chunks: int, reapeat_time: str = 'auto'):
"""
Context manager to track the exporting of OmeZarr files.
......@@ -120,12 +121,11 @@ class OmeZarrExportProgressBar(ProgressBar):
repeat_time : int or float, optional
The interval (in seconds) for updating the progress bar. Defaults to "auto", which
sets the update frequency based on the number of chunks.
"""
"""
# Calculate the repeat time for the progress bar
if reapeat_time == "auto":
if reapeat_time == 'auto':
# Approximate the repeat time based on the number of chunks
# This ratio is based on reading the HOA dataset over the network:
# 620,000 files took 300 seconds to read
......@@ -142,11 +142,7 @@ class OmeZarrExportProgressBar(ProgressBar):
self.path = path
tqdm_kwargs = dict(
total = n_chunks,
unit = "Chunks",
desc = "Saving",
unit_scale = True
total=n_chunks, unit='Chunks', desc='Saving', unit_scale=True
)
super().__init__(tqdm_kwargs, reapeat_time)
self.last_update = 0
......@@ -162,7 +158,7 @@ class OmeZarrExportProgressBar(ProgressBar):
new_path = os.path.join(folder_path, path)
if os.path.isfile(new_path):
filename = os.path.basename(os.path.normpath(new_path))
if not filename.startswith("."):
if not filename.startswith('.'):
count += 1
else:
count += file_count(new_path)
......
import os
from http.server import SimpleHTTPRequestHandler, HTTPServer
import threading
from http.server import HTTPServer, SimpleHTTPRequestHandler
from qim3d.utils._logger import log
class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
def end_headers(self):
"""Add CORS headers to each response."""
# Allow requests from any origin, or restrict to specific domains by specifying the origin
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header('Access-Control-Allow-Origin', '*')
# Allow specific methods
self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
# Allow specific headers (if needed)
self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")
self.send_header(
'Access-Control-Allow-Headers', 'X-Requested-With, Content-Type'
)
super().end_headers()
def list_directory(self, path: str | os.PathLike):
......@@ -19,7 +23,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
try:
file_list = os.listdir(path)
except OSError:
self.send_error(404, "No permission to list directory")
self.send_error(404, 'No permission to list directory')
return None
# Sort the file list
......@@ -28,40 +32,42 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
# Format the list with hidden files included
displaypath = os.path.basename(path)
r = ['<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">']
r.append(f"<html>\n<title>Directory listing for {displaypath}</title>\n")
r.append(f"<body>\n<h2>Directory listing for {displaypath}</h2>\n")
r.append("<hr>\n<ul>")
r.append(f'<html>\n<title>Directory listing for {displaypath}</title>\n')
r.append(f'<body>\n<h2>Directory listing for {displaypath}</h2>\n')
r.append('<hr>\n<ul>')
for name in file_list:
fullname = os.path.join(path, name)
displayname = linkname = name
# Append the files and directories to the HTML list
if os.path.isdir(fullname):
displayname = name + "/"
linkname = name + "/"
displayname = name + '/'
linkname = name + '/'
r.append(f'<li><a href="{linkname}">{displayname}</a></li>')
r.append("</ul>\n<hr>\n</body>\n</html>\n")
encoded = "\n".join(r).encode('utf-8', 'surrogateescape')
r.append('</ul>\n<hr>\n</body>\n</html>\n')
encoded = '\n'.join(r).encode('utf-8', 'surrogateescape')
self.send_response(200)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(encoded)))
self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(encoded)))
self.end_headers()
# Write the encoded HTML directly to the response
self.wfile.write(encoded)
def start_http_server(directory: str, port: int = 8000) -> HTTPServer:
"""
Starts an HTTP server serving the specified directory on the given port with CORS enabled.
Parameters:
Parameters
directory (str): The directory to serve.
port (int): The port number to use (default is 8000).
"""
# Change the working directory to the specified directory
os.chdir(directory)
# Create the server
server = HTTPServer(("", port), CustomHTTPRequestHandler)
server = HTTPServer(('', port), CustomHTTPRequestHandler)
# Run the server in a separate thread so it doesn't block execution
thread = threading.Thread(target=server.serve_forever)
......
"""Provides tools for obtaining information about the system."""
import os
import time
import numpy as np
import psutil
from qim3d.utils._misc import sizeof
from qim3d.utils._logger import log
import numpy as np
from qim3d.utils._misc import sizeof
class Memory:
"""Class for obtaining current memory information
Attributes:
"""
Class for obtaining current memory information
Attributes
total (int): Total system memory in bytes
free (int): Free system memory in bytes
used (int): Used system memory in bytes
used_pct (float): Used system memory in percentage
"""
def __init__(self):
......@@ -28,7 +34,7 @@ class Memory:
def report(self):
log.info(
"System memory:\n • Total.: %s\n • Used..: %s (%s%%)\n • Free..: %s (%s%%)",
'System memory:\n • Total.: %s\n • Used..: %s (%s%%)\n • Free..: %s (%s%%)',
sizeof(self.total),
sizeof(self.used),
round(self.used_pct, 1),
......@@ -36,8 +42,11 @@ class Memory:
round(self.free_pct, 1),
)
def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]:
'''
def _test_disk_speed(
file_size_bytes: int = 1024, ntimes: int = 10
) -> tuple[float, float, float, float]:
"""
Test the write and read speed of the disk by writing a file of a given size
and then reading it back.
......@@ -56,7 +65,8 @@ def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[floa
print(f"Write speed: {write_speed:.2f} GB/s")
print(f"Read speed: {read_speed:.2f} GB/s")
```
'''
"""
write_speeds = []
read_speeds = []
......@@ -96,7 +106,7 @@ def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[floa
def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None:
'''
"""
Report the average write and read speed of the disk.
Args:
......@@ -108,16 +118,19 @@ def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None:
qim3d.io.logger.level("info")
qim3d.utils.system.disk_report()
```
'''
"""
# Test disk speed
avg_write_speed, write_speed_std, avg_read_speed, read_speed_std = _test_disk_speed(file_size_bytes=file_size, ntimes=ntimes)
avg_write_speed, write_speed_std, avg_read_speed, read_speed_std = _test_disk_speed(
file_size_bytes=file_size, ntimes=ntimes
)
# Print disk information
log.info(
"Disk:\n • Write speed..: %.2f GB/s (± %.2f GB/s)\n • Read speed...: %.2f GB/s (± %.2f GB/s)",
'Disk:\n • Write speed..: %.2f GB/s (± %.2f GB/s)\n • Read speed...: %.2f GB/s (± %.2f GB/s)',
avg_write_speed,
write_speed_std,
avg_read_speed,
read_speed_std
read_speed_std,
)
from . import colormaps
from . import _layers2d, colormaps
from ._cc import plot_cc
from ._detection import circles
from ._data_exploration import (
chunks,
fade_mask,
histogram,
line_profile,
slicer,
slicer_orthogonal,
slices_grid,
chunks,
histogram,
threshold,
)
from .itk_vtk_viewer import itk_vtk
from ._k3d import volumetric, mesh
from ._detection import circles
from ._k3d import mesh, volumetric
from ._local_thickness import local_thickness
from ._structure_tensor import vectors
from ._metrics import plot_metrics, grid_overview, grid_pred, vol_masked
from ._metrics import grid_overview, grid_pred, plot_metrics, vol_masked
from ._preview import image_preview
from . import _layers2d
from ._structure_tensor import vectors
from .itk_vtk_viewer import itk_vtk
import matplotlib.pyplot as plt
import numpy as np
import qim3d
from qim3d.utils._logger import log
from qim3d.segmentation._connected_components import CC
from qim3d.utils._logger import log
def plot_cc(
connected_components: CC,
......@@ -11,7 +13,7 @@ def plot_cc(
overlay: np.ndarray = None,
crop: bool = False,
display_figure: bool = True,
color_map: str = "viridis",
color_map: str = 'viridis',
value_min: float = None,
value_max: float = None,
**kwargs,
......@@ -19,7 +21,7 @@ def plot_cc(
"""
Plots the connected components from a `qim3d.processing.cc.CC` object. If an overlay image is provided, the connected component will be masked to the overlay image.
Parameters:
Parameters
connected_components (CC): The connected components object.
component_indexs (list or tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
......@@ -31,7 +33,7 @@ def plot_cc(
value_max (float or None, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
**kwargs (Any): Additional keyword arguments to pass to `qim3d.viz.slices_grid`.
Returns:
Returns
figs (list[plt.Figure]): List of figures, if `display_figure=False`.
Example:
......@@ -45,12 +47,13 @@ def plot_cc(
```
![plot_cc_no_overlay](../../assets/screenshots/plot_cc_no_overlay.png)
![plot_cc_overlay](../../assets/screenshots/plot_cc_overlay.png)
"""
# if no components are given, plot the first max_cc_to_plot=32 components
if component_indexs is None:
if len(connected_components) > max_cc_to_plot:
log.warning(
f"More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components."
f'More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components.'
)
component_indexs = range(
1, min(max_cc_to_plot + 1, len(connected_components) + 1)
......@@ -61,7 +64,7 @@ def plot_cc(
if overlay is not None:
assert (
overlay.shape == connected_components.shape
), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
), f'Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}.'
# plots overlay masked to connected component
if crop:
......@@ -75,16 +78,25 @@ def plot_cc(
cc = connected_components.get_cc(component, crop=False)
overlay_crop = np.where(cc == 0, 0, overlay)
fig = qim3d.viz.slices_grid(
overlay_crop, display_figure=display_figure, color_map=color_map, value_min=value_min, value_max=value_max, **kwargs
overlay_crop,
display_figure=display_figure,
color_map=color_map,
value_min=value_min,
value_max=value_max,
**kwargs,
)
else:
# assigns discrete color map to each connected component if not given
if "color_map" not in kwargs:
kwargs["color_map"] = qim3d.viz.colormaps.segmentation(len(component_indexs))
if 'color_map' not in kwargs:
kwargs['color_map'] = qim3d.viz.colormaps.segmentation(
len(component_indexs)
)
# Plot the connected component without overlay
fig = qim3d.viz.slices_grid(
connected_components.get_cc(component, crop=crop), display_figure=display_figure, **kwargs
connected_components.get_cc(component, crop=crop),
display_figure=display_figure,
**kwargs,
)
figs.append(fig)
......
......@@ -4,22 +4,30 @@ Provides a collection of visualization functions.
import math
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib
import matplotlib.figure
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import SVG, display
import matplotlib
import numpy as np
import zarr
from qim3d.utils._logger import log
import seaborn as sns
import skimage.measure
from skimage.filters import (
threshold_otsu,
threshold_isodata,
threshold_li,
threshold_mean,
threshold_minimum,
threshold_triangle,
threshold_yen,
)
from IPython.display import clear_output, display
import qim3d
from qim3d.utils._logger import log
def slices_grid(
......@@ -41,7 +49,8 @@ def slices_grid(
color_bar_style: str = "small",
**matplotlib_imshow_kwargs,
) -> matplotlib.figure.Figure:
"""Displays one or several slices from a 3d volume.
"""
Displays one or several slices from a 3d volume.
By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position.
......@@ -82,6 +91,7 @@ def slices_grid(
qim3d.viz.slices_grid(vol, num_slices=15)
```
![Grid of slices](../../assets/screenshots/viz-slices.png)
"""
if image_size:
image_height = image_size
......@@ -330,7 +340,8 @@ def slicer(
color_bar: str = None,
**matplotlib_imshow_kwargs,
) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume.
"""
Interactive widget for visualizing slices of a 3D volume.
Args:
volume (np.ndarray): The 3D volume to be sliced.
......@@ -355,20 +366,21 @@ def slicer(
qim3d.viz.slicer(vol)
```
![viz slicer](../../assets/screenshots/viz-slicer.gif)
"""
if image_size:
image_height = image_size
image_width = image_size
color_bar_options = [None, 'slices', 'volume']
color_bar_options = [None, "slices", "volume"]
if color_bar not in color_bar_options:
raise ValueError(
f"Unrecognized value '{color_bar}' for parameter color_bar. "
f"Expected one of {color_bar_options}."
)
show_color_bar = color_bar is not None
if color_bar == 'slices':
if color_bar == "slices":
# Precompute the minimum and maximum along each slice for faster widget sliding.
non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis)
slice_mins = np.min(volume, axis=non_slice_axes)
......@@ -376,7 +388,7 @@ def slicer(
# Create the interactive widget
def _slicer(slice_positions):
if color_bar == 'slices':
if color_bar == "slices":
dynamic_min = slice_mins[slice_positions]
dynamic_max = slice_maxs[slice_positions]
else:
......@@ -425,7 +437,8 @@ def slicer_orthogonal(
interpolation: Optional[str] = None,
image_size: int = None,
) -> widgets.interactive:
"""Interactive widget for visualizing orthogonal slices of a 3D volume.
"""
Interactive widget for visualizing orthogonal slices of a 3D volume.
Args:
volume (np.ndarray): The 3D volume to be sliced.
......@@ -448,6 +461,7 @@ def slicer_orthogonal(
qim3d.viz.slicer_orthogonal(vol, color_map="magma")
```
![viz slicer_orthogonal](../../assets/screenshots/viz-orthogonal.gif)
"""
if image_size:
......@@ -484,7 +498,8 @@ def fade_mask(
value_min: float = None,
value_max: float = None,
) -> widgets.interactive:
"""Interactive widget for visualizing the effect of edge fading on a 3D volume.
"""
Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask.
......@@ -605,7 +620,8 @@ def fade_mask(
# Create the Checkbox widget
invert_checkbox = widgets.Checkbox(
value=False, description="Invert" # default value
value=False,
description="Invert", # default value
)
slicer_obj = widgets.interactive(
......@@ -644,6 +660,7 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
qim3d.viz.chunks("Escargot.zarr")
```
![chunks-visualization](../../assets/screenshots/chunks_visualization.gif)
"""
# Load the Zarr dataset
......@@ -794,7 +811,6 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def update_coordinate_dropdowns(scale):
disable_observers() # to avoid multiple reload of the visualization when updating the dropdowns
multiscale_shape = zarr_data[scale].shape
......@@ -873,20 +889,22 @@ def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
def histogram(
volume: np.ndarray,
bins: Union[int, str] = "auto",
slice_idx: Union[int, str] = None,
slice_idx: Union[int, str, None] = None,
vertical_line: int = None,
axis: int = 0,
kde: bool = True,
log_scale: bool = False,
despine: bool = True,
show_title: bool = True,
color: str = "qim3d",
edgecolor: str|None = None,
figsize: tuple[float, float] = (8, 4.5),
edgecolor: Optional[str] = None,
figsize: Tuple[float, float] = (8, 4.5),
element: str = "step",
return_fig: bool = False,
show: bool = True,
**sns_kwargs,
) -> None|matplotlib.figure.Figure:
ax: Optional[plt.Axes] = None,
**sns_kwargs: Union[str, float, int, bool]
) -> Optional[Union[plt.Figure, plt.Axes]]:
"""
Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
......@@ -894,54 +912,41 @@ def histogram(
Args:
volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (int or str, optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (int or str or None, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
show_title (bool, optional): If True, displays a title with slice information. Default is True.
color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d".
edgecolor (str, optional): Color for the edges of the histogram bars. Default is None.
figsize (tuple of floats, optional): Size of the figure (width, height). Default is (8, 4.5).
figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5).
element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
**sns_kwargs (Any): Additional keyword arguments for `seaborn.histplot`.
ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
**sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
Returns:
fig (Optional[matplotlib.figure.Figure]): If `return_fig` is True, returns the generated figure object. Otherwise, returns None.
Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
If `return_fig` is True, returns the generated figure object.
If `return_fig` is False and `ax` is provided, returns the `Axes` object.
Otherwise, returns None.
Raises:
ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol)
```
![viz histogram](../../assets/screenshots/viz-histogram-vol.png)
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True)
```
![viz histogram](../../assets/screenshots/viz-histogram-slice.png)
"""
if not (0 <= axis < volume.ndim):
raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == "middle":
slice_idx = volume.shape[axis] // 2
if slice_idx:
if slice_idx is not None:
if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel()
......@@ -954,10 +959,14 @@ def histogram(
data = volume.ravel()
title = f"Intensity histogram for whole volume {volume.shape}"
# Use provided Axes or create new figure
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
if log_scale:
plt.yscale("log")
ax.set_yscale("log")
if color == "qim3d":
color = qim3d.viz.colormaps.qim(1.0)
......@@ -969,13 +978,23 @@ def histogram(
color=color,
element=element,
edgecolor=edgecolor,
ax=ax, # Plot directly on the specified Axes
**sns_kwargs,
)
if vertical_line is not None:
ax.axvline(
x=vertical_line,
color='red',
linestyle="--",
linewidth=2,
)
if despine:
sns.despine(
fig=None,
ax=None,
ax=ax,
top=True,
right=True,
left=False,
......@@ -984,17 +1003,521 @@ def histogram(
trim=True,
)
plt.xlabel("Voxel Intensity")
plt.ylabel("Frequency")
ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency")
if show_title:
plt.title(title, fontsize=10)
ax.set_title(title, fontsize=10)
# Handle show and return
if show:
if show and fig is not None:
plt.show()
else:
plt.close(fig)
if return_fig:
return fig
elif ax is not None:
return ax
class _LineProfile:
def __init__(
self,
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
):
self.volume = volume
self.slice_axis = slice_axis
self.dims = np.array(volume.shape)
self.pad = 1 # Padding on pivot point to avoid border issues
self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1]
self.initialize_widgets()
self.update_slice_axis(slice_axis)
self.slice_index_widget.value = slice_index
self.x_widget.value = horizontal_position
self.y_widget.value = vertical_position
self.angle_widget.value = angle
self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]]
def update_slice_axis(self, slice_axis):
self.slice_axis = slice_axis
self.slice_index_widget.max = self.volume.shape[slice_axis] - 1
self.slice_index_widget.value = self.volume.shape[slice_axis] // 2
self.x_max, self.y_max = np.delete(self.dims, self.slice_axis) - 1
self.x_widget.max = self.x_max - self.pad
self.x_widget.value = self.x_max // 2
self.y_widget.max = self.y_max - self.pad
self.y_widget.value = self.y_max // 2
def initialize_widgets(self):
layout = widgets.Layout(width="300px", height="auto")
self.x_widget = widgets.IntSlider(
min=self.pad, step=1, description="", layout=layout
)
self.y_widget = widgets.IntSlider(
min=self.pad, step=1, description="", layout=layout
)
self.angle_widget = widgets.IntSlider(
min=0, max=360, step=1, value=0, description="", layout=layout
)
self.line_fraction_widget = widgets.FloatRangeSlider(
min=0, max=1, step=0.01, value=[0, 1], description="", layout=layout
)
self.slice_axis_widget = widgets.Dropdown(
options=[0, 1, 2], value=self.slice_axis, description="Slice axis"
)
self.slice_axis_widget.layout.width = "250px"
self.slice_index_widget = widgets.IntSlider(
min=0, step=1, description="Slice index", layout=layout
)
self.slice_index_widget.layout.width = "400px"
def calculate_line_endpoints(self, x, y, angle):
"""
Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)]
"""
if np.isclose(angle, 0):
return [0, y], [self.x_max, y]
elif np.isclose(angle, np.pi / 2):
return [x, 0], [x, self.y_max]
elif np.isclose(angle, np.pi):
return [self.x_max, y], [0, y]
elif np.isclose(angle, 3 * np.pi / 2):
return [x, self.y_max], [x, 0]
elif np.isclose(angle, 2 * np.pi):
return [0, y], [self.x_max, y]
t_left = -x / np.cos(angle)
t_bottom = -y / np.sin(angle)
t_right = (self.x_max - x) / np.cos(angle)
t_top = (self.y_max - y) / np.sin(angle)
t_values = np.array([t_left, t_top, t_right, t_bottom])
t_pos = np.min(t_values[t_values > 0])
t_neg = np.max(t_values[t_values < 0])
src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)]
dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)]
return src, dst
def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range):
if slice_axis != self.slice_axis:
self.update_slice_axis(slice_axis)
x = self.x_widget.value
y = self.y_widget.value
slice_index = self.slice_index_widget.value
clear_output(wait=True)
image = np.take(self.volume, slice_index, slice_axis)
angle = np.radians(angle_deg)
src, dst = (
np.array(point, dtype="float32")
for point in self.calculate_line_endpoints(x, y, angle)
)
# Rescale endpoints
line_vec = dst - src
dst = src + fraction_range[1] * line_vec
src = src + fraction_range[0] * line_vec
y_pline = skimage.measure.profile_line(image, src, dst)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Image with color-gradiented line
num_segments = 100
x_seg = np.linspace(src[0], dst[0], num_segments)
y_seg = np.linspace(src[1], dst[1], num_segments)
segments = np.stack(
[
np.column_stack([y_seg[:-2], x_seg[:-2]]),
np.column_stack([y_seg[2:], x_seg[2:]]),
],
axis=1,
)
norm = plt.Normalize(vmin=0, vmax=num_segments - 1)
colors = self.cmap(norm(np.arange(num_segments - 1)))
lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2)
ax[0].imshow(image, cmap="gray")
ax[0].add_collection(lc)
# pivot point
ax[0].plot(y, x, marker="s", linestyle="", color="cyan", markersize=4)
ax[0].set_xlabel(f"axis {np.delete(np.arange(3), self.slice_axis)[1]}")
ax[0].set_ylabel(f"axis {np.delete(np.arange(3), self.slice_axis)[0]}")
# Profile intensity plot
norm = plt.Normalize(0, vmax=len(y_pline) - 1)
x_pline = np.arange(len(y_pline))
points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :]
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc = matplotlib.collections.LineCollection(
segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2
)
ax[1].add_collection(lc)
ax[1].autoscale()
ax[1].set_xlabel("Distance along line")
ax[1].grid(True)
plt.tight_layout()
plt.show()
def build_interactive(self):
# Group widgets into two columns
title_style = (
"text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;"
)
title_column1 = widgets.HTML(
f"<div style='{title_style}'>Line parameterization</div>"
)
title_column2 = widgets.HTML(
f"<div style='{title_style}'>Slice selection</div>"
)
# Make label widgets instead of descriptions which have different lengths.
label_layout = widgets.Layout(width="120px")
label_x = widgets.Label("Vertical position", layout=label_layout)
label_y = widgets.Label("Horizontal position", layout=label_layout)
label_angle = widgets.Label("Angle (°)", layout=label_layout)
label_fraction = widgets.Label("Fraction range", layout=label_layout)
row_x = widgets.HBox([label_x, self.x_widget])
row_y = widgets.HBox([label_y, self.y_widget])
row_angle = widgets.HBox([label_angle, self.angle_widget])
row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget])
controls_column1 = widgets.VBox(
[title_column1, row_x, row_y, row_angle, row_fraction]
)
controls_column2 = widgets.VBox(
[title_column2, self.slice_axis_widget, self.slice_index_widget]
)
controls = widgets.HBox([controls_column1, controls_column2])
interactive_plot = widgets.interactive_output(
self.update,
{
"slice_axis": self.slice_axis_widget,
"slice_index": self.slice_index_widget,
"x": self.x_widget,
"y": self.y_widget,
"angle_deg": self.angle_widget,
"fraction_range": self.line_fraction_widget,
},
)
return widgets.VBox([controls, interactive_plot])
def line_profile(
volume: np.ndarray,
slice_axis: int = 0,
slice_index: int | str = "middle",
vertical_position: int | str = "middle",
horizontal_position: int | str = "middle",
angle: int = 0,
fraction_range: Tuple[float, float] = (0.00, 1.00),
) -> widgets.interactive:
"""
Returns an interactive widget for visualizing the intensity profiles of lines on slices.
Args:
volume (np.ndarray): The 3D volume of interest.
slice_axis (int, optional): Specifies the initial axis along which to slice.
slice_index (int or str, optional): Specifies the initial slice index along slice_axis.
vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point.
horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point.
angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo.
fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0].
Returns:
widget (widgets.widget_box.VBox): The interactive widget.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.line_profile(vol)
```
![viz histogram](../../assets/screenshots/viz-line_profile.gif)
"""
def parse_position(pos, pos_range, name):
if isinstance(pos, int):
if not pos_range[0] <= pos < pos_range[1]:
raise ValueError(
f"Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]"
)
return pos
elif isinstance(pos, str):
pos = pos.lower()
if pos == "start":
return pos_range[0]
elif pos == "middle":
return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
elif pos == "end":
return pos_range[1]
else:
raise ValueError(
f"Invalid string '{pos}' for {name}. "
"Must be 'start', 'middle', or 'end'."
)
else:
raise TypeError("Axis position must be of type int or str.")
if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError("Data type for volume not supported.")
if volume.ndim != 3:
raise ValueError("Volume must be 3D.")
dims = volume.shape
slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), "slice_index")
# the omission of the ends for the pivot point is due to border issues.
vertical_position = parse_position(
vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), "vertical_position"
)
horizontal_position = parse_position(
horizontal_position,
(1, np.delete(dims, slice_axis)[1] - 2),
"horizontal_position",
)
if not isinstance(angle, int | float):
raise ValueError("Invalid type for angle.")
angle = round(angle) % 360
if not (
0.0 <= fraction_range[0] <= 1.0
and 0.0 <= fraction_range[1] <= 1.0
and fraction_range[0] <= fraction_range[1]
):
raise ValueError("Invalid values for fraction_range.")
lp = _LineProfile(
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
)
return lp.build_interactive()
def threshold(
volume: np.ndarray,
cmap_image: str = 'magma',
vmin: float = None,
vmax: float = None,
) -> widgets.VBox:
"""
This function provides an interactive interface to explore thresholding on a
3D volume slice-by-slice. Users can either manually set the threshold value
using a slider or select an automatic thresholding method from `skimage`.
The visualization includes the original image slice, a binary mask showing regions above the
threshold and an overlay combining the binary mask and the original image.
Args:
volume (np.ndarray): 3D volume to threshold.
cmap_image (str, optional): Colormap for the original image. Defaults to 'viridis'.
cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'.
vmin (float, optional): Minimum value for the colormap. Defaults to None.
vmax (float, optional): Maximum value for the colormap. Defaults to None.
Returns:
slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume.
Interactivity:
- **Manual Thresholding**:
Select 'Manual' from the dropdown menu to manually adjust the threshold
using the slider.
- **Automatic Thresholding**:
Choose a method from the dropdown menu to apply an automatic thresholding
algorithm. Available methods include:
- Otsu
- Isodata
- Li
- Mean
- Minimum
- Triangle
- Yen
The threshold slider will display the computed value and will be disabled
in this mode.
```python
import qim3d
# Load a sample volume
vol = qim3d.examples.bone_128x128x128
# Visualize interactive thresholding
qim3d.viz.threshold(vol)
```
![interactive threshold](../../assets/screenshots/interactive_thresholding.gif)
"""
# Centralized state dictionary to track current parameters
state = {
"position": volume.shape[0] // 2,
"threshold": int((volume.min() + volume.max()) / 2),
"method": "Manual",
}
threshold_methods = {
"Otsu": threshold_otsu,
"Isodata": threshold_isodata,
"Li": threshold_li,
"Mean": threshold_mean,
"Minimum": threshold_minimum,
"Triangle": threshold_triangle,
"Yen": threshold_yen,
}
# Create an output widget to display the plot
output = widgets.Output()
# Function to update the state and trigger visualization
def update_state(change):
# Update state based on widget values
state["position"] = position_slider.value
state["method"] = method_dropdown.value
if state["method"] == "Manual":
state["threshold"] = threshold_slider.value
threshold_slider.disabled = False
else:
threshold_func = threshold_methods.get(state["method"])
if threshold_func:
slice_img = volume[state["position"], :, :]
computed_threshold = threshold_func(slice_img)
state["threshold"] = computed_threshold
# Programmatically update the slider without triggering callbacks
threshold_slider.unobserve_all()
threshold_slider.value = computed_threshold
threshold_slider.disabled = True
threshold_slider.observe(update_state, names="value")
else:
raise ValueError(f"Unsupported thresholding method: {state['method']}")
# Trigger visualization
update_visualization()
# Visualization function
def update_visualization():
slice_img = volume[state["position"], :, :]
with output:
output.clear_output(wait=True) # Clear previous plot
fig, axes = plt.subplots(1, 4, figsize=(25, 5))
# Original image
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title("Original")
axes[0].axis("off")
# Histogram
histogram(
volume=volume,
bins=32,
slice_idx=state["position"],
vertical_line=state["threshold"],
axis=1,
kde=False,
ax=axes[1],
show=False,
)
axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")
# Binary mask
mask = slice_img > state["threshold"]
axes[2].imshow(mask, cmap="gray")
axes[2].set_title("Binary mask")
axes[2].axis("off")
# Overlay
mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
mask_rgb[:, :, 0] = mask
masked_volume = qim3d.operations.overlay_rgb_images(
background=slice_img,
foreground=mask_rgb,
)
axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
axes[3].set_title("Overlay")
axes[3].axis("off")
plt.show()
# Widgets
position_slider = widgets.IntSlider(
value=state["position"],
min=0,
max=volume.shape[0] - 1,
description="Slice",
)
threshold_slider = widgets.IntSlider(
value=state["threshold"],
min=volume.min(),
max=volume.max(),
description="Threshold",
)
method_dropdown = widgets.Dropdown(
options=[
"Manual",
"Otsu",
"Isodata",
"Li",
"Mean",
"Minimum",
"Triangle",
"Yen",
],
value=state["method"],
description="Method",
)
# Attach the state update function to widgets
position_slider.observe(update_state, names="value")
threshold_slider.observe(update_state, names="value")
method_dropdown.observe(update_state, names="value")
# Layout
controls_left = widgets.VBox([position_slider, threshold_slider])
controls_right = widgets.VBox([method_dropdown])
controls_layout = widgets.HBox(
[controls_left, controls_right],
layout=widgets.Layout(justify_content="flex-start"),
)
interactive_ui = widgets.VBox([controls_layout, output])
update_visualization()
return interactive_ui