Skip to content
Snippets Groups Projects
Commit 2c9d0b67 authored by fima's avatar fima :beers:
Browse files

Merge branch 'load_PIL_stack' into 'main'

Load pil stack

See merge request !91
parents f0552880 5c938b3c
No related branches found
No related tags found
1 merge request!91Load pil stack
# import qim3d.io as io
# import qim3d.gui as gui
# import qim3d.viz as viz
# import qim3d.utils as utils
# import qim3d.models as models
# import qim3d.processing as processing
from . import io, gui, viz, utils, models, processing
import logging import logging
__version__ = '0.3.2' logging.basicConfig(level=logging.ERROR)
from qim3d import io
from qim3d import gui
from qim3d import viz
from qim3d import utils
from qim3d import models
from qim3d import processing
__version__ = "0.3.2"
examples = io.ImgExamples() examples = io.ImgExamples()
io.logger.set_level_info()
...@@ -17,12 +17,22 @@ import struct ...@@ -17,12 +17,22 @@ import struct
from pathlib import Path from pathlib import Path
import dask.array as da import dask.array as da
import dask_image.imread
import h5py import h5py
import nibabel as nib import nibabel as nib
import numpy as np import numpy as np
import olefile import olefile
import pydicom import pydicom
import tifffile import tifffile
# Dask
import dask_image
import dask
from dask import delayed
import dask.array as da
dask.config.set(scheduler="processes") # Dask parallel goes brrrrr
from PIL import Image, UnidentifiedImageError from PIL import Image, UnidentifiedImageError
import qim3d import qim3d
...@@ -77,6 +87,7 @@ class DataLoader: ...@@ -77,6 +87,7 @@ class DataLoader:
self.contains = kwargs.get("contains", None) self.contains = kwargs.get("contains", None)
self.force_load = kwargs.get("force_load", False) self.force_load = kwargs.get("force_load", False)
self.dim_order = kwargs.get("dim_order", (2, 1, 0)) self.dim_order = kwargs.get("dim_order", (2, 1, 0))
self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
def load_tiff(self, path): def load_tiff(self, path):
"""Load a TIFF file from the specified path. """Load a TIFF file from the specified path.
...@@ -348,6 +359,90 @@ class DataLoader: ...@@ -348,6 +359,90 @@ class DataLoader:
""" """
return np.array(Image.open(path)) return np.array(Image.open(path))
def load_PIL_stack(self, path):
"""Load a stack of PIL files from the specified path.
Args:
path (str): The path to the stack of PIL files.
Returns:
numpy.ndarray or numpy.memmap: The loaded volume.
If 'self.virtual_stack' is True, returns a numpy.memmap object.
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple PIL stacks in the directory
"""
if not self.contains:
raise ValueError(
"Please specify a part of the name that is common for the file stack with the argument 'contains'"
)
# List comprehension to filter files
PIL_stack = [
file
for file in os.listdir(path)
if file.endswith(self.PIL_extensions) and self.contains in file
]
PIL_stack.sort() # Ensure proper ordering
# Check that only one stack in the directory contains the provided string in its name
PIL_stack_only_letters = []
for filename in PIL_stack:
name = os.path.splitext(filename)[0] # Remove file extension
PIL_stack_only_letters.append(
"".join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements
unique_names = list(set(PIL_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded"
)
if self.virtual_stack:
full_paths = [os.path.join(path, file) for file in PIL_stack]
def lazy_loader(path):
with Image.open(path) as img:
return np.array(img)
# Use delayed to load each image with PIL
lazy_images = [delayed(lazy_loader)(path) for path in full_paths]
# Compute the shape of the first image to define the array dimensions
sample_image = np.array(Image.open(full_paths[0]))
image_shape = sample_image.shape
dtype = sample_image.dtype
# Stack the images into a single Dask array
dask_images = [
da.from_delayed(img, shape=image_shape, dtype=dtype) for img in lazy_images
]
stacked = da.stack(dask_images, axis=0)
return stacked
else:
# Generate placeholder volume
first_image = self.load_pil(os.path.join(path, PIL_stack[0]))
vol = np.zeros((len(PIL_stack), *first_image.shape), dtype=first_image.dtype)
# Load file sequence
for idx, file_name in enumerate(PIL_stack):
vol[idx] = self.load_pil(os.path.join(path, file_name))
return vol
# log.info("Found %s file(s)", len(PIL_stack))
# log.info("Loaded shape: %s", vol.shape)
def _load_vgi_metadata(self, path): def _load_vgi_metadata(self, path):
"""Helper functions that loads metadata from a VGI file """Helper functions that loads metadata from a VGI file
...@@ -599,6 +694,9 @@ class DataLoader: ...@@ -599,6 +694,9 @@ class DataLoader:
[f.endswith(".tif") or f.endswith(".tiff") for f in os.listdir(path)] [f.endswith(".tif") or f.endswith(".tiff") for f in os.listdir(path)]
): ):
return self.load_tiff_stack(path) return self.load_tiff_stack(path)
elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]):
return self.load_PIL_stack(path)
else: else:
return self.load_dicom_dir(path) return self.load_dicom_dir(path)
...@@ -758,6 +856,5 @@ class ImgExamples: ...@@ -758,6 +856,5 @@ class ImgExamples:
img_examples_path = Path(qim3d.__file__).parents[0] / "img_examples" img_examples_path = Path(qim3d.__file__).parents[0] / "img_examples"
img_paths = list(img_examples_path.glob("*.tif")) img_paths = list(img_examples_path.glob("*.tif"))
update_dict = {path.stem: load(path) for path in img_paths} update_dict = {path.stem: load(path) for path in img_paths}
self.__dict__.update(update_dict) self.__dict__.update(update_dict)
...@@ -36,7 +36,7 @@ def set_simple_output(): ...@@ -36,7 +36,7 @@ def set_simple_output():
formatter = logging.Formatter("%(message)s") formatter = logging.Formatter("%(message)s")
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger = logging.getLogger("qim3d") logger = logging.getLogger()
logger.handlers = [] logger.handlers = []
logger.addHandler(handler) logger.addHandler(handler)
...@@ -132,5 +132,5 @@ def level(log_level): ...@@ -132,5 +132,5 @@ def level(log_level):
# create the logger # create the logger
log = logging.getLogger("qim3d") log = logging.getLogger("qim3d")
# set_simple_output() #TODO: This used to work, but now it gives duplicated messages. Need to be investigated. set_simple_output() #TODO: This used to work, but now it gives duplicated messages. Need to be investigated.
set_level_warning() #set_level_warning()
...@@ -11,6 +11,7 @@ import numpy as np ...@@ -11,6 +11,7 @@ import numpy as np
import torch import torch
from matplotlib import colormaps from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
import dask.array as da
from qim3d.io.logger import log from qim3d.io.logger import log
...@@ -271,14 +272,17 @@ def slices( ...@@ -271,14 +272,17 @@ def slices(
""" """
# Numpy array or Torch tensor input # Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, torch.Tensor)): if not isinstance(vol, (np.ndarray, torch.Tensor, da.core.Array)):
raise ValueError("Input must be a numpy.ndarray or torch.Tensor") raise ValueError("Data type not supported")
if vol.ndim < 3: if vol.ndim < 3:
raise ValueError( raise ValueError(
"The provided object is not a volume as it has less than 3 dimensions." "The provided object is not a volume as it has less than 3 dimensions."
) )
if isinstance(vol, da.core.Array):
vol = vol.compute()
# Ensure axis is a valid choice # Ensure axis is a valid choice
if not (0 <= axis < vol.ndim): if not (0 <= axis < vol.ndim):
raise ValueError( raise ValueError(
...@@ -324,9 +328,11 @@ def slices( ...@@ -324,9 +328,11 @@ def slices(
if nrows == 1: if nrows == 1:
axs = [axs] # Convert to a list for uniformity axs = [axs] # Convert to a list for uniformity
# Convert Torch tensor to NumPy array in order to use the numpy.take method # Convert to NumPy array in order to use the numpy.take method
if isinstance(vol, torch.Tensor): if isinstance(vol, torch.Tensor):
vol = vol.numpy() vol = vol.numpy()
elif isinstance(vol, da.core.Array):
vol = vol.compute()
# Run through each ax of the grid # Run through each ax of the grid
for i, ax_row in enumerate(axs): for i, ax_row in enumerate(axs):
......
...@@ -16,6 +16,8 @@ from qim3d.utils.internal_tools import downscale_img, scale_to_float16 ...@@ -16,6 +16,8 @@ from qim3d.utils.internal_tools import downscale_img, scale_to_float16
def vol( def vol(
img, img,
vmin=None,
vmax=None,
aspectmode="data", aspectmode="data",
show=True, show=True,
save=False, save=False,
...@@ -80,7 +82,7 @@ def vol( ...@@ -80,7 +82,7 @@ def vol(
a = (y1 - y2) / (x1 - x2) a = (y1 - y2) / (x1 - x2)
b = y1 - a * x1 b = y1 - a * x1
samples = int(min(max(a * pixel_count + b, 32), 512)) samples = int(min(max(a * pixel_count + b, 64), 512))
else: else:
samples = int(samples) # make sure it's an integer samples = int(samples) # make sure it's an integer
...@@ -97,15 +99,23 @@ def vol( ...@@ -97,15 +99,23 @@ def vol(
f"Downsampled image for visualization. From {original_shape} to {new_shape}" f"Downsampled image for visualization. From {original_shape} to {new_shape}"
) )
# Set color ranges
color_range = [np.min(img), np.max(img)]
if vmin:
color_range[0] = vmin
if vmax:
color_range[1] = vmax
plt_volume = k3d.volume( plt_volume = k3d.volume(
scale_to_float16(img), scale_to_float16(img),
bounds=( bounds=(
[0, img.shape[0], 0, img.shape[1], 0, img.shape[2]] [0, img.shape[2], 0, img.shape[1], 0, img.shape[0]]
if aspectmode.lower() == "data" if aspectmode.lower() == "data"
else None else None
), ),
color_map=cmap, color_map=cmap,
samples=samples, samples=samples,
color_range=color_range,
) )
plot = k3d.plot(grid_visible=grid_visible, **kwargs) plot = k3d.plot(grid_visible=grid_visible, **kwargs)
plot += plt_volume plot += plt_volume
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment