Skip to content
Snippets Groups Projects
Commit 0d31c332 authored by fima's avatar fima :beers:
Browse files

Merge branch 'ome_zarr_export_optimization' into 'main'

Ome zarr export optimization

See merge request !125
parents 85014a4b 0053dccc
No related branches found
No related tags found
1 merge request!125Ome zarr export optimization
......@@ -5,6 +5,7 @@ Exporting data to different formats.
import os
import math
import shutil
import logging
import numpy as np
import zarr
......@@ -23,6 +24,9 @@ from ome_zarr import scale
from scipy.ndimage import zoom
from typing import Any, Callable, Iterator, List, Tuple, Union
import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from skimage.transform import (
resize,
)
......@@ -54,13 +58,112 @@ class OMEScaler(
log.info(f"- Scale 0: {rv[-1].shape}")
for i in range(self.max_layer):
downscale_ratio = (1 / self.downscale) ** (i + 1)
rv.append(zoom(base, zoom=downscale_ratio, order=self.order))
rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv)
def scaleZYXdask(self, base):
"""Downsample using :func:`scipy.ndimage.zoom`."""
"""
Downsample a 3D volume using Dask and scipy.ndimage.zoom.
This method performs multi-scale downsampling on a 3D dataset, generating image pyramids. It processes the data in chunks using Dask.
Args:
base (dask.array): The 3D array (volume) to be downsampled. Must be a Dask array for chunked processing.
Returns:
list of dask.array: A list of downsampled volumes, where each element represents a different scale. The first element corresponds to the original resolution, and subsequent elements represent progressively downsampled versions.
The downsampling process occurs scale by scale, using the following steps:
- For each scale, the array is resized based on the downscale factor, computed as a function of the current scale level.
- The `scipy.ndimage.zoom` function is used to perform interpolation, with chunk-wise processing handled by Dask's `map_blocks` function.
- The output is rechunked to match the input volume's original chunk size.
"""
def resize_zoom(vol, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library
better_chunksize = tuple(
np.maximum(
1, np.round(np.array(vol.chunksize) * scale_factors) / scale_factors
).astype(int)
)
log.debug(f"better chunk size: {better_chunksize}")
# Compute the chunk size after the downscaling
new_chunk_size = tuple(
np.ceil(np.multiply(better_chunksize, scale_factors)).astype(int)
)
log.debug(
f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}"
)
def resize_chunk(chunk, scale_factors, order):
#print(f"zoom factors: {scale_factors}")
resized_chunk = zoom(
chunk,
zoom=scale_factors,
order=order,
mode="grid-constant",
grid_mode=True,
)
#print(f"resized chunk shape: {resized_chunk.shape}")
return resized_chunk
output_slices = tuple(slice(0, d) for d in scaled_shape)
# Testing new shape
predicted_shape = np.multiply(vol.shape, scale_factors)
log.debug(f"predicted shape: {predicted_shape}")
scaled_vol = da.map_blocks(
resize_chunk,
vol,
scale_factors,
order,
chunks=new_chunk_size,
)[output_slices]
# Rechunk the output to match the input
# This is needed because chunks were scaled down
scaled_vol = scaled_vol.rechunk(vol.chunksize)
return scaled_vol
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
for i in range(self.max_layer):
log.debug(f"\nScale {i+1}\n{'-'*32}")
# Calculate the downscale factor for this scale
downscale_factor = 1 / (self.downscale ** (i + 1))
scaled_shape = tuple(
np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
)
log.debug(f"target shape: {scaled_shape}")
downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
log.debug(f"downscale rate: {downscale_rate}")
scale_factors = tuple(np.divide(1, downscale_rate))
log.debug(f"scale factors: {scale_factors}")
log.debug("\nResizing volume chunk-wise")
scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
rv.append(scaled_vol)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv)
def scaleZYXdask_legacy(self, base):
"""Downsample using the original OME-Zarr python library"""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
......@@ -69,7 +172,9 @@ class OMEScaler(
scaled_shape = tuple(
base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
)
rv.append(dask_resize(base, scaled_shape, order=self.order))
scaled = dask_resize(base, scaled_shape, order=self.order)
rv.append(scaled)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv)
......@@ -78,27 +183,30 @@ class OMEScaler(
def export_ome_zarr(
path,
data,
chunk_size=100,
chunk_size=256,
downsample_rate=2,
order=0,
order=1,
replace=False,
method="scaleZYX",
progress_bar: bool = True,
progress_bar_repeat_time="auto",
):
"""
Export image data to OME-Zarr format with pyramidal downsampling.
Export 3D image data to OME-Zarr format with pyramidal downsampling.
Automatically calculates the number of downsampled scales such that the smallest scale fits within the specified `chunk_size`.
This function generates a multi-scale OME-Zarr representation of the input data, which is commonly used for large imaging datasets. The downsampled scales are calculated such that the smallest scale fits within the specified `chunk_size`.
Args:
path (str): The directory where the OME-Zarr data will be stored.
data (np.ndarray): The image data to be exported.
chunk_size (int, optional): The size of the chunks for storing data. Defaults to 100.
downsample_rate (int, optional): Factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2.
order (int, optional): Interpolation order to use when downsampling. Defaults to 0 (nearest-neighbor).
data (np.ndarray or dask.array): The 3D image data to be exported. Supports both NumPy and Dask arrays.
chunk_size (int, optional): The size of the chunks for storing data. This affects both the original data and the downsampled scales. Defaults to 256.
downsample_rate (int, optional): The factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2.
order (int, optional): The interpolation order to use when downsampling. Defaults to 1 (linear). Use 0 for a faster nearest-neighbor interpolation.
replace (bool, optional): Whether to replace the existing directory if it already exists. Defaults to False.
progress_bar (bool, optional): Whether to display progress while writing data to disk. Defaults to True.
method (str, optional): The method used for downsampling. If set to "dask", Dask arrays are used for chunking and downsampling. Defaults to "scaleZYX".
progress_bar (bool, optional): Whether to display a progress bar during export. Defaults to True.
progress_bar_repeat_time (str or int, optional): The repeat interval (in seconds) for updating the progress bar. Defaults to "auto".
Raises:
ValueError: If the directory already exists and `replace` is False.
ValueError: If `downsample_rate` is less than or equal to 1.
......@@ -111,9 +219,10 @@ def export_ome_zarr(
data = downloader.Snail.Escargot(load_file=True)
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
```
Returns:
None: This function writes the OME-Zarr data to the specified directory and does not return any value.
"""
# Check if directory exists
......@@ -146,30 +255,47 @@ def export_ome_zarr(
store = parse_url(path, mode="w").store
root = zarr.group(store=store)
fmt = CurrentFormat()
log.info("Creating a multi-scale pyramid")
mip, axes = _create_mip(image=data, fmt=fmt, scaler=scaler, axes="zyx")
# Check if we want to process using Dask
if "dask" in method and not isinstance(data, da.Array):
log.info("\nConverting input data to Dask array")
data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
elif "dask" in method and isinstance(data, da.Array):
log.info("\nInput data will be rechunked")
data = data.rechunk((chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info("Calculating the multi-scale pyramid")
# Generate multi-scale pyramid
mip = scaler.func(data)
log.info("Writing data to disk")
kwargs = dict(
pyramid=mip,
group=root,
fmt=fmt,
axes=axes,
fmt=CurrentFormat(),
axes="zyx",
name=None,
compute=True,
storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)),
)
if progress_bar:
n_chunks = get_n_chunks(
shapes=(scaled_data.shape for scaled_data in mip),
dtypes = (scaled_data.dtype for scaled_data in mip)
dtypes=(scaled_data.dtype for scaled_data in mip),
)
with OmeZarrExportProgressBar(path = path, n_chunks = n_chunks, reapeat_time=progress_bar_repeat_time):
with OmeZarrExportProgressBar(
path=path, n_chunks=n_chunks, reapeat_time=progress_bar_repeat_time
):
write_multiscale(**kwargs)
else:
write_multiscale(**kwargs)
log.info("All done!")
log.info("\nAll done!")
return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment