Skip to content
Snippets Groups Projects
Commit 40da2311 authored by fima's avatar fima :beers:
Browse files

Merge branch 'save_files_function' into 'main'

Save files function

See merge request !45
parents 1c8ccdf8 f6aeb4d8
No related branches found
No related tags found
1 merge request!45Save files function
......@@ -378,10 +378,14 @@ class DataLoader:
match dt:
case 'float':
dt = np.float32
case 'float32':
dt = np.float32
case 'uint8':
dt = np.uint8
case 'unsigned integer':
dt = np.uint16
case 'uint16':
dt = np.uint16
case _:
raise ValueError(f"Unsupported data type: {dt}")
......
"""Provides functionality for saving data to various file formats."""
import os
import tifffile
import h5py
import nibabel as nib
import numpy as np
import PIL
import tifffile
from qim3d.io.logger import log
from qim3d.utils.internal_tools import sizeof, stringify_path
......@@ -90,6 +95,89 @@ class DataSaver:
log.info(f"Total of {no_slices} files saved following the pattern '{pattern_string}'")
def save_nifti(self, path, data):
""" Save data to a NIfTI file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# Create header
header = nib.Nifti1Header()
header.set_data_dtype(data.dtype)
# Create NIfTI image object
img = nib.Nifti1Image(data, np.eye(4), header)
# nib does automatically compress if filetype ends with .gz
if self.compression and not path.endswith(".gz"):
path += ".gz"
log.warning("File extension '.gz' is added since compression is enabled.")
if not self.compression and path.endswith(".gz"):
path = path[:-3]
log.warning("File extension '.gz' is ignored since compression is disabled.")
# Save image
nib.save(img, path)
def save_vol(self, path, data):
""" Save data to a VOL file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# No support for compression yet
if self.compression:
raise NotImplementedError("Saving compressed .vol files is not yet supported")
# Create custom .vgi metadata file
metadata = ""
metadata += "{volume1}\n" # .vgi organization
metadata += "[file1]\n" # .vgi organization
metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format
metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension
# Save metadata
with open(path[:-4] + ".vgi", "w") as f:
f.write(metadata)
# Save data using numpy in binary format
data.tofile(path[:-4] + ".vol")
def save_h5(self, path, data):
""" Save data to a HDF5 file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
with h5py.File(path, "w") as f:
f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None)
def save_PIL(self, path, data):
""" Save data to a PIL file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# No support for compression yet
if self.compression and path.endswith(".png"):
raise NotImplementedError("png does not support compression")
elif not self.compression and path.endswith((".jpeg",".jpg")):
raise NotImplementedError("jpeg does not support no compression")
# Convert to PIL image
img = PIL.Image.fromarray(data)
# Save image
img.save(path)
def save(self, path, data):
"""Save data to the given path.
......@@ -154,6 +242,16 @@ class DataSaver:
if path.endswith((".tif", ".tiff")):
return self.save_tiff(path, data)
elif path.endswith((".nii","nii.gz")):
return self.save_nifti(path, data)
elif path.endswith(("TXRM","XRM","TXM")):
raise NotImplementedError("Saving TXRM files is not yet supported")
elif path.endswith((".h5")):
return self.save_h5(path, data)
elif path.endswith((".vol",".vgi")):
return self.save_vol(path, data)
elif path.endswith((".jpeg",".jpg", ".png")):
return self.save_PIL(path, data)
else:
raise ValueError("Unsupported file format")
# If there is no file extension in the path
......
import qim3d
import hashlib
import os
import re
import tempfile
import numpy as np
import os
import hashlib
import pytest
import re
import qim3d
def test_image_exist():
# Create random test image
......@@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim():
qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim)
assert len(os.listdir(path2save))==test_image.shape[dim]
def test_tiff_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.tif")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
def test_vol_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.vol")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
def test_pil_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path_png = os.path.join(temp_dir,"test_image.png")
image_path_jpg = os.path.join(temp_dir,"test_image.jpg")
# Save to temporary directory
qim3d.io.save(image_path_png,original_image)
qim3d.io.save(image_path_jpg,original_image, compression=True)
# Load from temporary directory
saved_image_png = qim3d.io.load(image_path_png)
saved_image_jpg = qim3d.io.load(image_path_jpg)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_png_hash = calculate_image_hash(saved_image_png)
# Assert that original image is identical to saved_image
assert original_hash == saved_png_hash
# jpg is lossy so the hashes will not match, checks that the image is the same size and similar values
assert original_image.shape == saved_image_jpg.shape
def test_nifti_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.nii")
image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
qim3d.io.save(image_path_compressed,original_image, compression=True)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
saved_image_compressed = qim3d.io.load(image_path_compressed)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
saved_compressed_hash = calculate_image_hash(saved_image_compressed)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
assert original_hash == saved_compressed_hash
# Compute file sizes
file_size = os.path.getsize(image_path)
compressed_file_size = os.path.getsize(image_path_compressed)
# Assert that compressed file size is smaller than non-compressed file size
assert compressed_file_size < file_size
def test_h5_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.h5")
image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
qim3d.io.save(image_path_compressed,original_image, compression=True)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
saved_image_compressed = qim3d.io.load(image_path_compressed)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
saved_compressed_hash = calculate_image_hash(saved_image_compressed)
# Assert that original image is identical to saved_image
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
assert original_hash == saved_compressed_hash
# Compute file sizes
file_size = os.path.getsize(image_path)
compressed_file_size = os.path.getsize(image_path_compressed)
# Assert that compressed file size is smaller than non-compressed file size
assert compressed_file_size < file_size
def calculate_image_hash(image):
image_bytes = image.tobytes()
hash_object = hashlib.md5(image_bytes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment