diff --git a/qim3d/io/load.py b/qim3d/io/load.py index f2dd9fd398bed79a5da949eb0ce0409b3947968a..620a2eb81d10a0913677f3505a256c1370f1035b 100644 --- a/qim3d/io/load.py +++ b/qim3d/io/load.py @@ -4,11 +4,12 @@ import os import difflib import tifffile import h5py +import nibabel as nib import numpy as np from pathlib import Path import qim3d from qim3d.io.logger import log -from qim3d.utils.internal_tools import sizeof +from qim3d.utils.internal_tools import sizeof, stringify_path from qim3d.utils.system import Memory @@ -20,7 +21,7 @@ class DataLoader: dataset_name (str): Specifies the name of the dataset to be loaded (only relevant for HDF5 files) return_metadata (bool): Specifies if metadata is returned or not - (only relevant for HDF5 and TXRM/TXM/XRM files) + (only relevant for HDF5, TXRM/TXM/XRM and NIfTI files) contains (str): Specifies a part of the name that is common for the TIFF file stack to be loaded (only relevant for TIFF stacks) @@ -80,7 +81,7 @@ class DataLoader: Returns: numpy.ndarray or tuple: The loaded volume as a NumPy array. - If 'return_metadata' is True, returns a tuple (volume, metadata). + If 'self.return_metadata' is True, returns a tuple (volume, metadata). Raises: ValueError: If the specified dataset_name is not found or is invalid. @@ -214,7 +215,7 @@ class DataLoader: Returns: numpy.ndarray or tuple: The loaded volume as a NumPy array. - If 'return_metadata' is True, returns a tuple (volume, metadata). + If 'self.return_metadata' is True, returns a tuple (volume, metadata). Raises: ValueError: If the dxchange library is not installed @@ -243,13 +244,41 @@ class DataLoader: return vol, metadata else: return vol + + def load_nifti(self,path): + """Load a NIfTI file from the specified path. + + Args: + path (str): The path to the NIfTI file. + + Returns: + numpy.ndarray or tuple: The loaded volume as a NumPy array. + If 'self.return_metadata' is True, returns a tuple (volume, metadata). + """ + + data = nib.load(path) + + # Get image array proxy + vol = data.dataobj + + if not self.virtual_stack: + vol = np.asarray(vol,dtype=data.get_data_dtype()) + + if self.return_metadata: + metadata = {} + for key in data.header: + metadata[key]=data.header[key] + + return vol, metadata + else: + return vol def load(self, path): """ Load a file or directory based on the given path. Args: - path (str): The path to the file or directory. + path (str or os.PathLike): The path to the file or directory. Returns: numpy.ndarray: The loaded volume as a NumPy array. @@ -263,6 +292,8 @@ class DataLoader: data = loader.load("image.tif") """ + path = stringify_path(path) + # Load a file if os.path.isfile(path): # Choose the loader based on the file extension @@ -272,6 +303,8 @@ class DataLoader: return self.load_h5(path) elif path.endswith((".txrm", ".txm", ".xrm")): return self.load_txrm(path) + elif path.endswith((".nii",".nii.gz")): + return self.load_nifti(path) else: raise ValueError("Unsupported file format") @@ -312,7 +345,7 @@ def load( Load data from the specified file or directory. Args: - path (str): The path to the file or directory. + path (str or os.PathLike): The path to the file or directory. virtual_stack (bool, optional): Specifies whether to use virtual stack when loading files. Default is False. dataset_name (str, optional): Specifies the name of the dataset to be loaded @@ -368,4 +401,4 @@ class ImgExamples: # Generate loader for each image found for idx, name in enumerate(img_names): - exec(f"self.{name} = qim3d.io.load(path = str(img_paths[idx]))") + exec(f"self.{name} = qim3d.io.load(path = img_paths[idx])") diff --git a/qim3d/tests/io/test_load.py b/qim3d/tests/io/test_load.py index c957bd1f3d9b1dea367448de9f71c3705e42b85d..992400996bbac166e1c1781f6241d1d805a4282e 100644 --- a/qim3d/tests/io/test_load.py +++ b/qim3d/tests/io/test_load.py @@ -8,7 +8,7 @@ import pytest vol = qim3d.examples.blobs_256x256 # Ceate memory map to blobs -blobs_path = str(Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif") +blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif" vol_memmap = qim3d.io.load(blobs_path,virtual_stack=True) def test_load_shape(): @@ -28,9 +28,7 @@ def test_invalid_path(): def test_did_you_mean(): # Remove last two characters from the path - blobs_path_misspelled = blobs_path[:-2] + blobs_path_misspelled = str(blobs_path)[:-2] with pytest.raises(ValueError,match=f"Invalid path.\nDid you mean '{blobs_path}'?"): - qim3d.io.load(blobs_path_misspelled) - - + qim3d.io.load(blobs_path_misspelled) \ No newline at end of file diff --git a/qim3d/tests/utils/test_internal_tools.py b/qim3d/tests/utils/test_internal_tools.py index c41f2aab0ae0f92bfe2f47fafc31e355f93d7312..bb239ca703669273933a888638dd7088fd0a7078 100644 --- a/qim3d/tests/utils/test_internal_tools.py +++ b/qim3d/tests/utils/test_internal_tools.py @@ -1,6 +1,7 @@ import qim3d import os import re +from pathlib import Path def test_mock_plot(): @@ -35,3 +36,18 @@ def test_get_local_ip(): local_ip = qim3d.utils.internal_tools.get_local_ip() assert validate_ip(local_ip) == True + +def test_stringify_path1(): + """Test that the function converts os.PathLike objects to strings + """ + blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif" + + assert str(blobs_path) == qim3d.utils.internal_tools.stringify_path(blobs_path) + +def test_stringify_path2(): + """Test that the function returns input unchanged if input is a string + """ + # Create test_path + test_path = os.path.join('this','path','doesnt','exist.tif') + + assert test_path == qim3d.utils.internal_tools.stringify_path(test_path) diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py index 838ed93e1dc502dad883ead1a5103fed37cbf009..e31001266288c405f3fb1d9092341d4e51cf1e98 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/internal_tools.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import matplotlib import numpy as np import socket +import os @@ -175,4 +176,11 @@ def is_server_running(ip, port): s.shutdown(2) return True except: - return False \ No newline at end of file + return False + +def stringify_path(path): + """Converts an os.PathLike object to a string + """ + if isinstance(path,os.PathLike): + path = path.__fspath__() + return path \ No newline at end of file