Skip to content
Snippets Groups Projects
Commit 2c9d0b67 authored by fima's avatar fima :beers:
Browse files

Merge branch 'load_PIL_stack' into 'main'

Load pil stack

See merge request !91
parents f0552880 5c938b3c
No related branches found
No related tags found
1 merge request!91Load pil stack
# import qim3d.io as io
# import qim3d.gui as gui
# import qim3d.viz as viz
# import qim3d.utils as utils
# import qim3d.models as models
# import qim3d.processing as processing
from . import io, gui, viz, utils, models, processing
import logging
__version__ = '0.3.2'
logging.basicConfig(level=logging.ERROR)
from qim3d import io
from qim3d import gui
from qim3d import viz
from qim3d import utils
from qim3d import models
from qim3d import processing
__version__ = "0.3.2"
examples = io.ImgExamples()
io.logger.set_level_info()
......@@ -17,12 +17,22 @@ import struct
from pathlib import Path
import dask.array as da
import dask_image.imread
import h5py
import nibabel as nib
import numpy as np
import olefile
import pydicom
import tifffile
# Dask
import dask_image
import dask
from dask import delayed
import dask.array as da
dask.config.set(scheduler="processes") # Dask parallel goes brrrrr
from PIL import Image, UnidentifiedImageError
import qim3d
......@@ -77,6 +87,7 @@ class DataLoader:
self.contains = kwargs.get("contains", None)
self.force_load = kwargs.get("force_load", False)
self.dim_order = kwargs.get("dim_order", (2, 1, 0))
self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
def load_tiff(self, path):
"""Load a TIFF file from the specified path.
......@@ -348,6 +359,90 @@ class DataLoader:
"""
return np.array(Image.open(path))
def load_PIL_stack(self, path):
"""Load a stack of PIL files from the specified path.
Args:
path (str): The path to the stack of PIL files.
Returns:
numpy.ndarray or numpy.memmap: The loaded volume.
If 'self.virtual_stack' is True, returns a numpy.memmap object.
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple PIL stacks in the directory
"""
if not self.contains:
raise ValueError(
"Please specify a part of the name that is common for the file stack with the argument 'contains'"
)
# List comprehension to filter files
PIL_stack = [
file
for file in os.listdir(path)
if file.endswith(self.PIL_extensions) and self.contains in file
]
PIL_stack.sort() # Ensure proper ordering
# Check that only one stack in the directory contains the provided string in its name
PIL_stack_only_letters = []
for filename in PIL_stack:
name = os.path.splitext(filename)[0] # Remove file extension
PIL_stack_only_letters.append(
"".join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements
unique_names = list(set(PIL_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the stack matches multiple stacks: {unique_names}.\nPlease provide a string that is unique for the image stack that is intended to be loaded"
)
if self.virtual_stack:
full_paths = [os.path.join(path, file) for file in PIL_stack]
def lazy_loader(path):
with Image.open(path) as img:
return np.array(img)
# Use delayed to load each image with PIL
lazy_images = [delayed(lazy_loader)(path) for path in full_paths]
# Compute the shape of the first image to define the array dimensions
sample_image = np.array(Image.open(full_paths[0]))
image_shape = sample_image.shape
dtype = sample_image.dtype
# Stack the images into a single Dask array
dask_images = [
da.from_delayed(img, shape=image_shape, dtype=dtype) for img in lazy_images
]
stacked = da.stack(dask_images, axis=0)
return stacked
else:
# Generate placeholder volume
first_image = self.load_pil(os.path.join(path, PIL_stack[0]))
vol = np.zeros((len(PIL_stack), *first_image.shape), dtype=first_image.dtype)
# Load file sequence
for idx, file_name in enumerate(PIL_stack):
vol[idx] = self.load_pil(os.path.join(path, file_name))
return vol
# log.info("Found %s file(s)", len(PIL_stack))
# log.info("Loaded shape: %s", vol.shape)
def _load_vgi_metadata(self, path):
"""Helper functions that loads metadata from a VGI file
......@@ -599,6 +694,9 @@ class DataLoader:
[f.endswith(".tif") or f.endswith(".tiff") for f in os.listdir(path)]
):
return self.load_tiff_stack(path)
elif any([f.endswith(self.PIL_extensions) for f in os.listdir(path)]):
return self.load_PIL_stack(path)
else:
return self.load_dicom_dir(path)
......@@ -758,6 +856,5 @@ class ImgExamples:
img_examples_path = Path(qim3d.__file__).parents[0] / "img_examples"
img_paths = list(img_examples_path.glob("*.tif"))
update_dict = {path.stem: load(path) for path in img_paths}
self.__dict__.update(update_dict)
......@@ -36,7 +36,7 @@ def set_simple_output():
formatter = logging.Formatter("%(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger("qim3d")
logger = logging.getLogger()
logger.handlers = []
logger.addHandler(handler)
......@@ -132,5 +132,5 @@ def level(log_level):
# create the logger
log = logging.getLogger("qim3d")
# set_simple_output() #TODO: This used to work, but now it gives duplicated messages. Need to be investigated.
set_level_warning()
set_simple_output() #TODO: This used to work, but now it gives duplicated messages. Need to be investigated.
#set_level_warning()
......@@ -11,6 +11,7 @@ import numpy as np
import torch
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
import dask.array as da
from qim3d.io.logger import log
......@@ -271,14 +272,17 @@ def slices(
"""
# Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, torch.Tensor)):
raise ValueError("Input must be a numpy.ndarray or torch.Tensor")
if not isinstance(vol, (np.ndarray, torch.Tensor, da.core.Array)):
raise ValueError("Data type not supported")
if vol.ndim < 3:
raise ValueError(
"The provided object is not a volume as it has less than 3 dimensions."
)
if isinstance(vol, da.core.Array):
vol = vol.compute()
# Ensure axis is a valid choice
if not (0 <= axis < vol.ndim):
raise ValueError(
......@@ -324,9 +328,11 @@ def slices(
if nrows == 1:
axs = [axs] # Convert to a list for uniformity
# Convert Torch tensor to NumPy array in order to use the numpy.take method
# Convert to NumPy array in order to use the numpy.take method
if isinstance(vol, torch.Tensor):
vol = vol.numpy()
elif isinstance(vol, da.core.Array):
vol = vol.compute()
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
......
......@@ -16,6 +16,8 @@ from qim3d.utils.internal_tools import downscale_img, scale_to_float16
def vol(
img,
vmin=None,
vmax=None,
aspectmode="data",
show=True,
save=False,
......@@ -80,7 +82,7 @@ def vol(
a = (y1 - y2) / (x1 - x2)
b = y1 - a * x1
samples = int(min(max(a * pixel_count + b, 32), 512))
samples = int(min(max(a * pixel_count + b, 64), 512))
else:
samples = int(samples) # make sure it's an integer
......@@ -97,15 +99,23 @@ def vol(
f"Downsampled image for visualization. From {original_shape} to {new_shape}"
)
# Set color ranges
color_range = [np.min(img), np.max(img)]
if vmin:
color_range[0] = vmin
if vmax:
color_range[1] = vmax
plt_volume = k3d.volume(
scale_to_float16(img),
bounds=(
[0, img.shape[0], 0, img.shape[1], 0, img.shape[2]]
[0, img.shape[2], 0, img.shape[1], 0, img.shape[0]]
if aspectmode.lower() == "data"
else None
),
color_map=cmap,
samples=samples,
color_range=color_range,
)
plot = k3d.plot(grid_visible=grid_visible, **kwargs)
plot += plt_volume
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment