From 638c6809779b0f75a034c12490586dc44d18a8cc Mon Sep 17 00:00:00 2001
From: Christian Kento Rasmussen <christian.kento@gmail.com>
Date: Thu, 25 Jan 2024 11:08:40 +0100
Subject: [PATCH] Implemented save for h5, nifti, vol, png, jpeg, jpg

---
 qim3d/io/save.py            |  58 +++++++++++---
 qim3d/tests/io/test_save.py | 148 +++++++++++++++++++++++++++++++++++-
 2 files changed, 192 insertions(+), 14 deletions(-)

diff --git a/qim3d/io/save.py b/qim3d/io/save.py
index aeb85e06..8f4d7012 100644
--- a/qim3d/io/save.py
+++ b/qim3d/io/save.py
@@ -2,8 +2,10 @@
 
 import os
 
+import h5py
 import nibabel as nib
 import numpy as np
+import PIL
 import tifffile
 
 from qim3d.io.logger import log
@@ -126,21 +128,55 @@ class DataSaver:
             path (str): The path to save file to
             data (numpy.ndarray): The data to be saved
         """
+        # No support for compression yet
+        if self.compression:
+            raise NotImplementedError("Saving compressed .vol files is not yet supported")
 
         # Create custom .vgi metadata file
         metadata = ""
-        metadata += "{volume1}\n"
-        metadata += "[file1]"
-        metadata += "Size = 1000 1000 766"
-        metadata += "datatype = float\n"
-        metadata += "Name = WFW_200kV_6W_1mmSn_6micro_1s.vol"
+        metadata += "{volume1}\n" # .vgi organization 
+        metadata += "[file1]\n" # .vgi organization 
+        metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format
+        metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
+        metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension
 
         # Save metadata
-        with open(path[:-3] + ".vgi", "w") as f:
+        with open(path[:-4] + ".vgi", "w") as f:
             f.write(metadata)
 
         # Save data using numpy in binary format
-        np.save(path[:-3] + ".vol", data)
+        data.tofile(path[:-4] + ".vol")
+
+    def save_h5(self, path, data):
+        """ Save data to a HDF5 file to the given path.
+
+        Args:
+            path (str): The path to save file to
+            data (numpy.ndarray): The data to be saved
+        """
+
+        with h5py.File(path, "w") as f:
+            f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None)
+        
+    def save_PIL(self, path, data):
+        """ Save data to a PIL file to the given path.
+
+        Args:
+            path (str): The path to save file to
+            data (numpy.ndarray): The data to be saved
+        """
+        # No support for compression yet
+        if self.compression and path.endswith(".png"):
+            raise NotImplementedError("png does not support compression")
+        elif not self.compression and path.endswith((".jpeg",".jpg")):
+            raise NotImplementedError("jpeg does not support no compression")
+
+        # Convert to PIL image
+        img = PIL.Image.fromarray(data)
+
+        # Save image
+        img.save(path)
+
         
     def save(self, path, data):
         """Save data to the given path.
@@ -210,10 +246,12 @@ class DataSaver:
                         return self.save_nifti(path, data)
                     elif path.endswith(("TXRM","XRM","TXM")):
                         raise NotImplementedError("Saving TXRM files is not yet supported")
-                    elif path.endswith(("h5")):
-                        raise NotImplementedError("Saving HDF5 files is not yet supported")
+                    elif path.endswith((".h5")):
+                        return self.save_h5(path, data)
                     elif path.endswith((".vol",".vgi")):
-                        raise NotImplementedError("Saving VOL files is not yet supported")
+                        return self.save_vol(path, data)
+                    elif path.endswith((".jpeg",".jpg", ".png")):
+                        return self.save_PIL(path, data)
                     else:
                         raise ValueError("Unsupported file format")
                 # If there is no file extension in the path
diff --git a/qim3d/tests/io/test_save.py b/qim3d/tests/io/test_save.py
index fa5c2d22..ad615872 100644
--- a/qim3d/tests/io/test_save.py
+++ b/qim3d/tests/io/test_save.py
@@ -1,10 +1,13 @@
-import qim3d
+import hashlib
+import os
+import re
 import tempfile
+
 import numpy as np
-import os
-import hashlib
 import pytest
-import re
+
+import qim3d
+
 
 def test_image_exist():
     # Create random test image
@@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim():
             qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim)
             assert len(os.listdir(path2save))==test_image.shape[dim]
 
+def test_tiff_save_load():
+    # Create random test image
+    original_image = qim3d.examples.blobs_256x256
+
+    # Create temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path = os.path.join(temp_dir,"test_image.tif")
+
+        # Save to temporary directory
+        qim3d.io.save(image_path,original_image)
+
+        # Load from temporary directory
+        saved_image = qim3d.io.load(image_path)
+
+        # Get hashes
+        original_hash = calculate_image_hash(original_image)
+        saved_hash = calculate_image_hash(saved_image)
+
+        # Assert that original image is identical to saved_image
+        assert original_hash == saved_hash
+
+def test_vol_save_load():
+    # Create random test image
+    original_image = qim3d.examples.blobs_256x256x256
+
+    # Create temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path = os.path.join(temp_dir,"test_image.vol")
+
+        # Save to temporary directory
+        qim3d.io.save(image_path,original_image)
+
+        # Load from temporary directory
+        saved_image = qim3d.io.load(image_path)
+
+        # Get hashes
+        original_hash = calculate_image_hash(original_image)
+        saved_hash = calculate_image_hash(saved_image)
+
+        # Assert that original image is identical to saved_image
+        assert original_hash == saved_hash
+
+def test_pil_save_load():
+    # Create random test image
+    original_image = qim3d.examples.blobs_256x256
+
+    # Create temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path_png = os.path.join(temp_dir,"test_image.png")
+        image_path_jpg = os.path.join(temp_dir,"test_image.jpg")
+
+        # Save to temporary directory
+        qim3d.io.save(image_path_png,original_image)
+        qim3d.io.save(image_path_jpg,original_image, compression=True)
+
+        # Load from temporary directory
+        saved_image_png = qim3d.io.load(image_path_png)
+        saved_image_jpg = qim3d.io.load(image_path_jpg)
+
+        # Get hashes
+        original_hash = calculate_image_hash(original_image)
+        saved_png_hash = calculate_image_hash(saved_image_png)
+
+        # Assert that original image is identical to saved_image
+        assert original_hash == saved_png_hash
+        
+        # jpg is lossy so the hashes will not match, checks that the image is the same size and similar values
+        assert original_image.shape == saved_image_jpg.shape
+
+def test_nifti_save_load():
+    # Create random test image
+    original_image = qim3d.examples.blobs_256x256
+
+    # Create temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path = os.path.join(temp_dir,"test_image.nii")
+        image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
+
+        # Save to temporary directory
+        qim3d.io.save(image_path,original_image)
+        qim3d.io.save(image_path_compressed,original_image, compression=True)
+
+        # Load from temporary directory
+        saved_image = qim3d.io.load(image_path)
+        saved_image_compressed = qim3d.io.load(image_path_compressed)
+
+        # Get hashes
+        original_hash = calculate_image_hash(original_image)
+        saved_hash = calculate_image_hash(saved_image)
+        saved_compressed_hash = calculate_image_hash(saved_image_compressed)
+
+        # Assert that original image is identical to saved_image
+        assert original_hash == saved_hash
+        assert original_hash == saved_compressed_hash
+
+        # Compute file sizes
+        file_size = os.path.getsize(image_path)
+        compressed_file_size = os.path.getsize(image_path_compressed)
+
+        # Assert that compressed file size is smaller than non-compressed file size
+        assert compressed_file_size < file_size
+
+
+def test_h5_save_load():
+    # Create random test image
+    original_image = qim3d.examples.blobs_256x256x256
+
+    # Create temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path = os.path.join(temp_dir,"test_image.h5")
+        image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
+
+        # Save to temporary directory
+        qim3d.io.save(image_path,original_image)
+        qim3d.io.save(image_path_compressed,original_image, compression=True)
+
+        # Load from temporary directory
+        saved_image = qim3d.io.load(image_path)
+        saved_image_compressed = qim3d.io.load(image_path_compressed)
+
+        # Get hashes
+        original_hash = calculate_image_hash(original_image)
+        saved_hash = calculate_image_hash(saved_image)
+        saved_compressed_hash = calculate_image_hash(saved_image_compressed)
+
+        # Assert that original image is identical to saved_image
+        # Assert that original image is identical to saved_image
+        assert original_hash == saved_hash
+        assert original_hash == saved_compressed_hash
+
+        # Compute file sizes
+        file_size = os.path.getsize(image_path)
+        compressed_file_size = os.path.getsize(image_path_compressed)
+
+        # Assert that compressed file size is smaller than non-compressed file size
+        assert compressed_file_size < file_size
+
 def calculate_image_hash(image): 
     image_bytes = image.tobytes()
     hash_object = hashlib.md5(image_bytes)
-- 
GitLab