Skip to content
Snippets Groups Projects
Commit 1923c384 authored by fima's avatar fima :beers:
Browse files

Merge branch 'zarr_loading' into 'main'

Zarr loading and converting

See merge request !96
parents b48441fb 4dc09364
Branches
Tags
1 merge request!96Zarr loading and converting
...@@ -2,4 +2,5 @@ from .loading import DataLoader, load, ImgExamples ...@@ -2,4 +2,5 @@ from .loading import DataLoader, load, ImgExamples
from .downloader import Downloader from .downloader import Downloader
from .saving import DataSaver, save from .saving import DataSaver, save
from .sync import Sync from .sync import Sync
from .convert import convert
from . import logger from . import logger
\ No newline at end of file
import difflib
import os
from itertools import product
import numpy as np
import tifffile as tiff
import zarr
from tqdm import tqdm
from qim3d.utils.internal_tools import stringify_path
class Convert:
def __init__(self,**kwargs):
""" Utility class to convert files to other formats without loading the entire file into memory
Args:
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
def convert(self, input_path, output_path):
# Stringify path in case it is not already a string
input_path = stringify_path(input_path)
input_ext = os.path.splitext(input_path)[1]
output_ext = os.path.splitext(output_path)[1]
output_path = stringify_path(output_path)
if os.path.isfile(input_path):
match input_ext, output_ext:
case (".tif", ".zarr") | (".tiff", ".zarr"):
return self.convert_tif_to_zarr(input_path, output_path)
case _:
raise ValueError("Unsupported file format")
# Load a directory
elif os.path.isdir(input_path):
match input_ext, output_ext:
case (".zarr", ".tif") | (".zarr", ".tiff"):
return self.convert_zarr_to_tif(input_path, output_path)
case _:
raise ValueError("Unsupported file format")
# Fail
else:
# Find the closest matching path to warn the user
parent_dir = os.path.dirname(input_path) or "."
parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
similar_paths = difflib.get_close_matches(input_path, valid_paths)
if similar_paths:
suggestion = similar_paths[0] # Get the closest match
message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message))
else:
raise ValueError("Invalid path")
def convert_tif_to_zarr(self, tif_path, zarr_path, chunks=(64, 64, 64)):
""" Convert a tiff file to a zarr file
Args:
tif_path (str): path to the tiff file
zarr_path (str): path to the zarr file
chunks (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
Returns:
zarr.core.Array: zarr array containing the data from the tiff file
"""
vol = tiff.memmap(tif_path)
z = zarr.open(zarr_path, mode='w', shape=vol.shape, chunks=chunks, dtype=vol.dtype)
chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
# ! Fastest way is z[:] = vol[:], but does not have a progress bar
for chunk_indices in tqdm(product(*[range(n) for n in chunk_shape]), total=np.prod(chunk_shape)):
slices = tuple(slice(c * i, min(c * (i + 1), s))
for s, c, i in zip(z.shape, z.chunks, chunk_indices))
temp_data = vol[slices]
# The assignment takes 99% of the cpu-time
z.blocks[chunk_indices] = temp_data
return z
def convert_zarr_to_tif(self, zarr_path, tif_path):
""" Convert a zarr file to a tiff file
Args:
zarr_path (str): path to the zarr file
tif_path (str): path to the tiff file
returns:
None
"""
z = zarr.open(zarr_path)
tiff.imwrite(tif_path, z)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
""" Convert a file to another format without loading the entire file into memory
Args:
input_path (str): path to the input file
output_path (str): path to the output file
chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
"""
converter = Convert(chunk_shape=chunk_shape)
converter.convert(input_path, output_path)
...@@ -606,6 +606,27 @@ class DataLoader: ...@@ -606,6 +606,27 @@ class DataLoader:
else: else:
return vol return vol
def load_zarr(self, path: str):
""" Loads a Zarr array from disk.
Args:
path (str): The path to the Zarr array on disk.
Returns:
dask.array | numpy.ndarray: The dask array loaded from disk.
if 'self.virtual_stack' is True, returns a dask array object, else returns a numpy.ndarray object.
"""
# Opens the Zarr array
vol = da.from_zarr(path)
# If virtual stack is disabled, return the computed array (np.ndarray)
if not self.virtual_stack:
vol = vol.compute()
return vol
def check_file_size(self, filename: str): def check_file_size(self, filename: str):
""" """
Checks if there is enough memory where the file can be loaded. Checks if there is enough memory where the file can be loaded.
...@@ -693,6 +714,8 @@ class DataLoader: ...@@ -693,6 +714,8 @@ class DataLoader:
elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]): elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]):
return self.load_PIL_stack(path) return self.load_PIL_stack(path)
elif path.endswith(".zarr"):
return self.load_zarr(path)
else: else:
return self.load_dicom_dir(path) return self.load_dicom_dir(path)
......
...@@ -24,9 +24,17 @@ Example: ...@@ -24,9 +24,17 @@ Example:
import datetime import datetime
import os import os
import dask.array as da
import h5py
import nibabel as nib
import numpy as np import numpy as np
import PIL import PIL
import tifffile import tifffile
import zarr
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
from qim3d.io.logger import log from qim3d.io.logger import log
from qim3d.utils.internal_tools import sizeof, stringify_path from qim3d.utils.internal_tools import sizeof, stringify_path
...@@ -249,6 +257,22 @@ class DataSaver: ...@@ -249,6 +257,22 @@ class DataSaver:
ds.save_as(path) ds.save_as(path)
def save_to_zarr(self, path, data):
""" Saves a Dask array to a Zarr array on disk.
Args:
path (str): The path to the Zarr array on disk.
data (dask.array): The Dask array to be saved to disk.
Returns:
zarr.core.Array: The Zarr array saved on disk.
"""
assert isinstance(data, da.Array), 'data must be a dask array'
# forces compute when saving to zarr
da.to_zarr(data, path, compute=True, overwrite=self.replace, compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=2))
def save_PIL(self, path, data): def save_PIL(self, path, data):
"""Save data to a PIL file to the given path. """Save data to a PIL file to the given path.
...@@ -344,6 +368,8 @@ class DataSaver: ...@@ -344,6 +368,8 @@ class DataSaver:
return self.save_vol(path, data) return self.save_vol(path, data)
elif path.endswith((".dcm", ".DCM")): elif path.endswith((".dcm", ".DCM")):
return self.save_dicom(path, data) return self.save_dicom(path, data)
elif path.endswith((".zarr")):
return self.save_to_zarr(path, data)
elif path.endswith((".jpeg",".jpg", ".png")): elif path.endswith((".jpeg",".jpg", ".png")):
return self.save_PIL(path, data) return self.save_PIL(path, data)
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment