diff --git a/qim3d/io/ome_zarr.py b/qim3d/io/ome_zarr.py
index 29137382b45f8f099314891f1504f8b6670505c5..feca06225890176d5233e15c5ee7602bead750b2 100644
--- a/qim3d/io/ome_zarr.py
+++ b/qim3d/io/ome_zarr.py
@@ -54,13 +58,112 @@ class OMEScaler(
         log.info(f"- Scale 0: {rv[-1].shape}")
 
         for i in range(self.max_layer):
-            downscale_ratio = (1 / self.downscale) ** (i + 1)
-            rv.append(zoom(base, zoom=downscale_ratio, order=self.order))
+            # Downscale the previous level instead of re-zooming the full
+            # base volume at an ever-growing ratio
+            rv.append(zoom(rv[-1], zoom=1 / self.downscale, order=self.order))
             log.info(f"- Scale {i+1}: {rv[-1].shape}")
+
         return list(rv)
 
     def scaleZYXdask(self, base):
-        """Downsample using :func:`scipy.ndimage.zoom`."""
+        """
+        Downsample a 3D volume using Dask and scipy.ndimage.zoom.
+
+        This method performs multi-scale downsampling on a 3D dataset, generating an image pyramid. The data is processed in chunks using Dask.
+
+        Args:
+            base (dask.array): The 3D array (volume) to be downsampled. Must be a Dask array for chunked processing.
+
+        Returns:
+            list of dask.array: A list of downsampled volumes, where each element represents a different scale. The first element corresponds to the original resolution, and subsequent elements represent progressively downsampled versions.
+
+        The downsampling proceeds scale by scale, using the following steps:
+        - For each scale, the target shape is computed from the downscale factor at the current scale level.
+        - `scipy.ndimage.zoom` performs the interpolation, with chunk-wise processing handled by Dask's `map_blocks`.
+        - The output is rechunked to match the input volume's original chunk size.
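+
+        Example:
+            A minimal sketch of the expected call pattern. The keyword
+            arguments passed to `OMEScaler` below are assumptions for
+            illustration, based on the attributes this method reads
+            (`downscale`, `max_layer`, `order`):
+            ```python
+            import dask.array as da
+
+            vol = da.random.random((512, 512, 512), chunks=(256, 256, 256))
+            scaler = OMEScaler(downscale=2, max_layer=2, order=1)
+            pyramid = scaler.scaleZYXdask(vol)
+            # pyramid[0]: 512^3 (original), pyramid[1]: 256^3, pyramid[2]: 128^3
+            ```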
+
+        """
+
+        def resize_zoom(vol, scale_factors, order, scaled_shape):
+
+            # Get the chunk size needed so that all the blocks match the new shape
+            # This snippet comes from the original ome-zarr-py library
+            better_chunksize = tuple(
+                np.maximum(
+                    1, np.round(np.array(vol.chunksize) * scale_factors) / scale_factors
+                ).astype(int)
+            )
+
+            log.debug(f"better chunk size: {better_chunksize}")
+
+            # Compute the chunk size after the downscaling
+            new_chunk_size = tuple(
+                np.ceil(np.multiply(better_chunksize, scale_factors)).astype(int)
+            )
+
+            log.debug(
+                f"original chunk size: {vol.chunksize}, chunk size after downscale: {new_chunk_size}"
+            )
+
+            def resize_chunk(chunk, scale_factors, order):
+                # Downscale a single chunk; grid_mode=True scales the full
+                # pixel grid (like skimage.transform.resize)
+                return zoom(
+                    chunk,
+                    zoom=scale_factors,
+                    order=order,
+                    mode="grid-constant",
+                    grid_mode=True,
+                )
+
+            # Crop the result to the exact target shape, since per-chunk
+            # rounding can overshoot it
+            output_slices = tuple(slice(0, d) for d in scaled_shape)
+
+            # Log the shape expected after scaling (before cropping)
+            predicted_shape = np.multiply(vol.shape, scale_factors)
+            log.debug(f"predicted shape: {predicted_shape}")
+
+            scaled_vol = da.map_blocks(
+                resize_chunk,
+                vol,
+                scale_factors,
+                order,
+                chunks=new_chunk_size,
+            )[output_slices]
+
+            # Rechunk the output to match the input,
+            # since the blocks were scaled down
+            scaled_vol = scaled_vol.rechunk(vol.chunksize)
+
+            return scaled_vol
+
+        rv = [base]
+        log.info(f"- Scale 0: {rv[-1].shape}")
+
+        for i in range(self.max_layer):
+            log.debug(f"\nScale {i+1}\n{'-'*32}")
+
+            # Calculate the downscale factor for this scale
+            downscale_factor = 1 / (self.downscale ** (i + 1))
+
+            scaled_shape = tuple(
+                np.ceil(np.multiply(base.shape, downscale_factor)).astype(int)
+            )
+            log.debug(f"target shape: {scaled_shape}")
+
+            downscale_rate = tuple(np.divide(rv[-1].shape, scaled_shape).astype(float))
+            log.debug(f"downscale rate: {downscale_rate}")
+
+            scale_factors = tuple(np.divide(1, downscale_rate))
+            log.debug(f"scale factors: {scale_factors}")
+
+            log.debug("\nResizing volume chunk-wise")
+            scaled_vol = resize_zoom(rv[-1], scale_factors, self.order, scaled_shape)
+            rv.append(scaled_vol)
+
+            log.info(f"- Scale {i+1}: {rv[-1].shape}")
+
+        return list(rv)
+
+    def scaleZYXdask_legacy(self, base):
+        """Downsample using the original OME-Zarr Python library."""
+
         rv = [base]
         log.info(f"- Scale 0: {rv[-1].shape}")
 
@@ -69,7 +172,9 @@ class OMEScaler(
             scaled_shape = tuple(
                 base.shape[j] // (self.downscale ** (i + 1)) for j in range(3)
             )
-            rv.append(dask_resize(base, scaled_shape, order=self.order))
+
+            scaled = dask_resize(base, scaled_shape, order=self.order)
+            rv.append(scaled)
             log.info(f"- Scale {i+1}: {rv[-1].shape}")
 
         return list(rv)
 
@@ -78,27 +183,30 @@ def export_ome_zarr(
     path,
     data,
-    chunk_size=100,
+    chunk_size=256,
     downsample_rate=2,
-    order=0,
+    order=1,
     replace=False,
     method="scaleZYX",
-    progress_bar:bool = True,
-    progress_bar_repeat_time = "auto",
+    progress_bar: bool = True,
+    progress_bar_repeat_time="auto",
 ):
     """
-    Export image data to OME-Zarr format with pyramidal downsampling.
+    Export 3D image data to OME-Zarr format with pyramidal downsampling.
 
-    Automatically calculates the number of downsampled scales such that the smallest scale fits within the specified `chunk_size`.
+    This function generates a multi-scale OME-Zarr representation of the input data, which is commonly used for large imaging datasets.
+    The downsampled scales are calculated so that the smallest scale fits within the specified `chunk_size`.
 
     Args:
         path (str): The directory where the OME-Zarr data will be stored.
-        data (np.ndarray): The image data to be exported.
-        chunk_size (int, optional): The size of the chunks for storing data. Defaults to 100.
-        downsample_rate (int, optional): Factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2.
-        order (int, optional): Interpolation order to use when downsampling. Defaults to 0 (nearest-neighbor).
+        data (np.ndarray or dask.array): The 3D image data to be exported. Both NumPy and Dask arrays are supported.
+        chunk_size (int, optional): The size of the chunks for storing data. This affects both the original data and the downsampled scales. Defaults to 256.
+        downsample_rate (int, optional): The factor by which to downsample the data for each scale. Must be greater than 1. Defaults to 2.
+        order (int, optional): The interpolation order to use when downsampling. Defaults to 1 (linear). Use 0 for faster nearest-neighbor interpolation.
         replace (bool, optional): Whether to replace the existing directory if it already exists. Defaults to False.
-        progress_bar (bool, optional): Whether to display progress while writing data to disk. Defaults to True.
+        method (str, optional): The downsampling method. If the name contains "dask" (e.g. "scaleZYXdask"), Dask arrays are used for chunking and downsampling. Defaults to "scaleZYX".
+        progress_bar (bool, optional): Whether to display a progress bar during export. Defaults to True.
+        progress_bar_repeat_time (str or int, optional): The repeat interval (in seconds) for updating the progress bar. Defaults to "auto".
 
     Raises:
         ValueError: If the directory already exists and `replace` is False.
         ValueError: If `downsample_rate` is less than or equal to 1.
@@ -111,9 +219,10 @@ def export_ome_zarr(
 
         data = downloader.Snail.Escargot(load_file=True)
 
         qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2)
 
         ```
+
+    Returns:
+        None: This function writes the OME-Zarr data to the specified directory and does not return any value.
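+
+    Example:
+        A sketch of the Dask-based path. The array contents are illustrative,
+        and it is assumed that a `method` name containing "dask" (here the
+        `scaleZYXdask` scaler defined above) selects the chunked pipeline:
+        ```python
+        import numpy as np
+        import qim3d
+
+        vol = np.random.randint(0, 255, size=(512, 512, 512), dtype=np.uint8)
+        qim3d.io.export_ome_zarr(
+            "synthetic.zarr", vol, chunk_size=256, method="scaleZYXdask", replace=True
+        )
+        ```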
""" # Check if directory exists @@ -146,30 +255,47 @@ def export_ome_zarr( store = parse_url(path, mode="w").store root = zarr.group(store=store) - fmt = CurrentFormat() - log.info("Creating a multi-scale pyramid") - mip, axes = _create_mip(image=data, fmt=fmt, scaler=scaler, axes="zyx") + # Check if we want to process using Dask + if "dask" in method and not isinstance(data, da.Array): + log.info("\nConverting input data to Dask array") + data = da.from_array(data, chunks=(chunk_size, chunk_size, chunk_size)) + log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") + + elif "dask" in method and isinstance(data, da.Array): + log.info("\nInput data will be rechunked") + data = data.rechunk((chunk_size, chunk_size, chunk_size)) + log.info(f" - shape...: {data.shape}\n - chunks..: {data.chunksize}\n") + + + log.info("Calculating the multi-scale pyramid") + + # Generate multi-scale pyramid + mip = scaler.func(data) log.info("Writing data to disk") - kwargs = dict( - pyramid=mip, - group=root, - fmt=fmt, - axes=axes, - name=None, - compute=True, - ) + kwargs = dict( + pyramid=mip, + group=root, + fmt=CurrentFormat(), + axes="zyx", + name=None, + compute=True, + storage_options=dict(chunks=(chunk_size, chunk_size, chunk_size)), + ) if progress_bar: n_chunks = get_n_chunks( - shapes = (scaled_data.shape for scaled_data in mip), - dtypes = (scaled_data.dtype for scaled_data in mip) - ) - with OmeZarrExportProgressBar(path = path, n_chunks = n_chunks, reapeat_time=progress_bar_repeat_time): + shapes=(scaled_data.shape for scaled_data in mip), + dtypes=(scaled_data.dtype for scaled_data in mip), + ) + with OmeZarrExportProgressBar( + path=path, n_chunks=n_chunks, reapeat_time=progress_bar_repeat_time + ): write_multiscale(**kwargs) else: write_multiscale(**kwargs) - log.info("All done!") + log.info("\nAll done!") + return