Skip to content
Snippets Groups Projects
Commit 638c6809 authored by Christian Kento Rasmussen's avatar Christian Kento Rasmussen
Browse files

Implemented save for h5, nifti, vol, png, jpeg, jpg

parent 21fd55d7
No related tags found
1 merge request!45Save files function
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
import os import os
import h5py
import nibabel as nib import nibabel as nib
import numpy as np import numpy as np
import PIL
import tifffile import tifffile
from qim3d.io.logger import log from qim3d.io.logger import log
...@@ -126,21 +128,55 @@ class DataSaver: ...@@ -126,21 +128,55 @@ class DataSaver:
path (str): The path to save file to path (str): The path to save file to
data (numpy.ndarray): The data to be saved data (numpy.ndarray): The data to be saved
""" """
# No support for compression yet
if self.compression:
raise NotImplementedError("Saving compressed .vol files is not yet supported")
# Create custom .vgi metadata file # Create custom .vgi metadata file
metadata = "" metadata = ""
metadata += "{volume1}\n" metadata += "{volume1}\n" # .vgi organization
metadata += "[file1]" metadata += "[file1]\n" # .vgi organization
metadata += "Size = 1000 1000 766" metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format
metadata += "datatype = float\n" metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
metadata += "Name = WFW_200kV_6W_1mmSn_6micro_1s.vol" metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension
# Save metadata # Save metadata
with open(path[:-3] + ".vgi", "w") as f: with open(path[:-4] + ".vgi", "w") as f:
f.write(metadata) f.write(metadata)
# Save data using numpy in binary format # Save data using numpy in binary format
np.save(path[:-3] + ".vol", data) data.tofile(path[:-4] + ".vol")
def save_h5(self, path, data):
""" Save data to a HDF5 file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
with h5py.File(path, "w") as f:
f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None)
def save_PIL(self, path, data):
""" Save data to a PIL file to the given path.
Args:
path (str): The path to save file to
data (numpy.ndarray): The data to be saved
"""
# No support for compression yet
if self.compression and path.endswith(".png"):
raise NotImplementedError("png does not support compression")
elif not self.compression and path.endswith((".jpeg",".jpg")):
raise NotImplementedError("jpeg does not support no compression")
# Convert to PIL image
img = PIL.Image.fromarray(data)
# Save image
img.save(path)
def save(self, path, data): def save(self, path, data):
"""Save data to the given path. """Save data to the given path.
...@@ -210,10 +246,12 @@ class DataSaver: ...@@ -210,10 +246,12 @@ class DataSaver:
return self.save_nifti(path, data) return self.save_nifti(path, data)
elif path.endswith(("TXRM","XRM","TXM")): elif path.endswith(("TXRM","XRM","TXM")):
raise NotImplementedError("Saving TXRM files is not yet supported") raise NotImplementedError("Saving TXRM files is not yet supported")
elif path.endswith(("h5")): elif path.endswith((".h5")):
raise NotImplementedError("Saving HDF5 files is not yet supported") return self.save_h5(path, data)
elif path.endswith((".vol",".vgi")): elif path.endswith((".vol",".vgi")):
raise NotImplementedError("Saving VOL files is not yet supported") return self.save_vol(path, data)
elif path.endswith((".jpeg",".jpg", ".png")):
return self.save_PIL(path, data)
else: else:
raise ValueError("Unsupported file format") raise ValueError("Unsupported file format")
# If there is no file extension in the path # If there is no file extension in the path
......
import qim3d import hashlib
import os
import re
import tempfile import tempfile
import numpy as np import numpy as np
import os
import hashlib
import pytest import pytest
import re
import qim3d
def test_image_exist(): def test_image_exist():
# Create random test image # Create random test image
...@@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim(): ...@@ -221,6 +224,143 @@ def test_tiff_stack_slicing_dim():
qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim) qim3d.io.save(path2save,test_image,basename='test',sliced_dim=dim)
assert len(os.listdir(path2save))==test_image.shape[dim] assert len(os.listdir(path2save))==test_image.shape[dim]
def test_tiff_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.tif")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
def test_vol_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.vol")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
def test_pil_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path_png = os.path.join(temp_dir,"test_image.png")
image_path_jpg = os.path.join(temp_dir,"test_image.jpg")
# Save to temporary directory
qim3d.io.save(image_path_png,original_image)
qim3d.io.save(image_path_jpg,original_image, compression=True)
# Load from temporary directory
saved_image_png = qim3d.io.load(image_path_png)
saved_image_jpg = qim3d.io.load(image_path_jpg)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_png_hash = calculate_image_hash(saved_image_png)
# Assert that original image is identical to saved_image
assert original_hash == saved_png_hash
# jpg is lossy so the hashes will not match, checks that the image is the same size and similar values
assert original_image.shape == saved_image_jpg.shape
def test_nifti_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.nii")
image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
qim3d.io.save(image_path_compressed,original_image, compression=True)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
saved_image_compressed = qim3d.io.load(image_path_compressed)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
saved_compressed_hash = calculate_image_hash(saved_image_compressed)
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
assert original_hash == saved_compressed_hash
# Compute file sizes
file_size = os.path.getsize(image_path)
compressed_file_size = os.path.getsize(image_path_compressed)
# Assert that compressed file size is smaller than non-compressed file size
assert compressed_file_size < file_size
def test_h5_save_load():
# Create random test image
original_image = qim3d.examples.blobs_256x256x256
# Create temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir,"test_image.h5")
image_path_compressed = os.path.join(temp_dir,"test_image_compressed.nii.gz")
# Save to temporary directory
qim3d.io.save(image_path,original_image)
qim3d.io.save(image_path_compressed,original_image, compression=True)
# Load from temporary directory
saved_image = qim3d.io.load(image_path)
saved_image_compressed = qim3d.io.load(image_path_compressed)
# Get hashes
original_hash = calculate_image_hash(original_image)
saved_hash = calculate_image_hash(saved_image)
saved_compressed_hash = calculate_image_hash(saved_image_compressed)
# Assert that original image is identical to saved_image
# Assert that original image is identical to saved_image
assert original_hash == saved_hash
assert original_hash == saved_compressed_hash
# Compute file sizes
file_size = os.path.getsize(image_path)
compressed_file_size = os.path.getsize(image_path_compressed)
# Assert that compressed file size is smaller than non-compressed file size
assert compressed_file_size < file_size
def calculate_image_hash(image): def calculate_image_hash(image):
image_bytes = image.tobytes() image_bytes = image.tobytes()
hash_object = hashlib.md5(image_bytes) hash_object = hashlib.md5(image_bytes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment