From 2fd4fae77a9e9592dc9329313fde37f4288ab1a4 Mon Sep 17 00:00:00 2001
From: s204159 <s204159@student.dtu.dk>
Date: Fri, 28 Jun 2024 11:27:18 +0200
Subject: [PATCH] Conv zarr nifti

---
 qim3d/io/convert.py | 97 ++++++++++++++++++++++++++++++++++++---------
 qim3d/utils/cli.py  |  3 ++
 2 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/qim3d/io/convert.py b/qim3d/io/convert.py
index b2ab1e49..f2af9e6b 100644
--- a/qim3d/io/convert.py
+++ b/qim3d/io/convert.py
@@ -2,36 +2,44 @@ import difflib
 import os
 from itertools import product
 
+import nibabel as nib
 import numpy as np
 import tifffile as tiff
 import zarr
 from tqdm import tqdm
 
 from qim3d.utils.internal_tools import stringify_path
+from qim3d.io.saving import save
 
 
 class Convert:
-
-    def __init__(self,**kwargs):
-        """ Utility class to convert files to other formats without loading the entire file into memory
+    def __init__(self, **kwargs):
+        """Utility class to convert files to other formats without loading the entire file into memory
 
         Args:
             chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
         """
         self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
 
-
     def convert(self, input_path, output_path):
+        def get_file_extension(file_path):
+            # Return the extension, keeping double suffixes such as ".nii.gz" in one piece
+            stem, suffix = os.path.splitext(file_path)
+            if suffix in (".gz", ".bz2", ".xz"):
+                suffix = os.path.splitext(stem)[1] + suffix
+            return suffix
         # Stringify path in case it is not already a string
         input_path = stringify_path(input_path)
-        input_ext = os.path.splitext(input_path)[1]
-        output_ext = os.path.splitext(output_path)[1]
+        input_ext = get_file_extension(input_path)
+        output_ext = get_file_extension(output_path)
         output_path = stringify_path(output_path)
 
-        if os.path.isfile(input_path):  
+        if os.path.isfile(input_path):
             match input_ext, output_ext:
-                case (".tif", ".zarr")  | (".tiff", ".zarr"):
-                    return self.convert_tif_to_zarr(input_path, output_path, chunks=self.chunk_shape)
+                case (".tif", ".zarr") | (".tiff", ".zarr"):
+                    return self.convert_tif_to_zarr(input_path, output_path)
+                case (".nii", ".zarr") | (".nii.gz", ".zarr"):
+                    return self.convert_nifti_to_zarr(input_path, output_path)
                 case _:
                     raise ValueError("Unsupported file format")
         # Load a directory
@@ -39,6 +47,10 @@ class Convert:
             match input_ext, output_ext:
                 case (".zarr", ".tif") | (".zarr", ".tiff"):
                     return self.convert_zarr_to_tif(input_path, output_path)
+                case (".zarr", ".nii"):
+                    return self.convert_zarr_to_nifti(input_path, output_path)
+                case (".zarr", ".nii.gz"):
+                    return self.convert_zarr_to_nifti(input_path, output_path, compression=True)
                 case _:
                     raise ValueError("Unsupported file format")
         # Fail
@@ -55,8 +67,8 @@ class Convert:
             else:
                 raise ValueError("Invalid path")
 
-    def convert_tif_to_zarr(self, tif_path, zarr_path, chunks=(64, 64, 64)):
-        """ Convert a tiff file to a zarr file
+    def convert_tif_to_zarr(self, tif_path, zarr_path):
+        """Convert a tiff file to a zarr file
 
         Args:
             tif_path (str): path to the tiff file
@@ -67,12 +79,18 @@ class Convert:
             zarr.core.Array: zarr array containing the data from the tiff file
         """
         vol = tiff.memmap(tif_path)
