From 638c6809779b0f75a034c12490586dc44d18a8cc Mon Sep 17 00:00:00 2001 From: Christian Kento Rasmussen <christian.kento@gmail.com> Date: Thu, 25 Jan 2024 11:08:40 +0100 Subject: [PATCH] Implemented save for h5, nifti, vol, png, jpeg, jpg --- qim3d/io/save.py | 58 +++++++++++--- qim3d/tests/io/test_save.py | 148 +++++++++++++++++++++++++++++++++++- 2 files changed, 192 insertions(+), 14 deletions(-) diff --git a/qim3d/io/save.py b/qim3d/io/save.py index aeb85e06..8f4d7012 100644 --- a/qim3d/io/save.py +++ b/qim3d/io/save.py @@ -2,8 +2,10 @@ import os +import h5py import nibabel as nib import numpy as np +import PIL import tifffile from qim3d.io.logger import log @@ -126,21 +128,55 @@ class DataSaver: path (str): The path to save file to data (numpy.ndarray): The data to be saved """ + # No support for compression yet + if self.compression: + raise NotImplementedError("Saving compressed .vol files is not yet supported") # Create custom .vgi metadata file metadata = "" - metadata += "{volume1}\n" - metadata += "[file1]" - metadata += "Size = 1000 1000 766" - metadata += "datatype = float\n" - metadata += "Name = WFW_200kV_6W_1mmSn_6micro_1s.vol" + metadata += "{volume1}\n" # .vgi organization + metadata += "[file1]\n" # .vgi organization + metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format + metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string + metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension # Save metadata - with open(path[:-3] + ".vgi", "w") as f: + with open(path[:-4] + ".vgi", "w") as f: f.write(metadata) # Save data using numpy in binary format - np.save(path[:-3] + ".vol", data) + data.tofile(path[:-4] + ".vol") + + def save_h5(self, path, data): + """ Save data to a HDF5 file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + + with h5py.File(path, "w") as f: + f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None) + + def save_PIL(self, path, data): + """ Save data to a PIL file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + # No support for compression yet + if self.compression and path.endswith(".png"): + raise NotImplementedError("png does not support compression") + elif not self.compression and path.endswith((".jpeg",".jpg")): + raise NotImplementedError("jpeg does not support no compression") + + # Convert to PIL image + img = PIL.Image.fromarray(data) + + # Save image + img.save(path) + def save(self, path, data): """Save data to the given path. @@ -210,10 +246,12 @@ class DataSaver: return self.save_nifti(path, data) elif path.endswith(("TXRM","XRM","TXM")): raise NotImplementedError("Saving TXRM files is not yet supported") - elif path.endswith(("h5")): - raise NotImplementedError("Saving HDF5 files is not yet supported") + elif path.endswith((".h5")): + return self.save_h5(path, data) elif path.endswith((".vol",".vgi")): - raise NotImplementedError("Saving VOL files is not yet supported") + return self.save_vol(path, data) + elif path.endswith((".jpeg",".jpg", ".png")): + return self.save_PIL(path, data) else: raise ValueError("Unsupported file format") # If there is no file extension in the path diff --git a/qim3d/tests/io/test_save.py b/qim3d/tests/io/test_save.py index fa5c2d22..ad615872 100644 --- a/qim3d/tests/io/test_save.py +++ b/qim3d/tests/io/test_save.py @@ -1,10 +1,13 @@ -import qim3d +import hashlib +import os +import re import tempfile + import numpy as np -import os -import hashlib import pytest -import re + +import qim3d + def test_image_exist(): # Create random test image @@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim(): qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim) assert len(os.listdir(path2save))==test_image.shape[dim] +def test_tiff_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.tif") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + +def test_vol_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.vol") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + +def test_pil_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path_png = os.path.join(temp_dir,"test_image.png") + image_path_jpg = os.path.join(temp_dir,"test_image.jpg") + + # Save to temporary directory + qim3d.io.save(image_path_png,original_image) + qim3d.io.save(image_path_jpg,original_image, compression=True) + + # Load from temporary directory + saved_image_png = qim3d.io.load(image_path_png) + saved_image_jpg = qim3d.io.load(image_path_jpg) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_png_hash = calculate_image_hash(saved_image_png) + + # Assert that original image is identical to saved_image + assert original_hash == saved_png_hash + + # jpg is lossy so the hashes will not match, checks that the image is the same size and similar values + assert original_image.shape == saved_image_jpg.shape + +def test_nifti_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.nii") + image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path_compressed,original_image, compression=True) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + saved_image_compressed = qim3d.io.load(image_path_compressed) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + saved_compressed_hash = calculate_image_hash(saved_image_compressed) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + assert original_hash == saved_compressed_hash + + # Compute file sizes + file_size = os.path.getsize(image_path) + compressed_file_size = os.path.getsize(image_path_compressed) + + # Assert that compressed file size is smaller than non-compressed file size + assert compressed_file_size < file_size + + +def test_h5_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.h5") + image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path_compressed,original_image, compression=True) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + saved_image_compressed = qim3d.io.load(image_path_compressed) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + saved_compressed_hash = calculate_image_hash(saved_image_compressed) + + # Assert that original image is identical to saved_image + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + assert original_hash == saved_compressed_hash + + # Compute file sizes + file_size = os.path.getsize(image_path) + compressed_file_size = os.path.getsize(image_path_compressed) + + # Assert that compressed file size is smaller than non-compressed file size + assert compressed_file_size < file_size + def calculate_image_hash(image): image_bytes = image.tobytes() hash_object = hashlib.md5(image_bytes) -- GitLab