Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • QIM/tools/qim3d
1 result
Show changes
Showing
with 1476 additions and 1190 deletions
from ._generators import noise_object
from ._aggregators import noise_object_collection
from ._generators import noise_object
......@@ -22,6 +22,7 @@ def random_placement(
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
"""
# Find available (zero) elements in collection
available_z, available_y, available_x = np.where(collection == 0)
......@@ -44,14 +45,12 @@ def random_placement(
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
......@@ -81,6 +80,7 @@ def specific_placement(
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
positions (list[tuple]): List of remaining positions to place blobs.
"""
# Flag for placement success
placed = False
......@@ -99,14 +99,12 @@ def specific_placement(
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
......@@ -253,13 +251,13 @@ def noise_object_collection(
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_cylinder.html" width="100%" height="500" frameborder="0"></iframe>
```python
# Visualize slices
qim3d.viz.slices_grid(vol, num_slices=15)
```
![synthetic_collection_cylinder](../../assets/screenshots/synthetic_collection_cylinder_slices.png)
![synthetic_collection_cylinder](../../assets/screenshots/synthetic_collection_cylinder_slices.png)
Example:
```python
import qim3d
......@@ -283,29 +281,30 @@ def noise_object_collection(
qim3d.viz.volumetric(vol)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_tube.html" width="100%" height="500" frameborder="0"></iframe>
```python
# Visualize slices
qim3d.viz.slices_grid(vol, num_slices=15, slice_axis=1)
```
![synthetic_collection_tube](../../assets/screenshots/synthetic_collection_tube_slices.png)
"""
if verbose:
original_log_level = log.getEffectiveLevel()
log.setLevel("DEBUG")
log.setLevel('DEBUG')
# Check valid input types
if not isinstance(collection_shape, tuple) or len(collection_shape) != 3:
raise TypeError(
"Shape of collection must be a tuple with three dimensions (z, y, x)"
'Shape of collection must be a tuple with three dimensions (z, y, x)'
)
if len(min_shape) != len(max_shape):
raise ValueError("Object shapes must be tuples of the same length")
raise ValueError('Object shapes must be tuples of the same length')
if (positions is not None) and (len(positions) != num_objects):
raise ValueError(
"Number of objects must match number of positions, otherwise set positions = None"
'Number of objects must match number of positions, otherwise set positions = None'
)
# Set seed for random number generator
......@@ -318,8 +317,8 @@ def noise_object_collection(
labels = np.zeros_like(collection_array)
# Fill the 3D array with synthetic blobs
for i in tqdm(range(num_objects), desc="Objects placed"):
log.debug(f"\nObject #{i+1}")
for i in tqdm(range(num_objects), desc='Objects placed'):
log.debug(f'\nObject #{i+1}')
# Sample from blob parameter ranges
if min_shape == max_shape:
......@@ -328,27 +327,27 @@ def noise_object_collection(
blob_shape = tuple(
rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3)
)
log.debug(f"- Blob shape: {blob_shape}")
log.debug(f'- Blob shape: {blob_shape}')
# Scale object shape
final_shape = tuple(l * r for l, r in zip(blob_shape, object_shape_zoom))
final_shape = tuple(int(x) for x in final_shape) # NOTE: Added this
final_shape = tuple(int(x) for x in final_shape) # NOTE: Added this
# Sample noise scale
noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise)
log.debug(f"- Object noise scale: {noise_scale:.4f}")
log.debug(f'- Object noise scale: {noise_scale:.4f}')
gamma = rng.uniform(low=min_gamma, high=max_gamma)
log.debug(f"- Gamma correction: {gamma:.3f}")
log.debug(f'- Gamma correction: {gamma:.3f}')
if max_high_value > min_high_value:
max_value = rng.integers(low=min_high_value, high=max_high_value)
else:
max_value = min_high_value
log.debug(f"- Max value: {max_value}")
log.debug(f'- Max value: {max_value}')
threshold = rng.uniform(low=min_threshold, high=max_threshold)
log.debug(f"- Threshold: {threshold:.3f}")
log.debug(f'- Threshold: {threshold:.3f}')
# Generate synthetic object
blob = qim3d.generate.noise_object(
......@@ -368,7 +367,7 @@ def noise_object_collection(
low=min_rotation_degrees, high=max_rotation_degrees
) # Sample rotation angle
axes = rng.choice(rotation_axes) # Sample the two axes to rotate around
log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}")
log.debug(f'- Rotation angle: {angle:.2f} at axes: {axes}')
blob = scipy.ndimage.rotate(blob, angle, axes, order=1)
......@@ -397,7 +396,7 @@ def noise_object_collection(
if not placed:
# Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar)
log.error(
f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects."
f'Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects.'
)
if verbose:
......
......@@ -4,6 +4,7 @@ from noise import pnoise3
import qim3d.processing
def noise_object(
base_shape: tuple = (128, 128, 128),
final_shape: tuple = (128, 128, 128),
......@@ -14,8 +15,8 @@ def noise_object(
threshold: float = 0.5,
smooth_borders: bool = False,
object_shape: str = None,
dtype: str = "uint8",
) -> np.ndarray:
dtype: str = 'uint8',
) -> np.ndarray:
"""
Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction.
......@@ -97,18 +98,19 @@ def noise_object(
qim3d.viz.volumetric(vol)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_blob_tube.html" width="100%" height="500" frameborder="0"></iframe>
```python
# Visualize
qim3d.viz.slices_grid(vol, num_slices=15)
```
![synthetic_blob_tube_slice](../../assets/screenshots/synthetic_blob_tube_slice.png)
![synthetic_blob_tube_slice](../../assets/screenshots/synthetic_blob_tube_slice.png)
"""
if not isinstance(final_shape, tuple) or len(final_shape) != 3:
raise TypeError("Size must be a tuple of 3 dimensions")
raise TypeError('Size must be a tuple of 3 dimensions')
if not np.issubdtype(dtype, np.number):
raise ValueError("Invalid data type")
raise ValueError('Invalid data type')
# Initialize the 3D array for the shape
volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32)
......@@ -119,19 +121,18 @@ def noise_object(
# Calculate the distance from the center of the shape
center = np.array(base_shape) / 2
dist = np.sqrt((z - center[0])**2 +
(y - center[1])**2 +
(x - center[2])**2)
dist /= np.sqrt(3 * (center[0]**2))
dist = np.sqrt((z - center[0]) ** 2 + (y - center[1]) ** 2 + (x - center[2]) ** 2)
dist /= np.sqrt(3 * (center[0] ** 2))
# Generate Perlin noise and adjust the values based on the distance from the center
vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input
vectorized_pnoise3 = np.vectorize(
pnoise3
) # Vectorize pnoise3, since it only takes scalar input
noise = vectorized_pnoise3(z.flatten() * noise_scale,
y.flatten() * noise_scale,
x.flatten() * noise_scale
).reshape(base_shape)
noise = vectorized_pnoise3(
z.flatten() * noise_scale, y.flatten() * noise_scale, x.flatten() * noise_scale
).reshape(base_shape)
volume = (1 + noise) * (1 - dist)
......@@ -148,17 +149,22 @@ def noise_object(
if object_shape:
smooth_borders = False
if smooth_borders:
if smooth_borders:
# Maximum value among the six sides of the 3D volume
max_border_value = np.max([
np.max(volume[0, :, :]), np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]), np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]), np.max(volume[:, :, -1])
])
max_border_value = np.max(
[
np.max(volume[0, :, :]),
np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]),
np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]),
np.max(volume[:, :, -1]),
]
)
# Compute threshold such that there will be no straight cuts in the blob
threshold = max_border_value / max_value
# Clip the low values of the volume to create a coherent volume
volume[volume < threshold * max_value] = 0
......@@ -171,45 +177,50 @@ def noise_object(
)
# Fade into a shape if specified
if object_shape == "cylinder":
if object_shape == 'cylinder':
# Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical
volume = qim3d.operations.fade_mask(volume,
geometry = geometry,
axis = axis,
target_max_normalized_distance = target_max_normalized_distance
)
elif object_shape == "tube":
geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(
volume.shape
) # Fade along the dimension where the object is the largest
target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
volume = qim3d.operations.fade_mask(
volume,
geometry=geometry,
axis=axis,
target_max_normalized_distance=target_max_normalized_distance,
)
elif object_shape == 'tube':
# Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest
decay_rate = 5 # Decay rate for the fade operation
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical
geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(
volume.shape
) # Fade along the dimension where the object is the largest
decay_rate = 5 # Decay rate for the fade operation
target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
# Fade once for making the object cylindrical
volume = qim3d.operations.fade_mask(volume,
geometry = geometry,
axis = axis,
decay_rate = decay_rate,
target_max_normalized_distance = target_max_normalized_distance,
invert = False
)
volume = qim3d.operations.fade_mask(
volume,
geometry=geometry,
axis=axis,
decay_rate=decay_rate,
target_max_normalized_distance=target_max_normalized_distance,
invert=False,
)
# Fade again with invert = True for making the object a tube (i.e. with a hole in the middle)
volume = qim3d.operations.fade_mask(volume,
geometry = geometry,
axis = axis,
decay_rate = decay_rate,
invert = True
)
volume = qim3d.operations.fade_mask(
volume, geometry=geometry, axis=axis, decay_rate=decay_rate, invert=True
)
# Convert to desired data type
volume = volume.astype(dtype)
return volume
\ No newline at end of file
return volume
from fastapi import FastAPI
import qim3d.utils
from . import data_explorer
from . import iso3d
from . import local_thickness
from . import annotation_tool
from . import layers2d
from . import annotation_tool, data_explorer, iso3d, layers2d, local_thickness
from .qim_theme import QimTheme
def run_gradio_app(gradio_interface, host="0.0.0.0"):
def run_gradio_app(gradio_interface, host='0.0.0.0'):
import gradio as gr
import uvicorn
# Get port using the QIM API
port_dict = qim3d.utils.get_port_dict()
if "gradio_port" in port_dict:
port = port_dict["gradio_port"]
elif "port" in port_dict:
port = port_dict["port"]
if 'gradio_port' in port_dict:
port = port_dict['gradio_port']
elif 'port' in port_dict:
port = port_dict['port']
else:
raise Exception("Port not specified from QIM API")
raise Exception('Port not specified from QIM API')
qim3d.utils.gradio_header(gradio_interface.title, port)
......@@ -30,7 +28,7 @@ def run_gradio_app(gradio_interface, host="0.0.0.0"):
app = gr.mount_gradio_app(app, gradio_interface, path=path)
# Full path
print(f"http://{host}:{port}{path}")
print(f'http://{host}:{port}{path}')
# Run the FastAPI server usign uvicorn
uvicorn.run(app, host=host, port=int(port))
......@@ -27,6 +27,7 @@ import tempfile
import gradio as gr
import numpy as np
from PIL import Image
import qim3d
from qim3d.gui.interface import BaseInterface
......@@ -34,17 +35,19 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
def __init__(
self, name_suffix: str = '', verbose: bool = False, img: np.ndarray = None
):
super().__init__(
title="Annotation Tool",
title='Annotation Tool',
height=768,
width="100%",
width='100%',
verbose=verbose,
custom_css="annotation_tool.css",
custom_css='annotation_tool.css',
)
self.username = getpass.getuser()
self.temp_dir = os.path.join(tempfile.gettempdir(), f"qim-{self.username}")
self.temp_dir = os.path.join(tempfile.gettempdir(), f'qim-{self.username}')
self.name_suffix = name_suffix
self.img = img
......@@ -57,7 +60,7 @@ class Interface(BaseInterface):
# Get the temporary files from gradio
temp_path_list = []
for filename in os.listdir(self.temp_dir):
if "mask" and self.name_suffix in str(filename):
if 'mask' and self.name_suffix in str(filename):
# Get the list of the temporary files
temp_path_list.append(os.path.join(self.temp_dir, filename))
......@@ -76,9 +79,9 @@ class Interface(BaseInterface):
this is safer and backwards compatible (should be)
"""
self.mask_names = [
f"red{self.name_suffix}",
f"green{self.name_suffix}",
f"blue{self.name_suffix}",
f'red{self.name_suffix}',
f'green{self.name_suffix}',
f'blue{self.name_suffix}',
]
# Clean up old files
......@@ -86,7 +89,7 @@ class Interface(BaseInterface):
files = os.listdir(self.temp_dir)
for filename in files:
# Check if "mask" is in the filename
if ("mask" in filename) and (self.name_suffix in filename):
if ('mask' in filename) and (self.name_suffix in filename):
file_path = os.path.join(self.temp_dir, filename)
os.remove(file_path)
......@@ -94,13 +97,13 @@ class Interface(BaseInterface):
files = None
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"]
masks = img_editor["layers"][0]
background = img_editor['background']
masks = img_editor['layers'][0]
overlay_image = qim3d.operations.overlay_rgb_images(background, masks)
return overlay_image
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0]
masks_rgb = img_editor['layers'][0]
mask_threshold = 200 # This value is based
mask_list = []
......@@ -114,7 +117,7 @@ class Interface(BaseInterface):
# Save only if we have a mask
if np.sum(mask) > 0:
mask_list.append(mask)
filename = f"mask_{self.mask_names[idx]}.tif"
filename = f'mask_{self.mask_names[idx]}.tif'
if not os.path.exists(self.temp_dir):
os.makedirs(self.temp_dir)
filepath = os.path.join(self.temp_dir, filename)
......@@ -128,11 +131,11 @@ class Interface(BaseInterface):
def define_interface(self, **kwargs):
brush = gr.Brush(
colors=[
"rgb(255,50,100)",
"rgb(50,250,100)",
"rgb(50,100,255)",
'rgb(255,50,100)',
'rgb(50,250,100)',
'rgb(50,100,255)',
],
color_mode="fixed",
color_mode='fixed',
default_size=10,
)
with gr.Row():
......@@ -142,26 +145,25 @@ class Interface(BaseInterface):
img_editor = gr.ImageEditor(
value=(
{
"background": self.img,
"layers": [Image.new("RGBA", self.img.shape, (0, 0, 0, 0))],
"composite": None,
'background': self.img,
'layers': [Image.new('RGBA', self.img.shape, (0, 0, 0, 0))],
'composite': None,
}
if self.img is not None
else None
),
type="numpy",
image_mode="RGB",
type='numpy',
image_mode='RGB',
brush=brush,
sources="upload",
sources='upload',
interactive=True,
show_download_button=True,
container=False,
transforms=["crop"],
transforms=['crop'],
layers=False,
)
with gr.Column(scale=1, min_width=256):
with gr.Row():
overlay_img = gr.Image(
show_download_button=False,
......@@ -169,7 +171,7 @@ class Interface(BaseInterface):
visible=False,
)
with gr.Row():
masks_download = gr.File(label="Download masks", visible=False)
masks_download = gr.File(label='Download masks', visible=False)
# fmt: off
img_editor.change(
......
This diff is collapsed.
from abc import ABC, abstractmethod
from os import listdir, path
from pathlib import Path
from abc import abstractmethod, ABC
from os import path, listdir
import gradio as gr
import numpy as np
from .qim_theme import QimTheme
import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli
class BaseInterface(ABC):
"""
Annotation tool and Data explorer as those don't need any examples.
"""
......@@ -19,7 +19,7 @@ class BaseInterface(ABC):
self,
title: str,
height: int,
width: int = "100%",
width: int = '100%',
verbose: bool = False,
custom_css: str = None,
):
......@@ -38,7 +38,7 @@ class BaseInterface(ABC):
self.qim_dir = Path(qim3d.__file__).parents[0]
self.custom_css = (
path.join(self.qim_dir, "css", custom_css)
path.join(self.qim_dir, 'css', custom_css)
if custom_css is not None
else None
)
......@@ -48,9 +48,9 @@ class BaseInterface(ABC):
def set_invisible(self):
return gr.update(visible=False)
def change_visibility(self, is_visible: bool):
return gr.update(visible = is_visible)
return gr.update(visible=is_visible)
def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
"""
......@@ -72,8 +72,7 @@ class BaseInterface(ABC):
quiet=not self.verbose,
height=self.height,
width=self.width,
favicon_path=Path(qim3d.__file__).parents[0]
/ "gui/assets/qim3d-icon.svg",
favicon_path=Path(qim3d.__file__).parents[0] / 'gui/assets/qim3d-icon.svg',
**kwargs,
)
......@@ -88,7 +87,7 @@ class BaseInterface(ABC):
title=self.title,
css=self.custom_css,
) as gradio_interface:
gr.Markdown(f"# {self.title}")
gr.Markdown(f'# {self.title}')
self.define_interface(**kwargs)
return gradio_interface
......@@ -96,11 +95,12 @@ class BaseInterface(ABC):
def define_interface(self, **kwargs):
pass
def run_interface(self, host: str = "0.0.0.0"):
def run_interface(self, host: str = '0.0.0.0'):
qim3d.gui.run_gradio_app(self.create_interface(), host)
class InterfaceWithExamples(BaseInterface):
"""
For Iso3D and Local Thickness
"""
......@@ -117,7 +117,23 @@ class InterfaceWithExamples(BaseInterface):
self._set_examples_list()
def _set_examples_list(self):
valid_sufixes = (".tif", ".tiff", ".h5", ".nii", ".gz", ".dcm", ".DCM", ".vol", ".vgi", ".txrm", ".txm", ".xrm")
valid_sufixes = (
'.tif',
'.tiff',
'.h5',
'.nii',
'.gz',
'.dcm',
'.DCM',
'.vol',
'.vgi',
'.txrm',
'.txm',
'.xrm',
)
examples_folder = path.join(self.qim_dir, 'examples')
self.img_examples = [path.join(examples_folder, example) for example in listdir(examples_folder) if example.endswith(valid_sufixes)]
self.img_examples = [
path.join(examples_folder, example)
for example in listdir(examples_folder)
if example.endswith(valid_sufixes)
]
......@@ -15,6 +15,7 @@ app.launch()
```
"""
import os
import gradio as gr
......@@ -23,21 +24,19 @@ import plotly.graph_objects as go
from scipy import ndimage
import qim3d
from qim3d.utils._logger import log
from qim3d.gui.interface import InterfaceWithExamples
from qim3d.utils._logger import log
#TODO img in launch should be self.img
# TODO img in launch should be self.img
class Interface(InterfaceWithExamples):
def __init__(self,
verbose:bool = False,
plot_height:int = 768,
img = None):
super().__init__(title = "Isosurfaces for 3D visualization",
height = 1024,
width = 960,
verbose = verbose)
def __init__(self, verbose: bool = False, plot_height: int = 768, img=None):
super().__init__(
title='Isosurfaces for 3D visualization',
height=1024,
width=960,
verbose=verbose,
)
self.interface = None
self.img = img
......@@ -48,11 +47,13 @@ class Interface(InterfaceWithExamples):
self.vol = qim3d.io.load(gradiofile.name)
assert self.vol.ndim == 3
except AttributeError:
raise gr.Error("You have to select a file")
raise gr.Error('You have to select a file')
except ValueError:
raise gr.Error("Unsupported file format")
raise gr.Error('Unsupported file format')
except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size"""
......@@ -61,12 +62,12 @@ class Interface(InterfaceWithExamples):
original_Z, original_Y, original_X = np.shape(self.vol)
max_size = np.max([original_Z, original_Y, original_X])
if self.verbose:
log.info(f"\nOriginal volume: {original_Z, original_Y, original_X}")
log.info(f'\nOriginal volume: {original_Z, original_Y, original_X}')
# Resize for display
self.vol = ndimage.zoom(
input=self.vol,
zoom = display_size / max_size,
zoom=display_size / max_size,
order=0,
prefilter=False,
)
......@@ -76,16 +77,17 @@ class Interface(InterfaceWithExamples):
)
if self.verbose:
log.info(
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
f'Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}'
)
def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk
fig.write_html(filename)
def create_fig(self,
def create_fig(
self,
gradio_file: gr.File,
display_size: int ,
display_size: int,
opacity: float,
opacityscale: str,
only_wireframe: bool,
......@@ -105,8 +107,7 @@ class Interface(InterfaceWithExamples):
slice_y_location: int,
show_x_slice: bool,
slice_x_location: int,
) -> tuple[go.Figure, str]:
) -> tuple[go.Figure, str]:
# Load volume
self.load_data(gradio_file)
......@@ -129,191 +130,184 @@ class Interface(InterfaceWithExamples):
fig = go.Figure(
go.Volume(
z = Z.flatten(),
y = Y.flatten(),
x = X.flatten(),
value = self.vol.flatten(),
isomin = min_value * np.max(self.vol),
isomax = max_value * np.max(self.vol),
cmin = np.min(self.vol),
cmax = np.max(self.vol),
opacity = opacity,
opacityscale = opacityscale,
surface_count = surface_count,
colorscale = colormap,
slices_z = dict(
show = show_z_slice,
locations = [int(self.display_size_z * slice_z_location)],
z=Z.flatten(),
y=Y.flatten(),
x=X.flatten(),
value=self.vol.flatten(),
isomin=min_value * np.max(self.vol),
isomax=max_value * np.max(self.vol),
cmin=np.min(self.vol),
cmax=np.max(self.vol),
opacity=opacity,
opacityscale=opacityscale,
surface_count=surface_count,
colorscale=colormap,
slices_z=dict(
show=show_z_slice,
locations=[int(self.display_size_z * slice_z_location)],
),
slices_y = dict(
show = show_y_slice,
slices_y=dict(
show=show_y_slice,
locations=[int(self.display_size_y * slice_y_location)],
),
slices_x = dict(
show = show_x_slice,
locations = [int(self.display_size_x * slice_x_location)],
slices_x=dict(
show=show_x_slice,
locations=[int(self.display_size_x * slice_x_location)],
),
surface = dict(fill=surface_fill),
caps = dict(
x_show = show_caps,
y_show = show_caps,
z_show = show_caps,
surface=dict(fill=surface_fill),
caps=dict(
x_show=show_caps,
y_show=show_caps,
z_show=show_caps,
),
showscale = show_colorbar,
showscale=show_colorbar,
colorbar=dict(
thickness=8, outlinecolor="#fff", len=0.5, orientation="h"
thickness=8, outlinecolor='#fff', len=0.5, orientation='h'
),
reversescale = reversescale,
hoverinfo = "skip",
reversescale=reversescale,
hoverinfo='skip',
)
)
fig.update_layout(
scene_xaxis_showticklabels = show_ticks,
scene_yaxis_showticklabels = show_ticks,
scene_zaxis_showticklabels = show_ticks,
scene_xaxis_visible = show_axis,
scene_yaxis_visible = show_axis,
scene_zaxis_visible = show_axis,
scene_aspectmode="data",
scene_xaxis_showticklabels=show_ticks,
scene_yaxis_showticklabels=show_ticks,
scene_zaxis_showticklabels=show_ticks,
scene_xaxis_visible=show_axis,
scene_yaxis_visible=show_axis,
scene_zaxis_visible=show_axis,
scene_aspectmode='data',
height=self.plot_height,
hovermode=False,
scene_camera_eye=dict(x=2.0, y=-2.0, z=1.5),
)
filename = "iso3d.html"
filename = 'iso3d.html'
self.save_fig(fig, filename)
return fig, filename
def remove_unused_file(self):
# Remove localthickness.tif file from working directory
# as it otherwise is not deleted
os.remove("iso3d.html")
os.remove('iso3d.html')
def define_interface(self, **kwargs):
gr.Markdown(
"""
"""
This tool uses Plotly Volume (https://plotly.com/python/3d-volume-plots/) to create iso surfaces from voxels based on their intensity levels.
To optimize performance when generating visualizations, set the number of voxels (_display resolution_) and isosurfaces (_total surfaces_) to lower levels.
"""
)
)
with gr.Row():
# Input and parameters column
with gr.Column(scale=1, min_width=320):
with gr.Tab("Input"):
with gr.Tab('Input'):
# File loader
gradio_file = gr.File(
show_label=False
)
with gr.Tab("Examples"):
gradio_file = gr.File(show_label=False)
with gr.Tab('Examples'):
gr.Examples(examples=self.img_examples, inputs=gradio_file)
# Run button
with gr.Row():
with gr.Column(scale=3, min_width=64):
btn_run = gr.Button(
value="Run 3D visualization", variant = "primary"
value='Run 3D visualization', variant='primary'
)
with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button(
value="Clear", variant = "stop"
)
btn_clear = gr.Button(value='Clear', variant='stop')
with gr.Tab("Display"):
with gr.Tab('Display'):
# Display options
display_size = gr.Slider(
32,
128,
step=4,
label="Display resolution",
info="Number of voxels for the largest dimension",
label='Display resolution',
info='Number of voxels for the largest dimension',
value=64,
)
surface_count = gr.Slider(
2, 16, step=1, label="Total iso-surfaces", value=6
2, 16, step=1, label='Total iso-surfaces', value=6
)
show_caps = gr.Checkbox(value=False, label="Show surface caps")
show_caps = gr.Checkbox(value=False, label='Show surface caps')
with gr.Row():
opacityscale = gr.Dropdown(
choices=["uniform", "extremes", "min", "max"],
value="uniform",
label="Opacity scale",
info="Handles opacity acording to voxel value",
choices=['uniform', 'extremes', 'min', 'max'],
value='uniform',
label='Opacity scale',
info='Handles opacity acording to voxel value',
)
opacity = gr.Slider(
0.0, 1.0, step=0.1, label="Max opacity", value=0.4
0.0, 1.0, step=0.1, label='Max opacity', value=0.4
)
with gr.Row():
min_value = gr.Slider(
0.0, 1.0, step=0.05, label="Min value", value=0.1
0.0, 1.0, step=0.05, label='Min value', value=0.1
)
max_value = gr.Slider(
0.0, 1.0, step=0.05, label="Max value", value=1
0.0, 1.0, step=0.05, label='Max value', value=1
)
with gr.Tab("Slices") as slices:
show_z_slice = gr.Checkbox(value=False, label="Show Z slice")
with gr.Tab('Slices') as slices:
show_z_slice = gr.Checkbox(value=False, label='Show Z slice')
slice_z_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
show_y_slice = gr.Checkbox(value=False, label="Show Y slice")
show_y_slice = gr.Checkbox(value=False, label='Show Y slice')
slice_y_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
show_x_slice = gr.Checkbox(value=False, label="Show X slice")
show_x_slice = gr.Checkbox(value=False, label='Show X slice')
slice_x_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position"
0.0, 1.0, step=0.05, value=0.5, label='Position'
)
with gr.Tab("Misc"):
with gr.Tab('Misc'):
with gr.Row():
colormap = gr.Dropdown(
choices=[
"Blackbody",
"Bluered",
"Blues",
"Cividis",
"Earth",
"Electric",
"Greens",
"Greys",
"Hot",
"Jet",
"Magma",
"Picnic",
"Portland",
"Rainbow",
"RdBu",
"Reds",
"Viridis",
"YlGnBu",
"YlOrRd",
'Blackbody',
'Bluered',
'Blues',
'Cividis',
'Earth',
'Electric',
'Greens',
'Greys',
'Hot',
'Jet',
'Magma',
'Picnic',
'Portland',
'Rainbow',
'RdBu',
'Reds',
'Viridis',
'YlGnBu',
'YlOrRd',
],
value="Magma",
label="Colormap",
value='Magma',
label='Colormap',
)
show_colorbar = gr.Checkbox(
value=False, label="Show color scale"
value=False, label='Show color scale'
)
reversescale = gr.Checkbox(
value=False, label="Reverse color scale"
value=False, label='Reverse color scale'
)
flip_z = gr.Checkbox(value=True, label="Flip Z axis")
show_axis = gr.Checkbox(value=True, label="Show axis")
show_ticks = gr.Checkbox(value=False, label="Show ticks")
only_wireframe = gr.Checkbox(
value=False, label="Only wireframe"
)
flip_z = gr.Checkbox(value=True, label='Flip Z axis')
show_axis = gr.Checkbox(value=True, label='Show axis')
show_ticks = gr.Checkbox(value=False, label='Show ticks')
only_wireframe = gr.Checkbox(value=False, label='Only wireframe')
# Inputs for gradio
inputs = [
......@@ -346,7 +340,7 @@ class Interface(InterfaceWithExamples):
plot_download = gr.File(
interactive=False,
label="Download interactive plot",
label='Download interactive plot',
show_label=True,
visible=False,
)
......@@ -367,5 +361,6 @@ class Interface(InterfaceWithExamples):
fn=self.remove_unused_file).success(
fn=self.set_visible, inputs=None, outputs=plot_download)
if __name__ == "__main__":
Interface().run_interface()
\ No newline at end of file
if __name__ == '__main__':
Interface().run_interface()
This diff is collapsed.
"""
!!! quote "Reference"
Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW).
Dahl, V. A., & Dahl, A. B. (2023, June). Fast Local Thickness. 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW).
<https://doi.org/10.1109/cvprw59228.2023.00456>
```bibtex
@inproceedings{Dahl_2023, title={Fast Local Thickness},
url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456},
DOI={10.1109/cvprw59228.2023.00456},
booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
publisher={IEEE},
author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm},
year={2023},
@inproceedings{Dahl_2023, title={Fast Local Thickness},
url={http://dx.doi.org/10.1109/CVPRW59228.2023.00456},
DOI={10.1109/cvprw59228.2023.00456},
booktitle={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
publisher={IEEE},
author={Dahl, Vedrana Andersen and Dahl, Anders Bjorholm},
year={2023},
month=jun }
```
......@@ -32,29 +32,31 @@ app.launch()
```
"""
import os
import gradio as gr
import localthickness as lt
# matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import tifffile
import localthickness as lt
import qim3d
import qim3d
class Interface(qim3d.gui.interface.InterfaceWithExamples):
def __init__(self,
img: np.ndarray = None,
verbose:bool = False,
plot_height:int = 768,
figsize:int = 6):
super().__init__(title = "Local thickness",
height = 1024,
width = 960,
verbose = verbose)
def __init__(
self,
img: np.ndarray = None,
verbose: bool = False,
plot_height: int = 768,
figsize: int = 6,
):
super().__init__(
title='Local thickness', height=1024, width=960, verbose=verbose
)
self.plot_height = plot_height
self.figsize = figsize
......@@ -64,7 +66,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# Get the temporary files from gradio
temp_sets = self.interface.temp_file_sets
for temp_set in temp_sets:
if "localthickness" in str(temp_set):
if 'localthickness' in str(temp_set):
# Get the lsit of the temporary files
temp_path_list = list(temp_set)
......@@ -84,7 +86,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
def define_interface(self):
gr.Markdown(
"Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)"
'Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)'
)
with gr.Row():
......@@ -92,12 +94,12 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
if self.img is not None:
data = gr.State(value=self.img)
else:
with gr.Tab("Input"):
with gr.Tab('Input'):
data = gr.File(
show_label=False,
value=self.img,
)
with gr.Tab("Examples"):
with gr.Tab('Examples'):
gr.Examples(examples=self.img_examples, inputs=data)
with gr.Row():
......@@ -106,17 +108,15 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
maximum=1,
value=0.5,
step=0.01,
label="Z position",
info="Local thickness is calculated in 3D, this slider controls the visualization only.",
label='Z position',
info='Local thickness is calculated in 3D, this slider controls the visualization only.',
)
with gr.Tab("Parameters"):
with gr.Tab('Parameters'):
gr.Markdown(
"It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results."
)
lt_scale = gr.Slider(
0.1, 1.0, label="Scale", value=0.5, step=0.1
'It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results.'
)
lt_scale = gr.Slider(0.1, 1.0, label='Scale', value=0.5, step=0.1)
with gr.Row():
threshold = gr.Slider(
......@@ -124,85 +124,83 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
1.0,
value=0.5,
step=0.05,
label="Threshold",
info="Local thickness uses a binary image, so a threshold value is needed.",
label='Threshold',
info='Local thickness uses a binary image, so a threshold value is needed.',
)
dark_objects = gr.Checkbox(
value=False,
label="Dark objects",
info="Inverts the image before thresholding. Use in case your foreground is darker than the background.",
label='Dark objects',
info='Inverts the image before thresholding. Use in case your foreground is darker than the background.',
)
with gr.Tab("Display options"):
with gr.Tab('Display options'):
cmap_original = gr.Dropdown(
value="viridis",
value='viridis',
choices=plt.colormaps(),
label="Colormap - input",
label='Colormap - input',
interactive=True,
)
cmap_lt = gr.Dropdown(
value="magma",
value='magma',
choices=plt.colormaps(),
label="Colormap - local thickness",
label='Colormap - local thickness',
interactive=True,
)
nbins = gr.Slider(
5, 50, value=25, step=1, label="Histogram bins"
)
nbins = gr.Slider(5, 50, value=25, step=1, label='Histogram bins')
# Run button
with gr.Row():
with gr.Column(scale=3, min_width=64):
btn = gr.Button(
"Run local thickness", variant = "primary"
)
btn = gr.Button('Run local thickness', variant='primary')
with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button("Clear", variant = "stop")
btn_clear = gr.Button('Clear', variant='stop')
with gr.Column(scale=4):
def create_uniform_image(intensity=1):
"""
Generates a blank image with a single color.
Gradio `gr.Plot` components will flicker if there is no default value.
bug fix on gradio version 4.44.0
"""
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255)
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(
intensity * 255
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(pixels, interpolation="nearest")
ax.imshow(pixels, interpolation='nearest')
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
with gr.Row():
input_vol = gr.Plot(
show_label=True,
label="Original",
label='Original',
visible=True,
value=create_uniform_image(),
)
binary_vol = gr.Plot(
show_label=True,
label="Binary",
label='Binary',
visible=True,
value=create_uniform_image(),
)
output_vol = gr.Plot(
show_label=True,
label="Local thickness",
label='Local thickness',
visible=True,
value=create_uniform_image(),
)
with gr.Row():
histogram = gr.Plot(
show_label=True,
label="Thickness histogram",
label='Thickness histogram',
visible=True,
value=create_uniform_image(),
)
......@@ -210,11 +208,10 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
lt_output = gr.File(
interactive=False,
show_label=True,
label="Output file",
label='Output file',
visible=False,
)
# Run button
# fmt: off
viz_input = lambda zpos, cmap: self.show_slice(self.vol, zpos, self.vmin, self.vmax, cmap)
......@@ -246,11 +243,11 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
fn=viz_input, inputs = [zpos, cmap_original], outputs=input_vol, show_progress=False).success(
fn=viz_binary, inputs = [zpos, cmap_original], outputs=binary_vol, show_progress=False).success(
fn=viz_output, inputs = [zpos, cmap_lt], outputs=output_vol, show_progress=False)
cmap_original.change(
fn=viz_input, inputs = [zpos, cmap_original],outputs=input_vol, show_progress=False).success(
fn=viz_binary, inputs = [zpos, cmap_original], outputs=binary_vol, show_progress=False)
cmap_lt.change(
fn=viz_output, inputs = [zpos, cmap_lt], outputs=output_vol, show_progress=False
)
......@@ -274,7 +271,9 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
except AttributeError:
self.vol = data
except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
if dark_objects:
self.vol = np.invert(self.vol)
......@@ -283,15 +282,22 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol)
def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
def show_slice(
self,
vol: np.ndarray,
zpos: int,
vmin: float = None,
vmax: float = None,
cmap: str = 'viridis',
):
plt.close()
z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(vol[z_idx], interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
ax.imshow(vol[z_idx], interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax)
# Adjustments
ax.axis("off")
ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig
......@@ -300,7 +306,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# Make a binary volume
# Nothing fancy, but we could add new features here
self.vol_binary = self.vol > (threshold * np.max(self.vol))
def compute_localthickness(self, lt_scale: float):
self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
......@@ -318,29 +324,30 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge"
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge'
)
# Adjustments
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_yscale("log")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_yscale('log')
return fig
def save_lt(self):
filename = "localthickness.tif"
filename = 'localthickness.tif'
# Save output image in a temp space
tifffile.imwrite(filename, self.vol_thickness)
return filename
def remove_unused_file(self):
# Remove localthickness.tif file from working directory
# as it otherwise is not deleted
os.remove('localthickness.tif')
if __name__ == "__main__":
Interface().run_interface()
\ No newline at end of file
if __name__ == '__main__':
Interface().run_interface()
import gradio as gr
class QimTheme(gr.themes.Default):
"""
Theme for qim3d gradio interfaces.
The theming options are quite broad. However if there is something you can not achieve with this theme
there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
in comments.
"""
def __init__(self, force_light_mode: bool = True):
"""
Parameters:
-----------
- force_light_mode (bool, optional): Gradio themes have dark mode by default.
Parameters
----------
- force_light_mode (bool, optional): Gradio themes have dark mode by default.
QIM platform is not ready for dark mode yet, thus the tools should also be in light mode.
This sets the darkmode values to be the same as light mode values.
"""
super().__init__()
self.force_light_mode = force_light_mode
self.general_values() # Not color related
self.general_values() # Not color related
self.set_light_mode_values()
self.set_dark_mode_values() # Checks the light mode setting inside
self.set_dark_mode_values() # Checks the light mode setting inside
def general_values(self):
self.set_button()
self.set_h1()
def set_light_mode_values(self):
self.set_light_primary_button()
self.set_light_secondary_button()
......@@ -34,8 +38,14 @@ class QimTheme(gr.themes.Default):
def set_dark_mode_values(self):
if self.force_light_mode:
for attr in [dark_attr for dark_attr in dir(self) if not dark_attr.startswith("_") and dark_attr.endswith("dark")]:
self.__dict__[attr] = self.__dict__[attr[:-5]] # ligth and dark attributes have same names except for '_dark' at the end
for attr in [
dark_attr
for dark_attr in dir(self)
if not dark_attr.startswith('_') and dark_attr.endswith('dark')
]:
self.__dict__[attr] = self.__dict__[
attr[:-5]
] # ligth and dark attributes have same names except for '_dark' at the end
else:
self.set_dark_primary_button()
# Secondary button looks good by default in dark mode
......@@ -44,26 +54,28 @@ class QimTheme(gr.themes.Default):
# Example looks good by default in dark mode
def set_button(self):
self.button_transition = "0.15s"
self.button_large_text_weight = "normal"
self.button_transition = '0.15s'
self.button_large_text_weight = 'normal'
def set_light_primary_button(self):
self.run_color = "#198754"
self.button_primary_background_fill = "#FFFFFF"
self.run_color = '#198754'
self.button_primary_background_fill = '#FFFFFF'
self.button_primary_background_fill_hover = self.run_color
self.button_primary_border_color = self.run_color
self.button_primary_text_color = self.run_color
self.button_primary_text_color_hover = "#FFFFFF"
self.button_primary_text_color_hover = '#FFFFFF'
def set_dark_primary_button(self):
self.bright_run_color = "#299764"
self.button_primary_background_fill_dark = self.button_primary_background_fill_hover
self.bright_run_color = '#299764'
self.button_primary_background_fill_dark = (
self.button_primary_background_fill_hover
)
self.button_primary_background_fill_hover_dark = self.bright_run_color
self.button_primary_border_color_dark = self.button_primary_border_color
self.button_primary_border_color_hover_dark = self.bright_run_color
def set_light_secondary_button(self):
self.button_secondary_background_fill = "white"
self.button_secondary_background_fill = 'white'
def set_light_example(self):
"""
......@@ -73,10 +85,10 @@ class QimTheme(gr.themes.Default):
self.color_accent_soft = self.neutral_100
def set_h1(self):
self.text_xxl = "2.5rem"
self.text_xxl = '2.5rem'
def set_light_checkbox(self):
light_blue = "#60a5fa"
light_blue = '#60a5fa'
self.checkbox_background_color_selected = light_blue
self.checkbox_border_color_selected = light_blue
self.checkbox_border_color_focus = light_blue
......@@ -86,21 +98,20 @@ class QimTheme(gr.themes.Default):
self.checkbox_border_color_focus_dark = self.checkbox_border_color_focus_dark
def set_light_cancel_button(self):
self.cancel_color = "#dc3545"
self.button_cancel_background_fill = "white"
self.cancel_color = '#dc3545'
self.button_cancel_background_fill = 'white'
self.button_cancel_background_fill_hover = self.cancel_color
self.button_cancel_border_color = self.cancel_color
self.button_cancel_text_color = self.cancel_color
self.button_cancel_text_color_hover = "white"
self.button_cancel_text_color_hover = 'white'
def set_dark_cancel_button(self):
self.button_cancel_background_fill_dark = self.cancel_color
self.button_cancel_background_fill_hover_dark = "red"
self.button_cancel_background_fill_hover_dark = 'red'
self.button_cancel_border_color_dark = self.cancel_color
self.button_cancel_border_color_hover_dark = "red"
self.button_cancel_text_color_dark = "white"
self.button_cancel_border_color_hover_dark = 'red'
self.button_cancel_text_color_dark = 'white'
# def _get_theme_css(self):
# sup = super()._get_theme_css()
# return "\n.svelte-182fdeq {\nbackground: rgba(255, 0, 0, 0.5) !important;\n}\n" + sup # You have to use !important, so it overrides other css
\ No newline at end of file
# from ._sync import Sync # this will be added back after future development
from ._loading import load, load_mesh
from ._downloader import Downloader
from ._saving import save, save_mesh
# from ._sync import Sync # this will be added back after future development
from ._convert import convert
from ._ome_zarr import export_ome_zarr, import_ome_zarr
......@@ -6,21 +6,24 @@ import nibabel as nib
import numpy as np
import tifffile as tiff
import zarr
from tqdm import tqdm
import zarr.core
import qim3d
from tqdm import tqdm
from qim3d.utils._misc import stringify_path
from qim3d.io import save
class Convert:
def __init__(self, **kwargs):
"""Utility class to convert files to other formats without loading the entire file into memory
"""
Utility class to convert files to other formats without loading the entire file into memory
Args:
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
self.chunk_shape = kwargs.get('chunk_shape', (64, 64, 64))
def convert(self, input_path: str, output_path: str):
def get_file_extension(file_path):
......@@ -29,6 +32,7 @@ class Convert:
root, ext2 = os.path.splitext(root)
ext = ext2 + ext
return ext
# Stringify path in case it is not already a string
input_path = stringify_path(input_path)
input_ext = get_file_extension(input_path)
......@@ -37,28 +41,30 @@ class Convert:
if os.path.isfile(input_path):
match input_ext, output_ext:
case (".tif", ".zarr") | (".tiff", ".zarr"):
case ('.tif', '.zarr') | ('.tiff', '.zarr'):
return self.convert_tif_to_zarr(input_path, output_path)
case (".nii", ".zarr") | (".nii.gz", ".zarr"):
case ('.nii', '.zarr') | ('.nii.gz', '.zarr'):
return self.convert_nifti_to_zarr(input_path, output_path)
case _:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# Load a directory
elif os.path.isdir(input_path):
match input_ext, output_ext:
case (".zarr", ".tif") | (".zarr", ".tiff"):
case ('.zarr', '.tif') | ('.zarr', '.tiff'):
return self.convert_zarr_to_tif(input_path, output_path)
case (".zarr", ".nii"):
case ('.zarr', '.nii'):
return self.convert_zarr_to_nifti(input_path, output_path)
case (".zarr", ".nii.gz"):
return self.convert_zarr_to_nifti(input_path, output_path, compression=True)
case ('.zarr', '.nii.gz'):
return self.convert_zarr_to_nifti(
input_path, output_path, compression=True
)
case _:
raise ValueError("Unsupported file format")
raise ValueError('Unsupported file format')
# Fail
else:
# Find the closest matching path to warn the user
parent_dir = os.path.dirname(input_path) or "."
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
parent_dir = os.path.dirname(input_path) or '.'
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ''
valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
similar_paths = difflib.get_close_matches(input_path, valid_paths)
if similar_paths:
......@@ -66,10 +72,11 @@ class Convert:
message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message))
else:
raise ValueError("Invalid path")
raise ValueError('Invalid path')
def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file
"""
Convert a tiff file to a zarr file
Args:
tif_path (str): path to the tiff file
......@@ -77,10 +84,15 @@ class Convert:
Returns:
zarr.core.Array: zarr array containing the data from the tiff file
"""
vol = tiff.memmap(tif_path)
z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
)
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar
......@@ -98,7 +110,8 @@ class Convert:
return z
def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file
"""
Convert a zarr file to a tiff file
Args:
zarr_path (str): path to the zarr file
......@@ -106,12 +119,14 @@ class Convert:
returns:
None
"""
z = zarr.open(zarr_path)
save(tif_path, z)
qim3d.io.save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file
"""
Convert a nifti file to a zarr file
Args:
nifti_path (str): path to the nifti file
......@@ -119,10 +134,15 @@ class Convert:
Returns:
zarr.core.Array: zarr array containing the data from the nifti file
"""
vol = nib.load(nifti_path).dataobj
z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
)
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar
......@@ -139,8 +159,11 @@ class Convert:
return z
def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
"""Convert a zarr file to a nifti file
def convert_zarr_to_nifti(
self, zarr_path: str, nifti_path: str, compression: bool = False
) -> None:
"""
Convert a zarr file to a nifti file
Args:
zarr_path (str): path to the zarr file
......@@ -148,18 +171,23 @@ class Convert:
Returns:
None
"""
z = zarr.open(zarr_path)
save(nifti_path, z, compression=compression)
qim3d.io.save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
"""Convert a file to another format without loading the entire file into memory
def convert(
input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)
) -> None:
"""
Convert a file to another format without loading the entire file into memory
Args:
input_path (str): path to the input file
output_path (str): path to the output file
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
converter = Convert(chunk_shape=chunk_shape)
converter.convert(input_path, output_path)
"Manages downloads and access to data"
"""Manages downloads and access to data"""
import os
import urllib.request
from urllib.parse import quote
import outputformat as ouf
from tqdm import tqdm
from pathlib import Path
from qim3d.io import load
from qim3d.utils import log
import outputformat as ouf
class Downloader:
"""Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
"""
Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
Attributes:
folder_name (str or os.PathLike): Folder class with the name of the folder in <https://data.qim.dk/>
Methods:
list_files(): Prints the downloadable files from the QIM data repository.
......@@ -51,25 +52,26 @@ class Downloader:
Example:
```python
import qim3d
downloader = qim3d.io.Downloader()
downloader.list_files()
downloader.list_files()
data = downloader.Cowry_Shell.Cowry_DOWNSAMPLED(load_file=True)
qim3d.viz.slicer_orthogonal(data, color_map="magma")
```
![cowry shell](../../assets/screenshots/cowry_shell_slicer.gif)
"""
def __init__(self):
folders = _extract_names()
for idx, folder in enumerate(folders):
exec(f"self.{folder} = self._Myfolder(folder)")
exec(f'self.{folder} = self._Myfolder(folder)')
def list_files(self):
"""Generate and print formatted folder, file, and size information."""
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
folders = _extract_names()
......@@ -78,17 +80,20 @@ class Downloader:
files = _extract_names(folder)
for file in files:
url = os.path.join(url_dl, folder, file).replace("\\", "/")
url = os.path.join(url_dl, folder, file).replace('\\', '/')
file_size = _get_file_size(url)
formatted_file = f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}"
formatted_file = (
f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}"
)
formatted_size = _format_file_size(file_size)
path_string = f'{folder}.{formatted_file}'
log.info(f'{path_string:<50}({formatted_size})')
class _Myfolder:
"""Class for extracting the files from each folder in the Downloader class.
"""
Class for extracting the files from each folder in the Downloader class.
Args:
folder(str): name of the folder of interest in the QIM data repository.
......@@ -99,6 +104,7 @@ class Downloader:
[file_name_2](load_file,optional): Function to download file number 2 in the given folder.
...
[file_name_n](load_file,optional): Function to download file number n in the given folder.
"""
def __init__(self, folder: str):
......@@ -107,14 +113,15 @@ class Downloader:
for idx, file in enumerate(files):
# Changes names to usable function name.
file_name = file
if ("%20" in file) or ("-" in file):
file_name = file_name.replace("%20", "_")
file_name = file_name.replace("-", "_")
if ('%20' in file) or ('-' in file):
file_name = file_name.replace('%20', '_')
file_name = file_name.replace('-', '_')
setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
def _make_fn(self, folder: str, file: str):
"""Private method that returns a function. The function downloads the chosen file from the folder.
"""
Private method that returns a function. The function downloads the chosen file from the folder.
Args:
folder(str): Folder where the file is located.
......@@ -122,23 +129,26 @@ class Downloader:
Returns:
function: the function used to download the file.
"""
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
def _download(load_file: bool = False, virtual_stack: bool = True):
"""Downloads the file and optionally also loads it.
"""
Downloads the file and optionally also loads it.
Args:
load_file(bool,optional): Whether to simply download or also load the file.
Returns:
virtual_stack: The loaded image.
"""
download_file(url_dl, folder, file)
if load_file == True:
log.info(f"\nLoading {file}")
log.info(f'\nLoading {file}')
file_path = os.path.join(folder, file)
return load(path=file_path, virtual_stack=virtual_stack)
......@@ -159,38 +169,40 @@ def _get_file_size(url: str):
Helper function for the ´download_file()´ function. Finds the size of the file.
"""
return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
return int(urllib.request.urlopen(url).info().get('Content-Length', -1))
def download_file(path: str, name: str, file: str):
"""Downloads the file from path / name / file.
"""
Downloads the file from path / name / file.
Args:
path(str): path to the folders available.
name(str): name of the folder of interest.
file(str): name of the file to be downloaded.
"""
if not os.path.exists(name):
os.makedirs(name)
url = os.path.join(path, name, file).replace("\\", "/") # if user is on windows
url = os.path.join(path, name, file).replace('\\', '/') # if user is on windows
file_path = os.path.join(name, file)
if os.path.exists(file_path):
log.warning(f"File already downloaded:\n{os.path.abspath(file_path)}")
log.warning(f'File already downloaded:\n{os.path.abspath(file_path)}')
return
else:
log.info(
f"Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}"
f'Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}'
)
if " " in url:
url = quote(url, safe=":/")
if ' ' in url:
url = quote(url, safe=':/')
with tqdm(
total=_get_file_size(url),
unit="B",
unit='B',
unit_scale=True,
unit_divisor=1024,
ncols=80,
......@@ -203,28 +215,31 @@ def download_file(path: str, name: str, file: str):
def _extract_html(url: str):
"""Extracts the html content of a webpage in "utf-8"
"""
Extracts the html content of a webpage in "utf-8"
Args:
url(str): url to the location where all the data is stored.
Returns:
html_content(str): decoded html.
"""
try:
with urllib.request.urlopen(url) as response:
html_content = response.read().decode(
"utf-8"
'utf-8'
) # Assuming the content is in UTF-8 encoding
except urllib.error.URLError as e:
log.warning(f"Failed to retrieve data from {url}. Error: {e}")
log.warning(f'Failed to retrieve data from {url}. Error: {e}')
return html_content
def _extract_names(name: str = None):
"""Extracts the names of the folders and files.
"""
Extracts the names of the folders and files.
Finds the names of either the folders if no name is given,
or all the names of all files in the given folder.
......@@ -235,31 +250,33 @@ def _extract_names(name: str = None):
Returns:
list: If name is None, returns a list of all folders available.
If name is not None, returns a list of all files available in the given 'name' folder.
"""
url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository"
url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository'
if name:
datapath = os.path.join(url, name).replace("\\", "/")
datapath = os.path.join(url, name).replace('\\', '/')
html_content = _extract_html(datapath)
data_split = html_content.split(
"files/public/projects/viscomp_data_repository/"
'files/public/projects/viscomp_data_repository/'
)[3:]
data_files = [
element.split(" ")[0][(len(name) + 1) : -3] for element in data_split
element.split(' ')[0][(len(name) + 1) : -3] for element in data_split
]
return data_files
else:
html_content = _extract_html(url)
split = html_content.split('"icon-folder-open">')[2:]
folders = [element.split(" ")[0][4:-4] for element in split]
folders = [element.split(' ')[0][4:-4] for element in split]
return folders
def _format_file_size(size_in_bytes):
# Define size units
units = ["B", "KB", "MB", "GB", "TB", "PB"]
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
size = float(size_in_bytes)
unit_index = 0
......@@ -269,4 +286,4 @@ def _format_file_size(size_in_bytes):
unit_index += 1
# Format the size with 1 decimal place
return f"{size:.2f}{units[unit_index]}"
return f'{size:.2f}{units[unit_index]}'
This diff is collapsed.
......@@ -2,39 +2,27 @@
Exporting data to different formats.
"""
import os
import math
import os
import shutil
import logging
from typing import List, Union
import dask.array as da
import numpy as np
import zarr
import tqdm
from ome_zarr import scale
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.scale import dask_resize
from ome_zarr.writer import (
write_image,
_create_mip,
write_multiscale,
CurrentFormat,
Format,
write_multiscale,
)
from ome_zarr.scale import dask_resize
from ome_zarr.reader import Reader
from ome_zarr import scale
from scipy.ndimage import zoom
from typing import Any, Callable, Iterator, List, Tuple, Union
import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from skimage.transform import (
resize,
)
from qim3d.utils import log
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
from qim3d.utils._ome_zarr import get_n_chunks
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
ListOfArrayLike = Union[List[da.Array], List[np.ndarray]]
ArrayLike = Union[da.Array, np.ndarray]
......@@ -43,10 +31,19 @@ ArrayLike = Union[da.Array, np.ndarray]
class OMEScaler(
scale.Scaler,
):
"""Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken."""
def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
"""
Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken.
"""
def __init__(
self,
order: int = 0,
downscale: float = 2,
max_layer: int = 5,
method: str = 'scaleZYXdask',
):
self.order = order
self.downscale = downscale
self.max_layer = max_layer
......@@ -55,11 +52,11 @@ class OMEScaler(
def scaleZYX(self, base: da.core.Array):
"""Downsample using :func:`scipy.ndimage.zoom`."""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
......@@ -82,8 +79,8 @@ class OMEScaler(
"""
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library
better_chunksize = tuple(
......@@ -92,7 +89,7 @@ class OMEScaler(
).astype(int)
)
log.debug(f"better chunk size: {better_chunksize}")
log.debug(f'better chunk size: {better_chunksize}')
# Compute the chunk size after the downscaling
new_chunk_size = tuple(
......@@ -100,20 +97,19 @@ class OMEScaler(
)
log.debug(
f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}"
f'orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}'
)
def resize_chunk(chunk, scale_factors, order):
#print(f"zoom factors: {scale_factors}")
# print(f"zoom factors: {scale_factors}")
resized_chunk = zoom(
chunk,
zoom=scale_factors,
order=order,
mode="grid-constant",
mode='grid-constant',
grid_mode=True,
)
#print(f"resized chunk shape: {resized_chunk.shape}")
# print(f"resized chunk shape: {resized_chunk.shape}")
return resized_chunk
......@@ -121,7 +117,7 @@ class OMEScaler(
# Testing new shape
predicted_shape = np.multiply(vol.shape, scale_factors)
log.debug(f"predicted shape: {predicted_shape}")
log.debug(f'predicted shape: {predicted_shape}')
scaled_vol = da.map_blocks(
resize_chunk,
vol,
......@@ -136,7 +132,7 @@ class OMEScaler(
return scaled_vol
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
log.debug(f"\nScale {i+1}\n{'-'*32}")
......@@ -147,17 +143,17 @@ class OMEScaler(
np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
)
log.debug(f"target shape: {scaled_shape}")
log.debug(f'target shape: {scaled_shape}')
downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
log.debug(f"downscale rate: {downscale_rate}")
log.debug(f'downscale rate: {downscale_rate}')
scale_factors = tuple(np.divide(1, downscale_rate))
log.debug(f"scale factors: {scale_factors}")
log.debug(f'scale factors: {scale_factors}')
log.debug("\nResizing volume chunk-wise")
log.debug('\nResizing volume chunk-wise')
scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
rv.append(scaled_vol)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
......@@ -165,10 +161,9 @@ class OMEScaler(
"""Downsample using the original OME-Zarr python library"""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer):
scaled_shape = tuple(
base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
)
......@@ -176,20 +171,20 @@ class OMEScaler(
scaled = dask_resize(base, scaled_shape, order=self.order)
rv.append(scaled)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv)
def export_ome_zarr(
path: str|os.PathLike,
data: np.ndarray|da.core.Array,
path: str | os.PathLike,
data: np.ndarray | da.core.Array,
chunk_size: int = 256,
downsample_rate: int = 2,
order: int = 1,
replace: bool = False,
method: str = "scaleZYX",
method: str = 'scaleZYX',
progress_bar: bool = True,
progress_bar_repeat_time: str = "auto",
progress_bar_repeat_time: str = 'auto',
) -> None:
"""
Export 3D image data to OME-Zarr format with pyramidal downsampling.
......@@ -220,6 +215,7 @@ def export_ome_zarr(
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
```
"""
# Check if directory exists
......@@ -228,19 +224,19 @@ def export_ome_zarr(
shutil.rmtree(path)
else:
raise ValueError(
f"Directory {path} already exists. Use replace=True to overwrite."
f'Directory {path} already exists. Use replace=True to overwrite.'
)
# Check if downsample_rate is valid
if downsample_rate <= 1:
raise ValueError("Downsample rate must be greater than 1.")
raise ValueError('Downsample rate must be greater than 1.')
log.info(f"Exporting data to OME-Zarr format at {path}")
log.info(f'Exporting data to OME-Zarr format at {path}')
# Get the number of scales
min_dim = np.max(np.shape(data))
nscales = math.ceil(math.log(min_dim / chunk_size) / math.log(downsample_rate))
log.info(f"Number of scales: {nscales + 1}")
log.info(f'Number of scales: {nscales + 1}')
# Create scaler
scaler = OMEScaler(
......@@ -249,32 +245,31 @@ def export_ome_zarr(
# write the image data
os.mkdir(path)
store = parse_url(path, mode="w").store
store = parse_url(path, mode='w').store
root = zarr.group(store=store)
# Check if we want to process using Dask
if "dask" in method and not isinstance(data, da.Array):
log.info("\nConverting input data to Dask array")
if 'dask' in method and not isinstance(data, da.Array):
log.info('\nConverting input data to Dask array')
data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
elif "dask" in method and isinstance(data, da.Array):
log.info("\nInput data will be rechunked")
elif 'dask' in method and isinstance(data, da.Array):
log.info('\nInput data will be rechunked')
data = data.rechunk((chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
log.info("Calculating the multi-scale pyramid")
log.info('Calculating the multi-scale pyramid')
# Generate multi-scale pyramid
mip = scaler.func(data)
log.info("Writing data to disk")
log.info('Writing data to disk')
kwargs = dict(
pyramid=mip,
group=root,
fmt=CurrentFormat(),
axes="zyx",
axes='zyx',
name=None,
compute=True,
storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)),
......@@ -291,16 +286,14 @@ def export_ome_zarr(
else:
write_multiscale(**kwargs)
log.info("\nAll done!")
log.info('\nAll done!')
return
def import_ome_zarr(
path: str|os.PathLike,
scale: int = 0,
load: bool = True
) -> np.ndarray:
path: str | os.PathLike, scale: int = 0, load: bool = True
) -> np.ndarray:
"""
Import image data from an OME-Zarr file.
......@@ -339,22 +332,22 @@ def import_ome_zarr(
image_node = nodes[0]
dask_data = image_node.data
log.info(f"Data contains {len(dask_data)} scales:")
log.info(f'Data contains {len(dask_data)} scales:')
for i in np.arange(len(dask_data)):
log.info(f"- Scale {i}: {dask_data[i].shape}")
log.info(f'- Scale {i}: {dask_data[i].shape}')
if scale == "highest":
if scale == 'highest':
scale = 0
if scale == "lowest":
if scale == 'lowest':
scale = len(dask_data) - 1
if scale >= len(dask_data):
raise ValueError(
f"Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}."
f'Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}.'
)
log.info(f"\nLoading scale {scale} with shape {dask_data[scale].shape}")
log.info(f'\nLoading scale {scale} with shape {dask_data[scale].shape}')
if load:
vol = dask_data[scale].compute()
......
This diff is collapsed.
""" Dataset synchronization tasks """
"""Dataset synchronization tasks"""
import os
import subprocess
from pathlib import Path
import outputformat as ouf
from qim3d.utils import log
from pathlib import Path
class Sync:
"""Class for dataset synchronization tasks"""
def __init__(self):
# Checks if rsync is available
if not self._check_rsync():
raise RuntimeError(
"Could not find rsync, please check if it is installed in your system."
'Could not find rsync, please check if it is installed in your system.'
)
def _check_rsync(self):
"""Check if rsync is available"""
try:
subprocess.run(["rsync", "--version"], capture_output=True, check=True)
subprocess.run(['rsync', '--version'], capture_output=True, check=True)
return True
except Exception as error:
log.error("rsync is not available")
log.error('rsync is not available')
log.error(error)
return False
def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
"""Check if all files from 'source' are in 'destination'
def check_destination(
self,
source: str,
destination: str,
checksum: bool = False,
verbose: bool = True,
) -> list[str]:
"""
Check if all files from 'source' are in 'destination'
This function compares the files in the 'source' directory to those in
the 'destination' directory and reports any differences or missing files.
......@@ -51,13 +62,13 @@ class Sync:
destination = Path(destination)
if checksum:
rsync_args = "-avrc"
rsync_args = '-avrc'
else:
rsync_args = "-avr"
rsync_args = '-avr'
command = [
"rsync",
"-n",
'rsync',
'-n',
rsync_args,
str(source) + os.path.sep,
str(destination) + os.path.sep,
......@@ -70,18 +81,25 @@ class Sync:
)
diff_files_and_folders = out.stdout.decode().splitlines()[1:-3]
diff_files = [f for f in diff_files_and_folders if not f.endswith("/")]
diff_files = [f for f in diff_files_and_folders if not f.endswith('/')]
if len(diff_files) > 0 and verbose:
title = "Source files differing or missing in destination"
title = 'Source files differing or missing in destination'
log.info(
ouf.showlist(diff_files, style="line", return_str=True, title=title)
ouf.showlist(diff_files, style='line', return_str=True, title=title)
)
return diff_files
def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
"""Checks whether 'source' and 'destination' directories are synchronized.
def compare_dirs(
self,
source: str,
destination: str,
checksum: bool = False,
verbose: bool = True,
) -> None:
"""
Checks whether 'source' and 'destination' directories are synchronized.
This function compares the contents of two directories
('source' and 'destination') and reports any differences.
......@@ -107,7 +125,7 @@ class Sync:
if verbose:
s_files, s_dirs = self.count_files_and_dirs(source)
d_files, d_dirs = self.count_files_and_dirs(destination)
log.info("\n")
log.info('\n')
s_d = self.check_destination(
source, destination, checksum=checksum, verbose=False
......@@ -120,7 +138,7 @@ class Sync:
# No differences
if verbose:
log.info(
"Source and destination are synchronized, no differences found."
'Source and destination are synchronized, no differences found.'
)
return
......@@ -128,9 +146,9 @@ class Sync:
log.info(
ouf.showlist(
union,
style="line",
style='line',
return_str=True,
title=f"{len(union)} files are not in sync",
title=f'{len(union)} files are not in sync',
)
)
......@@ -139,9 +157,9 @@ class Sync:
log.info(
ouf.showlist(
intersection,
style="line",
style='line',
return_str=True,
title=f"{len(intersection)} files present on both, but not equal",
title=f'{len(intersection)} files present on both, but not equal',
)
)
......@@ -150,9 +168,9 @@ class Sync:
log.info(
ouf.showlist(
s_exclusive,
style="line",
style='line',
return_str=True,
title=f"{len(s_exclusive)} files present only on {source}",
title=f'{len(s_exclusive)} files present only on {source}',
)
)
......@@ -161,15 +179,18 @@ class Sync:
log.info(
ouf.showlist(
d_exclusive,
style="line",
style='line',
return_str=True,
title=f"{len(d_exclusive)} files present only on {destination}",
title=f'{len(d_exclusive)} files present only on {destination}',
)
)
return
def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
"""Count the number of files and directories in the given path.
def count_files_and_dirs(
self, path: str | os.PathLike, verbose: bool = True
) -> tuple[int, int]:
"""
Count the number of files and directories in the given path.
This function recursively counts the number of files and
directories in the specified directory 'path'.
......@@ -202,6 +223,6 @@ class Sync:
dirs += dirs_count
if verbose:
log.info(f"Total of {files} files and {dirs} directories on {path}")
log.info(f'Total of {files} files and {dirs} directories on {path}')
return files, dirs
This diff is collapsed.
from ._unet import UNet, Hyperparameters
from ._unet import Hyperparameters, UNet