Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes
Showing
with 1476 additions and 1190 deletions
from ._generators import noise_object
from ._aggregators import noise_object_collection
from ._generators import noise_object
......@@ -22,6 +22,7 @@ def random_placement(
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
"""
# Find available (zero) elements in collection
available_z, available_y, available_x = np.where(collection == 0)
......@@ -44,14 +45,12 @@ def random_placement(
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
......@@ -81,6 +80,7 @@ def specific_placement(
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
positions (list[tuple]): List of remaining positions to place blobs.
"""
# Flag for placement success
placed = False
......@@ -99,14 +99,12 @@ def specific_placement(
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
......@@ -289,23 +287,24 @@ def noise_object_collection(
qim3d.viz.slices_grid(vol, num_slices=15, slice_axis=1)
```
![synthetic_collection_tube](../../assets/screenshots/synthetic_collection_tube_slices.png)
"""
if verbose:
original_log_level = log.getEffectiveLevel()
log.setLevel("DEBUG")
log.setLevel('DEBUG')
# Check valid input types
if not isinstance(collection_shape, tuple) or len(collection_shape) != 3:
raise TypeError(
"Shape of collection must be a tuple with three dimensions (z, y, x)"
'Shape of collection must be a tuple with three dimensions (z, y, x)'
)
if len(min_shape) != len(max_shape):
raise ValueError("Object shapes must be tuples of the same length")
raise ValueError('Object shapes must be tuples of the same length')
if (positions is not None) and (len(positions) != num_objects):
raise ValueError(
"Number of objects must match number of positions, otherwise set positions = None"
'Number of objects must match number of positions, otherwise set positions = None'
)
# Set seed for random number generator
......@@ -318,8 +317,8 @@ def noise_object_collection(
labels = np.zeros_like(collection_array)
# Fill the 3D array with synthetic blobs
for i in tqdm(range(num_objects), desc="Objects placed"):
log.debug(f"\nObject #{i+1}")
for i in tqdm(range(num_objects), desc='Objects placed'):
log.debug(f'\nObject #{i+1}')
# Sample from blob parameter ranges
if min_shape == max_shape:
......@@ -328,7 +327,7 @@ def noise_object_collection(
blob_shape = tuple(
rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3)
)
log.debug(f"- Blob shape: {blob_shape}")
log.debug(f'- Blob shape: {blob_shape}')
# Scale object shape
final_shape = tuple(l * r for l, r in zip(blob_shape, object_shape_zoom))
......@@ -336,19 +335,19 @@ def noise_object_collection(
# Sample noise scale
noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise)
log.debug(f"- Object noise scale: {noise_scale:.4f}")
log.debug(f'- Object noise scale: {noise_scale:.4f}')
gamma = rng.uniform(low=min_gamma, high=max_gamma)
log.debug(f"- Gamma correction: {gamma:.3f}")
log.debug(f'- Gamma correction: {gamma:.3f}')
if max_high_value > min_high_value:
max_value = rng.integers(low=min_high_value, high=max_high_value)
else:
max_value = min_high_value
log.debug(f"- Max value: {max_value}")
log.debug(f'- Max value: {max_value}')
threshold = rng.uniform(low=min_threshold, high=max_threshold)
log.debug(f"- Threshold: {threshold:.3f}")
log.debug(f'- Threshold: {threshold:.3f}')
# Generate synthetic object
blob = qim3d.generate.noise_object(
......@@ -368,7 +367,7 @@ def noise_object_collection(
low=min_rotation_degrees, high=max_rotation_degrees
) # Sample rotation angle
axes = rng.choice(rotation_axes) # Sample the two axes to rotate around
log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}")
log.debug(f'- Rotation angle: {angle:.2f} at axes: {axes}')
blob = scipy.ndimage.rotate(blob, angle, axes, order=1)
......@@ -397,7 +396,7 @@ def noise_object_collection(
if not placed:
# Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar)
log.error(
f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects."
f'Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects.'
)
if verbose:
......
......@@ -4,6 +4,7 @@ from noise import pnoise3
import qim3d.processing
def noise_object(
base_shape: tuple = (128, 128, 128),
final_shape: tuple = (128, 128, 128),
......@@ -14,7 +15,7 @@ def noise_object(
threshold: float = 0.5,
smooth_borders: bool = False,
object_shape: str = None,
dtype: str = "uint8",
dtype: str = 'uint8',
) -> np.ndarray:
"""
Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction.
......@@ -103,12 +104,13 @@ def noise_object(
qim3d.viz.slices_grid(vol, num_slices=15)
```
![synthetic_blob_tube_slice](../../assets/screenshots/synthetic_blob_tube_slice.png)
"""
if not isinstance(final_shape, tuple) or len(final_shape) != 3:
raise TypeError("Size must be a tuple of 3 dimensions")
raise TypeError('Size must be a tuple of 3 dimensions')
if not np.issubdtype(dtype, np.number):
raise ValueError("Invalid data type")
raise ValueError('Invalid data type')
# Initialize the 3D array for the shape
volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32)
......@@ -119,18 +121,17 @@ def noise_object(
# Calculate the distance from the center of the shape
center = np.array(base_shape) / 2
dist = np.sqrt((z - center[0])**2 +
(y - center[1])**2 +
(x - center[2])**2)
dist = np.sqrt((z - center[0]) ** 2 + (y - center[1]) ** 2 + (x - center[2]) ** 2)
dist /= np.sqrt(3 * (center[0] ** 2))
# Generate Perlin noise and adjust the values based on the distance from the center
vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input
vectorized_pnoise3 = np.vectorize(
pnoise3
) # Vectorize pnoise3, since it only takes scalar input
noise = vectorized_pnoise3(z.flatten() * noise_scale,
y.flatten() * noise_scale,
x.flatten() * noise_scale
noise = vectorized_pnoise3(
z.flatten() * noise_scale, y.flatten() * noise_scale, x.flatten() * noise_scale
).reshape(base_shape)
volume = (1 + noise) * (1 - dist)
......@@ -150,11 +151,16 @@ def noise_object(
if smooth_borders:
# Maximum value among the six sides of the 3D volume
max_border_value = np.max([
np.max(volume[0, :, :]), np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]), np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]), np.max(volume[:, :, -1])
])
max_border_value = np.max(
[
np.max(volume[0, :, :]),
np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]),
np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]),
np.max(volume[:, :, -1]),
]
)
# Compute threshold such that there will be no straight cuts in the blob
threshold = max_border_value / max_value
......@@ -171,42 +177,47 @@ def noise_object(
)
# Fade into a shape if specified
if object_shape == "cylinder":
if object_shape == 'cylinder':
# Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical
geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(
volume.shape
) # Fade along the dimension where the object is the largest
target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
volume = qim3d.operations.fade_mask(volume,
volume = qim3d.operations.fade_mask(
volume,
geometry=geometry,
axis=axis,
target_max_normalized_distance = target_max_normalized_distance
target_max_normalized_distance=target_max_normalized_distance,
)
elif object_shape == "tube":
elif object_shape == 'tube':
# Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest
geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(
volume.shape
) # Fade along the dimension where the object is the largest
decay_rate = 5 # Decay rate for the fade operation
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical
target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
# Fade once for making the object cylindrical
volume = qim3d.operations.fade_mask(volume,
volume = qim3d.operations.fade_mask(
volume,
geometry=geometry,
axis=axis,
decay_rate=decay_rate,
target_max_normalized_distance=target_max_normalized_distance,
invert = False
invert=False,
)
# Fade again with invert = True for making the object a tube (i.e. with a hole in the middle)
volume = qim3d.operations.fade_mask(volume,
geometry = geometry,
axis = axis,
decay_rate = decay_rate,
invert = True
volume = qim3d.operations.fade_mask(
volume, geometry=geometry, axis=axis, decay_rate=decay_rate, invert=True
)
# Convert to desired data type
......
from fastapi import FastAPI
import qim3d.utils
from . import data_explorer
from . import iso3d
from . import local_thickness
from . import annotation_tool
from . import layers2d
from . import annotation_tool, data_explorer, iso3d, layers2d, local_thickness
from .qim_theme import QimTheme
def run_gradio_app(gradio_interface, host="0.0.0.0"):
def run_gradio_app(gradio_interface, host='0.0.0.0'):
import gradio as gr
import uvicorn
# Get port using the QIM API
port_dict = qim3d.utils.get_port_dict()
if "gradio_port" in port_dict:
port = port_dict["gradio_port"]
elif "port" in port_dict:
port = port_dict["port"]
if 'gradio_port' in port_dict:
port = port_dict['gradio_port']
elif 'port' in port_dict:
port = port_dict['port']
else:
raise Exception("Port not specified from QIM API")
raise Exception('Port not specified from QIM API')
qim3d.utils.gradio_header(gradio_interface.title, port)
......@@ -30,7 +28,7 @@ def run_gradio_app(gradio_interface, host="0.0.0.0"):
app = gr.mount_gradio_app(app, gradio_interface, path=path)
# Full path
print(f"http://{host}:{port}{path}")
print(f'http://{host}:{port}{path}')
# Run the FastAPI server usign uvicorn
uvicorn.run(app, host=host, port=int(port))
......@@ -27,6 +27,7 @@ import tempfile
import gradio as gr
import numpy as np
from PIL import Image
import qim3d
from qim3d.gui.interface import BaseInterface
......@@ -34,17 +35,19 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
def __init__(
self, name_suffix: str = '', verbose: bool = False, img: np.ndarray = None
):
super().__init__(
title="Annotation Tool",
title='Annotation Tool',
height=768,
width="100%",
width='100%',
verbose=verbose,
custom_css="annotation_tool.css",
custom_css='annotation_tool.css',
)
self.username = getpass.getuser()
self.temp_dir = os.path.join(tempfile.gettempdir(), f"qim-{self.username}")
self.temp_dir = os.path.join(tempfile.gettempdir(), f'qim-{self.username}')
self.name_suffix = name_suffix
self.img = img
......@@ -57,7 +60,7 @@ class Interface(BaseInterface):
# Get the temporary files from gradio
temp_path_list = []
for filename in os.listdir(self.temp_dir):
if "mask" and self.name_suffix in str(filename):
if 'mask' and self.name_suffix in str(filename):
# Get the list of the temporary files
temp_path_list.append(os.path.join(self.temp_dir, filename))
......@@ -76,9 +79,9 @@ class Interface(BaseInterface):
this is safer and backwards compatible (should be)
"""
self.mask_names = [
f"red{self.name_suffix}",
f"green{self.name_suffix}",
f"blue{self.name_suffix}",
f'red{self.name_suffix}',
f'green{self.name_suffix}',
f'blue{self.name_suffix}',
]
# Clean up old files
......@@ -86,7 +89,7 @@ class Interface(BaseInterface):
files = os.listdir(self.temp_dir)
for filename in files:
# Check if "mask" is in the filename
if ("mask" in filename) and (self.name_suffix in filename):
if ('mask' in filename) and (self.name_suffix in filename):
file_path = os.path.join(self.temp_dir, filename)
os.remove(file_path)
......@@ -94,13 +97,13 @@ class Interface(BaseInterface):
files = None
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"]
masks = img_editor["layers"][0]
background = img_editor['background']
masks = img_editor['layers'][0]
overlay_image = qim3d.operations.overlay_rgb_images(background, masks)
return overlay_image
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0]
masks_rgb = img_editor['layers'][0]
mask_threshold = 200 # This value is based
mask_list = []
......@@ -114,7 +117,7 @@ class Interface(BaseInterface):
# Save only if we have a mask
if np.sum(mask) > 0:
mask_list.append(mask)
filename = f"mask_{self.mask_names[idx]}.tif"
filename = f'mask_{self.mask_names[idx]}.tif'
if not os.path.exists(self.temp_dir):
os.makedirs(self.temp_dir)
filepath = os.path.join(self.temp_dir, filename)
......@@ -128,11 +131,11 @@ class Interface(BaseInterface):
def define_interface(self, **kwargs):
brush = gr.Brush(
colors=[
"rgb(255,50,100)",
"rgb(50,250,100)",
"rgb(50,100,255)",
'rgb(255,50,100)',
'rgb(50,250,100)',
'rgb(50,100,255)',
],
color_mode="fixed",
color_mode='fixed',
default_size=10,
)
with gr.Row():
......@@ -142,26 +145,25 @@ class Interface(BaseInterface):
img_editor = gr.ImageEditor(
value=(
{
"background": self.img,
"layers": [Image.new("RGBA", self.img.shape, (0, 0, 0, 0))],
"composite": None,
'background': self.img,
'layers': [Image.new('RGBA', self.img.shape, (0, 0, 0, 0))],
'composite': None,
}
if self.img is not None
else None
),
type="numpy",
image_mode="RGB",
type='numpy',
image_mode='RGB',
brush=brush,
sources="upload",
sources='upload',
interactive=True,
show_download_button=True,
container=False,
transforms=["crop"],
transforms=['crop'],
layers=False,
)
with gr.Column(scale=1, min_width=256):
with gr.Row():
overlay_img = gr.Image(
show_download_button=False,
......@@ -169,7 +171,7 @@ class Interface(BaseInterface):
visible=False,
)
with gr.Row():
masks_download = gr.File(label="Download masks", visible=False)
masks_download = gr.File(label='Download masks', visible=False)
# fmt: off
img_editor.change(
......
......@@ -18,50 +18,47 @@ app.launch()
import datetime
import os
import re
from typing import Any, Callable, Dict
import gradio as gr
import matplotlib
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import outputformat as ouf
from qim3d.gui.interface import BaseInterface
from qim3d.io import load
from qim3d.utils._logger import log
from qim3d.utils import _misc
from qim3d.gui.interface import BaseInterface
from typing import Callable, Any, Dict
import matplotlib
from qim3d.utils._logger import log
class Interface(BaseInterface):
def __init__(self,
def __init__(
self,
verbose: bool = False,
figsize: int = 8,
display_saturation_percentile: int = 99,
nbins:int = 32):
nbins: int = 32,
):
"""
Parameters:
-----------
Parameters
----------
verbose (bool, optional): If true, prints info during session into terminal. Defualt is False.
figsize (int, optional): Sets the size of plots displaying the slices. Default is 8.
display_saturation_percentile (int, optional): Sets the display saturation percentile. Defaults to 99.
"""
super().__init__(
title = "Data Explorer",
height = 1024,
width = 900,
verbose = verbose
)
self.axis_dict = {"Z":0, "Y":1, "X":2}
super().__init__(title='Data Explorer', height=1024, width=900, verbose=verbose)
self.axis_dict = {'Z': 0, 'Y': 1, 'X': 2}
self.all_operations = [
"Z Slicer",
"Y Slicer",
"X Slicer",
"Z max projection",
"Z min projection",
"Intensity histogram",
"Data summary",
'Z Slicer',
'Y Slicer',
'X Slicer',
'Z max projection',
'Z min projection',
'Intensity histogram',
'Data summary',
]
self.calculated_operations = [] # For changing the visibility of results, we keep track what was calculated and thus will be displayed
......@@ -79,7 +76,12 @@ class Interface(BaseInterface):
# Spinner state - what phase after clicking run button are we in
self.spinner_state = -1
self.spinner_messages = ["Starting session...", "Loading data...", "Running pipeline...", "Relaunch"]
self.spinner_messages = [
'Starting session...',
'Loading data...',
'Running pipeline...',
'Relaunch',
]
# Error message that we want to show, for more details look inside function check error state
self.error_message = None
......@@ -87,57 +89,55 @@ class Interface(BaseInterface):
# File selection and parameters
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### File selection")
gr.Markdown('### File selection')
with gr.Row():
with gr.Column(scale=99, min_width=128):
base_path = gr.Textbox(
max_lines=1,
container=False,
label="Base path",
label='Base path',
value=os.getcwd(),
)
with gr.Column(scale=1, min_width=36):
reload_base_path = gr.Button(
value=""
)
reload_base_path = gr.Button(value='')
explorer = gr.FileExplorer(
ignore_glob="*/.*", # ignores hidden files
ignore_glob='*/.*', # ignores hidden files
root_dir=os.getcwd(),
label=os.getcwd(),
render=True,
file_count="single",
file_count='single',
interactive=True,
height=320,
)
with gr.Column(scale=1):
gr.Markdown("### Parameters")
gr.Markdown('### Parameters')
cmap = gr.Dropdown(
value="viridis",
value='viridis',
choices=plt.colormaps(),
label="Colormap",
label='Colormap',
interactive=True,
)
virtual_stack = gr.Checkbox(
value=False,
label="Virtual stack",
info="If checked, will use less memory by loading the images on demand.",
label='Virtual stack',
info='If checked, will use less memory by loading the images on demand.',
)
load_series = gr.Checkbox(
value=False,
label="Load series",
info="If checked, will load the whole series of images in the same folder as the selected file.",
label='Load series',
info='If checked, will load the whole series of images in the same folder as the selected file.',
)
series_contains = gr.Textbox(
label="Specify common part of file names for series",
value="",
label='Specify common part of file names for series',
value='',
visible=False,
)
dataset_name = gr.Textbox(
label="Dataset name (in case of H5 files, for example)",
value="exchange/data",
label='Dataset name (in case of H5 files, for example)',
value='exchange/data',
)
def toggle_show(checkbox):
......@@ -151,7 +151,7 @@ class Interface(BaseInterface):
load_series.change(toggle_show, load_series, series_contains)
with gr.Column(scale=1):
gr.Markdown("### Operations")
gr.Markdown('### Operations')
operations = gr.CheckboxGroup(
choices=self.all_operations,
value=[self.all_operations[0], self.all_operations[-1]],
......@@ -161,11 +161,13 @@ class Interface(BaseInterface):
)
with gr.Row():
btn_run = gr.Button(
value="Load & Run", variant = "primary",
value='Load & Run',
variant='primary',
)
# Visualization and results
with gr.Row():
def create_uniform_image(intensity=1):
"""
Generates a blank image with a single color.
......@@ -174,50 +176,50 @@ class Interface(BaseInterface):
"""
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(pixels, interpolation="nearest")
ax.imshow(pixels, interpolation='nearest')
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
# Z Slicer
with gr.Column(visible=False) as result_z_slicer:
zslice_plot = gr.Plot(label="Z slice", value=create_uniform_image(1))
zslice_plot = gr.Plot(label='Z slice', value=create_uniform_image(1))
zpos = gr.Slider(
minimum=0, maximum=1, value=0.5, step=0.01, label="Z position"
minimum=0, maximum=1, value=0.5, step=0.01, label='Z position'
)
# Y Slicer
with gr.Column(visible=False) as result_y_slicer:
yslice_plot = gr.Plot(label="Y slice", value=create_uniform_image(1))
yslice_plot = gr.Plot(label='Y slice', value=create_uniform_image(1))
ypos = gr.Slider(
minimum=0, maximum=1, value=0.5, step=0.01, label="Y position"
minimum=0, maximum=1, value=0.5, step=0.01, label='Y position'
)
# X Slicer
with gr.Column(visible=False) as result_x_slicer:
xslice_plot = gr.Plot(label="X slice", value=create_uniform_image(1))
xslice_plot = gr.Plot(label='X slice', value=create_uniform_image(1))
xpos = gr.Slider(
minimum=0, maximum=1, value=0.5, step=0.01, label="X position"
minimum=0, maximum=1, value=0.5, step=0.01, label='X position'
)
# Z Max projection
with gr.Column(visible=False) as result_z_max_projection:
max_projection_plot = gr.Plot(
label="Z max projection",
label='Z max projection',
)
# Z Min projection
with gr.Column(visible=False) as result_z_min_projection:
min_projection_plot = gr.Plot(
label="Z min projection",
label='Z min projection',
)
# Intensity histogram
with gr.Column(visible=False) as result_intensity_histogram:
hist_plot = gr.Plot(label="Volume intensity histogram")
hist_plot = gr.Plot(label='Volume intensity histogram')
# Text box with data summary
with gr.Column(visible=False) as result_data_summary:
......@@ -225,12 +227,10 @@ class Interface(BaseInterface):
lines=24,
label=None,
show_label=False,
value="Data summary",
value='Data summary',
)
### Gradio objects lists
####################################
# EVENT LISTENERS
###################################
......@@ -256,20 +256,29 @@ class Interface(BaseInterface):
result_data_summary,
]
reload_base_path.click(fn=self.update_explorer,inputs=base_path, outputs=explorer)
btn_run.click(
fn=self.update_run_btn, inputs = [], outputs = btn_run).then(
fn=self.start_session, inputs = [load_series, series_contains, explorer, base_path], outputs = []).then(
fn=self.update_run_btn, inputs = [], outputs = btn_run).then(
fn=self.check_error_state, inputs = [], outputs = []).success(
fn=self.load_data, inputs= [virtual_stack, dataset_name, series_contains], outputs= []).then(
fn=self.update_run_btn, inputs = [], outputs = btn_run).then(
fn=self.check_error_state, inputs = [], outputs = []).success(
fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs).then(
fn=self.update_run_btn, inputs = [], outputs = btn_run).then(
fn=self.check_error_state, inputs = [], outputs = []).success(
fn=self.show_results, inputs = operations, outputs = results) # results are columns of images and other component, not just the components
reload_base_path.click(
fn=self.update_explorer, inputs=base_path, outputs=explorer
)
btn_run.click(fn=self.update_run_btn, inputs=[], outputs=btn_run).then(
fn=self.start_session,
inputs=[load_series, series_contains, explorer, base_path],
outputs=[],
).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then(
fn=self.check_error_state, inputs=[], outputs=[]
).success(
fn=self.load_data,
inputs=[virtual_stack, dataset_name, series_contains],
outputs=[],
).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then(
fn=self.check_error_state, inputs=[], outputs=[]
).success(
fn=self.run_operations, inputs=pipeline_inputs, outputs=pipeline_outputs
).then(fn=self.update_run_btn, inputs=[], outputs=btn_run).then(
fn=self.check_error_state, inputs=[], outputs=[]
).success(
fn=self.show_results, inputs=operations, outputs=results
) # results are columns of images and other component, not just the components
"""
Gradio passes only the value to the function, not the whole component.
......@@ -278,15 +287,21 @@ class Interface(BaseInterface):
The self.update_slice_wrapper returns a function.
"""
sliders = [xpos, ypos, zpos]
letters = ["X", "Y", "Z"]
letters = ['X', 'Y', 'Z']
plots = [xslice_plot, yslice_plot, zslice_plot]
for slider, letter, plot in zip(sliders, letters, plots):
slider.change(fn = self.update_slice_wrapper(letter), inputs = [slider, cmap], outputs = plot, show_progress="hidden")
slider.change(
fn=self.update_slice_wrapper(letter),
inputs=[slider, cmap],
outputs=plot,
show_progress='hidden',
)
# Immediate change without the need of pressing the relaunch button
operations.change(fn=self.show_results, inputs=operations, outputs=results)
cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
cmap.change(
fn=self.run_operations, inputs=pipeline_inputs, outputs=pipeline_outputs
)
def update_explorer(self, new_path: str):
new_path = os.path.expanduser(new_path)
......@@ -301,18 +316,22 @@ class Interface(BaseInterface):
return gr.update(root_dir=parent_dir, label=parent_dir, value=file_name)
else:
raise ValueError("Invalid path")
raise ValueError('Invalid path')
def update_run_btn(self):
"""
When run_btn is clicked, it becomes uninteractive and displays which operation is now in progress
When all operations are done, it becomes interactive again with 'Relaunch' label
"""
self.spinner_state = (self.spinner_state + 1) % len(self.spinner_messages) if self.error_message is None else len(self.spinner_messages) - 1
self.spinner_state = (
(self.spinner_state + 1) % len(self.spinner_messages)
if self.error_message is None
else len(self.spinner_messages) - 1
)
message = self.spinner_messages[self.spinner_state]
interactive = (self.spinner_state == len(self.spinner_messages) - 1)
interactive = self.spinner_state == len(self.spinner_messages) - 1
return gr.update(
value=f"{message}",
value=f'{message}',
interactive=interactive,
)
......@@ -332,21 +351,26 @@ class Interface(BaseInterface):
#
#######################################################
def start_session(self, load_series:bool, series_contains:str, explorer:str, base_path:str):
self.projections_calculated = False # Probably new file was loaded, we would need new projections
def start_session(
self, load_series: bool, series_contains: str, explorer: str, base_path: str
):
self.projections_calculated = (
False # Probably new file was loaded, we would need new projections
)
if load_series and series_contains == "":
if load_series and series_contains == '':
# Try to guess the common part of the file names
try:
filename = explorer.split("/")[-1] # Extract filename from path
series_contains = re.search(r"[^0-9]+", filename).group()
gr.Info(f"Using '{series_contains}' as common file name part for loading.")
filename = explorer.split('/')[-1] # Extract filename from path
series_contains = re.search(r'[^0-9]+', filename).group()
gr.Info(
f"Using '{series_contains}' as common file name part for loading."
)
self.series_contains = series_contains
except:
self.error_message = "For series, common part of file name must be provided in 'series_contains' field."
# Get the file path from the explorer or base path
# priority is given to the explorer if file is selected
# else the base path is used
......@@ -357,20 +381,19 @@ class Interface(BaseInterface):
self.file_path = base_path
else:
self.error_message = "Invalid file path"
self.error_message = 'Invalid file path'
# If we are loading a series, we need to get the directory
if load_series:
self.file_path = os.path.dirname(self.file_path)
def load_data(self, virtual_stack: bool, dataset_name: str, contains: str):
try:
self.vol = load(
path=self.file_path,
virtual_stack=virtual_stack,
dataset_name=dataset_name,
contains = contains
contains=contains,
)
# Incase the data is 4D (RGB for example), we take the mean of the last dimension
......@@ -379,54 +402,58 @@ class Interface(BaseInterface):
# The rest of the pipeline expects 3D data
if self.vol.ndim != 3:
self.error_message = F"Invalid data shape should be 3 dimensional, not shape: {self.vol.shape}"
self.error_message = f'Invalid data shape should be 3 dimensional, not shape: {self.vol.shape}'
except Exception as error_message:
self.error_message = F"Error when loading data: {error_message}"
self.error_message = f'Error when loading data: {error_message}'
def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
outputs = []
self.calculated_operations = []
for operation in self.all_operations:
if operation in operations:
log.info(f"Running {operation}")
log.info(f'Running {operation}')
try:
outputs.append(self.run_operation(operation, *args))
self.calculated_operations.append(operation)
except Exception as err:
self.error_message = F"Error while running operation '{operation}': {err}"
self.error_message = (
f"Error while running operation '{operation}': {err}"
)
log.info(self.error_message)
outputs.append(gr.update())
else:
log.info(f"Skipping {operation}")
log.info(f'Skipping {operation}')
outputs.append(gr.update())
return outputs
def run_operation(self, operation:list, zpos:float, ypos:float, xpos:float, cmap:str, *args):
def run_operation(
self, operation: list, zpos: float, ypos: float, xpos: float, cmap: str, *args
):
match operation:
case "Z Slicer":
return self.update_slice_wrapper("Z")(zpos, cmap)
case "Y Slicer":
return self.update_slice_wrapper("Y")(ypos, cmap)
case "X Slicer":
return self.update_slice_wrapper("X")(xpos, cmap)
case "Z max projection":
case 'Z Slicer':
return self.update_slice_wrapper('Z')(zpos, cmap)
case 'Y Slicer':
return self.update_slice_wrapper('Y')(ypos, cmap)
case 'X Slicer':
return self.update_slice_wrapper('X')(xpos, cmap)
case 'Z max projection':
return self.create_projections_figs()[0]
case "Z min projection":
case 'Z min projection':
return self.create_projections_figs()[1]
case "Intensity histogram":
case 'Intensity histogram':
# If the operations are run with the run_button, spinner_state == 2,
# If we just changed cmap, spinner state would be 3
# and we don't have to calculate histogram again
# That saves a lot of time as the histogram takes the most time to calculate
return self.plot_histogram() if self.spinner_state == 2 else gr.update()
case "Data summary":
case 'Data summary':
return self.show_data_summary()
case _:
raise NotImplementedError(F"Operation '{operation} is not defined")
raise NotImplementedError(f"Operation '{operation} is not defined")
def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
update_list = []
......@@ -446,15 +473,17 @@ class Interface(BaseInterface):
def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(img, interpolation="nearest", **kwargs)
ax.imshow(img, interpolation='nearest', **kwargs)
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
def update_slice_wrapper(
self, letter: str
) -> Callable[[float, str], Dict[str, Any]]:
def update_slice(position_slider: float, cmap: str) -> Dict[str, Any]:
"""
position_slider: float from gradio slider, saying which relative slice we want to see
......@@ -479,10 +508,15 @@ class Interface(BaseInterface):
fig_img = self.create_img_fig(slice_img, vmin=vmin, vmax=vmax)
return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
return gr.update(
value=fig_img, label=f'{letter} Slice: {slice_index}', visible=True
)
return update_slice
def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
def vol_histogram(
self, nbins: int, min_value: float, max_value: float
) -> tuple[np.ndarray, np.ndarray]:
# Start histogram
vol_hist = np.zeros(nbins)
......@@ -500,22 +534,28 @@ class Interface(BaseInterface):
if not self.projections_calculated:
_ = self.get_projections()
vol_hist, bin_edges = self.vol_histogram(self.nbins, self.min_value, self.max_value)
vol_hist, bin_edges = self.vol_histogram(
self.nbins, self.min_value, self.max_value
)
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge")
ax.bar(
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge'
)
# Adjustments
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_yscale("log")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_yscale('log')
return fig
def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
def create_projections_figs(
self,
) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
if not self.projections_calculated:
projections = self.get_projections()
self.max_projection = projections[0]
......@@ -539,7 +579,7 @@ class Interface(BaseInterface):
def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
# Create arrays for iteration
max_projection = np.zeros(np.shape(self.vol[0]))
min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
min_projection = np.ones(np.shape(self.vol[0])) * float('inf')
intensity_sum = 0
# Iterate over slices. This is needed in case of virtual stacks.
......@@ -566,20 +606,22 @@ class Interface(BaseInterface):
def show_data_summary(self):
summary_dict = {
"Last modified": datetime.datetime.fromtimestamp(os.path.getmtime(self.file_path)).strftime("%Y-%m-%d %H:%M"),
"File size": _misc.sizeof(os.path.getsize(self.file_path)),
"Z-size": str(self.vol.shape[self.axis_dict["Z"]]),
"Y-size": str(self.vol.shape[self.axis_dict["Y"]]),
"X-size": str(self.vol.shape[self.axis_dict["X"]]),
"Data type": str(self.vol.dtype),
"Min value": str(self.vol.min()),
"Mean value": str(np.mean(self.vol)),
"Max value": str(self.vol.max()),
'Last modified': datetime.datetime.fromtimestamp(
os.path.getmtime(self.file_path)
).strftime('%Y-%m-%d %H:%M'),
'File size': _misc.sizeof(os.path.getsize(self.file_path)),
'Z-size': str(self.vol.shape[self.axis_dict['Z']]),
'Y-size': str(self.vol.shape[self.axis_dict['Y']]),
'X-size': str(self.vol.shape[self.axis_dict['X']]),
'Data type': str(self.vol.dtype),
'Min value': str(self.vol.min()),
'Mean value': str(np.mean(self.vol)),
'Max value': str(self.vol.max()),
}
display_dict = {k: v for k, v in summary_dict.items() if v is not None}
return ouf.showdict(display_dict, return_str=True, title="Data summary")
return ouf.showdict(display_dict, return_str=True, title='Data summary')
if __name__ == "__main__":
if __name__ == '__main__':
Interface().run_interface()
from abc import ABC, abstractmethod
from os import listdir, path
from pathlib import Path
from abc import abstractmethod, ABC
from os import path, listdir
import gradio as gr
import numpy as np
from .qim_theme import QimTheme
import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli
class BaseInterface(ABC):
"""
Annotation tool and Data explorer as those don't need any examples.
"""
......@@ -19,7 +19,7 @@ class BaseInterface(ABC):
self,
title: str,
height: int,
width: int = "100%",
width: int = '100%',
verbose: bool = False,
custom_css: str = None,
):
......@@ -38,7 +38,7 @@ class BaseInterface(ABC):
self.qim_dir = Path(qim3d.__file__).parents[0]
self.custom_css = (
path.join(self.qim_dir, "css", custom_css)
path.join(self.qim_dir, 'css', custom_css)
if custom_css is not None
else None
)
......@@ -72,8 +72,7 @@ class BaseInterface(ABC):
quiet=not self.verbose,
height=self.height,
width=self.width,
favicon_path=Path(qim3d.__file__).parents[0]
/ "gui/assets/qim3d-icon.svg",
favicon_path=Path(qim3d.__file__).parents[0] / 'gui/assets/qim3d-icon.svg',
**kwargs,
)
......@@ -88,7 +87,7 @@ class BaseInterface(ABC):
title=self.title,
css=self.custom_css,
) as gradio_interface:
gr.Markdown(f"# {self.title}")
gr.Markdown(f'# {self.title}')
self.define_interface(**kwargs)
return gradio_interface
......@@ -96,11 +95,12 @@ class BaseInterface(ABC):
def define_interface(self, **kwargs):
pass
def run_interface(self, host: str = "0.0.0.0"):
def run_interface(self, host: str = '0.0.0.0'):
qim3d.gui.run_gradio_app(self.create_interface(), host)
class InterfaceWithExamples(BaseInterface):
"""
For Iso3D and Local Thickness
"""
......@@ -117,7 +117,23 @@ class InterfaceWithExamples(BaseInterface):
self._set_examples_list()
def _set_examples_list(self):
valid_sufixes = (".tif", ".tiff", ".h5", ".nii", ".gz", ".dcm", ".DCM", ".vol", ".vgi", ".txrm", ".txm", ".xrm")
valid_sufixes = (
'.tif',
'.tiff',
'.h5',
'.nii',
'.gz',
'.dcm',
'.DCM',
'.vol',
'.vgi',
'.txrm',
'.txm',
'.xrm',
)
examples_folder = path.join(self.qim_dir, 'examples')
self.img_examples = [path.join(examples_folder, example) for example in listdir(examples_folder) if example.endswith(valid_sufixes)]
self.img_examples = [
path.join(examples_folder, example)
for example in listdir(examples_folder)
if example.endswith(valid_sufixes)
]
......@@ -15,6 +15,7 @@ app.launch()
```
"""
import os
import gradio as gr
......@@ -23,21 +24,19 @@ import plotly.graph_objects as go
from scipy import ndimage
import qim3d
from qim3d.utils._logger import log
from qim3d.gui.interface import InterfaceWithExamples
from qim3d.utils._logger import log
# TODO img in launch should be self.img
class Interface(InterfaceWithExamples):
def __init__(self,
verbose:bool = False,
plot_height:int = 768,
img = None):
super().__init__(title = "Isosurfaces for 3D visualization",
def __init__(self, verbose: bool = False, plot_height: int = 768, img=None):
super().__init__(
title='Isosurfaces for 3D visualization',
height=1024,
width=960,
verbose = verbose)
verbose=verbose,
)
self.interface = None
self.img = img
......@@ -48,11 +47,13 @@ class Interface(InterfaceWithExamples):
self.vol = qim3d.io.load(gradiofile.name)
assert self.vol.ndim == 3
except AttributeError:
raise gr.Error("You have to select a file")
raise gr.Error('You have to select a file')
except ValueError:
raise gr.Error("Unsupported file format")
raise gr.Error('Unsupported file format')
except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size"""
......@@ -61,7 +62,7 @@ class Interface(InterfaceWithExamples):
original_Z, original_Y, original_X = np.shape(self.vol)
max_size = np.max([original_Z, original_Y, original_X])
if self.verbose:
log.info(f"\nOriginal volume: {original_Z, original_Y, original_X}")
log.info(f'\nOriginal volume: {original_Z, original_Y, original_X}')
# Resize for display
self.vol = ndimage.zoom(
......@@ -76,14 +77,15 @@ class Interface(InterfaceWithExamples):
)
if self.verbose:
log.info(
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
f'Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}'
)
def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk
fig.write_html(filename)
def create_fig(self,
def create_fig(
self,
gradio_file: gr.File,
display_size: int,
opacity: float,
......@@ -106,7 +108,6 @@ class Interface(InterfaceWithExamples):
show_x_slice: bool,
slice_x_location: int,
) -> tuple[go.Figure, str]:
# Load volume
self.load_data(gradio_file)
......@@ -161,10 +162,10 @@ class Interface(InterfaceWithExamples):
),
showscale=show_colorbar,
colorbar=dict(
thickness=8, outlinecolor="#fff", len=0.5, orientation="h"
thickness=8, outlinecolor='#fff', len=0.5, orientation='h'
),
reversescale=reversescale,
hoverinfo = "skip",
hoverinfo='skip',
)
)
......@@ -175,13 +176,13 @@ class Interface(InterfaceWithExamples):
scene_xaxis_visible=show_axis,
scene_yaxis_visible=show_axis,
scene_zaxis_visible=show_axis,
scene_aspectmode="data",
scene_aspectmode='data',
height=self.plot_height,
hovermode=False,
scene_camera_eye=dict(x=2.0, y=-2.0, z=1.5),
)
filename = "iso3d.html"
filename = 'iso3d.html'
self.save_fig(fig, filename)
return fig, filename
......@@ -189,10 +190,9 @@ class Interface(InterfaceWithExamples):
def remove_unused_file(self):
# Remove localthickness.tif file from working directory
# as it otherwise is not deleted
os.remove("iso3d.html")
os.remove('iso3d.html')
def define_interface(self, **kwargs):
gr.Markdown(
"""
This tool uses Plotly Volume (https://plotly.com/python/3d-volume-plots/) to create iso surfaces from voxels based on their intensity levels.
......@@ -203,117 +203,111 @@ class Interface(InterfaceWithExamples):
with gr.Row():
# Input and parameters column
with gr.Column(scale=1, min_width=320):
with gr.Tab("Input"):
with gr.Tab('Input'):
# File loader
gradio_file = gr.File(
show_label=False
)
with gr.Tab("Examples"):
gradio_file = gr.File(show_label=False)
with gr.Tab('Examples'):
gr.Examples(examples=self.img_examples, inputs=gradio_file)
# Run button
with gr.Row():
with gr.Column(scale=3, min_width=64):
btn_run = gr.Button(
value="Run 3D visualization", variant = "primary"
value='Run 3D visualization', variant='primary'
)
with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button(
value="Clear", variant = "stop"
)
btn_clear = gr.Button(value='Clear', variant='stop')
with gr.Tab("Display"):
with gr.Tab('Display'):
# Display options
display_size = gr.Slider(
32,
128,
step=4,
label="Display resolution",
info="Number of voxels for the largest dimension",
label='Display resolution',
info='Number of voxels for the largest dimension',
value=64,
)
surface_count = gr.Slider(
2, 16, step=1, label="Total iso-surfaces", value=6
2, 16, step=1, label='Total iso-surfaces', value=6
)
show_caps = gr.Checkbox(value=False, label="Show surface caps")
show_caps = gr.Checkbox(value=False, label='Show surface caps')
with gr.Row():
opacityscale = gr.Dropdown(
choices=["uniform", "extremes", "min", "max"],
value="uniform",
label="Opacity scale",
info="Handles opacity acording to voxel value",
choices=['uniform', 'extremes', 'min', 'max'],
value='uniform',
label='Opacity scale',
info='Handles opacity acording to voxel value',
)
opacity = gr.Slider(
0.0, 1.0, step=0.1, label="Max opacity", value=0.4
0.0, 1.0, step=0.1, label='Max opacity', value=0.4
)
with gr.Row():
min_value = gr.Slider(
0.0, 1.0, step=0.05, label="Min value", value=0.1
0.0, 1.0, step=0.05, label='Min value', value=0.1
)
max_value = gr.Slider(
0.0, 1.0, step=0.05, label="Max value", value=1
0.0, 1.0, step=0.05, label='Max value', value=1
)
with gr.Tab("Slices") as slices:
show_z_slice = gr.Checkbox(value=False, label="Show Z slice")
with gr.Tab('Slices') as slices:
show_z_slice = gr.Checkbox(value=False, label='Show Z slice')
slice_z_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
show_y_slice = gr.Checkbox(value=False, label="Show Y slice")
show_y_slice = gr.Checkbox(value=False, label='Show Y slice')
slice_y_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
show_x_slice = gr.Checkbox(value=False, label="Show X slice")
show_x_slice = gr.Checkbox(value=False, label='Show X slice')
slice_x_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
with gr.Tab("Misc"):
with gr.Tab('Misc'):
with gr.Row():
colormap = gr.Dropdown(
choices=[
"Blackbody",
"Bluered",
"Blues",
"Cividis",
"Earth",
"Electric",
"Greens",
"Greys",
"Hot",
"Jet",
"Magma",
"Picnic",
"Portland",
"Rainbow",
"RdBu",
"Reds",
"Viridis",
"YlGnBu",
"YlOrRd",
'Blackbody',
'Bluered',
'Blues',
'Cividis',
'Earth',
'Electric',
'Greens',
'Greys',
'Hot',
'Jet',
'Magma',
'Picnic',
'Portland',
'Rainbow',
'RdBu',
'Reds',
'Viridis',
'YlGnBu',
'YlOrRd',
],
value="Magma",
label="Colormap",
value='Magma',
label='Colormap',
)
show_colorbar = gr.Checkbox(
value=False, label="Show color scale"
value=False, label='Show color scale'
)
reversescale = gr.Checkbox(
value=False, label="Reverse color scale"
)
flip_z = gr.Checkbox(value=True, label="Flip Z axis")
show_axis = gr.Checkbox(value=True, label="Show axis")
show_ticks = gr.Checkbox(value=False, label="Show ticks")
only_wireframe = gr.Checkbox(
value=False, label="Only wireframe"
value=False, label='Reverse color scale'
)
flip_z = gr.Checkbox(value=True, label='Flip Z axis')
show_axis = gr.Checkbox(value=True, label='Show axis')
show_ticks = gr.Checkbox(value=False, label='Show ticks')
only_wireframe = gr.Checkbox(value=False, label='Only wireframe')
# Inputs for gradio
inputs = [
......@@ -346,7 +340,7 @@ class Interface(InterfaceWithExamples):
plot_download = gr.File(
interactive=False,
label="Download interactive plot",
label='Download interactive plot',
show_label=True,
visible=False,
)
......@@ -367,5 +361,6 @@ class Interface(InterfaceWithExamples):
fn=self.remove_unused_file).success(
fn=self.set_visible, inputs=None, outputs=plot_download)
if __name__ == "__main__":
if __name__ == '__main__':
Interface().run_interface()
......@@ -18,17 +18,13 @@ app = layers.launch()
"""
import os
from typing import Any, Dict
import gradio as gr
import numpy as np
from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d
from qim3d.processing import segment_layers, get_lines
from qim3d.operations import overlay_rgb_images
from qim3d.io import load
from qim3d.viz._layers2d import image_with_lines
from typing import Dict, Any
from .interface import BaseInterface
# TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error
......@@ -41,7 +37,9 @@ Z = 'Z'
AXES = {X: 2, Y: 1, Z: 0}
DEFAULT_PLOT_TYPE = 'Segmentation mask'
SEGMENTATION_COLORS = np.array([[0, 255, 255], # Cyan
SEGMENTATION_COLORS = np.array(
[
[0, 255, 255], # Cyan
[255, 195, 0], # Yellow Orange
[199, 0, 57], # Dark orange
[218, 247, 166], # Light green
......@@ -49,11 +47,13 @@ SEGMENTATION_COLORS = np.array([[0, 255, 255], # Cyan
[65, 105, 225], # Royal blue
[138, 43, 226], # Blue violet
[255, 0, 0], # Red
])
]
)
class Interface(BaseInterface):
def __init__(self):
super().__init__("Layered surfaces 2D", 1080)
super().__init__('Layered surfaces 2D', 1080)
self.data = None
# It important to keep the name of the attributes like this (including the capital letter) becuase of
......@@ -69,8 +69,6 @@ class Interface(BaseInterface):
self.error = False
def define_interface(self):
with gr.Row():
with gr.Column(scale=1, min_width=320):
......@@ -79,40 +77,40 @@ class Interface(BaseInterface):
base_path = gr.Textbox(
max_lines=1,
container=False,
label="Base path",
label='Base path',
value=os.getcwd(),
)
with gr.Column(scale=1, min_width=36):
reload_base_path = gr.Button(value="")
reload_base_path = gr.Button(value='')
explorer = gr.FileExplorer(
ignore_glob="*/.*",
ignore_glob='*/.*',
root_dir=os.getcwd(),
label=os.getcwd(),
render=True,
file_count="single",
file_count='single',
interactive=True,
height=230,
)
with gr.Group():
with gr.Row():
axis = gr.Radio(
choices=[Z, Y, X],
value=Z,
label='Layer axis',
info = 'Specifies in which direction are the layers. The order of axes is ZYX',)
info='Specifies in which direction are the layers. The order of axes is ZYX',
)
with gr.Row():
wrap = gr.Checkbox(
label = "Lines start and end at the same level.",
info = "Used when segmenting layers of unfolded image."
label='Lines start and end at the same level.',
info='Used when segmenting layers of unfolded image.',
)
is_inverted = gr.Checkbox(
label="Invert image before processing",
info="The algorithm effectively flips the gradient.",
label='Invert image before processing',
info='The algorithm effectively flips the gradient.',
)
with gr.Row():
......@@ -122,8 +120,8 @@ class Interface(BaseInterface):
value=0.75,
step=0.01,
interactive=True,
label="Delta value",
info="The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute. Delta above 1 is rounded down to closest lower integer",
label='Delta value',
info='The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute. Delta above 1 is rounded down to closest lower integer',
)
with gr.Row():
......@@ -133,8 +131,8 @@ class Interface(BaseInterface):
value=10,
step=1,
interactive=True,
label="Min margin",
info="Minimum margin between layers to be detected in the image.",
label='Min margin',
info='Minimum margin between layers to be detected in the image.',
)
with gr.Row():
......@@ -144,8 +142,8 @@ class Interface(BaseInterface):
value=2,
step=1,
interactive=True,
label="Number of layers",
info="Number of layers to be detected in the image",
label='Number of layers',
info='Number of layers to be detected in the image',
)
# with gr.Row():
......@@ -162,10 +160,15 @@ class Interface(BaseInterface):
change their height manually
"""
self.heights = ['60em', '30em', '20em'] # em units are relative to the parent,
self.heights = [
'60em',
'30em',
'20em',
] # em units are relative to the parent,
with gr.Column(scale=2,):
with gr.Column(
scale=2,
):
# with gr.Row(): # Source image outputs
# input_image_kwargs = lambda axis: dict(
# show_label = True,
......@@ -181,9 +184,9 @@ class Interface(BaseInterface):
with gr.Row(): # Detected layers outputs
output_image_kwargs = lambda axis: dict(
show_label=True,
label = F'Detected layers {axis}-axis',
label=f'Detected layers {axis}-axis',
visible=True,
height = self.heights[2]
height=self.heights[2],
)
output_plot_x = gr.Image(**output_image_kwargs('X'))
output_plot_y = gr.Image(**output_image_kwargs('Y'))
......@@ -195,8 +198,8 @@ class Interface(BaseInterface):
maximum=1,
value=0.5,
step=0.01,
label = F'{axis} position',
info = F'The 3D image is sliced along {axis}-axis'
label=f'{axis} position',
info=f'The 3D image is sliced along {axis}-axis',
)
x_pos = gr.Slider(**slider_kwargs('X'))
......@@ -204,17 +207,26 @@ class Interface(BaseInterface):
z_pos = gr.Slider(**slider_kwargs('Z'))
with gr.Row():
x_check = gr.Checkbox(value = True, interactive=True, label = 'Show X slice')
y_check = gr.Checkbox(value = True, interactive=True, label = 'Show Y slice')
z_check = gr.Checkbox(value = True, interactive=True, label = 'Show Z slice')
x_check = gr.Checkbox(
value=True, interactive=True, label='Show X slice'
)
y_check = gr.Checkbox(
value=True, interactive=True, label='Show Y slice'
)
z_check = gr.Checkbox(
value=True, interactive=True, label='Show Z slice'
)
with gr.Row():
with gr.Group():
plot_type = gr.Radio(
choices= (DEFAULT_PLOT_TYPE, 'Segmentation lines',),
choices=(
DEFAULT_PLOT_TYPE,
'Segmentation lines',
),
value=DEFAULT_PLOT_TYPE,
interactive=True,
show_label=False
show_label=False,
)
alpha = gr.Slider(
......@@ -225,7 +237,7 @@ class Interface(BaseInterface):
show_label=True,
value=0.5,
visible=True,
interactive=True
interactive=True,
)
line_thickness = gr.Slider(
......@@ -235,12 +247,11 @@ class Interface(BaseInterface):
label='Line thickness',
show_label=True,
visible=False,
interactive = True
interactive=True,
)
with gr.Row():
btn_run = gr.Button("Run Layers2D", variant = 'primary')
btn_run = gr.Button('Run Layers2D', variant='primary')
positions = [x_pos, y_pos, z_pos]
process_inputs = [axis, is_inverted, delta, min_margin, n_layers, wrap]
......@@ -249,20 +260,24 @@ class Interface(BaseInterface):
output_plots = [output_plot_x, output_plot_y, output_plot_z]
visibility_check_inputs = [x_check, y_check, z_check]
spinner_loading = gr.Text("Loading data...", visible=False)
spinner_running = gr.Text("Running pipeline...", visible=False)
spinner_loading = gr.Text('Loading data...', visible=False)
spinner_running = gr.Text('Running pipeline...', visible=False)
reload_base_path.click(
fn=self.update_explorer,inputs=base_path, outputs=explorer)
fn=self.update_explorer, inputs=base_path, outputs=explorer
)
plot_type.change(
self.change_plot_type, inputs = plot_type, outputs = [alpha, line_thickness]).then(
self.change_plot_type, inputs=plot_type, outputs=[alpha, line_thickness]
).then(
fn=self.plot_output_img_all, inputs=plotting_inputs, outputs=output_plots
)
gr.on(
triggers=[alpha.release, line_thickness.release],
fn = self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots
fn=self.plot_output_img_all,
inputs=plotting_inputs,
outputs=output_plots,
)
"""
......@@ -275,53 +290,92 @@ class Interface(BaseInterface):
update_component = gr.State(True)
btn_run.click(
fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run).then(
fn=self.load_data, inputs = [base_path, explorer]).then(
fn = lambda state: not state, inputs = update_component, outputs = update_component)
fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run
).then(fn=self.load_data, inputs=[base_path, explorer]).then(
fn=lambda state: not state,
inputs=update_component,
outputs=update_component,
)
gr.on(
triggers= (axis.change, is_inverted.change, delta.release, min_margin.release, n_layers.release, update_component.change, wrap.change),
fn=self.set_spinner, inputs = spinner_running, outputs=btn_run).then(
fn=self.process_all, inputs = [*positions, *process_inputs]).then(
triggers=(
axis.change,
is_inverted.change,
delta.release,
min_margin.release,
n_layers.release,
update_component.change,
wrap.change,
),
fn=self.set_spinner,
inputs=spinner_running,
outputs=btn_run,
).then(fn=self.process_all, inputs=[*positions, *process_inputs]).then(
# fn=self.plot_input_img_all, outputs = input_plots, show_progress='hidden').then(
fn=self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots, show_progress='hidden').then(
fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
fn=self.plot_output_img_all,
inputs=plotting_inputs,
outputs=output_plots,
show_progress='hidden',
).then(fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
# Chnages visibility and sizes of the plots - gives user the option to see only some of the images and in bigger scale
gr.on(
triggers=[x_check.change, y_check.change, z_check.change],
fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = positions).then(
fn=self.change_row_visibility,
inputs=visibility_check_inputs,
outputs=positions,
).then(
# fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = input_plots).then(
fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots)
fn=self.change_plot_size,
inputs=visibility_check_inputs,
outputs=output_plots,
)
# for axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots):
for axis, slider, output_plot in zip([X, Y, Z], positions, output_plots):
slider.change(
self.process_wrapper(axis), inputs = [slider, *process_inputs]).then(
self.process_wrapper(axis), inputs=[slider, *process_inputs]
).then(
# self.plot_input_img_wrapper(axis), outputs = input_plot).then(
self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot)
self.plot_output_img_wrapper(axis),
inputs=plotting_inputs,
outputs=output_plot,
)
def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
def change_plot_type(
self,
plot_type: str,
) -> tuple[Dict[str, Any], Dict[str, Any]]:
self.plot_type = plot_type
if plot_type == 'Segmentation lines':
return gr.update(visible=False), gr.update(visible=True)
else:
return gr.update(visible=True), gr.update(visible=False)
def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def change_plot_size(
self, x_check: int, y_check: int, z_check: int
) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""
Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
They are defines before defining the plot in 'define_interface'
"""
index = x_check + y_check + z_check - 1
height = self.heights[index] # also used to define heights of plots in the begining
return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
height = self.heights[
index
] # also used to define heights of plots in the begining
return (
gr.update(height=height, visible=x_check),
gr.update(height=height, visible=y_check),
gr.update(height=height, visible=z_check),
)
def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
return (
self.change_visibility(x_check),
self.change_visibility(y_check),
self.change_visibility(z_check),
)
def update_explorer(self, new_path: str):
# Refresh the file explorer object
......@@ -337,16 +391,16 @@ class Interface(BaseInterface):
return gr.update(root_dir=parent_dir, label=parent_dir, value=file_name)
else:
raise ValueError("Invalid path")
raise ValueError('Invalid path')
def set_relaunch_button(self):
return gr.update(value=f"Relaunch", interactive=True)
return gr.update(value='Relaunch', interactive=True)
def set_spinner(self, message: str):
if self.error:
return gr.Button()
# spinner icon/shows the user something is happeing
return gr.update(value=f"{message}", interactive=False)
return gr.update(value=f'{message}', interactive=False)
def load_data(self, base_path: str, explorer: str):
if base_path and os.path.isfile(base_path):
......@@ -354,22 +408,36 @@ class Interface(BaseInterface):
elif explorer and os.path.isfile(explorer):
file_path = explorer
else:
raise gr.Error("Invalid file path")
raise gr.Error('Invalid file path')
try:
self.data = qim3d.io.load(
file_path,
progress_bar=False
)
self.data = qim3d.io.load(file_path, progress_bar=False)
except Exception as error_message:
raise gr.Error(
f"Failed to load the image: {error_message}"
f'Failed to load the image: {error_message}'
) from error_message
def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
self.process_wrapper(X)(x_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper(Y)(y_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper(Z)(z_pos, axis, inverted, delta, min_margin, n_layers, wrap)
def process_all(
self,
x_pos: float,
y_pos: float,
z_pos: float,
axis: str,
inverted: bool,
delta: float,
min_margin: int,
n_layers: int,
wrap: bool,
):
self.process_wrapper(X)(
x_pos, axis, inverted, delta, min_margin, n_layers, wrap
)
self.process_wrapper(Y)(
y_pos, axis, inverted, delta, min_margin, n_layers, wrap
)
self.process_wrapper(Z)(
z_pos, axis, inverted, delta, min_margin, n_layers, wrap
)
def process_wrapper(self, slicing_axis: str):
"""
......@@ -377,14 +445,22 @@ class Interface(BaseInterface):
Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data
and we return a function working in that direction
"""
slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis}_segmentation'
slice_key = f'{slicing_axis}_slice'
seg_key = f'{slicing_axis}_segmentation'
slicing_axis_int = AXES[slicing_axis]
def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
def process(
pos: float,
segmenting_axis: str,
inverted: bool,
delta: float,
min_margin: int,
n_layers: int,
wrap: bool,
):
"""
Parameters:
-----------
Parameters
----------
pos: Relative position of a slice from data
segmenting_axis: In which direction we want to detect layers
inverted: If we want use inverted gradient
......@@ -392,6 +468,7 @@ class Interface(BaseInterface):
min_margin: What is the minimum distance between layers. If it was 0, all layers would be the same
n_layers: How many layer boarders we want to find
wrap: If True, the starting point and end point will be at the same level. Useful when segmenting unfolded images.
"""
slice = self.get_slice(pos, slicing_axis_int)
self.__dict__[slice_key] = slice
......@@ -399,10 +476,16 @@ class Interface(BaseInterface):
if segmenting_axis == slicing_axis:
self.__dict__[seg_key] = None
else:
if self.is_transposed(slicing_axis, segmenting_axis):
slice = np.rot90(slice)
self.__dict__[seg_key] = qim3d.processing.segment_layers(slice, inverted = inverted, n_layers = n_layers, delta = delta, min_margin = min_margin, wrap = wrap)
self.__dict__[seg_key] = qim3d.processing.segment_layers(
slice,
inverted=inverted,
n_layers=n_layers,
delta=delta,
min_margin=min_margin,
wrap=wrap,
)
return process
......@@ -411,7 +494,9 @@ class Interface(BaseInterface):
Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is.
If it is not, we have to rotate it before we put it to segmentation algorithm
"""
remaining_axis = F"{X}{Y}{Z}".replace(slicing_axis, '').replace(segmenting_axis, '')
remaining_axis = f'{X}{Y}{Z}'.replace(slicing_axis, '').replace(
segmenting_axis, ''
)
return AXES[segmenting_axis] > AXES[remaining_axis]
def get_slice(self, pos: float, axis: int):
......@@ -434,8 +519,8 @@ class Interface(BaseInterface):
# return x_plot, y_plot, z_plot
def plot_output_img_wrapper(self, slicing_axis: str):
slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis}_segmentation'
slice_key = f'{slicing_axis}_slice'
seg_key = f'{slicing_axis}_segmentation'
def plot_output_img(segmenting_axis: str, alpha: float, line_thickness: float):
slice = self.__dict__[slice_key]
......@@ -459,18 +544,28 @@ class Interface(BaseInterface):
else:
lines = qim3d.processing.get_lines(seg)
if self.is_transposed(slicing_axis, segmenting_axis):
return qim3d.viz.image_with_lines(np.rot90(slice), lines, line_thickness).rotate(270, expand = True)
return qim3d.viz.image_with_lines(
np.rot90(slice), lines, line_thickness
).rotate(270, expand=True)
else:
return qim3d.viz.image_with_lines(slice, lines, line_thickness)
return plot_output_img
def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float):
x_output = self.plot_output_img_wrapper(X)(segmenting_axis, alpha, line_thickness)
y_output = self.plot_output_img_wrapper(Y)(segmenting_axis, alpha, line_thickness)
z_output = self.plot_output_img_wrapper(Z)(segmenting_axis, alpha, line_thickness)
def plot_output_img_all(
self, segmenting_axis: str, alpha: float, line_thickness: float
):
x_output = self.plot_output_img_wrapper(X)(
segmenting_axis, alpha, line_thickness
)
y_output = self.plot_output_img_wrapper(Y)(
segmenting_axis, alpha, line_thickness
)
z_output = self.plot_output_img_wrapper(Z)(
segmenting_axis, alpha, line_thickness
)
return x_output, y_output, z_output
if __name__ == "__main__":
Interface().run_interface()
if __name__ == '__main__':
Interface().run_interface()
......@@ -32,29 +32,31 @@ app.launch()
```
"""
import os
import gradio as gr
import localthickness as lt
# matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import tifffile
import localthickness as lt
import qim3d
import qim3d
class Interface(qim3d.gui.interface.InterfaceWithExamples):
def __init__(self,
def __init__(
self,
img: np.ndarray = None,
verbose: bool = False,
plot_height: int = 768,
figsize:int = 6):
super().__init__(title = "Local thickness",
height = 1024,
width = 960,
verbose = verbose)
figsize: int = 6,
):
super().__init__(
title='Local thickness', height=1024, width=960, verbose=verbose
)
self.plot_height = plot_height
self.figsize = figsize
......@@ -64,7 +66,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# Get the temporary files from gradio
temp_sets = self.interface.temp_file_sets
for temp_set in temp_sets:
if "localthickness" in str(temp_set):
if 'localthickness' in str(temp_set):
# Get the lsit of the temporary files
temp_path_list = list(temp_set)
......@@ -84,7 +86,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
def define_interface(self):
gr.Markdown(
"Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)"
'Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)'
)
with gr.Row():
......@@ -92,12 +94,12 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
if self.img is not None:
data = gr.State(value=self.img)
else:
with gr.Tab("Input"):
with gr.Tab('Input'):
data = gr.File(
show_label=False,
value=self.img,
)
with gr.Tab("Examples"):
with gr.Tab('Examples'):
gr.Examples(examples=self.img_examples, inputs=data)
with gr.Row():
......@@ -106,17 +108,15 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
maximum=1,
value=0.5,
step=0.01,
label="Z position",
info="Local thickness is calculated in 3D, this slider controls the visualization only.",
label='Z position',
info='Local thickness is calculated in 3D, this slider controls the visualization only.',
)
with gr.Tab("Parameters"):
with gr.Tab('Parameters'):
gr.Markdown(
"It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results."
)
lt_scale = gr.Slider(
0.1, 1.0, label="Scale", value=0.5, step=0.1
'It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results.'
)
lt_scale = gr.Slider(0.1, 1.0, label='Scale', value=0.5, step=0.1)
with gr.Row():
threshold = gr.Slider(
......@@ -124,85 +124,83 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
1.0,
value=0.5,
step=0.05,
label="Threshold",
info="Local thickness uses a binary image, so a threshold value is needed.",
label='Threshold',
info='Local thickness uses a binary image, so a threshold value is needed.',
)
dark_objects = gr.Checkbox(
value=False,
label="Dark objects",
info="Inverts the image before thresholding. Use in case your foreground is darker than the background.",
label='Dark objects',
info='Inverts the image before thresholding. Use in case your foreground is darker than the background.',
)
with gr.Tab("Display options"):
with gr.Tab('Display options'):
cmap_original = gr.Dropdown(
value="viridis",
value='viridis',
choices=plt.colormaps(),
label="Colormap - input",
label='Colormap - input',
interactive=True,
)
cmap_lt = gr.Dropdown(
value="magma",
value='magma',
choices=plt.colormaps(),
label="Colormap - local thickness",
label='Colormap - local thickness',
interactive=True,
)
nbins = gr.Slider(
5, 50, value=25, step=1, label="Histogram bins"
)
nbins = gr.Slider(5, 50, value=25, step=1, label='Histogram bins')
# Run button
with gr.Row():
with gr.Column(scale=3, min_width=64):
btn = gr.Button(
"Run local thickness", variant = "primary"
)
btn = gr.Button('Run local thickness', variant='primary')
with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button("Clear", variant = "stop")
btn_clear = gr.Button('Clear', variant='stop')
with gr.Column(scale=4):
def create_uniform_image(intensity=1):
"""
Generates a blank image with a single color.
Gradio `gr.Plot` components will flicker if there is no default value.
bug fix on gradio version 4.44.0
"""
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255)
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(
intensity * 255
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(pixels, interpolation="nearest")
ax.imshow(pixels, interpolation='nearest')
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
with gr.Row():
input_vol = gr.Plot(
show_label=True,
label="Original",
label='Original',
visible=True,
value=create_uniform_image(),
)
binary_vol = gr.Plot(
show_label=True,
label="Binary",
label='Binary',
visible=True,
value=create_uniform_image(),
)
output_vol = gr.Plot(
show_label=True,
label="Local thickness",
label='Local thickness',
visible=True,
value=create_uniform_image(),
)
with gr.Row():
histogram = gr.Plot(
show_label=True,
label="Thickness histogram",
label='Thickness histogram',
visible=True,
value=create_uniform_image(),
)
......@@ -210,11 +208,10 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
lt_output = gr.File(
interactive=False,
show_label=True,
label="Output file",
label='Output file',
visible=False,
)
# Run button
# fmt: off
viz_input = lambda zpos, cmap: self.show_slice(self.vol, zpos, self.vmin, self.vmax, cmap)
......@@ -274,7 +271,9 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
except AttributeError:
self.vol = data
except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
if dark_objects:
self.vol = np.invert(self.vol)
......@@ -283,15 +282,22 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol)
def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
def show_slice(
self,
vol: np.ndarray,
zpos: int,
vmin: float = None,
vmax: float = None,
cmap: str = 'viridis',
):
plt.close()
z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(vol[z_idx], interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
ax.imshow(vol[z_idx], interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax)
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
......@@ -318,20 +324,20 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge"
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge'
)
# Adjustments
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_yscale("log")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_yscale('log')
return fig
def save_lt(self):
filename = "localthickness.tif"
filename = 'localthickness.tif'
# Save output image in a temp space
tifffile.imwrite(filename, self.vol_thickness)
......@@ -342,5 +348,6 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# as it otherwise is not deleted
os.remove('localthickness.tif')
if __name__ == "__main__":
if __name__ == '__main__':
Interface().run_interface()
import gradio as gr
class QimTheme(gr.themes.Default):
"""
Theme for qim3d gradio interfaces.
The theming options are quite broad. However if there is something you can not achieve with this theme
there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
in comments.
"""
def __init__(self, force_light_mode: bool = True):
"""
Parameters:
-----------
Parameters
----------
- force_light_mode (bool, optional): Gradio themes have dark mode by default.
QIM platform is not ready for dark mode yet, thus the tools should also be in light mode.
This sets the darkmode values to be the same as light mode values.
"""
super().__init__()
self.force_light_mode = force_light_mode
......@@ -34,8 +38,14 @@ class QimTheme(gr.themes.Default):
def set_dark_mode_values(self):
if self.force_light_mode:
for attr in [dark_attr for dark_attr in dir(self) if not dark_attr.startswith("_") and dark_attr.endswith("dark")]:
self.__dict__[attr] = self.__dict__[attr[:-5]] # ligth and dark attributes have same names except for '_dark' at the end
for attr in [
dark_attr
for dark_attr in dir(self)
if not dark_attr.startswith('_') and dark_attr.endswith('dark')
]:
self.__dict__[attr] = self.__dict__[
attr[:-5]
] # ligth and dark attributes have same names except for '_dark' at the end
else:
self.set_dark_primary_button()
# Secondary button looks good by default in dark mode
......@@ -44,26 +54,28 @@ class QimTheme(gr.themes.Default):
# Example looks good by default in dark mode
def set_button(self):
self.button_transition = "0.15s"
self.button_large_text_weight = "normal"
self.button_transition = '0.15s'
self.button_large_text_weight = 'normal'
def set_light_primary_button(self):
self.run_color = "#198754"
self.button_primary_background_fill = "#FFFFFF"
self.run_color = '#198754'
self.button_primary_background_fill = '#FFFFFF'
self.button_primary_background_fill_hover = self.run_color
self.button_primary_border_color = self.run_color
self.button_primary_text_color = self.run_color
self.button_primary_text_color_hover = "#FFFFFF"
self.button_primary_text_color_hover = '#FFFFFF'
def set_dark_primary_button(self):
self.bright_run_color = "#299764"
self.button_primary_background_fill_dark = self.button_primary_background_fill_hover
self.bright_run_color = '#299764'
self.button_primary_background_fill_dark = (
self.button_primary_background_fill_hover
)
self.button_primary_background_fill_hover_dark = self.bright_run_color
self.button_primary_border_color_dark = self.button_primary_border_color
self.button_primary_border_color_hover_dark = self.bright_run_color
def set_light_secondary_button(self):
self.button_secondary_background_fill = "white"
self.button_secondary_background_fill = 'white'
def set_light_example(self):
"""
......@@ -73,10 +85,10 @@ class QimTheme(gr.themes.Default):
self.color_accent_soft = self.neutral_100
def set_h1(self):
self.text_xxl = "2.5rem"
self.text_xxl = '2.5rem'
def set_light_checkbox(self):
light_blue = "#60a5fa"
light_blue = '#60a5fa'
self.checkbox_background_color_selected = light_blue
self.checkbox_border_color_selected = light_blue
self.checkbox_border_color_focus = light_blue
......@@ -86,21 +98,20 @@ class QimTheme(gr.themes.Default):
self.checkbox_border_color_focus_dark = self.checkbox_border_color_focus_dark
def set_light_cancel_button(self):
self.cancel_color = "#dc3545"
self.button_cancel_background_fill = "white"
self.cancel_color = '#dc3545'
self.button_cancel_background_fill = 'white'
self.button_cancel_background_fill_hover = self.cancel_color
self.button_cancel_border_color = self.cancel_color
self.button_cancel_text_color = self.cancel_color
self.button_cancel_text_color_hover = "white"
self.button_cancel_text_color_hover = 'white'
def set_dark_cancel_button(self):
self.button_cancel_background_fill_dark = self.cancel_color
self.button_cancel_background_fill_hover_dark = "red"
self.button_cancel_background_fill_hover_dark = 'red'
self.button_cancel_border_color_dark = self.cancel_color
self.button_cancel_border_color_hover_dark = "red"
self.button_cancel_text_color_dark = "white"
self.button_cancel_border_color_hover_dark = 'red'
self.button_cancel_text_color_dark = 'white'
# def _get_theme_css(self):
# sup = super()._get_theme_css()
# return "\n.svelte-182fdeq {\nbackground: rgba(255, 0, 0, 0.5) !important;\n}\n" + sup # You have to use !important, so it overrides other css
\ No newline at end of file
# from ._sync import Sync # this will be added back after future development
from ._loading import load, load_mesh
from ._downloader import Downloader
from ._saving import save, save_mesh
# from ._sync import Sync # this will be added back after future development
from ._convert import convert
from ._ome_zarr import export_ome_zarr, import_ome_zarr
......@@ -6,21 +6,24 @@ import nibabel as nib
import numpy as np
import tifffile as tiff
import zarr
from tqdm import tqdm
import zarr.core
import qim3d
from tqdm import tqdm
from qim3d.utils._misc import stringify_path
from qim3d.io import save
class Convert:
def __init__(self, **kwargs):
"""Utility class to convert files to other formats without loading the entire file into memory
"""
Utility class to convert files to other formats without loading the entire file into memory
Args:
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
self.chunk_shape = kwargs.get('chunk_shape', (64, 64, 64))
def convert(self, input_path: str, output_path: str):
def get_file_extension(file_path):
......@@ -29,6 +32,7 @@ class Convert:
root, ext2 = os.path.splitext(root)
ext = ext2 + ext
return ext
# Stringify path in case it is not already a string
input_path = stringify_path(input_path)
input_ext = get_file_extension(input_path)
......@@ -37,28 +41,30 @@ class Convert:
if os.path.isfile(input_path):
match input_ext, output_ext:
case (".tif", ".zarr") | (".tiff", ".zarr"):
case ('.tif', '.zarr') | ('.tiff', '.zarr'):
return self.convert_tif_to_zarr(input_path, output_path)
case (".nii", ".zarr") | (".nii.gz", ".zarr"):
case ('.nii', '.zarr') | ('.nii.gz', '.zarr'):
return self.convert_nifti_to_zarr(input_path, output_path)
case _:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# Load a directory
elif os.path.isdir(input_path):
match input_ext, output_ext:
case (".zarr", ".tif") | (".zarr", ".tiff"):
case ('.zarr', '.tif') | ('.zarr', '.tiff'):
return self.convert_zarr_to_tif(input_path, output_path)
case (".zarr", ".nii"):
case ('.zarr', '.nii'):
return self.convert_zarr_to_nifti(input_path, output_path)
case (".zarr", ".nii.gz"):
return self.convert_zarr_to_nifti(input_path, output_path, compression=True)
case ('.zarr', '.nii.gz'):
return self.convert_zarr_to_nifti(
input_path, output_path, compression=True
)
case _:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# Fail
else:
# Find the closest matching path to warn the user
parent_dir = os.path.dirname(input_path) or "."
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
parent_dir = os.path.dirname(input_path) or '.'
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ''
valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
similar_paths = difflib.get_close_matches(input_path, valid_paths)
if similar_paths:
......@@ -66,10 +72,11 @@ class Convert:
message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message))
else:
raise ValueError("Invalid path")
raise ValueError('Invalid path')
def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file
"""
Convert a tiff file to a zarr file
Args:
tif_path (str): path to the tiff file
......@@ -77,10 +84,15 @@ class Convert:
Returns:
zarr.core.Array: zarr array containing the data from the tiff file
"""
vol = tiff.memmap(tif_path)
z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
)
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar
......@@ -98,7 +110,8 @@ class Convert:
return z
def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file
"""
Convert a zarr file to a tiff file
Args:
zarr_path (str): path to the zarr file
......@@ -106,12 +119,14 @@ class Convert:
returns:
None
"""
z = zarr.open(zarr_path)
save(tif_path, z)
qim3d.io.save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file
"""
Convert a nifti file to a zarr file
Args:
nifti_path (str): path to the nifti file
......@@ -119,10 +134,15 @@ class Convert:
Returns:
zarr.core.Array: zarr array containing the data from the nifti file
"""
vol = nib.load(nifti_path).dataobj
z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
)
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar
......@@ -139,8 +159,11 @@ class Convert:
return z
def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
"""Convert a zarr file to a nifti file
def convert_zarr_to_nifti(
self, zarr_path: str, nifti_path: str, compression: bool = False
) -> None:
"""
Convert a zarr file to a nifti file
Args:
zarr_path (str): path to the zarr file
......@@ -148,18 +171,23 @@ class Convert:
Returns:
None
"""
z = zarr.open(zarr_path)
save(nifti_path, z, compression=compression)
qim3d.io.save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
"""Convert a file to another format without loading the entire file into memory
def convert(
input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)
) -> None:
"""
Convert a file to another format without loading the entire file into memory
Args:
input_path (str): path to the input file
output_path (str): path to the output file
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
converter = Convert(chunk_shape=chunk_shape)
converter.convert(input_path, output_path)
"Manages downloads and access to data"
"""Manages downloads and access to data"""
import os
import urllib.request
from urllib.parse import quote
import outputformat as ouf
from tqdm import tqdm
from pathlib import Path
from qim3d.io import load
from qim3d.utils import log
import outputformat as ouf
class Downloader:
"""Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
"""
Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
Attributes:
folder_name (str or os.PathLike): Folder class with the name of the folder in <https://data.qim.dk/>
......@@ -59,17 +60,18 @@ class Downloader:
qim3d.viz.slicer_orthogonal(data, color_map="magma")
```
![cowry shell](../../assets/screenshots/cowry_shell_slicer.gif)
"""
def __init__(self):
folders = _extract_names()
for idx, folder in enumerate(folders):
exec(f"self.{folder} = self._Myfolder(folder)")
exec(f'self.{folder} = self._Myfolder(folder)')
def list_files(self):
"""Generate and print formatted folder, file, and size information."""
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
folders = _extract_names()
......@@ -78,17 +80,20 @@ class Downloader:
files = _extract_names(folder)
for file in files:
url = os.path.join(url_dl, folder, file).replace("\\", "/")
url = os.path.join(url_dl, folder, file).replace('\\', '/')
file_size = _get_file_size(url)
formatted_file = f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}"
formatted_file = (
f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}"
)
formatted_size = _format_file_size(file_size)
path_string = f'{folder}.{formatted_file}'
log.info(f'{path_string:<50}({formatted_size})')
class _Myfolder:
"""Class for extracting the files from each folder in the Downloader class.
"""
Class for extracting the files from each folder in the Downloader class.
Args:
folder(str): name of the folder of interest in the QIM data repository.
......@@ -99,6 +104,7 @@ class Downloader:
[file_name_2](load_file,optional): Function to download file number 2 in the given folder.
...
[file_name_n](load_file,optional): Function to download file number n in the given folder.
"""
def __init__(self, folder: str):
......@@ -107,14 +113,15 @@ class Downloader:
for idx, file in enumerate(files):
# Changes names to usable function name.
file_name = file
if ("%20" in file) or ("-" in file):
file_name = file_name.replace("%20", "_")
file_name = file_name.replace("-", "_")
if ('%20' in file) or ('-' in file):
file_name = file_name.replace('%20', '_')
file_name = file_name.replace('-', '_')
setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
def _make_fn(self, folder: str, file: str):
"""Private method that returns a function. The function downloads the chosen file from the folder.
"""
Private method that returns a function. The function downloads the chosen file from the folder.
Args:
folder(str): Folder where the file is located.
......@@ -122,23 +129,26 @@ class Downloader:
Returns:
function: the function used to download the file.
"""
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
def _download(load_file: bool = False, virtual_stack: bool = True):
"""Downloads the file and optionally also loads it.
"""
Downloads the file and optionally also loads it.
Args:
load_file(bool,optional): Whether to simply download or also load the file.
Returns:
virtual_stack: The loaded image.
"""
download_file(url_dl, folder, file)
if load_file == True:
log.info(f"\nLoading {file}")
log.info(f'\nLoading {file}')
file_path = os.path.join(folder, file)
return load(path=file_path, virtual_stack=virtual_stack)
......@@ -159,38 +169,40 @@ def _get_file_size(url: str):
Helper function for the ´download_file()´ function. Finds the size of the file.
"""
return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
return int(urllib.request.urlopen(url).info().get('Content-Length', -1))
def download_file(path: str, name: str, file: str):
"""Downloads the file from path / name / file.
"""
Downloads the file from path / name / file.
Args:
path(str): path to the folders available.
name(str): name of the folder of interest.
file(str): name of the file to be downloaded.
"""
if not os.path.exists(name):
os.makedirs(name)
url = os.path.join(path, name, file).replace("\\", "/") # if user is on windows
url = os.path.join(path, name, file).replace('\\', '/') # if user is on windows
file_path = os.path.join(name, file)
if os.path.exists(file_path):
log.warning(f"File already downloaded:\n{os.path.abspath(file_path)}")
log.warning(f'File already downloaded:\n{os.path.abspath(file_path)}')
return
else:
log.info(
f"Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}"
f'Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}'
)
if " " in url:
url = quote(url, safe=":/")
if ' ' in url:
url = quote(url, safe=':/')
with tqdm(
total=_get_file_size(url),
unit="B",
unit='B',
unit_scale=True,
unit_divisor=1024,
ncols=80,
......@@ -203,28 +215,31 @@ def download_file(path: str, name: str, file: str):
def _extract_html(url: str):
"""Extracts the html content of a webpage in "utf-8"
"""
Extracts the html content of a webpage in "utf-8"
Args:
url(str): url to the location where all the data is stored.
Returns:
html_content(str): decoded html.
"""
try:
with urllib.request.urlopen(url) as response:
html_content = response.read().decode(
"utf-8"
'utf-8'
) # Assuming the content is in UTF-8 encoding
except urllib.error.URLError as e:
log.warning(f"Failed to retrieve data from {url}. Error: {e}")
log.warning(f'Failed to retrieve data from {url}. Error: {e}')
return html_content
def _extract_names(name: str = None):
"""Extracts the names of the folders and files.
"""
Extracts the names of the folders and files.
Finds the names of either the folders if no name is given,
or all the names of all files in the given folder.
......@@ -235,31 +250,33 @@ def _extract_names(name: str = None):
Returns:
list: If name is None, returns a list of all folders available.
If name is not None, returns a list of all files available in the given 'name' folder.
"""
url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository"
url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository'
if name:
datapath = os.path.join(url, name).replace("\\", "/")
datapath = os.path.join(url, name).replace('\\', '/')
html_content = _extract_html(datapath)
data_split = html_content.split(
"files/public/projects/viscomp_data_repository/"
'files/public/projects/viscomp_data_repository/'
)[3:]
data_files = [
element.split(" ")[0][(len(name) + 1) : -3] for element in data_split
element.split(' ')[0][(len(name) + 1) : -3] for element in data_split
]
return data_files
else:
html_content = _extract_html(url)
split = html_content.split('"icon-folder-open">')[2:]
folders = [element.split(" ")[0][4:-4] for element in split]
folders = [element.split(' ')[0][4:-4] for element in split]
return folders
def _format_file_size(size_in_bytes):
# Define size units
units = ["B", "KB", "MB", "GB", "TB", "PB"]
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
size = float(size_in_bytes)
unit_index = 0
......@@ -269,4 +286,4 @@ def _format_file_size(size_in_bytes):
unit_index += 1
# Format the size with 1 decimal place
return f"{size:.2f}{units[unit_index]}"
return f'{size:.2f}{units[unit_index]}'
......@@ -13,31 +13,33 @@ Example:
import difflib
import os
import re
from pathlib import Path
from typing import Dict, Optional
import dask
import dask.array as da
import numpy as np
import tifffile
import trimesh
from dask import delayed
from PIL import Image, UnidentifiedImageError
import qim3d
from qim3d.utils import log
from qim3d.utils import Memory, log
from qim3d.utils._misc import get_file_size, sizeof, stringify_path
from qim3d.utils import Memory
from qim3d.utils._progress_bar import FileLoadingProgressBar
import trimesh
from pygel3d import hmesh
from typing import Optional, Dict
dask.config.set(scheduler="processes")
class DataLoader:
"""Utility class for loading data from different file formats.
Attributes:
"""
Utility class for loading data from different file formats.
Attributes
virtual_stack (bool): Specifies whether virtual stack is enabled.
dataset_name (str): Specifies the name of the dataset to be loaded
(only relevant for HDF5 files)
......@@ -46,17 +48,19 @@ class DataLoader:
contains (str): Specifies a part of the name that is common for the
TIFF file stack to be loaded (only relevant for TIFF stacks)
Methods:
Methods
load_tiff(path): Load a TIFF file from the specified path.
load_h5(path): Load an HDF5 file from the specified path.
load_tiff_stack(path): Load a stack of TIFF files from the specified path.
load_txrm(path): Load a TXRM/TXM/XRM file from the specified path
load_vol(path): Load a VOL file from the specified path. Path should point to the .vgi metadata file
load(path): Load a file or directory based on the given path
"""
def __init__(self, **kwargs):
"""Initializes a new instance of the DataLoader class.
"""
Initializes a new instance of the DataLoader class.
Args:
virtual_stack (bool, optional): Specifies whether to use virtual
......@@ -69,17 +73,19 @@ class DataLoader:
force_load (bool, optional): If false and user tries to load file that exceeds available memory, throws a MemoryError. If true, this error is
changed to warning and dataloader tries to load the file. Default is False.
dim_order (tuple, optional): The order of the dimensions in the volume. Default is (2,1,0) which corresponds to (z,y,x)
"""
self.virtual_stack = kwargs.get("virtual_stack", False)
self.dataset_name = kwargs.get("dataset_name", None)
self.return_metadata = kwargs.get("return_metadata", False)
self.contains = kwargs.get("contains", None)
self.force_load = kwargs.get("force_load", False)
self.dim_order = kwargs.get("dim_order", (2, 1, 0))
self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
self.virtual_stack = kwargs.get('virtual_stack', False)
self.dataset_name = kwargs.get('dataset_name', None)
self.return_metadata = kwargs.get('return_metadata', False)
self.contains = kwargs.get('contains', None)
self.force_load = kwargs.get('force_load', False)
self.dim_order = kwargs.get('dim_order', (2, 1, 0))
self.PIL_extensions = ('.jp2', '.jpg', 'jpeg', '.png', 'gif', '.bmp', '.webp')
def load_tiff(self, path: str | os.PathLike):
"""Load a TIFF file from the specified path.
"""
Load a TIFF file from the specified path.
Args:
path (str): The path to the TIFF file.
......@@ -98,12 +104,13 @@ class DataLoader:
else:
vol = tifffile.imread(path, key=range(series) if series > 1 else None)
log.info("Loaded shape: %s", vol.shape)
log.info('Loaded shape: %s', vol.shape)
return vol
def load_h5(self, path: str | os.PathLike) -> tuple[np.ndarray, Optional[Dict]]:
"""Load an HDF5 file from the specified path.
"""
Load an HDF5 file from the specified path.
Args:
path (str): The path to the HDF5 file.
......@@ -117,11 +124,12 @@ class DataLoader:
ValueError: If the specified dataset_name is not found or is invalid.
ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file
ValueError: If no datasets are found in the file.
"""
import h5py
# Read file
f = h5py.File(path, "r")
f = h5py.File(path, 'r')
data_keys = _get_h5_dataset_keys(f)
datasets = []
metadata = {}
......@@ -132,7 +140,7 @@ class DataLoader:
datasets.append(key)
if f[key].attrs.keys():
metadata[key] = {
"value": f[key][()],
'value': f[key][()],
**{attr_key: val for attr_key, val in f[key].attrs.items()},
}
......@@ -162,7 +170,7 @@ class DataLoader:
)
else:
raise ValueError(
f"Invalid dataset name. Please choose between the following datasets: {datasets}"
f'Invalid dataset name. Please choose between the following datasets: {datasets}'
)
else:
raise ValueError(
......@@ -171,14 +179,14 @@ class DataLoader:
# No datasets were found
else:
raise ValueError(f"Did not find any data in the file: {path}")
raise ValueError(f'Did not find any data in the file: {path}')
if not self.virtual_stack:
vol = vol[()] # Load dataset into memory
f.close()
log.info("Loaded the following dataset: %s", name)
log.info("Loaded shape: %s", vol.shape)
log.info('Loaded the following dataset: %s', name)
log.info('Loaded shape: %s', vol.shape)
if self.return_metadata:
return vol, metadata
......@@ -186,7 +194,8 @@ class DataLoader:
return vol
def load_tiff_stack(self, path: str | os.PathLike) -> np.ndarray | np.memmap:
"""Load a stack of TIFF files from the specified path.
"""
Load a stack of TIFF files from the specified path.
Args:
path (str): The path to the stack of TIFF files.
......@@ -198,6 +207,7 @@ class DataLoader:
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory
"""
if not self.contains:
raise ValueError(
......@@ -207,7 +217,7 @@ class DataLoader:
tiff_stack = [
file
for file in os.listdir(path)
if (file.endswith(".tif") or file.endswith(".tiff"))
if (file.endswith('.tif') or file.endswith('.tiff'))
and self.contains in file
]
tiff_stack.sort() # Ensure proper ordering
......@@ -217,30 +227,33 @@ class DataLoader:
for filename in tiff_stack:
name = os.path.splitext(filename)[0] # Remove file extension
tiff_stack_only_letters.append(
"".join(filter(str.isalpha, name))
''.join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements from tiff_stack_only_letters
unique_names = list(set(tiff_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded"
f'The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded'
)
vol = tifffile.imread(
[os.path.join(path, file) for file in tiff_stack], out="memmap"
[os.path.join(path, file) for file in tiff_stack], out='memmap'
)
if not self.virtual_stack:
vol = np.copy(vol) # Copy to memory
log.info("Found %s file(s)", len(tiff_stack))
log.info("Loaded shape: %s", vol.shape)
log.info('Found %s file(s)', len(tiff_stack))
log.info('Loaded shape: %s', vol.shape)
return vol
def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]:
"""Load a TXRM/XRM/TXM file from the specified path.
def load_txrm(
self, path: str | os.PathLike
) -> tuple[dask.array.core.Array | np.ndarray, Optional[Dict]]:
"""
Load a TXRM/XRM/TXM file from the specified path.
Args:
path (str): The path to the TXRM/TXM file.
......@@ -252,6 +265,7 @@ class DataLoader:
Raises:
ValueError: If the dxchange library is not installed
"""
import olefile
......@@ -259,13 +273,13 @@ class DataLoader:
import dxchange
except ImportError:
raise ValueError(
"The library dxchange is required to load TXRM files. Please find installation instructions at https://dxchange.readthedocs.io/en/latest/source/install.html"
'The library dxchange is required to load TXRM files. Please find installation instructions at https://dxchange.readthedocs.io/en/latest/source/install.html'
)
if self.virtual_stack:
if not path.endswith(".txm"):
if not path.endswith('.txm'):
log.warning(
"Virtual stack is only thoroughly tested for reconstructed volumes in TXM format and is thus not guaranteed to load TXRM and XRM files correctly"
'Virtual stack is only thoroughly tested for reconstructed volumes in TXM format and is thus not guaranteed to load TXRM and XRM files correctly'
)
# Get metadata
......@@ -275,7 +289,7 @@ class DataLoader:
# Compute data offsets in bytes for each slice
offsets = _get_ole_offsets(ole)
if len(offsets) != metadata["number_of_images"]:
if len(offsets) != metadata['number_of_images']:
raise ValueError(
f'Metadata is erroneous: number of images {metadata["number_of_images"]} is different from number of data offsets {len(offsets)}'
)
......@@ -286,17 +300,17 @@ class DataLoader:
np.memmap(
path,
dtype=dxchange.reader._get_ole_data_type(metadata).newbyteorder(
"<"
'<'
),
mode="r",
mode='r',
offset=offset,
shape=(1, metadata["image_height"], metadata["image_width"]),
shape=(1, metadata['image_height'], metadata['image_width']),
)
)
vol = da.concatenate(slices, axis=0)
log.warning(
"Virtual stack volume will be returned as a dask array. To load certain slices into memory, use normal indexing followed by the compute() method, e.g. vol[:,0,:].compute()"
'Virtual stack volume will be returned as a dask array. To load certain slices into memory, use normal indexing followed by the compute() method, e.g. vol[:,0,:].compute()'
)
else:
......@@ -311,7 +325,8 @@ class DataLoader:
return vol
def load_nifti(self, path: str | os.PathLike):
"""Load a NIfTI file from the specified path.
"""
Load a NIfTI file from the specified path.
Args:
path (str): The path to the NIfTI file.
......@@ -320,6 +335,7 @@ class DataLoader:
numpy.ndarray, nibabel.arrayproxy.ArrayProxy or tuple: The loaded volume.
If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object
If 'self.return_metadata' is True, returns a tuple (volume, metadata).
"""
import nibabel as nib
......@@ -341,18 +357,21 @@ class DataLoader:
return vol
def load_pil(self, path: str | os.PathLike):
"""Load a PIL image from the specified path
"""
Load a PIL image from the specified path
Args:
path (str): The path to the image supported by PIL.
Returns:
numpy.ndarray: The loaded image/volume.
"""
return np.array(Image.open(path))
def load_PIL_stack(self, path: str | os.PathLike):
"""Load a stack of PIL files from the specified path.
"""
Load a stack of PIL files from the specified path.
Args:
path (str): The path to the stack of PIL files.
......@@ -364,6 +383,7 @@ class DataLoader:
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple PIL stacks in the directory
"""
if not self.contains:
raise ValueError(
......@@ -384,18 +404,17 @@ class DataLoader:
for filename in PIL_stack:
name = os.path.splitext(filename)[0] # Remove file extension
PIL_stack_only_letters.append(
"".join(filter(str.isalpha, name))
''.join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements
unique_names = list(set(PIL_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded"
f'The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded'
)
if self.virtual_stack:
full_paths = [os.path.join(path, file) for file in PIL_stack]
def lazy_loader(path):
......@@ -411,7 +430,8 @@ class DataLoader:
# Stack the images into a single Dask array
dask_images = [
da.from_delayed(img, shape=image_shape, dtype=dtype) for img in lazy_images
da.from_delayed(img, shape=image_shape, dtype=dtype)
for img in lazy_images
]
stacked = da.stack(dask_images, axis=0)
......@@ -420,29 +440,28 @@ class DataLoader:
else:
# Generate placeholder volume
first_image = self.load_pil(os.path.join(path, PIL_stack[0]))
vol = np.zeros((len(PIL_stack), *first_image.shape), dtype=first_image.dtype)
vol = np.zeros(
(len(PIL_stack), *first_image.shape), dtype=first_image.dtype
)
# Load file sequence
for idx, file_name in enumerate(PIL_stack):
vol[idx] = self.load_pil(os.path.join(path, file_name))
return vol
# log.info("Found %s file(s)", len(PIL_stack))
# log.info("Loaded shape: %s", vol.shape)
def _load_vgi_metadata(self, path: str | os.PathLike):
"""Helper functions that loads metadata from a VGI file
"""
Helper functions that loads metadata from a VGI file
Args:
path (str): The path to the VGI file.
returns:
dict: The loaded metadata.
"""
meta_data = {}
current_section = meta_data
......@@ -450,11 +469,11 @@ class DataLoader:
should_indent = True
with open(path, "r") as f:
with open(path) as f:
for line in f:
line = line.strip()
# {NAME} is start of a new object, so should indent
if line.startswith("{") and line.endswith("}"):
if line.startswith('{') and line.endswith('}'):
section_name = line[1:-1]
current_section[section_name] = {}
section_stack.append(current_section)
......@@ -462,7 +481,7 @@ class DataLoader:
should_indent = True
# [NAME] is start of a section, so should not indent
elif line.startswith("[") and line.endswith("]"):
elif line.startswith('[') and line.endswith(']'):
section_name = line[1:-1]
if not should_indent:
......@@ -475,17 +494,18 @@ class DataLoader:
should_indent = False
# = is a key value pair
elif "=" in line:
key, value = line.split("=", 1)
elif '=' in line:
key, value = line.split('=', 1)
current_section[key.strip()] = value.strip()
elif line == "":
elif line == '':
if len(section_stack) > 1:
current_section = section_stack.pop()
return meta_data
def load_vol(self, path: str | os.PathLike):
"""Load a VOL filed based on the VGI metadata file
"""
Load a VOL filed based on the VGI metadata file
Args:
path (str): The path to the VGI file.
......@@ -496,43 +516,44 @@ class DataLoader:
returns:
numpy.ndarray, numpy.memmap or tuple: The loaded volume.
If 'self.return_metadata' is True, returns a tuple (volume, metadata).
"""
# makes sure path point to .VGI metadata file and not the .VOL file
if path.endswith(".vol") and os.path.isfile(path.replace(".vol", ".vgi")):
path = path.replace(".vol", ".vgi")
log.warning("Corrected path to .vgi metadata file from .vol file")
elif path.endswith(".vol") and not os.path.isfile(path.replace(".vol", ".vgi")):
if path.endswith('.vol') and os.path.isfile(path.replace('.vol', '.vgi')):
path = path.replace('.vol', '.vgi')
log.warning('Corrected path to .vgi metadata file from .vol file')
elif path.endswith('.vol') and not os.path.isfile(path.replace('.vol', '.vgi')):
raise ValueError(
f"Unsupported file format, should point to .vgi metadata file assumed to be in same folder as .vol file: {path}"
f'Unsupported file format, should point to .vgi metadata file assumed to be in same folder as .vol file: {path}'
)
meta_data = self._load_vgi_metadata(path)
# Extracts relevant information from the metadata
file_name = meta_data["volume1"]["file1"]["Name"]
path = path.rsplit("/", 1)[
file_name = meta_data['volume1']['file1']['Name']
path = path.rsplit('/', 1)[
0
] # Remove characters after the last "/" to be replaced with .vol filename
vol_path = os.path.join(
path, file_name
) # .vol and .vgi files are assumed to be in the same directory
dims = meta_data["volume1"]["file1"]["Size"]
dims = meta_data['volume1']['file1']['Size']
dims = [int(n) for n in dims.split() if n.isdigit()]
dt = meta_data["volume1"]["file1"]["Datatype"]
dt = meta_data['volume1']['file1']['Datatype']
match dt:
case "float":
case 'float':
dt = np.float32
case "float32":
case 'float32':
dt = np.float32
case "uint8":
case 'uint8':
dt = np.uint8
case "unsigned integer":
case 'unsigned integer':
dt = np.uint16
case "uint16":
case 'uint16':
dt = np.uint16
case _:
raise ValueError(f"Unsupported data type: {dt}")
raise ValueError(f'Unsupported data type: {dt}')
dims_order = (
dims[self.dim_order[0]],
......@@ -540,7 +561,7 @@ class DataLoader:
dims[self.dim_order[2]],
)
if self.virtual_stack:
vol = np.memmap(vol_path, dtype=dt, mode="r", shape=dims_order)
vol = np.memmap(vol_path, dtype=dt, mode='r', shape=dims_order)
else:
vol = np.fromfile(vol_path, dtype=dt, count=np.prod(dims))
vol = np.reshape(vol, dims_order)
......@@ -551,10 +572,12 @@ class DataLoader:
return vol
def load_dicom(self, path: str | os.PathLike):
"""Load a DICOM file
"""
Load a DICOM file
Args:
path (str): Path to file
"""
import pydicom
......@@ -566,7 +589,8 @@ class DataLoader:
return dcm_data.pixel_array
def load_dicom_dir(self, path: str | os.PathLike):
"""Load a directory of DICOM files into a numpy 3d array
"""
Load a directory of DICOM files into a numpy 3d array
Args:
path (str): Directory path
......@@ -574,6 +598,7 @@ class DataLoader:
returns:
numpy.ndarray, numpy.memmap or tuple: The loaded volume.
If 'self.return_metadata' is True, returns a tuple (volume, metadata).
"""
import pydicom
......@@ -590,14 +615,14 @@ class DataLoader:
for filename in dicom_stack:
name = os.path.splitext(filename)[0] # Remove file extension
dicom_stack_only_letters.append(
"".join(filter(str.isalpha, name))
''.join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements from tiff_stack_only_letters
unique_names = list(set(dicom_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded"
f'The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded'
)
# dicom_list contains the dicom objects with metadata
......@@ -610,9 +635,9 @@ class DataLoader:
else:
return vol
def load_zarr(self, path: str | os.PathLike):
""" Loads a Zarr array from disk.
"""
Loads a Zarr array from disk.
Args:
path (str): The path to the Zarr array on disk.
......@@ -620,6 +645,7 @@ class DataLoader:
Returns:
dask.array | numpy.ndarray: The dask array loaded from disk.
if 'self.virtual_stack' is True, returns a dask array object, else returns a numpy.ndarray object.
"""
# Opens the Zarr array
......@@ -634,25 +660,25 @@ class DataLoader:
def check_file_size(self, filename: str):
"""
Checks if there is enough memory where the file can be loaded.
Args:
------------
----
filename: (str) Specifies path to file
force_load: (bool, optional) If true, the memory error will not be raised. Warning will be printed insted and
the loader will attempt to load the file.
Raises:
-----------
------
MemoryError: If filesize is greater then available memory
"""
if (
self.virtual_stack
): # If virtual_stack is True, then data is loaded from the disk, no need for loading into memory
if self.virtual_stack: # If virtual_stack is True, then data is loaded from the disk, no need for loading into memory
return
file_size = get_file_size(filename)
available_memory = Memory().free
if file_size > available_memory:
message = f"The file {filename} has {sizeof(file_size)} but only {sizeof(available_memory)} of memory is available."
message = f'The file {filename} has {sizeof(file_size)} but only {sizeof(available_memory)} of memory is available.'
if self.force_load:
log.warning(message)
else:
......@@ -677,6 +703,7 @@ class DataLoader:
ValueError: If the format is not supported
ValueError: If the file or directory does not exist.
MemoryError: If file size exceeds available memory and force_load is not set to True. In check_size function.
"""
# Stringify path in case it is not already a string
......@@ -686,35 +713,35 @@ class DataLoader:
if os.path.isfile(path):
# Choose the loader based on the file extension
self.check_file_size(path)
if path.endswith(".tif") or path.endswith(".tiff"):
if path.endswith('.tif') or path.endswith('.tiff'):
return self.load_tiff(path)
elif path.endswith(".h5"):
elif path.endswith('.h5'):
return self.load_h5(path)
elif path.endswith((".txrm", ".txm", ".xrm")):
elif path.endswith(('.txrm', '.txm', '.xrm')):
return self.load_txrm(path)
elif path.endswith((".nii", ".nii.gz")):
elif path.endswith(('.nii', '.nii.gz')):
return self.load_nifti(path)
elif path.endswith((".vol", ".vgi")):
elif path.endswith(('.vol', '.vgi')):
return self.load_vol(path)
elif path.endswith((".dcm", ".DCM")):
elif path.endswith(('.dcm', '.DCM')):
return self.load_dicom(path)
else:
try:
return self.load_pil(path)
except UnidentifiedImageError:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# Load a directory
elif os.path.isdir(path):
# load tiff stack if folder contains tiff files else load dicom directory
if any(
[f.endswith(".tif") or f.endswith(".tiff") for f in os.listdir(path)]
[f.endswith('.tif') or f.endswith('.tiff') for f in os.listdir(path)]
):
return self.load_tiff_stack(path)
elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]):
return self.load_PIL_stack(path)
elif path.endswith(".zarr"):
elif path.endswith('.zarr'):
return self.load_zarr(path)
else:
return self.load_dicom_dir(path)
......@@ -729,7 +756,7 @@ class DataLoader:
message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message))
else:
raise ValueError("Invalid path")
raise ValueError('Invalid path')
def _get_h5_dataset_keys(f):
......@@ -743,18 +770,18 @@ def _get_h5_dataset_keys(f):
def _get_ole_offsets(ole):
slice_offset = {}
for stream in ole.listdir():
if stream[0].startswith("ImageData"):
if stream[0].startswith('ImageData'):
sid = ole._find(stream)
direntry = ole.direntries[sid]
sect_start = direntry.isectStart
offset = ole.sectorsize * (sect_start + 1)
slice_offset[f"{stream[0]}/{stream[1]}"] = offset
slice_offset[f'{stream[0]}/{stream[1]}'] = offset
# sort dictionary after natural sorting (https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/)
sorted_keys = sorted(
slice_offset.keys(),
key=lambda string_: [
int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_)
int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)
],
)
slice_offset_sorted = {key: slice_offset[key] for key in sorted_keys}
......@@ -822,6 +849,7 @@ def load(
vol = qim3d.io.load("path/to/image.tif", virtual_stack=True)
```
"""
loader = DataLoader(
......@@ -843,13 +871,13 @@ def load(
def log_memory_info(data):
mem = Memory()
log.info(
"Volume using %s of memory\n",
'Volume using %s of memory\n',
sizeof(data[0].nbytes if isinstance(data, tuple) else data.nbytes),
)
mem.report()
if return_metadata and not isinstance(data, tuple):
log.warning("The file format does not contain metadata")
log.warning('The file format does not contain metadata')
if not virtual_stack:
log_memory_info(data)
......@@ -858,22 +886,31 @@ def load(
if not isinstance(
type(data[0]) if isinstance(data, tuple) else type(data), np.ndarray
):
log.info("Using virtual stack")
log.info('Using virtual stack')
else:
log.warning("Virtual stack is not supported for this file format")
log.warning('Virtual stack is not supported for this file format')
log_memory_info(data)
return data
def load_mesh(filename: str) -> trimesh.Trimesh:
def load_mesh(filename: str) -> hmesh.Manifold:
"""
Load a mesh from an .obj file using trimesh.
Load a mesh from a specific file.
This function is based on the [PyGEL3D library's loading function implementation](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#load).
Supported formats:
- `X3D`
- `OBJ`
- `OFF`
- `PLY`
Args:
filename (str or os.PathLike): The path to the .obj file.
filename (str or os.PathLike): The path to the file.
Returns:
mesh (trimesh.Trimesh): A trimesh object containing the mesh data (vertices and faces).
mesh (hmesh.Manifold or None): A hmesh object containing the mesh data or None if loading failed.
Example:
```python
......@@ -881,6 +918,8 @@ def load_mesh(filename: str) -> trimesh.Trimesh:
mesh = qim3d.io.load_mesh("path/to/mesh.obj")
```
"""
mesh = trimesh.load(filename)
mesh = hmesh.load(filename)
return mesh
\ No newline at end of file
......@@ -2,39 +2,27 @@
Exporting data to different formats.
"""
import os
import math
import os
import shutil
import logging
from typing import List, Union
import dask.array as da
import numpy as np
import zarr
import tqdm
from ome_zarr import scale
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.scale import dask_resize
from ome_zarr.writer import (
write_image,
_create_mip,
write_multiscale,
CurrentFormat,
Format,
write_multiscale,
)
from ome_zarr.scale import dask_resize
from ome_zarr.reader import Reader
from ome_zarr import scale
from scipy.ndimage import zoom
from typing import Any, Callable, Iterator, List, Tuple, Union
import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from skimage.transform import (
resize,
)
from qim3d.utils import log
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
from qim3d.utils._ome_zarr import get_n_chunks
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
ListOfArrayLike = Union[List[da.Array], List[np.ndarray]]
ArrayLike = Union[da.Array, np.ndarray]
......@@ -43,10 +31,19 @@ ArrayLike = Union[da.Array, np.ndarray]
class OMEScaler(
scale.Scaler,
):
"""Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken."""
def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
"""
Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken.
"""
def __init__(
self,
order: int = 0,
downscale: float = 2,
max_layer: int = 5,
method: str = 'scaleZYXdask',
):
self.order = order
self.downscale = downscale
self.max_layer = max_layer
......@@ -55,11 +52,11 @@ class OMEScaler(
def scaleZYX(self, base: da.core.Array):
"""Downsample using :func:`scipy.ndimage.zoom`."""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
......@@ -82,8 +79,8 @@ class OMEScaler(
"""
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library
better_chunksize = tuple(
......@@ -92,7 +89,7 @@ class OMEScaler(
).astype(int)
)
log.debug(f"better chunk size: {better_chunksize}")
log.debug(f'better chunk size: {better_chunksize}')
# Compute the chunk size after the downscaling
new_chunk_size = tuple(
......@@ -100,17 +97,16 @@ class OMEScaler(
)
log.debug(
f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}"
f'orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}'
)
def resize_chunk(chunk, scale_factors, order):
# print(f"zoom factors: {scale_factors}")
resized_chunk = zoom(
chunk,
zoom=scale_factors,
order=order,
mode="grid-constant",
mode='grid-constant',
grid_mode=True,
)
# print(f"resized chunk shape: {resized_chunk.shape}")
......@@ -121,7 +117,7 @@ class OMEScaler(
# Testing new shape
predicted_shape = np.multiply(vol.shape, scale_factors)
log.debug(f"predicted shape: {predicted_shape}")
log.debug(f'predicted shape: {predicted_shape}')
scaled_vol = da.map_blocks(
resize_chunk,
vol,
......@@ -136,7 +132,7 @@ class OMEScaler(
return scaled_vol
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
log.debug(f"\nScale {i+1}\n{'-'*32}")
......@@ -147,17 +143,17 @@ class OMEScaler(
np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
)
log.debug(f"target shape: {scaled_shape}")
log.debug(f'target shape: {scaled_shape}')
downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
log.debug(f"downscale rate: {downscale_rate}")
log.debug(f'downscale rate: {downscale_rate}')
scale_factors = tuple(np.divide(1, downscale_rate))
log.debug(f"scale factors: {scale_factors}")
log.debug(f'scale factors: {scale_factors}')
log.debug("\nResizing volume chunk-wise")
log.debug('\nResizing volume chunk-wise')
scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
rv.append(scaled_vol)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
......@@ -165,10 +161,9 @@ class OMEScaler(
"""Downsample using the original OME-Zarr python library"""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
scaled_shape = tuple(
base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
)
......@@ -176,7 +171,7 @@ class OMEScaler(
scaled = dask_resize(base, scaled_shape, order=self.order)
rv.append(scaled)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
......@@ -187,9 +182,9 @@ def export_ome_zarr(
downsample_rate: int = 2,
order: int = 1,
replace: bool = False,
method: str = "scaleZYX",
method: str = 'scaleZYX',
progress_bar: bool = True,
progress_bar_repeat_time: str = "auto",
progress_bar_repeat_time: str = 'auto',
) -> None:
"""
Export 3D image data to OME-Zarr format with pyramidal downsampling.
......@@ -220,6 +215,7 @@ def export_ome_zarr(
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
```
"""
# Check if directory exists
......@@ -228,19 +224,19 @@ def export_ome_zarr(
shutil.rmtree(path)
else:
raise ValueError(
f"Directory {path} already exists. Use replace=True to overwrite."
f'Directory {path} already exists. Use replace=True to overwrite.'
)
# Check if downsample_rate is valid
if downsample_rate <= 1:
raise ValueError("Downsample rate must be greater than 1.")
raise ValueError('Downsample rate must be greater than 1.')
log.info(f"Exporting data to OME-Zarr format at {path}")
log.info(f'Exporting data to OME-Zarr format at {path}')
# Get the number of scales
min_dim = np.max(np.shape(data))
nscales = math.ceil(math.log(min_dim / chunk_size) / math.log(downsample_rate))
log.info(f"Number of scales: {nscales + 1}")
log.info(f'Number of scales: {nscales + 1}')
# Create scaler
scaler = OMEScaler(
......@@ -249,32 +245,31 @@ def export_ome_zarr(
# write the image data
os.mkdir(path)
store = parse_url(path, mode="w").store
store = parse_url(path, mode='w').store
root = zarr.group(store=store)
# Check if we want to process using Dask
if "dask" in method and not isinstance(data, da.Array):
log.info("\nConverting input data to Dask array")
if 'dask' in method and not isinstance(data, da.Array):
log.info('\nConverting input data to Dask array')
data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
elif "dask" in method and isinstance(data, da.Array):
log.info("\nInput data will be rechunked")
elif 'dask' in method and isinstance(data, da.Array):
log.info('\nInput data will be rechunked')
data = data.rechunk((chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
log.info("Calculating the multi-scale pyramid")
log.info('Calculating the multi-scale pyramid')
# Generate multi-scale pyramid
mip = scaler.func(data)
log.info("Writing data to disk")
log.info('Writing data to disk')
kwargs = dict(
pyramid=mip,
group=root,
fmt=CurrentFormat(),
axes="zyx",
axes='zyx',
name=None,
compute=True,
storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)),
......@@ -291,15 +286,13 @@ def export_ome_zarr(
else:
write_multiscale(**kwargs)
log.info("\nAll done!")
log.info('\nAll done!')
return
def import_ome_zarr(
path: str|os.PathLike,
scale: int = 0,
load: bool = True
path: str | os.PathLike, scale: int = 0, load: bool = True
) -> np.ndarray:
"""
Import image data from an OME-Zarr file.
......@@ -339,22 +332,22 @@ def import_ome_zarr(
image_node = nodes[0]
dask_data = image_node.data
log.info(f"Data contains {len(dask_data)} scales:")
log.info(f'Data contains {len(dask_data)} scales:')
for i in np.arange(len(dask_data)):
log.info(f"- Scale {i}: {dask_data[i].shape}")
log.info(f'- Scale {i}: {dask_data[i].shape}')
if scale == "highest":
if scale == 'highest':
scale = 0
if scale == "lowest":
if scale == 'lowest':
scale = len(dask_data) - 1
if scale >= len(dask_data):
raise ValueError(
f"Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}."
f'Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}.'
)
log.info(f"\nLoading scale {scale} with shape {dask_data[scale].shape}")
log.info(f'\nLoading scale {scale} with shape {dask_data[scale].shape}')
if load:
vol = dask_data[scale].compute()
......
......@@ -26,25 +26,29 @@ Example:
import datetime
import os
import dask.array as da
import h5py
import nibabel as nib
import numpy as np
import PIL
import tifffile
import trimesh
import zarr
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
import trimesh
from pygel3d import hmesh
from qim3d.utils import log
from qim3d.utils._misc import sizeof, stringify_path
from qim3d.utils._misc import stringify_path
class DataSaver:
"""Utility class for saving data to different file formats.
Attributes:
"""
Utility class for saving data to different file formats.
Attributes
replace (bool): Specifies if an existing file with identical path is replaced.
compression (bool): Specifies if the file is saved with Deflate compression (lossless).
basename (str): Specifies the basename for a TIFF stack saved as several files
......@@ -52,13 +56,15 @@ class DataSaver:
sliced_dim (int): Specifies the dimension that is sliced in case a TIFF stack is saved
as several files (only relevant for TIFF stacks)
Methods:
Methods
save_tiff(path,data): Save data to a TIFF file to the given path.
load(path,data): Save data to the given path.
"""
def __init__(self, **kwargs):
"""Initializes a new instance of the DataSaver class.
"""
Initializes a new instance of the DataSaver class.
Args:
replace (bool, optional): Specifies if an existing file with identical path should be replaced.
......@@ -69,35 +75,40 @@ class DataSaver:
(only relevant for TIFF stacks). Default is None
sliced_dim (int, optional): Specifies the dimension that is sliced in case a TIFF stack is saved
as several files (only relevant for TIFF stacks). Default is 0, i.e., the first dimension.
"""
self.replace = kwargs.get("replace", False)
self.compression = kwargs.get("compression", False)
self.basename = kwargs.get("basename", None)
self.sliced_dim = kwargs.get("sliced_dim", 0)
self.chunk_shape = kwargs.get("chunk_shape", "auto")
self.replace = kwargs.get('replace', False)
self.compression = kwargs.get('compression', False)
self.basename = kwargs.get('basename', None)
self.sliced_dim = kwargs.get('sliced_dim', 0)
self.chunk_shape = kwargs.get('chunk_shape', 'auto')
def save_tiff(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to a TIFF file to the given path.
"""
Save data to a TIFF file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
tifffile.imwrite(path, data, compression=self.compression)
def save_tiff_stack(self, path: str | os.PathLike, data: np.ndarray):
"""Save data as a TIFF stack containing slices in separate files to the given path.
"""
Save data as a TIFF stack containing slices in separate files to the given path.
The slices will be named according to the basename plus a suffix with a zero-filled
value corresponding to the slice number
Args:
path (str): The directory to save files to
data (numpy.ndarray): The data to be saved
"""
extension = ".tif"
extension = '.tif'
if data.ndim <= 2:
path = os.path.join(path, self.basename, ".tif")
path = os.path.join(path, self.basename, '.tif')
self.save_tiff(path, data)
else:
# get number of total slices
......@@ -117,7 +128,7 @@ class DataSaver:
self.save_tiff(filepath, sliced)
pattern_string = (
filepath[: -(len(extension) + zfill_val)] + "-" * zfill_val + extension
filepath[: -(len(extension) + zfill_val)] + '-' * zfill_val + extension
)
log.info(
......@@ -125,13 +136,14 @@ class DataSaver:
)
def save_nifti(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to a NIfTI file to the given path.
"""
Save data to a NIfTI file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
import nibabel as nib
# Create header
header = nib.Nifti1Header()
......@@ -141,11 +153,11 @@ class DataSaver:
img = nib.Nifti1Image(data, np.eye(4), header)
# nib does automatically compress if filetype ends with .gz
if self.compression and not path.endswith(".gz"):
path += ".gz"
if self.compression and not path.endswith('.gz'):
path += '.gz'
log.warning("File extension '.gz' is added since compression is enabled.")
if not self.compression and path.endswith(".gz"):
if not self.compression and path.endswith('.gz'):
path = path[:-3]
log.warning(
"File extension '.gz' is ignored since compression is disabled."
......@@ -155,82 +167,83 @@ class DataSaver:
nib.save(img, path)
def save_vol(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to a VOL file to the given path.
"""
Save data to a VOL file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# No support for compression yet
if self.compression:
raise NotImplementedError(
"Saving compressed .vol files is not yet supported"
'Saving compressed .vol files is not yet supported'
)
# Create custom .vgi metadata file
metadata = ""
metadata += "{volume1}\n" # .vgi organization
metadata += "[file1]\n" # .vgi organization
metadata += "Size = {} {} {}\n".format(
data.shape[1], data.shape[2], data.shape[0]
) # Swap axes to match .vol format
metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
metadata += "Name = {}.vol\n".format(
path.rsplit("/", 1)[-1][:-4]
metadata = ''
metadata += '{volume1}\n' # .vgi organization
metadata += '[file1]\n' # .vgi organization
metadata += f'Size = {data.shape[1]} {data.shape[2]} {data.shape[0]}\n' # Swap axes to match .vol format
metadata += f'Datatype = {str(data.dtype)}\n' # Get datatype as string
metadata += 'Name = {}.vol\n'.format(
path.rsplit('/', 1)[-1][:-4]
) # Get filename without extension
# Save metadata
with open(path[:-4] + ".vgi", "w") as f:
with open(path[:-4] + '.vgi', 'w') as f:
f.write(metadata)
# Save data using numpy in binary format
data.tofile(path[:-4] + ".vol")
data.tofile(path[:-4] + '.vol')
def save_h5(self, path, data):
"""Save data to a HDF5 file to the given path.
"""
Save data to a HDF5 file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
import h5py
with h5py.File(path, "w") as f:
with h5py.File(path, 'w') as f:
f.create_dataset(
"dataset", data=data, compression="gzip" if self.compression else None
'dataset', data=data, compression='gzip' if self.compression else None
)
def save_dicom(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to a DICOM file to the given path.
"""
Save data to a DICOM file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
import pydicom
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
# based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html
# Populate required values for file meta information
file_meta = FileMetaDataset()
file_meta.MediaStorageSOPClassUID = UID("1.2.840.10008.5.1.4.1.1.2")
file_meta.MediaStorageSOPInstanceUID = UID("1.2.3")
file_meta.ImplementationClassUID = UID("1.2.3.4")
file_meta.MediaStorageSOPClassUID = UID('1.2.840.10008.5.1.4.1.1.2')
file_meta.MediaStorageSOPInstanceUID = UID('1.2.3')
file_meta.ImplementationClassUID = UID('1.2.3.4')
# Create the FileDataset instance (initially no data elements, but file_meta
# supplied)
ds = FileDataset(path, {}, file_meta=file_meta, preamble=b"\0" * 128)
ds = FileDataset(path, {}, file_meta=file_meta, preamble=b'\0' * 128)
ds.PatientName = "Test^Firstname"
ds.PatientID = "123456"
ds.StudyInstanceUID = "1.2.3.4.5"
ds.PatientName = 'Test^Firstname'
ds.PatientID = '123456'
ds.StudyInstanceUID = '1.2.3.4.5'
ds.SamplesPerPixel = 1
ds.PixelRepresentation = 0
ds.BitsStored = 16
ds.BitsAllocated = 16
ds.PhotometricInterpretation = "MONOCHROME2"
ds.PhotometricInterpretation = 'MONOCHROME2'
ds.Rows = data.shape[1]
ds.Columns = data.shape[2]
ds.NumberOfFrames = data.shape[0]
......@@ -240,8 +253,8 @@ class DataSaver:
# Set creation date/time
dt = datetime.datetime.now()
ds.ContentDate = dt.strftime("%Y%m%d")
timeStr = dt.strftime("%H%M%S.%f") # long format with micro seconds
ds.ContentDate = dt.strftime('%Y%m%d')
timeStr = dt.strftime('%H%M%S.%f') # long format with micro seconds
ds.ContentTime = timeStr
# Needs to be here because of bug in pydicom
ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian
......@@ -256,7 +269,8 @@ class DataSaver:
ds.save_as(path)
def save_to_zarr(self, path: str | os.PathLike, data: da.core.Array):
"""Saves a Dask array to a Zarr array on disk.
"""
Saves a Dask array to a Zarr array on disk.
Args:
path (str): The path to the Zarr array on disk.
......@@ -264,20 +278,21 @@ class DataSaver:
Returns:
zarr.core.Array: The Zarr array saved on disk.
"""
if isinstance(data, da.Array):
# If the data is a Dask array, save using dask
if self.chunk_shape:
log.info("Rechunking data to shape %s", self.chunk_shape)
log.info('Rechunking data to shape %s', self.chunk_shape)
data = data.rechunk(self.chunk_shape)
log.info("Saving Dask array to Zarr array on disk")
log.info('Saving Dask array to Zarr array on disk')
da.to_zarr(data, path, overwrite=self.replace)
else:
zarr_array = zarr.open(
path,
mode="w",
mode='w',
shape=data.shape,
chunks=self.chunk_shape,
dtype=data.dtype,
......@@ -285,17 +300,19 @@ class DataSaver:
zarr_array[:] = data
def save_PIL(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to a PIL file to the given path.
"""
Save data to a PIL file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# No support for compression yet
if self.compression and path.endswith(".png"):
raise NotImplementedError("png does not support compression")
elif not self.compression and path.endswith((".jpeg", ".jpg")):
raise NotImplementedError("jpeg does not support no compression")
if self.compression and path.endswith('.png'):
raise NotImplementedError('png does not support compression')
elif not self.compression and path.endswith(('.jpeg', '.jpg')):
raise NotImplementedError('jpeg does not support no compression')
# Convert to PIL image
img = PIL.Image.fromarray(data)
......@@ -304,7 +321,8 @@ class DataSaver:
img.save(path)
def save(self, path: str | os.PathLike, data: np.ndarray):
"""Save data to the given path.
"""
Save data to the given path.
Args:
path (str): The path to save file to
......@@ -316,6 +334,7 @@ class DataSaver:
ValueError: If the provided path does not exist and self.basename is not provided
ValueError: If a file extension is not provided.
ValueError: if a file with the specified path already exists and replace=False.
"""
path = stringify_path(path)
......@@ -325,7 +344,7 @@ class DataSaver:
# If path is an existing directory
if isdir:
# Check if this is a Zarr directory
if ".zarr" in path:
if '.zarr' in path:
if self.replace:
return self.save_to_zarr(path, data)
if not self.replace:
......@@ -340,7 +359,7 @@ class DataSaver:
else:
raise ValueError(
f"To save a stack as several TIFF files to the directory '{path}', please provide the keyword argument 'basename'. "
+ "Otherwise, to save a single file, please provide a full path with a filename and valid extension."
+ 'Otherwise, to save a single file, please provide a full path with a filename and valid extension.'
)
# If path is not an existing directory
......@@ -353,7 +372,7 @@ class DataSaver:
return self.save_tiff_stack(path, data)
# Check if a parent directory exists
parentdir = os.path.dirname(path) or "."
parentdir = os.path.dirname(path) or '.'
if os.path.isdir(parentdir):
# If there is a file extension in the path
if ext:
......@@ -367,36 +386,36 @@ class DataSaver:
"A file with the provided path already exists. To replace it set 'replace=True'"
)
if path.endswith((".tif", ".tiff")):
if path.endswith(('.tif', '.tiff')):
return self.save_tiff(path, data)
elif path.endswith((".nii", "nii.gz")):
elif path.endswith(('.nii', 'nii.gz')):
return self.save_nifti(path, data)
elif path.endswith(("TXRM", "XRM", "TXM")):
elif path.endswith(('TXRM', 'XRM', 'TXM')):
raise NotImplementedError(
"Saving TXRM files is not yet supported"
'Saving TXRM files is not yet supported'
)
elif path.endswith((".h5")):
elif path.endswith('.h5'):
return self.save_h5(path, data)
elif path.endswith((".vol", ".vgi")):
elif path.endswith(('.vol', '.vgi')):
return self.save_vol(path, data)
elif path.endswith((".dcm", ".DCM")):
elif path.endswith(('.dcm', '.DCM')):
return self.save_dicom(path, data)
elif path.endswith((".zarr")):
elif path.endswith('.zarr'):
return self.save_to_zarr(path, data)
elif path.endswith((".jpeg", ".jpg", ".png")):
elif path.endswith(('.jpeg', '.jpg', '.png')):
return self.save_PIL(path, data)
else:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# If there is no file extension in the path
else:
raise ValueError(
"Please provide a file extension if you want to save as a single file."
+ " Otherwise, please provide a basename to save as a TIFF stack"
'Please provide a file extension if you want to save as a single file.'
+ ' Otherwise, please provide a basename to save as a TIFF stack'
)
else:
raise ValueError(
f"The directory '{parentdir}' does not exist.\n"
+ "Please provide a valid directory or specify a basename if you want to save a tiff stack as several files to a folder that does not yet exist"
+ 'Please provide a valid directory or specify a basename if you want to save a tiff stack as several files to a folder that does not yet exist'
)
......@@ -407,10 +426,11 @@ def save(
compression: bool = False,
basename: bool = None,
sliced_dim: int = 0,
chunk_shape: str = "auto",
chunk_shape: str = 'auto',
**kwargs,
) -> None:
"""Save data to a specified file path.
"""
Save data to a specified file path.
Args:
path (str or os.PathLike): The path to save file to. File format is chosen based on the extension.
......@@ -452,6 +472,7 @@ def save(
qim3d.io.save("slices", vol, basename="blob-slices", sliced_dim=0)
```
"""
DataSaver(
......@@ -464,31 +485,54 @@ def save(
).save(path, data)
def save_mesh(
filename: str,
mesh: trimesh.Trimesh
) -> None:
# def save_mesh(
# filename: str,
# mesh: trimesh.Trimesh
# ) -> None:
# """
# Save a trimesh object to an .obj file.
# Args:
# filename (str or os.PathLike): The name of the file to save the mesh.
# mesh (trimesh.Trimesh): A trimesh.Trimesh object representing the mesh.
# Example:
# ```python
# import qim3d
# vol = qim3d.generate.noise_object(base_shape=(32, 32, 32),
# final_shape=(32, 32, 32),
# noise_scale=0.05,
# order=1,
# gamma=1.0,
# max_value=255,
# threshold=0.5)
# mesh = qim3d.mesh.from_volume(vol)
# qim3d.io.save_mesh("mesh.obj", mesh)
# ```
# """
# # Export the mesh to the specified filename
# mesh.export(filename)
def save_mesh(filename: str, mesh: hmesh.Manifold) -> None:
"""
Save a trimesh object to an .obj file.
Save a mesh object to a specific file.
This function is based on the [PyGEL3D library's saving function implementation](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#save).
Args:
filename (str or os.PathLike): The name of the file to save the mesh.
mesh (trimesh.Trimesh): A trimesh.Trimesh object representing the mesh.
filename (str or os.PathLike): The path to save file to. File format is chosen based on the extension. Supported extensions are: '.x3d', '.obj', '.off'.
mesh (pygel3d.hmesh.Manifold): A hmesh.Manifold object representing the mesh.
Example:
```python
import qim3d
vol = qim3d.generate.noise_object(base_shape=(32, 32, 32),
final_shape=(32, 32, 32),
noise_scale=0.05,
order=1,
gamma=1.0,
max_value=255,
threshold=0.5)
mesh = qim3d.mesh.from_volume(vol)
synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015)
mesh = qim3d.mesh.from_volume(synthetic_blob)
qim3d.io.save_mesh("mesh.obj", mesh)
```
"""
# Export the mesh to the specified filename
mesh.export(filename)
\ No newline at end of file
hmesh.save(filename, mesh)
\ No newline at end of file
"""Dataset synchronization tasks"""
import os
import subprocess
from pathlib import Path
import outputformat as ouf
from qim3d.utils import log
from pathlib import Path
class Sync:
"""Class for dataset synchronization tasks"""
def __init__(self):
# Checks if rsync is available
if not self._check_rsync():
raise RuntimeError(
"Could not find rsync, please check if it is installed in your system."
'Could not find rsync, please check if it is installed in your system.'
)
def _check_rsync(self):
"""Check if rsync is available"""
try:
subprocess.run(["rsync", "--version"], capture_output=True, check=True)
subprocess.run(['rsync', '--version'], capture_output=True, check=True)
return True
except Exception as error:
log.error("rsync is not available")
log.error('rsync is not available')
log.error(error)
return False
def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
"""Check if all files from 'source' are in 'destination'
def check_destination(
self,
source: str,
destination: str,
checksum: bool = False,
verbose: bool = True,
) -> list[str]:
"""
Check if all files from 'source' are in 'destination'
This function compares the files in the 'source' directory to those in
the 'destination' directory and reports any differences or missing files.
......@@ -51,13 +62,13 @@ class Sync:
destination = Path(destination)
if checksum:
rsync_args = "-avrc"
rsync_args = '-avrc'
else:
rsync_args = "-avr"
rsync_args = '-avr'
command = [
"rsync",
"-n",
'rsync',
'-n',
rsync_args,
str(source) + os.path.sep,
str(destination) + os.path.sep,
......@@ -70,18 +81,25 @@ class Sync:
)
diff_files_and_folders = out.stdout.decode().splitlines()[1:-3]
diff_files = [f for f in diff_files_and_folders if not f.endswith("/")]
diff_files = [f for f in diff_files_and_folders if not f.endswith('/')]
if len(diff_files) > 0 and verbose:
title = "Source files differing or missing in destination"
title = 'Source files differing or missing in destination'
log.info(
ouf.showlist(diff_files, style="line", return_str=True, title=title)
ouf.showlist(diff_files, style='line', return_str=True, title=title)
)
return diff_files
def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
"""Checks whether 'source' and 'destination' directories are synchronized.
def compare_dirs(
self,
source: str,
destination: str,
checksum: bool = False,
verbose: bool = True,
) -> None:
"""
Checks whether 'source' and 'destination' directories are synchronized.
This function compares the contents of two directories
('source' and 'destination') and reports any differences.
......@@ -107,7 +125,7 @@ class Sync:
if verbose:
s_files, s_dirs = self.count_files_and_dirs(source)
d_files, d_dirs = self.count_files_and_dirs(destination)
log.info("\n")
log.info('\n')
s_d = self.check_destination(
source, destination, checksum=checksum, verbose=False
......@@ -120,7 +138,7 @@ class Sync:
# No differences
if verbose:
log.info(
"Source and destination are synchronized, no differences found."
'Source and destination are synchronized, no differences found.'
)
return
......@@ -128,9 +146,9 @@ class Sync:
log.info(
ouf.showlist(
union,
style="line",
style='line',
return_str=True,
title=f"{len(union)} files are not in sync",
title=f'{len(union)} files are not in sync',
)
)
......@@ -139,9 +157,9 @@ class Sync:
log.info(
ouf.showlist(
intersection,
style="line",
style='line',
return_str=True,
title=f"{len(intersection)} files present on both, but not equal",
title=f'{len(intersection)} files present on both, but not equal',
)
)
......@@ -150,9 +168,9 @@ class Sync:
log.info(
ouf.showlist(
s_exclusive,
style="line",
style='line',
return_str=True,
title=f"{len(s_exclusive)} files present only on {source}",
title=f'{len(s_exclusive)} files present only on {source}',
)
)
......@@ -161,15 +179,18 @@ class Sync:
log.info(
ouf.showlist(
d_exclusive,
style="line",
style='line',
return_str=True,
title=f"{len(d_exclusive)} files present only on {destination}",
title=f'{len(d_exclusive)} files present only on {destination}',
)
)
return
def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
"""Count the number of files and directories in the given path.
def count_files_and_dirs(
self, path: str | os.PathLike, verbose: bool = True
) -> tuple[int, int]:
"""
Count the number of files and directories in the given path.
This function recursively counts the number of files and
directories in the specified directory 'path'.
......@@ -202,6 +223,6 @@ class Sync:
dirs += dirs_count
if verbose:
log.info(f"Total of {files} files and {dirs} directories on {path}")
log.info(f'Total of {files} files and {dirs} directories on {path}')
return files, dirs
from typing import Any, Tuple
import numpy as np
from skimage import measure, filters
import trimesh
from pygel3d import hmesh
from typing import Tuple, Any
from qim3d.utils._logger import log
def from_volume(
volume: np.ndarray,
level: float = None,
step_size: int = 1,
allow_degenerate: bool = False,
padding: Tuple[int, int, int] = (2, 2, 2),
**kwargs: Any,
) -> trimesh.Trimesh:
"""
Convert a volume to a mesh using the Marching Cubes algorithm, with optional thresholding and padding.
**kwargs: any
) -> hmesh.Manifold:
""" Convert a 3D numpy array to a mesh object using the [volumetric_isocontour](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/hmesh.html#volumetric_isocontour) function from Pygel3D.
Args:
volume (np.ndarray): The 3D numpy array representing the volume.
level (float, optional): The threshold value for Marching Cubes. If None, Otsu's method is used.
step_size (int, optional): The step size for the Marching Cubes algorithm.
allow_degenerate (bool, optional): Whether to allow degenerate (i.e. zero-area) triangles in the end-result. If False, degenerate triangles are removed, at the cost of making the algorithm slower. Default False.
padding (tuple of ints, optional): Padding to add around the volume.
**kwargs: Additional keyword arguments to pass to `skimage.measure.marching_cubes`.
volume (np.ndarray): A 3D numpy array representing a volume.
**kwargs: Additional arguments to pass to the Pygel3D volumetric_isocontour function.
Raises:
ValueError: If the input volume is not a 3D numpy array or if the input volume is empty.
Returns:
mesh (trimesh.Trimesh): The generated mesh.
hmesh.Manifold: A Pygel3D mesh object representing the input volume.
Example:
Convert a 3D numpy array to a Pygel3D mesh object:
```python
import qim3d
vol = qim3d.generate.noise_object(base_shape=(128,128,128),
final_shape=(128,128,128),
noise_scale=0.03,
order=1,
gamma=1,
max_value=255,
threshold=0.5,
dtype='uint8'
)
mesh = qim3d.mesh.from_volume(vol, step_size=3)
qim3d.viz.mesh(mesh.vertices, mesh.faces)
# Generate a 3D blob
synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015)
# Convert the 3D numpy array to a Pygel3D mesh object
mesh = qim3d.mesh.from_volume(synthetic_blob)
```
<iframe src="https://platform.qim.dk/k3d/mesh_visualization.html" width="100%" height="500" frameborder="0"></iframe>
"""
if volume.ndim != 3:
raise ValueError("The input volume must be a 3D numpy array.")
# Compute the threshold level if not provided
if level is None:
level = filters.threshold_otsu(volume)
log.info(f"Computed level using Otsu's method: {level}")
# Apply padding to the volume
if padding is not None:
pad_z, pad_y, pad_x = padding
padding_value = np.min(volume)
volume = np.pad(
volume,
((pad_z, pad_z), (pad_y, pad_y), (pad_x, pad_x)),
mode="constant",
constant_values=padding_value,
)
log.info(f"Padded volume with {padding} to shape: {volume.shape}")
# Call skimage.measure.marching_cubes with user-provided kwargs
verts, faces, normals, values = measure.marching_cubes(
volume, level=level, step_size=step_size, allow_degenerate=allow_degenerate, **kwargs
)
# Create the Trimesh object
mesh = trimesh.Trimesh(vertices=verts, faces=faces)
# Fix face orientation to ensure normals point outwards
trimesh.repair.fix_inversion(mesh, multibody=True)
if volume.size == 0:
raise ValueError("The input volume must not be empty.")
mesh = hmesh.volumetric_isocontour(volume, **kwargs)
return mesh
\ No newline at end of file
from ._unet import UNet, Hyperparameters
from ._unet import Hyperparameters, UNet