-        z = zarr.open(zarr_path, mode='w', shape=vol.shape, chunks=chunks, dtype=vol.dtype)
+        z = zarr.open(
+            zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
+        )
         chunk_shape = tuple((s + c - 1) // c for s, c in zip(z.shape, z.chunks))
         # ! Fastest way is z[:] = vol[:], but does not have a progress bar
-        for chunk_indices in tqdm(product(*[range(n) for n in chunk_shape]), total=np.prod(chunk_shape)):
-            slices = tuple(slice(c * i, min(c * (i + 1), s))
-                        for s, c, i in zip(z.shape, z.chunks, chunk_indices))
+        for chunk_indices in tqdm(
+            product(*[range(n) for n in chunk_shape]), total=np.prod(chunk_shape)
+        ):
+            slices = tuple(
+                slice(c * i, min(c * (i + 1), s))
+                for s, c, i in zip(z.shape, z.chunks, chunk_indices)
+            )
             temp_data = vol[slices]
             # The assignment takes 99% of the cpu-time
             z.blocks[chunk_indices] = temp_data
@@ -80,7 +98,7 @@ class Convert:
         return z
 
     def convert_zarr_to_tif(self, zarr_path, tif_path):
-        """ Convert a zarr file to a tiff file
+        """Convert a zarr file to a tiff file
 
         Args:
             zarr_path (str): path to the zarr file
@@ -90,17 +108,58 @@ class Convert:
             None
         """
         z = zarr.open(zarr_path)
-        tiff.imwrite(tif_path, z)
+        save(tif_path, z)
 
+    def convert_nifti_to_zarr(self, nifti_path, zarr_path):
+        """Copy a nifti volume into a new zarr array, one chunk at a time
+
+        Args:
+            nifti_path (str): path to the input nifti file
+            zarr_path (str): path of the zarr array, opened with mode "w"
+
+        Returns:
+            zarr.core.Array: zarr array containing the data from the nifti file
+        """
+        # dataobj is sliced per chunk below instead of being indexed in one go
+        vol = nib.load(nifti_path).dataobj
+        z = zarr.open(
+            zarr_path, mode="w", shape=vol.shape, chunks=self.chunk_shape, dtype=vol.dtype
+        )
+        # Number of chunks along each axis (ceil division of shape by chunk size)
+        n_chunks = tuple(-(-s // c) for s, c in zip(z.shape, z.chunks))
+        # ! Fastest way is z[:] = vol[:], but does not have a progress bar
+        for idx in tqdm(product(*map(range, n_chunks)), total=np.prod(n_chunks)):
+            # Clip the upper bound of each slice to the array edge
+            region = tuple(
+                slice(c * i, min(c * (i + 1), s))
+                for s, c, i in zip(z.shape, z.chunks, idx)
+            )
+            # The assignment takes 99% of the cpu-time
+            z.blocks[idx] = vol[region]
+
+        return z
+
+    def convert_zarr_to_nifti(self, zarr_path, nifti_path, compression=False):
+        """Convert a zarr file to a nifti file
+
+        Args:
+            zarr_path (str): path to the zarr file
+            nifti_path (str): path to the nifti file
+            compression (bool, optional): passed on to save(); convert() sets it for ".nii.gz". Defaults to False.
+
+        Returns:
+            None
+        """
+        save(nifti_path, zarr.open(zarr_path), compression=compression)
+
 
 def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
-    """ Convert a file to another format without loading the entire file into memory
+    """Convert a file to another format without loading the entire file into memory
 
     Args:
         input_path (str): path to the input file
         output_path (str): path to the output file
         chunk_shape (tuple, optional): chunk size for the zarr file. Defaults to (64, 64, 64).
     """
-
     converter = Convert(chunk_shape=chunk_shape)
     converter.convert(input_path, output_path)
diff --git a/qim3d/utils/cli.py b/qim3d/utils/cli.py
index ac8708c4..750997de 100644
--- a/qim3d/utils/cli.py
+++ b/qim3d/utils/cli.py
@@ -189,6 +189,9 @@ def main():
         parser.print_help()
         print("\n")
 
+    elif args.subcommand == 'convert':
+        qim3d.io.convert(args.input_path, args.output_path)
+
 
 if __name__ == "__main__":
     main()
-- 
GitLab