Skip to content
Snippets Groups Projects
Commit 0053dccc authored by fima's avatar fima :beers:
Browse files

Ome zarr export optimization

parent 85014a4b
No related branches found
No related tags found
1 merge request!125Ome zarr export optimization
...@@ -5,6 +5,7 @@ Exporting data to different formats. ...@@ -5,6 +5,7 @@ Exporting data to different formats.
import os import os
import math import math
import shutil import shutil
import logging
import numpy as np import numpy as np
import zarr import zarr
...@@ -23,6 +24,9 @@ from ome_zarr import scale ...@@ -23,6 +24,9 @@ from ome_zarr import scale
from scipy.ndimage import zoom from scipy.ndimage import zoom
from typing import Any, Callable, Iterator, List, Tuple, Union from typing import Any, Callable, Iterator, List, Tuple, Union
import dask.array as da import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from skimage.transform import ( from skimage.transform import (
resize, resize,
) )
...@@ -54,13 +58,112 @@ class OMEScaler( ...@@ -54,13 +58,112 @@ class OMEScaler(
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f"- Scale 0: {rv[-1].shape}")
for i in range(self.max_layer): for i in range(self.max_layer):
downscale_ratio = (1 / self.downscale) ** (i + 1) rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
rv.append(zoom(base, zoom=downscale_ratio, order=self.order))
log.info(f"- Scale {i+1}: {rv[-1].shape}") log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv) return list(rv)
def scaleZYXdask(self, base): def scaleZYXdask(self, base):
"""Downsample using :func:`scipy.ndimage.zoom`.""" """
Downsample a 3D volume using Dask and scipy.ndimage.zoom.
This method performs multi-scale downsampling on a 3D dataset, generating image pyramids. It processes the data in chunks using Dask.
Args:
base (dask.array): The 3D array (volume) to be downsampled. Must be a Dask array for chunked processing.
Returns:
list of dask.array: A list of downsampled volumes, where each element represents a different scale. The first element corresponds to the original resolution, and subsequent elements represent progressively downsampled versions.
The downsampling process occurs scale by scale, using the following steps:
- For each scale, the array is resized based on the downscale factor, computed as a function of the current scale level.
- The `scipy.ndimage.zoom` function is used to perform interpolation, with chunk-wise processing handled by Dask's `map_blocks` function.
- The output is rechunked to match the input volume's original chunk size.
"""
def resize_zoom(vol, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library
better_chunksize = tuple(
np.maximum(
1, np.round(np.array(vol.chunksize) * scale_factors) / scale_factors
).astype(int)
)
log.debug(f"better chunk size: {better_chunksize}")
# Compute the chunk size after the downscaling
new_chunk_size = tuple(
np.ceil(np.multiply(better_chunksize, scale_factors)).astype(int)
)
log.debug(
f"orginal chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}"
)
def resize_chunk(chunk, scale_factors, order):
#print(f"zoom factors: {scale_factors}")
resized_chunk = zoom(
chunk,
zoom=scale_factors,
order=order,
mode="grid-constant",
grid_mode=True,
)
#print(f"resized chunk shape: {resized_chunk.shape}")
return resized_chunk
output_slices = tuple(slice(0, d) for d in scaled_shape)
# Testing new shape
predicted_shape = np.multiply(vol.shape, scale_factors)
log.debug(f"predicted shape: {predicted_shape}")
scaled_vol = da.map_blocks(
resize_chunk,
vol,
scale_factors,
order,
chunks=new_chunk_size,
)[output_slices]
# Rechunk the output to match the input
# This is needed because chunks were scaled down
scaled_vol = scaled_vol.rechunk(vol.chunksize)
return scaled_vol
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
for i in range(self.max_layer):
log.debug(f"\nScale {i+1}\n{'-'*32}")
# Calculate the downscale factor for this scale
downscale_factor = 1 / (self.downscale ** (i + 1))
scaled_shape = tuple(
np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
)
log.debug(f"target shape: {scaled_shape}")
downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
log.debug(f"downscale rate: {downscale_rate}")
scale_factors = tuple(np.divide(1, downscale_rate))
log.debug(f"scale factors: {scale_factors}")
log.debug("\nResizing volume chunk-wise")
scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
rv.append(scaled_vol)
log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv)
def scaleZYXdask_legacy(self, base):
"""Downsample using the original OME-Zarr python library"""
rv = [base] rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f"- Scale 0: {rv[-1].shape}")
...@@ -69,7 +172,9 @@ class OMEScaler( ...@@ -69,7 +172,9 @@ class OMEScaler(
scaled_shape = tuple( scaled_shape = tuple(
base.shape[j] // (self.downscale ** (i + 1)) for j in range(3) base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
) )
rv.append(dask_resize(base, scaled_shape, order=self.order))
scaled = dask_resize(base, scaled_shape, order=self.order)
rv.append(scaled)
log.info(f"- Scale {i+1}: {rv[-1].shape}") log.info(f"- Scale {i+1}: {rv[-1].shape}")
return list(rv) return list(rv)
...@@ -78,27 +183,30 @@ class OMEScaler( ...@@ -78,27 +183,30 @@ class OMEScaler(
def export_ome_zarr( def export_ome_zarr(
path, path,
data, data,
chunk_size=100, chunk_size=256,
downsample_rate=2, downsample_rate=2,
order=0, order=1,
replace=False, replace=False,
method="scaleZYX", method="scaleZYX",
progress_bar:bool = True, progress_bar: bool = True,
progress_bar_repeat_time = "auto", progress_bar_repeat_time="auto",
): ):
""" """
Export image data to OME-Zarr format with pyramidal downsampling. Export 3D image data to OME-Zarr format with pyramidal downsampling.
Automatically calculates the number of downsampled scales such that the smallest scale fits within the specified `chunk_size`. This function generates a multi-scale OME-Zarr representation of the input data, which is commonly used for large imaging datasets. The downsampled scales are calculated such that the smallest scale fits within the specified `chunk_size`.
Args: Args:
path (str): The directory where the OME-Zarr data will be stored. path (str): The directory where the OME-Zarr data will be stored.
data (np.ndarray): The image data to be exported. data (np.ndarray or dask.array): The 3D image data to be exported. Supports both NumPy and Dask arrays.
chunk_size (int, optional): The size of the chunks for storing data. Defaults to 100. chunk_size (int, optional): The size of the chunks for storing data. This affects both the original data and the downsampled scales. Defaults to 256.
downsample_rate (int, optional): Factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2. downsample_rate (int, optional): The factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2.
order (int, optional): Interpolation order to use when downsampling. Defaults to 0 (nearest-neighbor). order (int, optional): The interpolation order to use when downsampling. Defaults to 1 (linear). Use 0 for a faster nearest-neighbor interpolation.
replace (bool, optional): Whether to replace the existing directory if it already exists. Defaults to False. replace (bool, optional): Whether to replace the existing directory if it already exists. Defaults to False.
progress_bar (bool, optional): Whether to display progress while writing data to disk. Defaults to True. method (str, optional): The method used for downsampling. If set to "dask", Dask arrays are used for chunking and downsampling. Defaults to "scaleZYX".
progress_bar (bool, optional): Whether to display a progress bar during export. Defaults to True.
progress_bar_repeat_time (str or int, optional): The repeat interval (in seconds) for updating the progress bar. Defaults to "auto".
Raises: Raises:
ValueError: If the directory already exists and `replace` is False. ValueError: If the directory already exists and `replace` is False.
ValueError: If `downsample_rate` is less than or equal to 1. ValueError: If `downsample_rate` is less than or equal to 1.
...@@ -111,9 +219,10 @@ def export_ome_zarr( ...@@ -111,9 +219,10 @@ def export_ome_zarr(
data = downloader.Snail.Escargot(load_file=True) data = downloader.Snail.Escargot(load_file=True)
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2) qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
``` ```
Returns:
None: This function writes the OME-Zarr data to the specified directory and does not return any value.
""" """
# Check if directory exists # Check if directory exists
...@@ -146,30 +255,47 @@ def export_ome_zarr( ...@@ -146,30 +255,47 @@ def export_ome_zarr(
store = parse_url(path, mode="w").store store = parse_url(path, mode="w").store
root = zarr.group(store=store) root = zarr.group(store=store)
fmt = CurrentFormat() # Check if we want to process using Dask
log.info("Creating a multi-scale pyramid") if "dask" in method and not isinstance(data, da.Array):
mip, axes = _create_mip(image=data, fmt=fmt, scaler=scaler, axes="zyx") log.info("\nConverting input data to Dask array")
data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
elif "dask" in method and isinstance(data, da.Array):
log.info("\nInput data will be rechunked")
data = data.rechunk((chunk_size, chunk_size, chunk_size))
log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n")
log.info("Calculating the multi-scale pyramid")
# Generate multi-scale pyramid
mip = scaler.func(data)
log.info("Writing data to disk") log.info("Writing data to disk")
kwargs = dict( kwargs = dict(
pyramid=mip, pyramid=mip,
group=root, group=root,
fmt=fmt, fmt=CurrentFormat(),
axes=axes, axes="zyx",
name=None, name=None,
compute=True, compute=True,
) storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)),
)
if progress_bar: if progress_bar:
n_chunks = get_n_chunks( n_chunks = get_n_chunks(
shapes = (scaled_data.shape for scaled_data in mip), shapes=(scaled_data.shape for scaled_data in mip),
dtypes = (scaled_data.dtype for scaled_data in mip) dtypes=(scaled_data.dtype for scaled_data in mip),
) )
with OmeZarrExportProgressBar(path = path, n_chunks = n_chunks, reapeat_time=progress_bar_repeat_time): with OmeZarrExportProgressBar(
path=path, n_chunks=n_chunks, reapeat_time=progress_bar_repeat_time
):
write_multiscale(**kwargs) write_multiscale(**kwargs)
else: else:
write_multiscale(**kwargs) write_multiscale(**kwargs)
log.info("All done!") log.info("\nAll done!")
return return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment