diff --git a/qim3d/io/convert.py b/qim3d/io/convert.py index 5efee712d39158248f36578ac7125222db90e0fd..2e18a23ae2a86d3627b980e1f5d1c89790801a0d 100644 --- a/qim3d/io/convert.py +++ b/qim3d/io/convert.py @@ -9,6 +9,7 @@ import zarr from tqdm import tqdm from qim3d.utils.internal_tools import stringify_path +from qim3d.io.saving import save class Convert: @@ -21,17 +22,23 @@ class Convert: self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64)) def convert(self, input_path, output_path): + def get_file_extension(file_path): + root, ext = os.path.splitext(file_path) + if ext in ['.gz', '.bz2', '.xz']: # handle common compressed extensions + root, ext2 = os.path.splitext(root) + ext = ext2 + ext + return ext # Stringify path in case it is not already a string input_path = stringify_path(input_path) - input_ext = os.path.splitext(input_path)[1] - output_ext = os.path.splitext(output_path)[1] + input_ext = get_file_extension(input_path) + output_ext = get_file_extension(output_path) output_path = stringify_path(output_path) if os.path.isfile(input_path): match input_ext, output_ext: case (".tif", ".zarr") | (".tiff", ".zarr"): return self.convert_tif_to_zarr(input_path, output_path) - case (".nii", ".zarr"): + case (".nii", ".zarr") | (".nii.gz", ".zarr"): return self.convert_nifti_to_zarr(input_path, output_path) case _: raise ValueError("Unsupported file format") @@ -41,7 +48,9 @@ class Convert: case (".zarr", ".tif") | (".zarr", ".tiff"): return self.convert_zarr_to_tif(input_path, output_path) case (".zarr", ".nii"): - return self.convert_zarr_to_tif(input_path, output_path) + return self.convert_zarr_to_nifti(input_path, output_path) + case (".zarr", ".nii.gz"): + return self.convert_zarr_to_nifti(input_path, output_path, compress=True) case _: raise ValueError("Unsupported file format") # Fail @@ -71,7 +80,7 @@ class Convert: """ vol = tiff.memmap(tif_path) z = zarr.open( - zarr_path, mode="w", shape=vol.shape, chunks=self.chunks, dtype=vol.dtype + zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype ) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) # ! Fastest way is z[:] = vol[:], but does not have a progress bar @@ -99,7 +108,7 @@ class Convert: None """ z = zarr.open(zarr_path) - tiff.imwrite(tif_path, z) + save(tif_path, z) def convert_nifti_to_zarr(self, nifti_path, zarr_path): """Convert a nifti file to a zarr file @@ -113,7 +122,7 @@ class Convert: """ vol = nib.load(nifti_path).dataobj z = zarr.open( - zarr_path, mode="w", shape=vol.shape, chunks=self.chunks, dtype=vol.dtype + zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype ) chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks)) # ! Fastest way is z[:] = vol[:], but does not have a progress bar @@ -130,7 +139,7 @@ class Convert: return z - def convert_zarr_to_nifti(self, zarr_path, nifti_path): + def convert_zarr_to_nifti(self, zarr_path, nifti_path, compress=False): """Convert a zarr file to a nifti file Args: @@ -141,7 +150,8 @@ class Convert: None """ z = zarr.open(zarr_path) - nib.save(nib.Nifti1Image(z, np.eye(4)), nifti_path) + save(nifti_path, z, compress=compress) + def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)): """Convert a file to another format without loading the entire file into memory @@ -151,6 +161,5 @@ def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64) output_path (str): path to the output file chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64). """ - converter = Convert(chunk_shape=chunk_shape) converter.convert(input_path, output_path)