diff --git a/qim3d/io/load.py b/qim3d/io/load.py index 33a7c6b13512a1c8b08418d6f9514c874d9e9350..7a5ebe881b6d2267f638309ba6f780f870ecc796 100644 --- a/qim3d/io/load.py +++ b/qim3d/io/load.py @@ -378,10 +378,14 @@ class DataLoader: match dt: case 'float': dt = np.float32 + case 'float32': + dt = np.float32 case 'uint8': dt = np.uint8 case 'unsigned integer': dt = np.uint16 + case 'uint16': + dt = np.uint16 case _: raise ValueError(f"Unsupported data type: {dt}") diff --git a/qim3d/io/save.py b/qim3d/io/save.py index 73cea5c5115f6893a537cd4e98ed66828bde29df..8f4d70129f64b76e15f731cb95bc404a204eb8e8 100644 --- a/qim3d/io/save.py +++ b/qim3d/io/save.py @@ -1,8 +1,13 @@ """Provides functionality for saving data to various file formats.""" import os -import tifffile + +import h5py +import nibabel as nib import numpy as np +import PIL +import tifffile + from qim3d.io.logger import log from qim3d.utils.internal_tools import sizeof, stringify_path @@ -90,6 +95,89 @@ class DataSaver: log.info(f"Total of {no_slices} files saved following the pattern '{pattern_string}'") + def save_nifti(self, path, data): + """ Save data to a NIfTI file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + # Create header + header = nib.Nifti1Header() + header.set_data_dtype(data.dtype) + + # Create NIfTI image object + img = nib.Nifti1Image(data, np.eye(4), header) + + # nib does automatically compress if filetype ends with .gz + if self.compression and not path.endswith(".gz"): + path += ".gz" + log.warning("File extension '.gz' is added since compression is enabled.") + + if not self.compression and path.endswith(".gz"): + path = path[:-3] + log.warning("File extension '.gz' is ignored since compression is disabled.") + + # Save image + nib.save(img, path) + + def save_vol(self, path, data): + """ Save data to a VOL file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + # No support for compression yet + if self.compression: + raise NotImplementedError("Saving compressed .vol files is not yet supported") + + # Create custom .vgi metadata file + metadata = "" + metadata += "{volume1}\n" # .vgi organization + metadata += "[file1]\n" # .vgi organization + metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format + metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string + metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension + + # Save metadata + with open(path[:-4] + ".vgi", "w") as f: + f.write(metadata) + + # Save data using numpy in binary format + data.tofile(path[:-4] + ".vol") + + def save_h5(self, path, data): + """ Save data to a HDF5 file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + + with h5py.File(path, "w") as f: + f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None) + + def save_PIL(self, path, data): + """ Save data to a PIL file to the given path. + + Args: + path (str): The path to save file to + data (numpy.ndarray): The data to be saved + """ + # No support for compression yet + if self.compression and path.endswith(".png"): + raise NotImplementedError("png does not support compression") + elif not self.compression and path.endswith((".jpeg",".jpg")): + raise NotImplementedError("jpeg does not support no compression") + + # Convert to PIL image + img = PIL.Image.fromarray(data) + + # Save image + img.save(path) + + def save(self, path, data): """Save data to the given path. @@ -154,6 +242,16 @@ class DataSaver: if path.endswith((".tif", ".tiff")): return self.save_tiff(path, data) + elif path.endswith((".nii","nii.gz")): + return self.save_nifti(path, data) + elif path.endswith(("TXRM","XRM","TXM")): + raise NotImplementedError("Saving TXRM files is not yet supported") + elif path.endswith((".h5")): + return self.save_h5(path, data) + elif path.endswith((".vol",".vgi")): + return self.save_vol(path, data) + elif path.endswith((".jpeg",".jpg", ".png")): + return self.save_PIL(path, data) else: raise ValueError("Unsupported file format") # If there is no file extension in the path diff --git a/qim3d/tests/io/test_save.py b/qim3d/tests/io/test_save.py index fa5c2d22c47d19aa101bf5cee33fa884c4b1b132..ad61587213c36a52e569f88726a7fddef3257ebb 100644 --- a/qim3d/tests/io/test_save.py +++ b/qim3d/tests/io/test_save.py @@ -1,10 +1,13 @@ -import qim3d +import hashlib +import os +import re import tempfile + import numpy as np -import os -import hashlib import pytest -import re + +import qim3d + def test_image_exist(): # Create random test image @@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim(): qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim) assert len(os.listdir(path2save))==test_image.shape[dim] +def test_tiff_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.tif") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + +def test_vol_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.vol") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + +def test_pil_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path_png = os.path.join(temp_dir,"test_image.png") + image_path_jpg = os.path.join(temp_dir,"test_image.jpg") + + # Save to temporary directory + qim3d.io.save(image_path_png,original_image) + qim3d.io.save(image_path_jpg,original_image, compression=True) + + # Load from temporary directory + saved_image_png = qim3d.io.load(image_path_png) + saved_image_jpg = qim3d.io.load(image_path_jpg) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_png_hash = calculate_image_hash(saved_image_png) + + # Assert that original image is identical to saved_image + assert original_hash == saved_png_hash + + # jpg is lossy so the hashes will not match, checks that the image is the same size and similar values + assert original_image.shape == saved_image_jpg.shape + +def test_nifti_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.nii") + image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path_compressed,original_image, compression=True) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + saved_image_compressed = qim3d.io.load(image_path_compressed) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + saved_compressed_hash = calculate_image_hash(saved_image_compressed) + + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + assert original_hash == saved_compressed_hash + + # Compute file sizes + file_size = os.path.getsize(image_path) + compressed_file_size = os.path.getsize(image_path_compressed) + + # Assert that compressed file size is smaller than non-compressed file size + assert compressed_file_size < file_size + + +def test_h5_save_load(): + # Create random test image + original_image = qim3d.examples.blobs_256x256x256 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir,"test_image.h5") + image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz") + + # Save to temporary directory + qim3d.io.save(image_path,original_image) + qim3d.io.save(image_path_compressed,original_image, compression=True) + + # Load from temporary directory + saved_image = qim3d.io.load(image_path) + saved_image_compressed = qim3d.io.load(image_path_compressed) + + # Get hashes + original_hash = calculate_image_hash(original_image) + saved_hash = calculate_image_hash(saved_image) + saved_compressed_hash = calculate_image_hash(saved_image_compressed) + + # Assert that original image is identical to saved_image + # Assert that original image is identical to saved_image + assert original_hash == saved_hash + assert original_hash == saved_compressed_hash + + # Compute file sizes + file_size = os.path.getsize(image_path) + compressed_file_size = os.path.getsize(image_path_compressed) + + # Assert that compressed file size is smaller than non-compressed file size + assert compressed_file_size < file_size + def calculate_image_hash(image): image_bytes = image.tobytes() hash_object = hashlib.md5(image_bytes)