Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
Loading items

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes
Showing
with 1476 additions and 1190 deletions
from ._generators import noise_object
from ._aggregators import noise_object_collection from ._aggregators import noise_object_collection
from ._generators import noise_object
...@@ -22,6 +22,7 @@ def random_placement( ...@@ -22,6 +22,7 @@ def random_placement(
Returns: Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed. collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success. placed (bool): Flag for placement success.
""" """
# Find available (zero) elements in collection # Find available (zero) elements in collection
available_z, available_y, available_x = np.where(collection == 0) available_z, available_y, available_x = np.where(collection == 0)
...@@ -44,14 +45,12 @@ def random_placement( ...@@ -44,14 +45,12 @@ def random_placement(
if np.all( if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0 collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
): ):
# Check if placement is within bounds (bool) # Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all( within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape) end <= np.array(collection.shape)
) )
if within_bounds: if within_bounds:
# Place blob # Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = ( collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob blob
...@@ -81,6 +80,7 @@ def specific_placement( ...@@ -81,6 +80,7 @@ def specific_placement(
collection (numpy.ndarray): 3D volume of the collection with the blob placed. collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success. placed (bool): Flag for placement success.
positions (list[tuple]): List of remaining positions to place blobs. positions (list[tuple]): List of remaining positions to place blobs.
""" """
# Flag for placement success # Flag for placement success
placed = False placed = False
...@@ -99,14 +99,12 @@ def specific_placement( ...@@ -99,14 +99,12 @@ def specific_placement(
if np.all( if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0 collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
): ):
# Check if placement is within bounds (bool) # Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all( within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape) end <= np.array(collection.shape)
) )
if within_bounds: if within_bounds:
# Place blob # Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = ( collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob blob
...@@ -289,23 +287,24 @@ def noise_object_collection( ...@@ -289,23 +287,24 @@ def noise_object_collection(
qim3d.viz.slices_grid(vol, num_slices=15, slice_axis=1) qim3d.viz.slices_grid(vol, num_slices=15, slice_axis=1)
``` ```
![synthetic_collection_tube](../../assets/screenshots/synthetic_collection_tube_slices.png) ![synthetic_collection_tube](../../assets/screenshots/synthetic_collection_tube_slices.png)
""" """
if verbose: if verbose:
original_log_level = log.getEffectiveLevel() original_log_level = log.getEffectiveLevel()
log.setLevel("DEBUG") log.setLevel('DEBUG')
# Check valid input types # Check valid input types
if not isinstance(collection_shape, tuple) or len(collection_shape) != 3: if not isinstance(collection_shape, tuple) or len(collection_shape) != 3:
raise TypeError( raise TypeError(
"Shape of collection must be a tuple with three dimensions (z, y, x)" 'Shape of collection must be a tuple with three dimensions (z, y, x)'
) )
if len(min_shape) != len(max_shape): if len(min_shape) != len(max_shape):
raise ValueError("Object shapes must be tuples of the same length") raise ValueError('Object shapes must be tuples of the same length')
if (positions is not None) and (len(positions) != num_objects): if (positions is not None) and (len(positions) != num_objects):
raise ValueError( raise ValueError(
"Number of objects must match number of positions, otherwise set positions = None" 'Number of objects must match number of positions, otherwise set positions = None'
) )
# Set seed for random number generator # Set seed for random number generator
...@@ -318,8 +317,8 @@ def noise_object_collection( ...@@ -318,8 +317,8 @@ def noise_object_collection(
labels = np.zeros_like(collection_array) labels = np.zeros_like(collection_array)
# Fill the 3D array with synthetic blobs # Fill the 3D array with synthetic blobs
for i in tqdm(range(num_objects), desc="Objects placed"): for i in tqdm(range(num_objects), desc='Objects placed'):
log.debug(f"\nObject #{i+1}") log.debug(f'\nObject #{i+1}')
# Sample from blob parameter ranges # Sample from blob parameter ranges
if min_shape == max_shape: if min_shape == max_shape:
...@@ -328,7 +327,7 @@ def noise_object_collection( ...@@ -328,7 +327,7 @@ def noise_object_collection(
blob_shape = tuple( blob_shape = tuple(
rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3) rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3)
) )
log.debug(f"- Blob shape: {blob_shape}") log.debug(f'- Blob shape: {blob_shape}')
# Scale object shape # Scale object shape
final_shape = tuple(l * r for l, r in zip(blob_shape, object_shape_zoom)) final_shape = tuple(l * r for l, r in zip(blob_shape, object_shape_zoom))
...@@ -336,19 +335,19 @@ def noise_object_collection( ...@@ -336,19 +335,19 @@ def noise_object_collection(
# Sample noise scale # Sample noise scale
noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise) noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise)
log.debug(f"- Object noise scale: {noise_scale:.4f}") log.debug(f'- Object noise scale: {noise_scale:.4f}')
gamma = rng.uniform(low=min_gamma, high=max_gamma) gamma = rng.uniform(low=min_gamma, high=max_gamma)
log.debug(f"- Gamma correction: {gamma:.3f}") log.debug(f'- Gamma correction: {gamma:.3f}')
if max_high_value > min_high_value: if max_high_value > min_high_value:
max_value = rng.integers(low=min_high_value, high=max_high_value) max_value = rng.integers(low=min_high_value, high=max_high_value)
else: else:
max_value = min_high_value max_value = min_high_value
log.debug(f"- Max value: {max_value}") log.debug(f'- Max value: {max_value}')
threshold = rng.uniform(low=min_threshold, high=max_threshold) threshold = rng.uniform(low=min_threshold, high=max_threshold)
log.debug(f"- Threshold: {threshold:.3f}") log.debug(f'- Threshold: {threshold:.3f}')
# Generate synthetic object # Generate synthetic object
blob = qim3d.generate.noise_object( blob = qim3d.generate.noise_object(
...@@ -368,7 +367,7 @@ def noise_object_collection( ...@@ -368,7 +367,7 @@ def noise_object_collection(
low=min_rotation_degrees, high=max_rotation_degrees low=min_rotation_degrees, high=max_rotation_degrees
) # Sample rotation angle ) # Sample rotation angle
axes = rng.choice(rotation_axes) # Sample the two axes to rotate around axes = rng.choice(rotation_axes) # Sample the two axes to rotate around
log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}") log.debug(f'- Rotation angle: {angle:.2f} at axes: {axes}')
blob = scipy.ndimage.rotate(blob, angle, axes, order=1) blob = scipy.ndimage.rotate(blob, angle, axes, order=1)
...@@ -397,7 +396,7 @@ def noise_object_collection( ...@@ -397,7 +396,7 @@ def noise_object_collection(
if not placed: if not placed:
# Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar) # Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar)
log.error( log.error(
f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects." f'Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects.'
) )
if verbose: if verbose:
......
...@@ -4,6 +4,7 @@ from noise import pnoise3 ...@@ -4,6 +4,7 @@ from noise import pnoise3
import qim3d.processing import qim3d.processing
def noise_object( def noise_object(
base_shape: tuple = (128, 128, 128), base_shape: tuple = (128, 128, 128),
final_shape: tuple = (128, 128, 128), final_shape: tuple = (128, 128, 128),
...@@ -14,7 +15,7 @@ def noise_object( ...@@ -14,7 +15,7 @@ def noise_object(
threshold: float = 0.5, threshold: float = 0.5,
smooth_borders: bool = False, smooth_borders: bool = False,
object_shape: str = None, object_shape: str = None,
dtype: str = "uint8", dtype: str = 'uint8',
) -> np.ndarray: ) -> np.ndarray:
""" """
Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction. Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction.
...@@ -103,12 +104,13 @@ def noise_object( ...@@ -103,12 +104,13 @@ def noise_object(
qim3d.viz.slices_grid(vol, num_slices=15) qim3d.viz.slices_grid(vol, num_slices=15)
``` ```
![synthetic_blob_tube_slice](../../assets/screenshots/synthetic_blob_tube_slice.png) ![synthetic_blob_tube_slice](../../assets/screenshots/synthetic_blob_tube_slice.png)
""" """
if not isinstance(final_shape, tuple) or len(final_shape) != 3: if not isinstance(final_shape, tuple) or len(final_shape) != 3:
raise TypeError("Size must be a tuple of 3 dimensions") raise TypeError('Size must be a tuple of 3 dimensions')
if not np.issubdtype(dtype, np.number): if not np.issubdtype(dtype, np.number):
raise ValueError("Invalid data type") raise ValueError('Invalid data type')
# Initialize the 3D array for the shape # Initialize the 3D array for the shape
volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32) volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32)
...@@ -119,18 +121,17 @@ def noise_object( ...@@ -119,18 +121,17 @@ def noise_object(
# Calculate the distance from the center of the shape # Calculate the distance from the center of the shape
center = np.array(base_shape) / 2 center = np.array(base_shape) / 2
dist = np.sqrt((z - center[0])**2 + dist = np.sqrt((z - center[0]) ** 2 + (y - center[1]) ** 2 + (x - center[2]) ** 2)
(y - center[1])**2 +
(x - center[2])**2)
dist /= np.sqrt(3 * (center[0] ** 2)) dist /= np.sqrt(3 * (center[0] ** 2))
# Generate Perlin noise and adjust the values based on the distance from the center # Generate Perlin noise and adjust the values based on the distance from the center
vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input vectorized_pnoise3 = np.vectorize(
pnoise3
) # Vectorize pnoise3, since it only takes scalar input
noise = vectorized_pnoise3(z.flatten() * noise_scale, noise = vectorized_pnoise3(
y.flatten() * noise_scale, z.flatten() * noise_scale, y.flatten() * noise_scale, x.flatten() * noise_scale
x.flatten() * noise_scale
).reshape(base_shape) ).reshape(base_shape)
volume = (1 + noise) * (1 - dist) volume = (1 + noise) * (1 - dist)
...@@ -150,11 +151,16 @@ def noise_object( ...@@ -150,11 +151,16 @@ def noise_object(
if smooth_borders: if smooth_borders:
# Maximum value among the six sides of the 3D volume # Maximum value among the six sides of the 3D volume
max_border_value = np.max([ max_border_value = np.max(
np.max(volume[0, :, :]), np.max(volume[-1, :, :]), [
np.max(volume[:, 0, :]), np.max(volume[:, -1, :]), np.max(volume[0, :, :]),
np.max(volume[:, :, 0]), np.max(volume[:, :, -1]) np.max(volume[-1, :, :]),
]) np.max(volume[:, 0, :]),
np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]),
np.max(volume[:, :, -1]),
]
)
# Compute threshold such that there will be no straight cuts in the blob # Compute threshold such that there will be no straight cuts in the blob
threshold = max_border_value / max_value threshold = max_border_value / max_value
...@@ -171,42 +177,47 @@ def noise_object( ...@@ -171,42 +177,47 @@ def noise_object(
) )
# Fade into a shape if specified # Fade into a shape if specified
if object_shape == "cylinder": if object_shape == 'cylinder':
# Arguments for the fade_mask function # Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest axis = np.argmax(
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical volume.shape
) # Fade along the dimension where the object is the largest
target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
volume = qim3d.operations.fade_mask(volume, volume = qim3d.operations.fade_mask(
volume,
geometry=geometry, geometry=geometry,
axis=axis, axis=axis,
target_max_normalized_distance = target_max_normalized_distance target_max_normalized_distance=target_max_normalized_distance,
) )
elif object_shape == "tube": elif object_shape == 'tube':
# Arguments for the fade_mask function # Arguments for the fade_mask function
geometry = "cylindrical" # Fade in cylindrical geometry geometry = 'cylindrical' # Fade in cylindrical geometry
axis = np.argmax(volume.shape) # Fade along the dimension where the object is the largest axis = np.argmax(
volume.shape
) # Fade along the dimension where the object is the largest
decay_rate = 5 # Decay rate for the fade operation decay_rate = 5 # Decay rate for the fade operation
target_max_normalized_distance = 1.4 # This value ensures that the object will become cylindrical target_max_normalized_distance = (
1.4 # This value ensures that the object will become cylindrical
)
# Fade once for making the object cylindrical # Fade once for making the object cylindrical
volume = qim3d.operations.fade_mask(volume, volume = qim3d.operations.fade_mask(
volume,
geometry=geometry, geometry=geometry,
axis=axis, axis=axis,
decay_rate=decay_rate, decay_rate=decay_rate,
target_max_normalized_distance=target_max_normalized_distance, target_max_normalized_distance=target_max_normalized_distance,
invert = False invert=False,
) )
# Fade again with invert = True for making the object a tube (i.e. with a hole in the middle) # Fade again with invert = True for making the object a tube (i.e. with a hole in the middle)
volume = qim3d.operations.fade_mask(volume, volume = qim3d.operations.fade_mask(
geometry = geometry, volume, geometry=geometry, axis=axis, decay_rate=decay_rate, invert=True
axis = axis,
decay_rate = decay_rate,
invert = True
) )
# Convert to desired data type # Convert to desired data type
......
from fastapi import FastAPI from fastapi import FastAPI
import qim3d.utils import qim3d.utils
from . import data_explorer
from . import iso3d from . import annotation_tool, data_explorer, iso3d, layers2d, local_thickness
from . import local_thickness
from . import annotation_tool
from . import layers2d
from .qim_theme import QimTheme from .qim_theme import QimTheme
def run_gradio_app(gradio_interface, host="0.0.0.0"): def run_gradio_app(gradio_interface, host='0.0.0.0'):
import gradio as gr import gradio as gr
import uvicorn import uvicorn
# Get port using the QIM API # Get port using the QIM API
port_dict = qim3d.utils.get_port_dict() port_dict = qim3d.utils.get_port_dict()
if "gradio_port" in port_dict: if 'gradio_port' in port_dict:
port = port_dict["gradio_port"] port = port_dict['gradio_port']
elif "port" in port_dict: elif 'port' in port_dict:
port = port_dict["port"] port = port_dict['port']
else: else:
raise Exception("Port not specified from QIM API") raise Exception('Port not specified from QIM API')
qim3d.utils.gradio_header(gradio_interface.title, port) qim3d.utils.gradio_header(gradio_interface.title, port)
...@@ -30,7 +28,7 @@ def run_gradio_app(gradio_interface, host="0.0.0.0"): ...@@ -30,7 +28,7 @@ def run_gradio_app(gradio_interface, host="0.0.0.0"):
app = gr.mount_gradio_app(app, gradio_interface, path=path) app = gr.mount_gradio_app(app, gradio_interface, path=path)
# Full path # Full path
print(f"http://{host}:{port}{path}") print(f'http://{host}:{port}{path}')
# Run the FastAPI server usign uvicorn # Run the FastAPI server usign uvicorn
uvicorn.run(app, host=host, port=int(port)) uvicorn.run(app, host=host, port=int(port))
...@@ -27,6 +27,7 @@ import tempfile ...@@ -27,6 +27,7 @@ import tempfile
import gradio as gr import gradio as gr
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import qim3d import qim3d
from qim3d.gui.interface import BaseInterface from qim3d.gui.interface import BaseInterface
...@@ -34,17 +35,19 @@ from qim3d.gui.interface import BaseInterface ...@@ -34,17 +35,19 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface): class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None): def __init__(
self, name_suffix: str = '', verbose: bool = False, img: np.ndarray = None
):
super().__init__( super().__init__(
title="Annotation Tool", title='Annotation Tool',
height=768, height=768,
width="100%", width='100%',
verbose=verbose, verbose=verbose,
custom_css="annotation_tool.css", custom_css='annotation_tool.css',
) )
self.username = getpass.getuser() self.username = getpass.getuser()
self.temp_dir = os.path.join(tempfile.gettempdir(), f"qim-{self.username}") self.temp_dir = os.path.join(tempfile.gettempdir(), f'qim-{self.username}')
self.name_suffix = name_suffix self.name_suffix = name_suffix
self.img = img self.img = img
...@@ -57,7 +60,7 @@ class Interface(BaseInterface): ...@@ -57,7 +60,7 @@ class Interface(BaseInterface):
# Get the temporary files from gradio # Get the temporary files from gradio
temp_path_list = [] temp_path_list = []
for filename in os.listdir(self.temp_dir): for filename in os.listdir(self.temp_dir):
if "mask" and self.name_suffix in str(filename): if 'mask' and self.name_suffix in str(filename):
# Get the list of the temporary files # Get the list of the temporary files
temp_path_list.append(os.path.join(self.temp_dir, filename)) temp_path_list.append(os.path.join(self.temp_dir, filename))
...@@ -76,9 +79,9 @@ class Interface(BaseInterface): ...@@ -76,9 +79,9 @@ class Interface(BaseInterface):
this is safer and backwards compatible (should be) this is safer and backwards compatible (should be)
""" """
self.mask_names = [ self.mask_names = [
f"red{self.name_suffix}", f'red{self.name_suffix}',
f"green{self.name_suffix}", f'green{self.name_suffix}',
f"blue{self.name_suffix}", f'blue{self.name_suffix}',
] ]
# Clean up old files # Clean up old files
...@@ -86,7 +89,7 @@ class Interface(BaseInterface): ...@@ -86,7 +89,7 @@ class Interface(BaseInterface):
files = os.listdir(self.temp_dir) files = os.listdir(self.temp_dir)
for filename in files: for filename in files:
# Check if "mask" is in the filename # Check if "mask" is in the filename
if ("mask" in filename) and (self.name_suffix in filename): if ('mask' in filename) and (self.name_suffix in filename):
file_path = os.path.join(self.temp_dir, filename) file_path = os.path.join(self.temp_dir, filename)
os.remove(file_path) os.remove(file_path)
...@@ -94,13 +97,13 @@ class Interface(BaseInterface): ...@@ -94,13 +97,13 @@ class Interface(BaseInterface):
files = None files = None
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray: def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"] background = img_editor['background']
masks = img_editor["layers"][0] masks = img_editor['layers'][0]
overlay_image = qim3d.operations.overlay_rgb_images(background, masks) overlay_image = qim3d.operations.overlay_rgb_images(background, masks)
return overlay_image return overlay_image
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]: def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0] masks_rgb = img_editor['layers'][0]
mask_threshold = 200 # This value is based mask_threshold = 200 # This value is based
mask_list = [] mask_list = []
...@@ -114,7 +117,7 @@ class Interface(BaseInterface): ...@@ -114,7 +117,7 @@ class Interface(BaseInterface):
# Save only if we have a mask # Save only if we have a mask
if np.sum(mask) > 0: if np.sum(mask) > 0:
mask_list.append(mask) mask_list.append(mask)
filename = f"mask_{self.mask_names[idx]}.tif" filename = f'mask_{self.mask_names[idx]}.tif'
if not os.path.exists(self.temp_dir): if not os.path.exists(self.temp_dir):
os.makedirs(self.temp_dir) os.makedirs(self.temp_dir)
filepath = os.path.join(self.temp_dir, filename) filepath = os.path.join(self.temp_dir, filename)
...@@ -128,11 +131,11 @@ class Interface(BaseInterface): ...@@ -128,11 +131,11 @@ class Interface(BaseInterface):
def define_interface(self, **kwargs): def define_interface(self, **kwargs):
brush = gr.Brush( brush = gr.Brush(
colors=[ colors=[
"rgb(255,50,100)", 'rgb(255,50,100)',
"rgb(50,250,100)", 'rgb(50,250,100)',
"rgb(50,100,255)", 'rgb(50,100,255)',
], ],
color_mode="fixed", color_mode='fixed',
default_size=10, default_size=10,
) )
with gr.Row(): with gr.Row():
...@@ -142,26 +145,25 @@ class Interface(BaseInterface): ...@@ -142,26 +145,25 @@ class Interface(BaseInterface):
img_editor = gr.ImageEditor( img_editor = gr.ImageEditor(
value=( value=(
{ {
"background": self.img, 'background': self.img,
"layers": [Image.new("RGBA", self.img.shape, (0, 0, 0, 0))], 'layers': [Image.new('RGBA', self.img.shape, (0, 0, 0, 0))],
"composite": None, 'composite': None,
} }
if self.img is not None if self.img is not None
else None else None
), ),
type="numpy", type='numpy',
image_mode="RGB", image_mode='RGB',
brush=brush, brush=brush,
sources="upload", sources='upload',
interactive=True, interactive=True,
show_download_button=True, show_download_button=True,
container=False, container=False,
transforms=["crop"], transforms=['crop'],
layers=False, layers=False,
) )
with gr.Column(scale=1, min_width=256): with gr.Column(scale=1, min_width=256):
with gr.Row(): with gr.Row():
overlay_img = gr.Image( overlay_img = gr.Image(
show_download_button=False, show_download_button=False,
...@@ -169,7 +171,7 @@ class Interface(BaseInterface): ...@@ -169,7 +171,7 @@ class Interface(BaseInterface):
visible=False, visible=False,
) )
with gr.Row(): with gr.Row():
masks_download = gr.File(label="Download masks", visible=False) masks_download = gr.File(label='Download masks', visible=False)
# fmt: off # fmt: off
img_editor.change( img_editor.change(
......
This diff is collapsed.
from abc import ABC, abstractmethod
from os import listdir, path
from pathlib import Path from pathlib import Path
from abc import abstractmethod, ABC
from os import path, listdir
import gradio as gr import gradio as gr
import numpy as np
from .qim_theme import QimTheme
import qim3d.gui import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli # TODO: when offline it throws an error in cli
class BaseInterface(ABC): class BaseInterface(ABC):
""" """
Annotation tool and Data explorer as those don't need any examples. Annotation tool and Data explorer as those don't need any examples.
""" """
...@@ -19,7 +19,7 @@ class BaseInterface(ABC): ...@@ -19,7 +19,7 @@ class BaseInterface(ABC):
self, self,
title: str, title: str,
height: int, height: int,
width: int = "100%", width: int = '100%',
verbose: bool = False, verbose: bool = False,
custom_css: str = None, custom_css: str = None,
): ):
...@@ -38,7 +38,7 @@ class BaseInterface(ABC): ...@@ -38,7 +38,7 @@ class BaseInterface(ABC):
self.qim_dir = Path(qim3d.__file__).parents[0] self.qim_dir = Path(qim3d.__file__).parents[0]
self.custom_css = ( self.custom_css = (
path.join(self.qim_dir, "css", custom_css) path.join(self.qim_dir, 'css', custom_css)
if custom_css is not None if custom_css is not None
else None else None
) )
...@@ -72,8 +72,7 @@ class BaseInterface(ABC): ...@@ -72,8 +72,7 @@ class BaseInterface(ABC):
quiet=not self.verbose, quiet=not self.verbose,
height=self.height, height=self.height,
width=self.width, width=self.width,
favicon_path=Path(qim3d.__file__).parents[0] favicon_path=Path(qim3d.__file__).parents[0] / 'gui/assets/qim3d-icon.svg',
/ "gui/assets/qim3d-icon.svg",
**kwargs, **kwargs,
) )
...@@ -88,7 +87,7 @@ class BaseInterface(ABC): ...@@ -88,7 +87,7 @@ class BaseInterface(ABC):
title=self.title, title=self.title,
css=self.custom_css, css=self.custom_css,
) as gradio_interface: ) as gradio_interface:
gr.Markdown(f"# {self.title}") gr.Markdown(f'# {self.title}')
self.define_interface(**kwargs) self.define_interface(**kwargs)
return gradio_interface return gradio_interface
...@@ -96,11 +95,12 @@ class BaseInterface(ABC): ...@@ -96,11 +95,12 @@ class BaseInterface(ABC):
def define_interface(self, **kwargs): def define_interface(self, **kwargs):
pass pass
def run_interface(self, host: str = "0.0.0.0"): def run_interface(self, host: str = '0.0.0.0'):
qim3d.gui.run_gradio_app(self.create_interface(), host) qim3d.gui.run_gradio_app(self.create_interface(), host)
class InterfaceWithExamples(BaseInterface): class InterfaceWithExamples(BaseInterface):
""" """
For Iso3D and Local Thickness For Iso3D and Local Thickness
""" """
...@@ -117,7 +117,23 @@ class InterfaceWithExamples(BaseInterface): ...@@ -117,7 +117,23 @@ class InterfaceWithExamples(BaseInterface):
self._set_examples_list() self._set_examples_list()
def _set_examples_list(self): def _set_examples_list(self):
valid_sufixes = (".tif", ".tiff", ".h5", ".nii", ".gz", ".dcm", ".DCM", ".vol", ".vgi", ".txrm", ".txm", ".xrm") valid_sufixes = (
'.tif',
'.tiff',
'.h5',
'.nii',
'.gz',
'.dcm',
'.DCM',
'.vol',
'.vgi',
'.txrm',
'.txm',
'.xrm',
)
examples_folder = path.join(self.qim_dir, 'examples') examples_folder = path.join(self.qim_dir, 'examples')
self.img_examples = [path.join(examples_folder, example) for example in listdir(examples_folder) if example.endswith(valid_sufixes)] self.img_examples = [
path.join(examples_folder, example)
for example in listdir(examples_folder)
if example.endswith(valid_sufixes)
]
...@@ -15,6 +15,7 @@ app.launch() ...@@ -15,6 +15,7 @@ app.launch()
``` ```
""" """
import os import os
import gradio as gr import gradio as gr
...@@ -23,21 +24,19 @@ import plotly.graph_objects as go ...@@ -23,21 +24,19 @@ import plotly.graph_objects as go
from scipy import ndimage from scipy import ndimage
import qim3d import qim3d
from qim3d.utils._logger import log
from qim3d.gui.interface import InterfaceWithExamples from qim3d.gui.interface import InterfaceWithExamples
from qim3d.utils._logger import log
# TODO img in launch should be self.img # TODO img in launch should be self.img
class Interface(InterfaceWithExamples): class Interface(InterfaceWithExamples):
def __init__(self, def __init__(self, verbose: bool = False, plot_height: int = 768, img=None):
verbose:bool = False, super().__init__(
plot_height:int = 768, title='Isosurfaces for 3D visualization',
img = None):
super().__init__(title = "Isosurfaces for 3D visualization",
height=1024, height=1024,
width=960, width=960,
verbose = verbose) verbose=verbose,
)
self.interface = None self.interface = None
self.img = img self.img = img
...@@ -48,11 +47,13 @@ class Interface(InterfaceWithExamples): ...@@ -48,11 +47,13 @@ class Interface(InterfaceWithExamples):
self.vol = qim3d.io.load(gradiofile.name) self.vol = qim3d.io.load(gradiofile.name)
assert self.vol.ndim == 3 assert self.vol.ndim == 3
except AttributeError: except AttributeError:
raise gr.Error("You have to select a file") raise gr.Error('You have to select a file')
except ValueError: except ValueError:
raise gr.Error("Unsupported file format") raise gr.Error('Unsupported file format')
except AssertionError: except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
def resize_vol(self, display_size: int): def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size""" """Resizes the loaded volume to the display size"""
...@@ -61,7 +62,7 @@ class Interface(InterfaceWithExamples): ...@@ -61,7 +62,7 @@ class Interface(InterfaceWithExamples):
original_Z, original_Y, original_X = np.shape(self.vol) original_Z, original_Y, original_X = np.shape(self.vol)
max_size = np.max([original_Z, original_Y, original_X]) max_size = np.max([original_Z, original_Y, original_X])
if self.verbose: if self.verbose:
log.info(f"\nOriginal volume: {original_Z, original_Y, original_X}") log.info(f'\nOriginal volume: {original_Z, original_Y, original_X}')
# Resize for display # Resize for display
self.vol = ndimage.zoom( self.vol = ndimage.zoom(
...@@ -76,14 +77,15 @@ class Interface(InterfaceWithExamples): ...@@ -76,14 +77,15 @@ class Interface(InterfaceWithExamples):
) )
if self.verbose: if self.verbose:
log.info( log.info(
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}" f'Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}'
) )
def save_fig(self, fig: go.Figure, filename: str): def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk # Write Plotly figure to disk
fig.write_html(filename) fig.write_html(filename)
def create_fig(self, def create_fig(
self,
gradio_file: gr.File, gradio_file: gr.File,
display_size: int, display_size: int,
opacity: float, opacity: float,
...@@ -106,7 +108,6 @@ class Interface(InterfaceWithExamples): ...@@ -106,7 +108,6 @@ class Interface(InterfaceWithExamples):
show_x_slice: bool, show_x_slice: bool,
slice_x_location: int, slice_x_location: int,
) -> tuple[go.Figure, str]: ) -> tuple[go.Figure, str]:
# Load volume # Load volume
self.load_data(gradio_file) self.load_data(gradio_file)
...@@ -161,10 +162,10 @@ class Interface(InterfaceWithExamples): ...@@ -161,10 +162,10 @@ class Interface(InterfaceWithExamples):
), ),
showscale=show_colorbar, showscale=show_colorbar,
colorbar=dict( colorbar=dict(
thickness=8, outlinecolor="#fff", len=0.5, orientation="h" thickness=8, outlinecolor='#fff', len=0.5, orientation='h'
), ),
reversescale=reversescale, reversescale=reversescale,
hoverinfo = "skip", hoverinfo='skip',
) )
) )
...@@ -175,13 +176,13 @@ class Interface(InterfaceWithExamples): ...@@ -175,13 +176,13 @@ class Interface(InterfaceWithExamples):
scene_xaxis_visible=show_axis, scene_xaxis_visible=show_axis,
scene_yaxis_visible=show_axis, scene_yaxis_visible=show_axis,
scene_zaxis_visible=show_axis, scene_zaxis_visible=show_axis,
scene_aspectmode="data", scene_aspectmode='data',
height=self.plot_height, height=self.plot_height,
hovermode=False, hovermode=False,
scene_camera_eye=dict(x=2.0, y=-2.0, z=1.5), scene_camera_eye=dict(x=2.0, y=-2.0, z=1.5),
) )
filename = "iso3d.html" filename = 'iso3d.html'
self.save_fig(fig, filename) self.save_fig(fig, filename)
return fig, filename return fig, filename
...@@ -189,10 +190,9 @@ class Interface(InterfaceWithExamples): ...@@ -189,10 +190,9 @@ class Interface(InterfaceWithExamples):
def remove_unused_file(self): def remove_unused_file(self):
# Remove localthickness.tif file from working directory # Remove localthickness.tif file from working directory
# as it otherwise is not deleted # as it otherwise is not deleted
os.remove("iso3d.html") os.remove('iso3d.html')
def define_interface(self, **kwargs): def define_interface(self, **kwargs):
gr.Markdown( gr.Markdown(
""" """
This tool uses Plotly Volume (https://plotly.com/python/3d-volume-plots/) to create iso surfaces from voxels based on their intensity levels. This tool uses Plotly Volume (https://plotly.com/python/3d-volume-plots/) to create iso surfaces from voxels based on their intensity levels.
...@@ -203,117 +203,111 @@ class Interface(InterfaceWithExamples): ...@@ -203,117 +203,111 @@ class Interface(InterfaceWithExamples):
with gr.Row(): with gr.Row():
# Input and parameters column # Input and parameters column
with gr.Column(scale=1, min_width=320): with gr.Column(scale=1, min_width=320):
with gr.Tab("Input"): with gr.Tab('Input'):
# File loader # File loader
gradio_file = gr.File( gradio_file = gr.File(show_label=False)
show_label=False with gr.Tab('Examples'):
)
with gr.Tab("Examples"):
gr.Examples(examples=self.img_examples, inputs=gradio_file) gr.Examples(examples=self.img_examples, inputs=gradio_file)
# Run button # Run button
with gr.Row(): with gr.Row():
with gr.Column(scale=3, min_width=64): with gr.Column(scale=3, min_width=64):
btn_run = gr.Button( btn_run = gr.Button(
value="Run 3D visualization", variant = "primary" value='Run 3D visualization', variant='primary'
) )
with gr.Column(scale=1, min_width=64): with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button( btn_clear = gr.Button(value='Clear', variant='stop')
value="Clear", variant = "stop"
)
with gr.Tab("Display"): with gr.Tab('Display'):
# Display options # Display options
display_size = gr.Slider( display_size = gr.Slider(
32, 32,
128, 128,
step=4, step=4,
label="Display resolution", label='Display resolution',
info="Number of voxels for the largest dimension", info='Number of voxels for the largest dimension',
value=64, value=64,
) )
surface_count = gr.Slider( surface_count = gr.Slider(
2, 16, step=1, label="Total iso-surfaces", value=6 2, 16, step=1, label='Total iso-surfaces', value=6
) )
show_caps = gr.Checkbox(value=False, label="Show surface caps") show_caps = gr.Checkbox(value=False, label='Show surface caps')
with gr.Row(): with gr.Row():
opacityscale = gr.Dropdown( opacityscale = gr.Dropdown(
choices=["uniform", "extremes", "min", "max"], choices=['uniform', 'extremes', 'min', 'max'],
value="uniform", value='uniform',
label="Opacity scale", label='Opacity scale',
info="Handles opacity acording to voxel value", info='Handles opacity acording to voxel value',
) )
opacity = gr.Slider( opacity = gr.Slider(
0.0, 1.0, step=0.1, label="Max opacity", value=0.4 0.0, 1.0, step=0.1, label='Max opacity', value=0.4
) )
with gr.Row(): with gr.Row():
min_value = gr.Slider( min_value = gr.Slider(
0.0, 1.0, step=0.05, label="Min value", value=0.1 0.0, 1.0, step=0.05, label='Min value', value=0.1
) )
max_value = gr.Slider( max_value = gr.Slider(
0.0, 1.0, step=0.05, label="Max value", value=1 0.0, 1.0, step=0.05, label='Max value', value=1
) )
with gr.Tab("Slices") as slices: with gr.Tab('Slices') as slices:
show_z_slice = gr.Checkbox(value=False, label="Show Z slice") show_z_slice = gr.Checkbox(value=False, label='Show Z slice')
slice_z_location = gr.Slider( slice_z_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position" 0.0, 1.0, step=0.05, value=0.5, label='Position'
) )
show_y_slice = gr.Checkbox(value=False, label="Show Y slice") show_y_slice = gr.Checkbox(value=False, label='Show Y slice')
slice_y_location = gr.Slider( slice_y_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position" 0.0, 1.0, step=0.05, value=0.5, label='Position'
) )
show_x_slice = gr.Checkbox(value=False, label="Show X slice") show_x_slice = gr.Checkbox(value=False, label='Show X slice')
slice_x_location = gr.Slider( slice_x_location = gr.Slider(
0.0, 1.0, step=0.05, value=0.5, label="Position" 0.0, 1.0, step=0.05, value=0.5, label='Position'
) )
with gr.Tab("Misc"): with gr.Tab('Misc'):
with gr.Row(): with gr.Row():
colormap = gr.Dropdown( colormap = gr.Dropdown(
choices=[ choices=[
"Blackbody", 'Blackbody',
"Bluered", 'Bluered',
"Blues", 'Blues',
"Cividis", 'Cividis',
"Earth", 'Earth',
"Electric", 'Electric',
"Greens", 'Greens',
"Greys", 'Greys',
"Hot", 'Hot',
"Jet", 'Jet',
"Magma", 'Magma',
"Picnic", 'Picnic',
"Portland", 'Portland',
"Rainbow", 'Rainbow',
"RdBu", 'RdBu',
"Reds", 'Reds',
"Viridis", 'Viridis',
"YlGnBu", 'YlGnBu',
"YlOrRd", 'YlOrRd',
], ],
value="Magma", value='Magma',
label="Colormap", label='Colormap',
) )
show_colorbar = gr.Checkbox( show_colorbar = gr.Checkbox(
value=False, label="Show color scale" value=False, label='Show color scale'
) )
reversescale = gr.Checkbox( reversescale = gr.Checkbox(
value=False, label="Reverse color scale" value=False, label='Reverse color scale'
)
flip_z = gr.Checkbox(value=True, label="Flip Z axis")
show_axis = gr.Checkbox(value=True, label="Show axis")
show_ticks = gr.Checkbox(value=False, label="Show ticks")
only_wireframe = gr.Checkbox(
value=False, label="Only wireframe"
) )
flip_z = gr.Checkbox(value=True, label='Flip Z axis')
show_axis = gr.Checkbox(value=True, label='Show axis')
show_ticks = gr.Checkbox(value=False, label='Show ticks')
only_wireframe = gr.Checkbox(value=False, label='Only wireframe')
# Inputs for gradio # Inputs for gradio
inputs = [ inputs = [
...@@ -346,7 +340,7 @@ class Interface(InterfaceWithExamples): ...@@ -346,7 +340,7 @@ class Interface(InterfaceWithExamples):
plot_download = gr.File( plot_download = gr.File(
interactive=False, interactive=False,
label="Download interactive plot", label='Download interactive plot',
show_label=True, show_label=True,
visible=False, visible=False,
) )
...@@ -367,5 +361,6 @@ class Interface(InterfaceWithExamples): ...@@ -367,5 +361,6 @@ class Interface(InterfaceWithExamples):
fn=self.remove_unused_file).success( fn=self.remove_unused_file).success(
fn=self.set_visible, inputs=None, outputs=plot_download) fn=self.set_visible, inputs=None, outputs=plot_download)
if __name__ == "__main__":
if __name__ == '__main__':
Interface().run_interface() Interface().run_interface()
This diff is collapsed.
...@@ -32,29 +32,31 @@ app.launch() ...@@ -32,29 +32,31 @@ app.launch()
``` ```
""" """
import os import os
import gradio as gr
import localthickness as lt
# matplotlib.use("Agg") # matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import gradio as gr
import numpy as np import numpy as np
import tifffile import tifffile
import localthickness as lt
import qim3d
import qim3d
class Interface(qim3d.gui.interface.InterfaceWithExamples): class Interface(qim3d.gui.interface.InterfaceWithExamples):
def __init__(self, def __init__(
self,
img: np.ndarray = None, img: np.ndarray = None,
verbose: bool = False, verbose: bool = False,
plot_height: int = 768, plot_height: int = 768,
figsize:int = 6): figsize: int = 6,
):
super().__init__(title = "Local thickness", super().__init__(
height = 1024, title='Local thickness', height=1024, width=960, verbose=verbose
width = 960, )
verbose = verbose)
self.plot_height = plot_height self.plot_height = plot_height
self.figsize = figsize self.figsize = figsize
...@@ -64,7 +66,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -64,7 +66,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# Get the temporary files from gradio # Get the temporary files from gradio
temp_sets = self.interface.temp_file_sets temp_sets = self.interface.temp_file_sets
for temp_set in temp_sets: for temp_set in temp_sets:
if "localthickness" in str(temp_set): if 'localthickness' in str(temp_set):
# Get the lsit of the temporary files # Get the lsit of the temporary files
temp_path_list = list(temp_set) temp_path_list = list(temp_set)
...@@ -84,7 +86,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -84,7 +86,7 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
def define_interface(self): def define_interface(self):
gr.Markdown( gr.Markdown(
"Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)" 'Interface for _Fast local thickness in 3D_ (https://github.com/vedranaa/local-thickness)'
) )
with gr.Row(): with gr.Row():
...@@ -92,12 +94,12 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -92,12 +94,12 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
if self.img is not None: if self.img is not None:
data = gr.State(value=self.img) data = gr.State(value=self.img)
else: else:
with gr.Tab("Input"): with gr.Tab('Input'):
data = gr.File( data = gr.File(
show_label=False, show_label=False,
value=self.img, value=self.img,
) )
with gr.Tab("Examples"): with gr.Tab('Examples'):
gr.Examples(examples=self.img_examples, inputs=data) gr.Examples(examples=self.img_examples, inputs=data)
with gr.Row(): with gr.Row():
...@@ -106,17 +108,15 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -106,17 +108,15 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
maximum=1, maximum=1,
value=0.5, value=0.5,
step=0.01, step=0.01,
label="Z position", label='Z position',
info="Local thickness is calculated in 3D, this slider controls the visualization only.", info='Local thickness is calculated in 3D, this slider controls the visualization only.',
) )
with gr.Tab("Parameters"): with gr.Tab('Parameters'):
gr.Markdown( gr.Markdown(
"It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results." 'It is possible to scale down the image before processing. Lower values will make the algorithm run faster, but decreases the accuracy of results.'
)
lt_scale = gr.Slider(
0.1, 1.0, label="Scale", value=0.5, step=0.1
) )
lt_scale = gr.Slider(0.1, 1.0, label='Scale', value=0.5, step=0.1)
with gr.Row(): with gr.Row():
threshold = gr.Slider( threshold = gr.Slider(
...@@ -124,85 +124,83 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -124,85 +124,83 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
1.0, 1.0,
value=0.5, value=0.5,
step=0.05, step=0.05,
label="Threshold", label='Threshold',
info="Local thickness uses a binary image, so a threshold value is needed.", info='Local thickness uses a binary image, so a threshold value is needed.',
) )
dark_objects = gr.Checkbox( dark_objects = gr.Checkbox(
value=False, value=False,
label="Dark objects", label='Dark objects',
info="Inverts the image before thresholding. Use in case your foreground is darker than the background.", info='Inverts the image before thresholding. Use in case your foreground is darker than the background.',
) )
with gr.Tab("Display options"): with gr.Tab('Display options'):
cmap_original = gr.Dropdown( cmap_original = gr.Dropdown(
value="viridis", value='viridis',
choices=plt.colormaps(), choices=plt.colormaps(),
label="Colormap - input", label='Colormap - input',
interactive=True, interactive=True,
) )
cmap_lt = gr.Dropdown( cmap_lt = gr.Dropdown(
value="magma", value='magma',
choices=plt.colormaps(), choices=plt.colormaps(),
label="Colormap - local thickness", label='Colormap - local thickness',
interactive=True, interactive=True,
) )
nbins = gr.Slider( nbins = gr.Slider(5, 50, value=25, step=1, label='Histogram bins')
5, 50, value=25, step=1, label="Histogram bins"
)
# Run button # Run button
with gr.Row(): with gr.Row():
with gr.Column(scale=3, min_width=64): with gr.Column(scale=3, min_width=64):
btn = gr.Button( btn = gr.Button('Run local thickness', variant='primary')
"Run local thickness", variant = "primary"
)
with gr.Column(scale=1, min_width=64): with gr.Column(scale=1, min_width=64):
btn_clear = gr.Button("Clear", variant = "stop") btn_clear = gr.Button('Clear', variant='stop')
with gr.Column(scale=4): with gr.Column(scale=4):
def create_uniform_image(intensity=1): def create_uniform_image(intensity=1):
""" """
Generates a blank image with a single color. Generates a blank image with a single color.
Gradio `gr.Plot` components will flicker if there is no default value. Gradio `gr.Plot` components will flicker if there is no default value.
bug fix on gradio version 4.44.0 bug fix on gradio version 4.44.0
""" """
pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(intensity * 255) pixels = np.zeros((100, 100, 3), dtype=np.uint8) + int(
intensity * 255
)
fig, ax = plt.subplots(figsize=(10, 10)) fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(pixels, interpolation="nearest") ax.imshow(pixels, interpolation='nearest')
# Adjustments # Adjustments
ax.axis("off") ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1) fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig return fig
with gr.Row(): with gr.Row():
input_vol = gr.Plot( input_vol = gr.Plot(
show_label=True, show_label=True,
label="Original", label='Original',
visible=True, visible=True,
value=create_uniform_image(), value=create_uniform_image(),
) )
binary_vol = gr.Plot( binary_vol = gr.Plot(
show_label=True, show_label=True,
label="Binary", label='Binary',
visible=True, visible=True,
value=create_uniform_image(), value=create_uniform_image(),
) )
output_vol = gr.Plot( output_vol = gr.Plot(
show_label=True, show_label=True,
label="Local thickness", label='Local thickness',
visible=True, visible=True,
value=create_uniform_image(), value=create_uniform_image(),
) )
with gr.Row(): with gr.Row():
histogram = gr.Plot( histogram = gr.Plot(
show_label=True, show_label=True,
label="Thickness histogram", label='Thickness histogram',
visible=True, visible=True,
value=create_uniform_image(), value=create_uniform_image(),
) )
...@@ -210,11 +208,10 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -210,11 +208,10 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
lt_output = gr.File( lt_output = gr.File(
interactive=False, interactive=False,
show_label=True, show_label=True,
label="Output file", label='Output file',
visible=False, visible=False,
) )
# Run button # Run button
# fmt: off # fmt: off
viz_input = lambda zpos, cmap: self.show_slice(self.vol, zpos, self.vmin, self.vmax, cmap) viz_input = lambda zpos, cmap: self.show_slice(self.vol, zpos, self.vmin, self.vmax, cmap)
...@@ -274,7 +271,9 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -274,7 +271,9 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
except AttributeError: except AttributeError:
self.vol = data self.vol = data
except AssertionError: except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") raise gr.Error(
f"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}"
)
if dark_objects: if dark_objects:
self.vol = np.invert(self.vol) self.vol = np.invert(self.vol)
...@@ -283,15 +282,22 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -283,15 +282,22 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
self.vmin = np.min(self.vol) self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol) self.vmax = np.max(self.vol)
def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"): def show_slice(
self,
vol: np.ndarray,
zpos: int,
vmin: float = None,
vmax: float = None,
cmap: str = 'viridis',
):
plt.close() plt.close()
z_idx = int(zpos * (vol.shape[0] - 1)) z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(vol[z_idx], interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) ax.imshow(vol[z_idx], interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax)
# Adjustments # Adjustments
ax.axis("off") ax.axis('off')
fig.subplots_adjust(left=0, right=1, bottom=0, top=1) fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return fig return fig
...@@ -318,20 +324,20 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -318,20 +324,20 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
ax.bar( ax.bar(
bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec="white", align="edge" bin_edges[:-1], vol_hist, width=np.diff(bin_edges), ec='white', align='edge'
) )
# Adjustments # Adjustments
ax.spines["right"].set_visible(False) ax.spines['right'].set_visible(False)
ax.spines["top"].set_visible(False) ax.spines['top'].set_visible(False)
ax.spines["left"].set_visible(True) ax.spines['left'].set_visible(True)
ax.spines["bottom"].set_visible(True) ax.spines['bottom'].set_visible(True)
ax.set_yscale("log") ax.set_yscale('log')
return fig return fig
def save_lt(self): def save_lt(self):
filename = "localthickness.tif" filename = 'localthickness.tif'
# Save output image in a temp space # Save output image in a temp space
tifffile.imwrite(filename, self.vol_thickness) tifffile.imwrite(filename, self.vol_thickness)
...@@ -342,5 +348,6 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples): ...@@ -342,5 +348,6 @@ class Interface(qim3d.gui.interface.InterfaceWithExamples):
# as it otherwise is not deleted # as it otherwise is not deleted
os.remove('localthickness.tif') os.remove('localthickness.tif')
if __name__ == "__main__":
if __name__ == '__main__':
Interface().run_interface() Interface().run_interface()
import gradio as gr import gradio as gr
class QimTheme(gr.themes.Default): class QimTheme(gr.themes.Default):
""" """
Theme for qim3d gradio interfaces. Theme for qim3d gradio interfaces.
The theming options are quite broad. However if there is something you can not achieve with this theme The theming options are quite broad. However if there is something you can not achieve with this theme
there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
in comments. in comments.
""" """
def __init__(self, force_light_mode: bool = True): def __init__(self, force_light_mode: bool = True):
""" """
Parameters: Parameters
----------- ----------
- force_light_mode (bool, optional): Gradio themes have dark mode by default. - force_light_mode (bool, optional): Gradio themes have dark mode by default.
QIM platform is not ready for dark mode yet, thus the tools should also be in light mode. QIM platform is not ready for dark mode yet, thus the tools should also be in light mode.
This sets the darkmode values to be the same as light mode values. This sets the darkmode values to be the same as light mode values.
""" """
super().__init__() super().__init__()
self.force_light_mode = force_light_mode self.force_light_mode = force_light_mode
...@@ -34,8 +38,14 @@ class QimTheme(gr.themes.Default): ...@@ -34,8 +38,14 @@ class QimTheme(gr.themes.Default):
def set_dark_mode_values(self): def set_dark_mode_values(self):
if self.force_light_mode: if self.force_light_mode:
for attr in [dark_attr for dark_attr in dir(self) if not dark_attr.startswith("_") and dark_attr.endswith("dark")]: for attr in [
self.__dict__[attr] = self.__dict__[attr[:-5]] # ligth and dark attributes have same names except for '_dark' at the end dark_attr
for dark_attr in dir(self)
if not dark_attr.startswith('_') and dark_attr.endswith('dark')
]:
self.__dict__[attr] = self.__dict__[
attr[:-5]
] # ligth and dark attributes have same names except for '_dark' at the end
else: else:
self.set_dark_primary_button() self.set_dark_primary_button()
# Secondary button looks good by default in dark mode # Secondary button looks good by default in dark mode
...@@ -44,26 +54,28 @@ class QimTheme(gr.themes.Default): ...@@ -44,26 +54,28 @@ class QimTheme(gr.themes.Default):
# Example looks good by default in dark mode # Example looks good by default in dark mode
def set_button(self): def set_button(self):
self.button_transition = "0.15s" self.button_transition = '0.15s'
self.button_large_text_weight = "normal" self.button_large_text_weight = 'normal'
def set_light_primary_button(self): def set_light_primary_button(self):
self.run_color = "#198754" self.run_color = '#198754'
self.button_primary_background_fill = "#FFFFFF" self.button_primary_background_fill = '#FFFFFF'
self.button_primary_background_fill_hover = self.run_color self.button_primary_background_fill_hover = self.run_color
self.button_primary_border_color = self.run_color self.button_primary_border_color = self.run_color
self.button_primary_text_color = self.run_color self.button_primary_text_color = self.run_color
self.button_primary_text_color_hover = "#FFFFFF" self.button_primary_text_color_hover = '#FFFFFF'
def set_dark_primary_button(self): def set_dark_primary_button(self):
self.bright_run_color = "#299764" self.bright_run_color = '#299764'
self.button_primary_background_fill_dark = self.button_primary_background_fill_hover self.button_primary_background_fill_dark = (
self.button_primary_background_fill_hover
)
self.button_primary_background_fill_hover_dark = self.bright_run_color self.button_primary_background_fill_hover_dark = self.bright_run_color
self.button_primary_border_color_dark = self.button_primary_border_color self.button_primary_border_color_dark = self.button_primary_border_color
self.button_primary_border_color_hover_dark = self.bright_run_color self.button_primary_border_color_hover_dark = self.bright_run_color
def set_light_secondary_button(self): def set_light_secondary_button(self):
self.button_secondary_background_fill = "white" self.button_secondary_background_fill = 'white'
def set_light_example(self): def set_light_example(self):
""" """
...@@ -73,10 +85,10 @@ class QimTheme(gr.themes.Default): ...@@ -73,10 +85,10 @@ class QimTheme(gr.themes.Default):
self.color_accent_soft = self.neutral_100 self.color_accent_soft = self.neutral_100
def set_h1(self): def set_h1(self):
self.text_xxl = "2.5rem" self.text_xxl = '2.5rem'
def set_light_checkbox(self): def set_light_checkbox(self):
light_blue = "#60a5fa" light_blue = '#60a5fa'
self.checkbox_background_color_selected = light_blue self.checkbox_background_color_selected = light_blue
self.checkbox_border_color_selected = light_blue self.checkbox_border_color_selected = light_blue
self.checkbox_border_color_focus = light_blue self.checkbox_border_color_focus = light_blue
...@@ -86,21 +98,20 @@ class QimTheme(gr.themes.Default): ...@@ -86,21 +98,20 @@ class QimTheme(gr.themes.Default):
self.checkbox_border_color_focus_dark = self.checkbox_border_color_focus_dark self.checkbox_border_color_focus_dark = self.checkbox_border_color_focus_dark
def set_light_cancel_button(self): def set_light_cancel_button(self):
self.cancel_color = "#dc3545" self.cancel_color = '#dc3545'
self.button_cancel_background_fill = "white" self.button_cancel_background_fill = 'white'
self.button_cancel_background_fill_hover = self.cancel_color self.button_cancel_background_fill_hover = self.cancel_color
self.button_cancel_border_color = self.cancel_color self.button_cancel_border_color = self.cancel_color
self.button_cancel_text_color = self.cancel_color self.button_cancel_text_color = self.cancel_color
self.button_cancel_text_color_hover = "white" self.button_cancel_text_color_hover = 'white'
def set_dark_cancel_button(self): def set_dark_cancel_button(self):
self.button_cancel_background_fill_dark = self.cancel_color self.button_cancel_background_fill_dark = self.cancel_color
self.button_cancel_background_fill_hover_dark = "red" self.button_cancel_background_fill_hover_dark = 'red'
self.button_cancel_border_color_dark = self.cancel_color self.button_cancel_border_color_dark = self.cancel_color
self.button_cancel_border_color_hover_dark = "red" self.button_cancel_border_color_hover_dark = 'red'
self.button_cancel_text_color_dark = "white" self.button_cancel_text_color_dark = 'white'
# def _get_theme_css(self): # def _get_theme_css(self):
# sup = super()._get_theme_css() # sup = super()._get_theme_css()
# return "\n.svelte-182fdeq {\nbackground: rgba(255, 0, 0, 0.5) !important;\n}\n" + sup # You have to use !important, so it overrides other css # return "\n.svelte-182fdeq {\nbackground: rgba(255, 0, 0, 0.5) !important;\n}\n" + sup # You have to use !important, so it overrides other css
\ No newline at end of file
# from ._sync import Sync # this will be added back after future development
from ._loading import load, load_mesh from ._loading import load, load_mesh
from ._downloader import Downloader from ._downloader import Downloader
from ._saving import save, save_mesh from ._saving import save, save_mesh
# from ._sync import Sync # this will be added back after future development
from ._convert import convert from ._convert import convert
from ._ome_zarr import export_ome_zarr, import_ome_zarr from ._ome_zarr import export_ome_zarr, import_ome_zarr
...@@ -6,21 +6,24 @@ import nibabel as nib ...@@ -6,21 +6,24 @@ import nibabel as nib
import numpy as np import numpy as np
import tifffile as tiff import tifffile as tiff
import zarr import zarr
from tqdm import tqdm
import zarr.core import zarr.core
import qim3d
from tqdm import tqdm
from qim3d.utils._misc import stringify_path from qim3d.utils._misc import stringify_path
from qim3d.io import save
class Convert: class Convert:
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Utility class to convert files to other formats without loading the entire file into memory """
Utility class to convert files to other formats without loading the entire file into memory
Args: Args:
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64). chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
""" """
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64)) self.chunk_shape = kwargs.get('chunk_shape', (64, 64, 64))
def convert(self, input_path: str, output_path: str): def convert(self, input_path: str, output_path: str):
def get_file_extension(file_path): def get_file_extension(file_path):
...@@ -29,6 +32,7 @@ class Convert: ...@@ -29,6 +32,7 @@ class Convert:
root, ext2 = os.path.splitext(root) root, ext2 = os.path.splitext(root)
ext = ext2 + ext ext = ext2 + ext
return ext return ext
# Stringify path in case it is not already a string # Stringify path in case it is not already a string
input_path = stringify_path(input_path) input_path = stringify_path(input_path)
input_ext = get_file_extension(input_path) input_ext = get_file_extension(input_path)
...@@ -37,28 +41,30 @@ class Convert: ...@@ -37,28 +41,30 @@ class Convert:
if os.path.isfile(input_path): if os.path.isfile(input_path):
match input_ext, output_ext: match input_ext, output_ext:
case (".tif", ".zarr") | (".tiff", ".zarr"): case ('.tif', '.zarr') | ('.tiff', '.zarr'):
return self.convert_tif_to_zarr(input_path, output_path) return self.convert_tif_to_zarr(input_path, output_path)
case (".nii", ".zarr") | (".nii.gz", ".zarr"): case ('.nii', '.zarr') | ('.nii.gz', '.zarr'):
return self.convert_nifti_to_zarr(input_path, output_path) return self.convert_nifti_to_zarr(input_path, output_path)
case _: case _:
raise ValueError("Unsupported file format") raise ValueError('Unsupported file format')
# Load a directory # Load a directory
elif os.path.isdir(input_path): elif os.path.isdir(input_path):
match input_ext, output_ext: match input_ext, output_ext:
case (".zarr", ".tif") | (".zarr", ".tiff"): case ('.zarr', '.tif') | ('.zarr', '.tiff'):
return self.convert_zarr_to_tif(input_path, output_path) return self.convert_zarr_to_tif(input_path, output_path)
case (".zarr", ".nii"): case ('.zarr', '.nii'):
return self.convert_zarr_to_nifti(input_path, output_path) return self.convert_zarr_to_nifti(input_path, output_path)
case (".zarr", ".nii.gz"): case ('.zarr', '.nii.gz'):
return self.convert_zarr_to_nifti(input_path, output_path, compression=True) return self.convert_zarr_to_nifti(
input_path, output_path, compression=True
)
case _: case _:
raise ValueError("Unsupported file format") raise ValueError('Unsupported file format')
# Fail # Fail
else: else:
# Find the closest matching path to warn the user # Find the closest matching path to warn the user
parent_dir = os.path.dirname(input_path) or "." parent_dir = os.path.dirname(input_path) or '.'
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ''
valid_paths = [os.path.join(parent_dir, file) for file in parent_files] valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
similar_paths = difflib.get_close_matches(input_path, valid_paths) similar_paths = difflib.get_close_matches(input_path, valid_paths)
if similar_paths: if similar_paths:
...@@ -66,10 +72,11 @@ class Convert: ...@@ -66,10 +72,11 @@ class Convert:
message = f"Invalid path. Did you mean '{suggestion}'?" message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message)) raise ValueError(repr(message))
else: else:
raise ValueError("Invalid path") raise ValueError('Invalid path')
def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array: def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file """
Convert a tiff file to a zarr file
Args: Args:
tif_path (str): path to the tiff file tif_path (str): path to the tiff file
...@@ -77,10 +84,15 @@ class Convert: ...@@ -77,10 +84,15 @@ class Convert:
Returns: Returns:
zarr.core.Array: zarr array containing the data from the tiff file zarr.core.Array: zarr array containing the data from the tiff file
""" """
vol = tiff.memmap(tif_path) vol = tiff.memmap(tif_path)
z = zarr.open( z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
) )
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar # ! Fastest way is z[:] = vol[:], but does not have a progress bar
...@@ -98,7 +110,8 @@ class Convert: ...@@ -98,7 +110,8 @@ class Convert:
return z return z
def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None: def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file """
Convert a zarr file to a tiff file
Args: Args:
zarr_path (str): path to the zarr file zarr_path (str): path to the zarr file
...@@ -106,12 +119,14 @@ class Convert: ...@@ -106,12 +119,14 @@ class Convert:
returns: returns:
None None
""" """
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(tif_path, z) qim3d.io.save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array: def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file """
Convert a nifti file to a zarr file
Args: Args:
nifti_path (str): path to the nifti file nifti_path (str): path to the nifti file
...@@ -119,10 +134,15 @@ class Convert: ...@@ -119,10 +134,15 @@ class Convert:
Returns: Returns:
zarr.core.Array: zarr array containing the data from the nifti file zarr.core.Array: zarr array containing the data from the nifti file
""" """
vol = nib.load(nifti_path).dataobj vol = nib.load(nifti_path).dataobj
z = zarr.open( z = zarr.open(
zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype zarr_path,
mode='w',
shape=vol.shape,
chunks=self.chunk_shape,
dtype=vol.dtype,
) )
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar # ! Fastest way is z[:] = vol[:], but does not have a progress bar
...@@ -139,8 +159,11 @@ class Convert: ...@@ -139,8 +159,11 @@ class Convert:
return z return z
def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None: def convert_zarr_to_nifti(
"""Convert a zarr file to a nifti file self, zarr_path: str, nifti_path: str, compression: bool = False
) -> None:
"""
Convert a zarr file to a nifti file
Args: Args:
zarr_path (str): path to the zarr file zarr_path (str): path to the zarr file
...@@ -148,18 +171,23 @@ class Convert: ...@@ -148,18 +171,23 @@ class Convert:
Returns: Returns:
None None
""" """
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(nifti_path, z, compression=compression) qim3d.io.save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None: def convert(
"""Convert a file to another format without loading the entire file into memory input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)
) -> None:
"""
Convert a file to another format without loading the entire file into memory
Args: Args:
input_path (str): path to the input file input_path (str): path to the input file
output_path (str): path to the output file output_path (str): path to the output file
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64). chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
""" """
converter = Convert(chunk_shape=chunk_shape) converter = Convert(chunk_shape=chunk_shape)
converter.convert(input_path, output_path) converter.convert(input_path, output_path)
"Manages downloads and access to data" """Manages downloads and access to data"""
import os import os
import urllib.request import urllib.request
from urllib.parse import quote from urllib.parse import quote
import outputformat as ouf
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path
from qim3d.io import load from qim3d.io import load
from qim3d.utils import log from qim3d.utils import log
import outputformat as ouf
class Downloader: class Downloader:
"""Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
"""
Class for downloading large data files available on the [QIM data repository](https://data.qim.dk/).
Attributes: Attributes:
folder_name (str or os.PathLike): Folder class with the name of the folder in <https://data.qim.dk/> folder_name (str or os.PathLike): Folder class with the name of the folder in <https://data.qim.dk/>
...@@ -59,17 +60,18 @@ class Downloader: ...@@ -59,17 +60,18 @@ class Downloader:
qim3d.viz.slicer_orthogonal(data, color_map="magma") qim3d.viz.slicer_orthogonal(data, color_map="magma")
``` ```
![cowry shell](../../assets/screenshots/cowry_shell_slicer.gif) ![cowry shell](../../assets/screenshots/cowry_shell_slicer.gif)
""" """
def __init__(self): def __init__(self):
folders = _extract_names() folders = _extract_names()
for idx, folder in enumerate(folders): for idx, folder in enumerate(folders):
exec(f"self.{folder} = self._Myfolder(folder)") exec(f'self.{folder} = self._Myfolder(folder)')
def list_files(self): def list_files(self):
"""Generate and print formatted folder, file, and size information.""" """Generate and print formatted folder, file, and size information."""
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository" url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
folders = _extract_names() folders = _extract_names()
...@@ -78,17 +80,20 @@ class Downloader: ...@@ -78,17 +80,20 @@ class Downloader:
files = _extract_names(folder) files = _extract_names(folder)
for file in files: for file in files:
url = os.path.join(url_dl, folder, file).replace("\\", "/") url = os.path.join(url_dl, folder, file).replace('\\', '/')
file_size = _get_file_size(url) file_size = _get_file_size(url)
formatted_file = f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}" formatted_file = (
f"{file[:-len(file.split('.')[-1])-1].replace('%20', '_')}"
)
formatted_size = _format_file_size(file_size) formatted_size = _format_file_size(file_size)
path_string = f'{folder}.{formatted_file}' path_string = f'{folder}.{formatted_file}'
log.info(f'{path_string:<50}({formatted_size})') log.info(f'{path_string:<50}({formatted_size})')
class _Myfolder: class _Myfolder:
"""Class for extracting the files from each folder in the Downloader class.
"""
Class for extracting the files from each folder in the Downloader class.
Args: Args:
folder(str): name of the folder of interest in the QIM data repository. folder(str): name of the folder of interest in the QIM data repository.
...@@ -99,6 +104,7 @@ class Downloader: ...@@ -99,6 +104,7 @@ class Downloader:
[file_name_2](load_file,optional): Function to download file number 2 in the given folder. [file_name_2](load_file,optional): Function to download file number 2 in the given folder.
... ...
[file_name_n](load_file,optional): Function to download file number n in the given folder. [file_name_n](load_file,optional): Function to download file number n in the given folder.
""" """
def __init__(self, folder: str): def __init__(self, folder: str):
...@@ -107,14 +113,15 @@ class Downloader: ...@@ -107,14 +113,15 @@ class Downloader:
for idx, file in enumerate(files): for idx, file in enumerate(files):
# Changes names to usable function name. # Changes names to usable function name.
file_name = file file_name = file
if ("%20" in file) or ("-" in file): if ('%20' in file) or ('-' in file):
file_name = file_name.replace("%20", "_") file_name = file_name.replace('%20', '_')
file_name = file_name.replace("-", "_") file_name = file_name.replace('-', '_')
setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file)) setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
def _make_fn(self, folder: str, file: str): def _make_fn(self, folder: str, file: str):
"""Private method that returns a function. The function downloads the chosen file from the folder. """
Private method that returns a function. The function downloads the chosen file from the folder.
Args: Args:
folder(str): Folder where the file is located. folder(str): Folder where the file is located.
...@@ -122,23 +129,26 @@ class Downloader: ...@@ -122,23 +129,26 @@ class Downloader:
Returns: Returns:
function: the function used to download the file. function: the function used to download the file.
""" """
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository" url_dl = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository'
def _download(load_file: bool = False, virtual_stack: bool = True): def _download(load_file: bool = False, virtual_stack: bool = True):
"""Downloads the file and optionally also loads it. """
Downloads the file and optionally also loads it.
Args: Args:
load_file(bool,optional): Whether to simply download or also load the file. load_file(bool,optional): Whether to simply download or also load the file.
Returns: Returns:
virtual_stack: The loaded image. virtual_stack: The loaded image.
""" """
download_file(url_dl, folder, file) download_file(url_dl, folder, file)
if load_file == True: if load_file == True:
log.info(f"\nLoading {file}") log.info(f'\nLoading {file}')
file_path = os.path.join(folder, file) file_path = os.path.join(folder, file)
return load(path=file_path, virtual_stack=virtual_stack) return load(path=file_path, virtual_stack=virtual_stack)
...@@ -159,38 +169,40 @@ def _get_file_size(url: str): ...@@ -159,38 +169,40 @@ def _get_file_size(url: str):
Helper function for the ´download_file()´ function. Finds the size of the file. Helper function for the ´download_file()´ function. Finds the size of the file.
""" """
return int(urllib.request.urlopen(url).info().get("Content-Length", -1)) return int(urllib.request.urlopen(url).info().get('Content-Length', -1))
def download_file(path: str, name: str, file: str): def download_file(path: str, name: str, file: str):
"""Downloads the file from path / name / file. """
Downloads the file from path / name / file.
Args: Args:
path(str): path to the folders available. path(str): path to the folders available.
name(str): name of the folder of interest. name(str): name of the folder of interest.
file(str): name of the file to be downloaded. file(str): name of the file to be downloaded.
""" """
if not os.path.exists(name): if not os.path.exists(name):
os.makedirs(name) os.makedirs(name)
url = os.path.join(path, name, file).replace("\\", "/") # if user is on windows url = os.path.join(path, name, file).replace('\\', '/') # if user is on windows
file_path = os.path.join(name, file) file_path = os.path.join(name, file)
if os.path.exists(file_path): if os.path.exists(file_path):
log.warning(f"File already downloaded:\n{os.path.abspath(file_path)}") log.warning(f'File already downloaded:\n{os.path.abspath(file_path)}')
return return
else: else:
log.info( log.info(
f"Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}" f'Downloading {ouf.b(file, return_str=True)}\n{os.path.join(path,name,file)}'
) )
if " " in url: if ' ' in url:
url = quote(url, safe=":/") url = quote(url, safe=':/')
with tqdm( with tqdm(
total=_get_file_size(url), total=_get_file_size(url),
unit="B", unit='B',
unit_scale=True, unit_scale=True,
unit_divisor=1024, unit_divisor=1024,
ncols=80, ncols=80,
...@@ -203,28 +215,31 @@ def download_file(path: str, name: str, file: str): ...@@ -203,28 +215,31 @@ def download_file(path: str, name: str, file: str):
def _extract_html(url: str): def _extract_html(url: str):
"""Extracts the html content of a webpage in "utf-8" """
Extracts the html content of a webpage in "utf-8"
Args: Args:
url(str): url to the location where all the data is stored. url(str): url to the location where all the data is stored.
Returns: Returns:
html_content(str): decoded html. html_content(str): decoded html.
""" """
try: try:
with urllib.request.urlopen(url) as response: with urllib.request.urlopen(url) as response:
html_content = response.read().decode( html_content = response.read().decode(
"utf-8" 'utf-8'
) # Assuming the content is in UTF-8 encoding ) # Assuming the content is in UTF-8 encoding
except urllib.error.URLError as e: except urllib.error.URLError as e:
log.warning(f"Failed to retrieve data from {url}. Error: {e}") log.warning(f'Failed to retrieve data from {url}. Error: {e}')
return html_content return html_content
def _extract_names(name: str = None): def _extract_names(name: str = None):
"""Extracts the names of the folders and files. """
Extracts the names of the folders and files.
Finds the names of either the folders if no name is given, Finds the names of either the folders if no name is given,
or all the names of all files in the given folder. or all the names of all files in the given folder.
...@@ -235,31 +250,33 @@ def _extract_names(name: str = None): ...@@ -235,31 +250,33 @@ def _extract_names(name: str = None):
Returns: Returns:
list: If name is None, returns a list of all folders available. list: If name is None, returns a list of all folders available.
If name is not None, returns a list of all files available in the given 'name' folder. If name is not None, returns a list of all files available in the given 'name' folder.
""" """
url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository" url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository'
if name: if name:
datapath = os.path.join(url, name).replace("\\", "/") datapath = os.path.join(url, name).replace('\\', '/')
html_content = _extract_html(datapath) html_content = _extract_html(datapath)
data_split = html_content.split( data_split = html_content.split(
"files/public/projects/viscomp_data_repository/" 'files/public/projects/viscomp_data_repository/'
)[3:] )[3:]
data_files = [ data_files = [
element.split(" ")[0][(len(name) + 1) : -3] for element in data_split element.split(' ')[0][(len(name) + 1) : -3] for element in data_split
] ]
return data_files return data_files
else: else:
html_content = _extract_html(url) html_content = _extract_html(url)
split = html_content.split('"icon-folder-open">')[2:] split = html_content.split('"icon-folder-open">')[2:]
folders = [element.split(" ")[0][4:-4] for element in split] folders = [element.split(' ')[0][4:-4] for element in split]
return folders return folders
def _format_file_size(size_in_bytes): def _format_file_size(size_in_bytes):
# Define size units # Define size units
units = ["B", "KB", "MB", "GB", "TB", "PB"] units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
size = float(size_in_bytes) size = float(size_in_bytes)
unit_index = 0 unit_index = 0
...@@ -269,4 +286,4 @@ def _format_file_size(size_in_bytes): ...@@ -269,4 +286,4 @@ def _format_file_size(size_in_bytes):
unit_index += 1 unit_index += 1
# Format the size with 1 decimal place # Format the size with 1 decimal place
return f"{size:.2f}{units[unit_index]}" return f'{size:.2f}{units[unit_index]}'
This diff is collapsed.
...@@ -2,39 +2,27 @@ ...@@ -2,39 +2,27 @@
Exporting data to different formats. Exporting data to different formats.
""" """
import os
import math import math
import os
import shutil import shutil
import logging from typing import List, Union
import dask.array as da
import numpy as np import numpy as np
import zarr import zarr
import tqdm from ome_zarr import scale
from ome_zarr.io import parse_url from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.scale import dask_resize
from ome_zarr.writer import ( from ome_zarr.writer import (
write_image,
_create_mip,
write_multiscale,
CurrentFormat, CurrentFormat,
Format, write_multiscale,
) )
from ome_zarr.scale import dask_resize
from ome_zarr.reader import Reader
from ome_zarr import scale
from scipy.ndimage import zoom from scipy.ndimage import zoom
from typing import Any, Callable, Iterator, List, Tuple, Union
import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from skimage.transform import (
resize,
)
from qim3d.utils import log from qim3d.utils import log
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
from qim3d.utils._ome_zarr import get_n_chunks from qim3d.utils._ome_zarr import get_n_chunks
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
ListOfArrayLike = Union[List[da.Array], List[np.ndarray]] ListOfArrayLike = Union[List[da.Array], List[np.ndarray]]
ArrayLike = Union[da.Array, np.ndarray] ArrayLike = Union[da.Array, np.ndarray]
...@@ -43,10 +31,19 @@ ArrayLike = Union[da.Array, np.ndarray] ...@@ -43,10 +31,19 @@ ArrayLike = Union[da.Array, np.ndarray]
class OMEScaler( class OMEScaler(
scale.Scaler, scale.Scaler,
): ):
"""Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken."""
def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"): """
Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken.
"""
def __init__(
self,
order: int = 0,
downscale: float = 2,
max_layer: int = 5,
method: str = 'scaleZYXdask',
):
self.order = order self.order = order
self.downscale = downscale self.downscale = downscale
self.max_layer = max_layer self.max_layer = max_layer
...@@ -55,11 +52,11 @@ class OMEScaler( ...@@ -55,11 +52,11 @@ class OMEScaler(
def scaleZYX(self, base: da.core.Array): def scaleZYX(self, base: da.core.Array):
"""Downsample using :func:`scipy.ndimage.zoom`.""" """Downsample using :func:`scipy.ndimage.zoom`."""
rv = [base] rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer): for i in range(self.max_layer):
rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order)) rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
log.info(f"- Scale {i+1}: {rv[-1].shape}") log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv) return list(rv)
...@@ -82,8 +79,8 @@ class OMEScaler( ...@@ -82,8 +79,8 @@ class OMEScaler(
""" """
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape # Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library # This snippet comes from the original OME-Zarr-python library
better_chunksize = tuple( better_chunksize = tuple(
...@@ -92,7 +89,7 @@ class OMEScaler( ...@@ -92,7 +89,7 @@ class OMEScaler(
).astype(int) ).astype(int)
) )
log.debug(f"better chunk size: {better_chunksize}") log.debug(f'better chunk size: {better_chunksize}')
# Compute the chunk size after the downscaling # Compute the chunk size after the downscaling
new_chunk_size = tuple( new_chunk_size = tuple(
...@@ -100,17 +97,16 @@ class OMEScaler( ...@@ -100,17 +97,16 @@ class OMEScaler(
) )
log.debug( log.debug(
f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}" f'orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}'
) )
def resize_chunk(chunk, scale_factors, order): def resize_chunk(chunk, scale_factors, order):
# print(f"zoom factors: {scale_factors}") # print(f"zoom factors: {scale_factors}")
resized_chunk = zoom( resized_chunk = zoom(
chunk, chunk,
zoom=scale_factors, zoom=scale_factors,
order=order, order=order,
mode="grid-constant", mode='grid-constant',
grid_mode=True, grid_mode=True,
) )
# print(f"resized chunk shape: {resized_chunk.shape}") # print(f"resized chunk shape: {resized_chunk.shape}")
...@@ -121,7 +117,7 @@ class OMEScaler( ...@@ -121,7 +117,7 @@ class OMEScaler(
# Testing new shape # Testing new shape
predicted_shape = np.multiply(vol.shape, scale_factors) predicted_shape = np.multiply(vol.shape, scale_factors)
log.debug(f"predicted shape: {predicted_shape}") log.debug(f'predicted shape: {predicted_shape}')
scaled_vol = da.map_blocks( scaled_vol = da.map_blocks(
resize_chunk, resize_chunk,
vol, vol,
...@@ -136,7 +132,7 @@ class OMEScaler( ...@@ -136,7 +132,7 @@ class OMEScaler(
return scaled_vol return scaled_vol
rv = [base] rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer): for i in range(self.max_layer):
log.debug(f"\nScale {i+1}\n{'-'*32}") log.debug(f"\nScale {i+1}\n{'-'*32}")
...@@ -147,17 +143,17 @@ class OMEScaler( ...@@ -147,17 +143,17 @@ class OMEScaler(
np.ceil(np.multiply(base.shape, downscale_factor)).astype(int) np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
) )
log.debug(f"target shape: {scaled_shape}") log.debug(f'target shape: {scaled_shape}')
downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float)) downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
log.debug(f"downscale rate: {downscale_rate}") log.debug(f'downscale rate: {downscale_rate}')
scale_factors = tuple(np.divide(1, downscale_rate)) scale_factors = tuple(np.divide(1, downscale_rate))
log.debug(f"scale factors: {scale_factors}") log.debug(f'scale factors: {scale_factors}')
log.debug("\nResizing volume chunk-wise") log.debug('\nResizing volume chunk-wise')
scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape) scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
rv.append(scaled_vol) rv.append(scaled_vol)
log.info(f"- Scale {i+1}: {rv[-1].shape}") log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv) return list(rv)
...@@ -165,10 +161,9 @@ class OMEScaler( ...@@ -165,10 +161,9 @@ class OMEScaler(
"""Downsample using the original OME-Zarr python library""" """Downsample using the original OME-Zarr python library"""
rv = [base] rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f'- Scale 0: {rv[-1].shape}')
for i in range(self.max_layer): for i in range(self.max_layer):
scaled_shape = tuple( scaled_shape = tuple(
base.shape[j] // (self.downscale ** (i + 1)) for j in range(3) base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
) )
...@@ -176,7 +171,7 @@ class OMEScaler( ...@@ -176,7 +171,7 @@ class OMEScaler(
scaled = dask_resize(base, scaled_shape, order=self.order) scaled = dask_resize(base, scaled_shape, order=self.order)
rv.append(scaled) rv.append(scaled)
log.info(f"- Scale {i+1}: {rv[-1].shape}") log.info(f'- Scale {i+1}: {rv[-1].shape}')
return list(rv) return list(rv)
...@@ -187,9 +182,9 @@ def export_ome_zarr( ...@@ -187,9 +182,9 @@ def export_ome_zarr(
downsample_rate: int = 2, downsample_rate: int = 2,
order: int = 1, order: int = 1,
replace: bool = False, replace: bool = False,
method: str = "scaleZYX", method: str = 'scaleZYX',
progress_bar: bool = True, progress_bar: bool = True,
progress_bar_repeat_time: str = "auto", progress_bar_repeat_time: str = 'auto',
) -> None: ) -> None:
""" """
Export 3D image data to OME-Zarr format with pyramidal downsampling. Export 3D image data to OME-Zarr format with pyramidal downsampling.
...@@ -220,6 +215,7 @@ def export_ome_zarr( ...@@ -220,6 +215,7 @@ def export_ome_zarr(
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2) qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
``` ```
""" """
# Check if directory exists # Check if directory exists
...@@ -228,19 +224,19 @@ def export_ome_zarr( ...@@ -228,19 +224,19 @@ def export_ome_zarr(
shutil.rmtree(path) shutil.rmtree(path)
else: else:
raise ValueError( raise ValueError(
f"Directory {path} already exists. Use replace=True to overwrite." f'Directory {path} already exists. Use replace=True to overwrite.'
) )
# Check if downsample_rate is valid # Check if downsample_rate is valid
if downsample_rate <= 1: if downsample_rate <= 1:
raise ValueError("Downsample rate must be greater than 1.") raise ValueError('Downsample rate must be greater than 1.')
log.info(f"Exporting data to OME-Zarr format at {path}") log.info(f'Exporting data to OME-Zarr format at {path}')
# Get the number of scales # Get the number of scales
min_dim = np.max(np.shape(data)) min_dim = np.max(np.shape(data))
nscales = math.ceil(math.log(min_dim / chunk_size) / math.log(downsample_rate)) nscales = math.ceil(math.log(min_dim / chunk_size) / math.log(downsample_rate))
log.info(f"Number of scales: {nscales + 1}") log.info(f'Number of scales: {nscales + 1}')
# Create scaler # Create scaler
scaler = OMEScaler( scaler = OMEScaler(
...@@ -249,32 +245,31 @@ def export_ome_zarr( ...@@ -249,32 +245,31 @@ def export_ome_zarr(
# write the image data # write the image data
os.mkdir(path) os.mkdir(path)
store = parse_url(path, mode="w").store store = parse_url(path, mode='w').store
root = zarr.group(store=store) root = zarr.group(store=store)
# Check if we want to process using Dask # Check if we want to process using Dask
if "dask" in method and not isinstance(data, da.Array): if 'dask' in method and not isinstance(data, da.Array):
log.info("\nConverting input data to Dask array") log.info('\nConverting input data to Dask array')
data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size)) data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
elif "dask" in method and isinstance(data, da.Array): elif 'dask' in method and isinstance(data, da.Array):
log.info("\nInput data will be rechunked") log.info('\nInput data will be rechunked')
data = data.rechunk((chunk_size, chunk_size, chunk_size)) data = data.rechunk((chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") log.info(f' - shape...: {data.shape}\n - chunks..: {data.chunksize}\n')
log.info("Calculating the multi-scale pyramid") log.info('Calculating the multi-scale pyramid')
# Generate multi-scale pyramid # Generate multi-scale pyramid
mip = scaler.func(data) mip = scaler.func(data)
log.info("Writing data to disk") log.info('Writing data to disk')
kwargs = dict( kwargs = dict(
pyramid=mip, pyramid=mip,
group=root, group=root,
fmt=CurrentFormat(), fmt=CurrentFormat(),
axes="zyx", axes='zyx',
name=None, name=None,
compute=True, compute=True,
storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)), storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)),
...@@ -291,15 +286,13 @@ def export_ome_zarr( ...@@ -291,15 +286,13 @@ def export_ome_zarr(
else: else:
write_multiscale(**kwargs) write_multiscale(**kwargs)
log.info("\nAll done!") log.info('\nAll done!')
return return
def import_ome_zarr( def import_ome_zarr(
path: str|os.PathLike, path: str | os.PathLike, scale: int = 0, load: bool = True
scale: int = 0,
load: bool = True
) -> np.ndarray: ) -> np.ndarray:
""" """
Import image data from an OME-Zarr file. Import image data from an OME-Zarr file.
...@@ -339,22 +332,22 @@ def import_ome_zarr( ...@@ -339,22 +332,22 @@ def import_ome_zarr(
image_node = nodes[0] image_node = nodes[0]
dask_data = image_node.data dask_data = image_node.data
log.info(f"Data contains {len(dask_data)} scales:") log.info(f'Data contains {len(dask_data)} scales:')
for i in np.arange(len(dask_data)): for i in np.arange(len(dask_data)):
log.info(f"- Scale {i}: {dask_data[i].shape}") log.info(f'- Scale {i}: {dask_data[i].shape}')
if scale == "highest": if scale == 'highest':
scale = 0 scale = 0
if scale == "lowest": if scale == 'lowest':
scale = len(dask_data) - 1 scale = len(dask_data) - 1
if scale >= len(dask_data): if scale >= len(dask_data):
raise ValueError( raise ValueError(
f"Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}." f'Scale {scale} does not exist in the data. Please choose a scale between 0 and {len(dask_data)-1}.'
) )
log.info(f"\nLoading scale {scale} with shape {dask_data[scale].shape}") log.info(f'\nLoading scale {scale} with shape {dask_data[scale].shape}')
if load: if load:
vol = dask_data[scale].compute() vol = dask_data[scale].compute()
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from ._unet import UNet, Hyperparameters from ._unet import Hyperparameters, UNet