From a96611b177bbd2bba52300e8a5b5a5d3f8d62309 Mon Sep 17 00:00:00 2001
From: Felipe <fima@dtu.dk>
Date: Wed, 25 Sep 2024 12:11:18 +0200
Subject: [PATCH] Saving zarr now does not use dask

---
 qim3d/io/saving.py | 31 +++++++++++++++++++++++--------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/qim3d/io/saving.py b/qim3d/io/saving.py
index 42f8e29f..dda53eb5 100644
--- a/qim3d/io/saving.py
+++ b/qim3d/io/saving.py
@@ -77,6 +77,7 @@ class DataSaver:
         self.compression = kwargs.get("compression", False)
         self.basename = kwargs.get("basename", None)
         self.sliced_dim = kwargs.get("sliced_dim", 0)
+        self.chunk_shape = kwargs.get("chunk_shape", "auto")
 
     def save_tiff(self, path, data):
         """Save data to a TIFF file to the given path.
@@ -258,7 +259,7 @@ class DataSaver:
         ds.save_as(path)
 
     def save_to_zarr(self, path, data):
-        """ Saves a Dask array to a Zarr array on disk.
+        """Saves a Dask array to a Zarr array on disk.
 
         Args:
             path (str): The path to the Zarr array on disk.
@@ -267,12 +268,18 @@ class DataSaver:
         Returns:
             zarr.core.Array: The Zarr array saved on disk.
         """
-        assert isinstance(data, da.Array), 'data must be a dask array'
 
-        # forces compute when saving to zarr
-        da.to_zarr(data, path, compute=True, overwrite=self.replace, compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=2))
-        
-    
+        # Old version that is using dask
+
+        # assert isinstance(data, da.Array), 'data must be a dask array'
+
+        # # forces compute when saving to zarr
+        # da.to_zarr(data, path, compute=True, overwrite=self.replace, compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=2))
+        zarr_array = zarr.open(
+            path, mode="w", shape=data.shape, chunks=self.chunk_shape, dtype=data.dtype
+        )
+        zarr_array[:] = data
+
     def save_PIL(self, path, data):
         """Save data to a PIL file to the given path.
 
@@ -378,7 +385,7 @@ class DataSaver:
                         return self.save_dicom(path, data)
                     elif path.endswith((".zarr")):
                         return self.save_to_zarr(path, data)
-                    elif path.endswith((".jpeg",".jpg", ".png")):
+                    elif path.endswith((".jpeg", ".jpg", ".png")):
                         return self.save_PIL(path, data)
                     else:
                         raise ValueError("Unsupported file format")
@@ -396,7 +403,14 @@ class DataSaver:
 
 
 def save(
-    path, data, replace=False, compression=False, basename=None, sliced_dim=0, **kwargs
+    path,
+    data,
+    replace=False,
+    compression=False,
+    basename=None,
+    sliced_dim=0,
+    chunk_shape="auto",
+    **kwargs,
 ):
     """Save data to a specified file path.
 
@@ -427,5 +441,6 @@ def save(
         compression=compression,
         basename=basename,
         sliced_dim=sliced_dim,
+        chunk_shape=chunk_shape,
         **kwargs,
     ).save(path, data)
-- 
GitLab