Skip to content
Snippets Groups Projects
Commit 5d1ba6f3 authored by s193396's avatar s193396 Committed by fima
Browse files

Synthetic collection generation

parent db5ac596
No related branches found
No related tags found
1 merge request!108Synthetic collection generation
Showing
with 651 additions and 60 deletions
docs/assets/screenshots/synthetic_blob_slices.png

61.2 KiB

docs/assets/screenshots/synthetic_collection_cc.gif

525 KiB

docs/assets/screenshots/synthetic_collection_default.gif

620 KiB

docs/assets/screenshots/synthetic_collection_default_labels.gif

312 KiB

# Generating synthetic data
The `qim3d` library provides a set of methods for generating volumes consisting of a single synthetic blob or a collection of multiple synthetic blobs.
::: qim3d.generate
options:
members:
- blob
- collection
\ No newline at end of file
......@@ -5,7 +5,6 @@ Here, we provide functionalities designed specifically for 3D image analysis and
::: qim3d.processing
options:
members:
- test_blob_detection
- blob_detection
- structure_tensor
- local_thickness
......@@ -26,5 +25,5 @@ Here, we provide functionalities designed specifically for 3D image analysis and
members:
- remove_background
- watershed
- edge_fade
- fade_mask
- overlay_rgb_images
......@@ -2,12 +2,6 @@
A set of tools to ease managment of the system, with the common needs for large data in mind.
::: qim3d.utils.img
options:
members:
- generate_volume
- overlay_rgb_images
::: qim3d.utils.system
options:
members:
......
......@@ -10,6 +10,7 @@ nav:
- qim3d: index.md
- Input & Output: io.md
- Data Generation: generate.md
- Processing: processing.md
- Visualization: viz.md
- Utils: utils.md
......
......@@ -15,6 +15,7 @@ from . import gui
from . import viz
from . import utils
from . import processing
from . import generate
# commented out to avoid torch import
# from . import models
......
from .blob_ import blob
from .collection_ import collection
\ No newline at end of file
......@@ -2,65 +2,17 @@ import numpy as np
import scipy.ndimage
from noise import pnoise3
def overlay_rgb_images(background, foreground, alpha=0.5):
"""Overlay a RGB foreground onto an RGB background using alpha blending.
Args:
background (numpy.ndarray): The background RGB image.
foreground (numpy.ndarray): The foreground RGB image (usually masks).
alpha (float, optional): The alpha value for blending. Defaults to 0.5.
Returns:
numpy.ndarray: The composite RGB image with overlaid foreground.
Raises:
ValueError: If input images have different shapes.
Note:
- The function performs alpha blending to overlay the foreground onto the background.
- It ensures that the background and foreground have the same shape before blending.
- It calculates the maximum projection of the foreground and blends them onto the background.
- Brightness outside the foreground is adjusted to maintain consistency with the background.
"""
# Igonore alpha in case its there
background = background[..., :3]
foreground = foreground[..., :3]
# Ensure both images have the same shape
if background.shape != foreground.shape:
raise ValueError("Input images must have the same shape")
# Perform alpha blending
foreground_max_projection = np.amax(foreground, axis=2)
foreground_max_projection = np.stack((foreground_max_projection,) * 3, axis=-1)
# Normalize if we have something
if np.max(foreground_max_projection) > 0:
foreground_max_projection = foreground_max_projection / np.max(
foreground_max_projection
)
composite = background * (1 - alpha) + foreground * alpha
composite = np.clip(composite, 0, 255).astype("uint8")
# Adjust brightness outside foreground
composite = composite + (background * (1 - alpha)) * (1 - foreground_max_projection)
return composite.astype("uint8")
def generate_volume(
base_shape=(128, 128, 128),
final_shape=(128, 128, 128),
noise_scale=0.05,
order=1,
gamma=1.0,
max_value=255,
threshold=0.5,
dtype="uint8",
):
def blob(
base_shape: tuple = (128, 128, 128),
final_shape: tuple = (128, 128, 128),
noise_scale: float = 0.05,
order: int = 1,
gamma: int = 1.0,
max_value: int = 255,
threshold: float = 0.5,
smooth_borders: bool = False,
dtype: str = "uint8",
) -> np.ndarray:
"""
Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction.
......@@ -72,55 +24,65 @@ def generate_volume(
gamma (float, optional): Gamma correction factor. Defaults to 1.0.
max_value (int, optional): Maximum value for the volume intensity. Defaults to 255.
threshold (float, optional): Threshold value for clipping low intensity values. Defaults to 0.5.
smooth_borders (bool, optional): Flag for automatic computation of the threshold value to ensure a blob with no straight edges. If True, the `threshold` parameter is ignored. Defaults to False.
dtype (str, optional): Desired data type of the output volume. Defaults to "uint8".
Returns:
numpy.ndarray: Generated 3D volume with specified parameters.
synthetic_blob (numpy.ndarray): Generated 3D volume with specified parameters.
Raises:
ValueError: If `final_shape` is not a tuple or does not have three elements.
TypeError: If `final_shape` is not a tuple or does not have three elements.
ValueError: If `dtype` is not a valid numpy number type.
Example:
```python
import qim3d
vol = qim3d.utils.generate_volume(noise_scale=0.05, threshold=0.4)
qim3d.viz.slices(vol, vmin=0, vmax=255, n_slices=15)
# Generate synthetic blob
synthetic_blob = qim3d.generate.blob(noise_scale = 0.05)
# Visualize slices
qim3d.viz.slices(synthetic_blob, vmin = 0, vmax = 255, n_slices = 15)
```
![generate_volume](assets/screenshots/generate_volume.png)
![synthetic_blob](assets/screenshots/synthetic_blob_slices.png)
```python
qim3d.viz.vol(vol)
# Visualize 3D volume
qim3d.viz.vol(synthetic_blob)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_volume.html" width="100%" height="500" frameborder="0"></iframe>
<iframe src="https://platform.qim.dk/k3d/synthetic_blob.html" width="100%" height="500" frameborder="0"></iframe>
"""
if not isinstance(final_shape, tuple) or len(final_shape) != 3:
raise ValueError("Size must be a tuple")
raise TypeError("Size must be a tuple of 3 dimensions")
if not np.issubdtype(dtype, np.number):
raise ValueError("Invalid data type")
# Define the dimensions of the shape for generating Perlin noise
# Initialize the 3D array for the shape
volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32)
# Fill the 3D array with values from the Perlin noise function
for i in range(base_shape[0]):
for j in range(base_shape[1]):
for k in range(base_shape[2]):
# Generate grid of coordinates
z, y, x = np.indices(base_shape)
# Calculate the distance from the center of the shape
dist = np.sqrt(
(i - base_shape[0] / 2) ** 2
+ (j - base_shape[1] / 2) ** 2
+ (k - base_shape[2] / 2) ** 2
) / np.sqrt(3 * ((base_shape[0] / 2) ** 2))
center = np.array(base_shape) / 2
dist = np.sqrt((z - center[0])**2 +
(y - center[1])**2 +
(x - center[2])**2)
dist /= np.sqrt(3 * (center[0]**2))
# Generate Perlin noise and adjust the values based on the distance from the center
# This creates a spherical shape with noise
volume[i][j][k] = (
1 + pnoise3(i * noise_scale, j * noise_scale, k * noise_scale)
) * (1 - dist)
vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input
noise = vectorized_pnoise3(z.flatten() * noise_scale,
y.flatten() * noise_scale,
x.flatten() * noise_scale
).reshape(base_shape)
volume = (1 + noise) * (1 - dist)
# Normalize
volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
......@@ -131,7 +93,18 @@ def generate_volume(
# Scale the volume to the maximum value
volume = volume * max_value
# clip the low values of the volume to create a coherent volume
if smooth_borders:
# Maximum value among the six sides of the 3D volume
max_border_value = np.max([
np.max(volume[0, :, :]), np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]), np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]), np.max(volume[:, :, -1])
])
# Compute threshold such that there will be no straight cuts in the blob
threshold = max_border_value / max_value
# Clip the low values of the volume to create a coherent volume
volume[volume < threshold * max_value] = 0
# Clip high values
......
import numpy as np
import scipy.ndimage
from tqdm.notebook import tqdm
from skimage.filters import threshold_li
from qim3d.generate import blob as generate_blob
from qim3d.processing import get_3d_cc
from qim3d.io.logger import log
def random_placement(
collection: np.ndarray,
blob: np.ndarray,
rng: np.random.Generator,
) -> tuple[np.ndarray, bool]:
"""
Place blob at random available position in collection.
Args:
collection (numpy.ndarray): 3D volume of the collection.
blob (numpy.ndarray): 3D volume of the blob.
rng (numpy.random.Generator): Random number generator.
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
"""
# Find available (zero) elements in collection
available_z, available_y, available_x = np.where(collection == 0)
# Flag for placement success
placed = False
# Attempt counter
j = 1
while (not placed) and (j < 200_000):
# Select a random available position in collection
idx = rng.choice(len(available_z))
z, y, x = available_z[idx], available_y[idx], available_x[idx]
start = np.array([z, y, x]) # Start position of blob placement
end = start + np.array(blob.shape) # End position of blob placement
# Check if the blob fits in the selected region (without overlap)
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
)
placed = True
# Increment attempt counter
j += 1
return collection, placed
def specific_placement(
collection: np.ndarray,
blob: np.ndarray,
positions: list[tuple],
) -> tuple[np.ndarray, bool]:
"""
Place blob at one of the specified positions in the collection.
Args:
collection (numpy.ndarray): 3D volume of the collection.
blob (numpy.ndarray): 3D volume of the blob.
positions (list[tuple]): List of specified positions as (z, y, x) coordinates for the blobs.
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
positions (list[tuple]): List of remaining positions to place blobs.
"""
# Flag for placement success
placed = False
for position in positions:
# Get coordinates of next position
z, y, x = position
# Place blob with center at specified position
start = (
np.array([z, y, x]) - np.array(blob.shape) // 2
) # Start position of blob placement
end = start + np.array(blob.shape) # End position of blob placement
# Check if the blob fits in the selected region (without overlap)
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
)
placed = True
# Remove position from list
positions.remove(position)
break
return collection, placed, positions
def collection(
collection_shape: tuple = (200, 200, 200),
num_objects: int = 15,
positions: list[tuple] = None,
min_shape: tuple = (40, 40, 40),
max_shape: tuple = (60, 60, 60),
object_shape_zoom: tuple = (1.0, 1.0, 1.0),
min_object_noise: float = 0.02,
max_object_noise: float = 0.05,
min_rotation_degrees: int = 0,
max_rotation_degrees: int = 360,
rotation_axes: list[tuple] = [(0, 1), (0, 2), (1, 2)],
min_gamma: float = 0.8,
max_gamma: float = 1.2,
min_high_value: int = 128,
max_high_value: int = 255,
min_threshold: float = 0.5,
max_threshold: float = 0.6,
smooth_borders: bool = False,
seed: int = 0,
verbose: bool = False,
) -> tuple[np.ndarray, object]:
"""
Generate a 3D volume of multiple synthetic objects using Perlin noise.
Args:
collection_shape (tuple, optional): Shape of the final collection volume to generate. Defaults to (200, 200, 200).
num_objects (int, optional): Number of synthetic objects to include in the collection. Defaults to 15.
positions (list[tuple], optional): List of specific positions as (z, y, x) coordinates for the objects. If not provided, they are placed randomly into the collection. Defaults to None.
min_shape (tuple, optional): Minimum shape of the objects. Defaults to (40, 40, 40).
max_shape (tuple, optional): Maximum shape of the objects. Defaults to (60, 60, 60).
object_shape_zoom (tuple, optional): Scaling factors for each dimension of each object. Defaults to (1.0, 1.0, 1.0).
min_object_noise (float, optional): Minimum scale factor for Perlin noise. Defaults to 0.02.
max_object_noise (float, optional): Maximum scale factor for Perlin noise. Defaults to 0.05.
min_rotation_degrees (int, optional): Minimum rotation angle in degrees. Defaults to 0.
max_rotation_degrees (int, optional): Maximum rotation angle in degrees. Defaults to 360.
rotation_axes (list[tuple], optional): List of axis pairs that will be randomly chosen to rotate around. Defaults to [(0, 1), (0, 2), (1, 2)].
min_gamma (float, optional): Minimum gamma correction factor. Defaults to 0.8.
max_gamma (float, optional): Maximum gamma correction factor. Defaults to 1.2.
min_high_value (int, optional): Minimum maximum value for the volume intensity. Defaults to 128.
max_high_value (int, optional): Maximum maximum value for the volume intensity. Defaults to 255.
min_threshold (float, optional): Minimum threshold value for clipping low intensity values. Defaults to 0.5.
max_threshold (float, optional): Maximum threshold value for clipping low intensity values. Defaults to 0.6.
smooth_borders (bool, optional): Flag for smoothing blob borders to avoid straight edges in the objects. If True, the `min_threshold` and `max_threshold` parameters are ignored. Defaults to False.
seed (int, optional): Seed for reproducibility. Defaults to 0.
verbose (bool, optional): Flag to enable verbose logging. Defaults to False.
Returns:
synthetic_collection (numpy.ndarray): 3D volume of the generated collection of synthetic blobs with specified parameters.
CC (object): A ConnectedComponents object containing the connected components and the number of connected components found in the collection.
Raises:
TypeError: If `collection_shape` is not 3D.
ValueError: If blob parameters are invalid.
Note:
- The function places objects without overlap.
- The function can either place objects at random positions in the collection (if `positions = None`) or at specific positions provided in the `positions` argument. If specific positions are provided, the number of blobs must match the number of positions (e.g. `num_objects = 2` with `positions = [(12, 8, 10), (24, 20, 18)]`).
- If not all `num_objects` can be placed, the function returns the `synthetic_collection` volume with as many blobs as possible in it, and logs an error.
- Labels for all objects are returned, even if they are not a sigle connected component.
Example:
```python
import qim3d
# Generate synthetic collection of blobs
num_objects = 15
synthetic_collection, labels = qim3d.generate.collection(num_objects = num_objects)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_default.html" width="100%" height="500" frameborder="0"></iframe>
```python
qim3d.viz.slicer(synthetic_collection)
```
![synthetic_collection](assets/screenshots/synthetic_collection_default.gif)
```python
# Visualize labels
cmap = qim3d.viz.colormaps.objects(nlabels=num_objects)
qim3d.viz.slicer(labels, cmap=cmap, vmax=num_objects)
```
![synthetic_collection](assets/screenshots/synthetic_collection_default_labels.gif)
Example:
```python
import qim3d
# Generate synthetic collection of dense blobs
synthetic_collection, labels = qim3d.generate.collection(
min_high_value = 255,
max_high_value = 255,
min_object_noise = 0.05,
max_object_noise = 0.05,
min_threshold = 0.99,
max_threshold = 0.99,
min_gamma = 0.02,
max_gamma = 0.02)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_dense.html" width="100%" height="500" frameborder="0"></iframe>
Example:
```python
import qim3d
# Generate synthetic collection of tubular structures
synthetic_collection, labels = qim3d.generate.collection(
num_objects=10,
collection_shape=(200,100,100),
min_shape = (190, 50, 50),
max_shape = (200, 60, 60),
object_shape_zoom = (1, 0.2, 0.2),
min_object_noise = 0.01,
max_object_noise = 0.02,
max_rotation_degrees=10,
min_threshold = 0.95,
max_threshold = 0.98,
min_gamma = 0.02,
max_gamma = 0.03
)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_tubular.html" width="100%" height="500" frameborder="0"></iframe>
"""
if verbose:
original_log_level = log.getEffectiveLevel()
log.setLevel("DEBUG")
# Check valid input types
if not isinstance(collection_shape, tuple) or len(collection_shape) != 3:
raise TypeError(
"Shape of collection must be a tuple with three dimensions (z, y, x)"
)
if len(min_shape) != len(max_shape):
raise ValueError("Object shapes must be tuples of the same length")
# if not isinstance(blob_shapes, list) or \
# len(blob_shapes) != 2 or len(blob_shapes[0]) != 3 or len(blob_shapes[1]) != 3:
# raise TypeError("Blob shapes must be a list of two tuples with three dimensions (z, y, x)")
if (positions is not None) and (len(positions) != num_objects):
raise ValueError(
"Number of objects must match number of positions, otherwise set positions = None"
)
# Set seed for random number generator
rng = np.random.default_rng(seed)
# Initialize the 3D array for the shape
collection_array = np.zeros(
(collection_shape[0], collection_shape[1], collection_shape[2]), dtype=np.uint8
)
labels = np.zeros_like(collection_array)
# Fill the 3D array with synthetic blobs
for i in tqdm(range(num_objects), desc="Objects placed"):
log.debug(f"\nObject #{i+1}")
# Sample from blob parameter ranges
if min_shape == max_shape:
blob_shape = min_shape
else:
blob_shape = tuple(
rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3)
)
log.debug(f"- Blob shape: {blob_shape}")
# Sample noise scale
noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise)
log.debug(f"- Object noise scale: {noise_scale:.4f}")
gamma = rng.uniform(low=min_gamma, high=max_gamma)
log.debug(f"- Gamma correction: {gamma:.3f}")
if max_high_value > min_high_value:
max_value = rng.integers(low=min_high_value, high=max_high_value)
else:
max_value = min_high_value
log.debug(f"- Max value: {max_value}")
threshold = rng.uniform(low=min_threshold, high=max_threshold)
log.debug(f"- Threshold: {threshold:.3f}")
# Generate synthetic blob
blob = generate_blob(
base_shape=blob_shape,
final_shape=tuple(l * r for l, r in zip(blob_shape, object_shape_zoom)),
noise_scale=noise_scale,
gamma=gamma,
max_value=max_value,
threshold=threshold,
smooth_borders=smooth_borders,
)
# Rotate object
if max_rotation_degrees > 0:
angle = rng.uniform(
low=min_rotation_degrees, high=max_rotation_degrees
) # Sample rotation angle
axes = rng.choice(rotation_axes) # Sample the two axes to rotate around
log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}")
blob = scipy.ndimage.rotate(blob, angle, axes, order=0)
# Place synthetic object into the collection
# If positions are specified, place blob at one of the specified positions
collection_before = collection_array.copy()
if positions:
collection_array, placed, positions = specific_placement(
collection_array, blob, positions
)
# Otherwise, place blob at a random available position
else:
collection_array, placed = random_placement(collection_array, blob, rng)
# Break if blob could not be placed
if not placed:
break
# Update labels
new_labels = np.where(collection_array != collection_before, i + 1, 0).astype(
labels.dtype
)
labels += new_labels
if not placed:
# Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar)
log.error(
f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects."
)
if verbose:
log.setLevel(original_log_level)
return collection_array, labels
......@@ -29,7 +29,7 @@ import numpy as np
from PIL import Image
from qim3d.io import load, save
from qim3d.utils import overlay_rgb_images
from qim3d.processing.operations import overlay_rgb_images
from qim3d.gui.interface import BaseInterface
# TODO: img in launch should be self.img
......
"Testing docstring"
from .local_thickness_ import local_thickness
from .structure_tensor_ import structure_tensor
from .detection import blob_detection
......
......@@ -112,22 +112,22 @@ def fade_mask(
vol: np.ndarray,
decay_rate: float = 10,
ratio: float = 0.5,
geometry: str = "sphere",
invert=False,
geometry: str = "spherical",
invert: bool = False,
axis: int = 0,
):
) -> np.ndarray:
"""
Apply edge fading to a volume.
Args:
vol (np.ndarray): The volume to apply edge fading to.
decay_rate (float, optional): The decay rate of the fading. Defaults to 10.
ratio (float, optional): The ratio of the volume to fade. Defaults to 0.
ratio (float, optional): The ratio of the volume to fade. Defaults to 0.5.
geometric (str, optional): The geometric shape of the fading. Can be 'spherical' or 'cylindrical'. Defaults to 'spherical'.
axis (int, optional): The axis along which to apply the fading. Defaults to 0.
Returns:
np.ndarray: The volume with edge fading applied.
vol_faded (np.ndarray): The volume with edge fading applied.
Example:
```python
......@@ -157,9 +157,9 @@ def fade_mask(
center = np.array([(s - 1) / 2 for s in shape])
# Calculate the distance of each point from the center
if geometry == "sphere":
if geometry == "spherical":
distance = np.linalg.norm([z - center[0], y - center[1], x - center[2]], axis=0)
elif geometry == "cilinder":
elif geometry == "cylindrical":
distance_list = np.array([z - center[0], y - center[1], x - center[2]])
# remove the axis along which the fading is not applied
distance_list = np.delete(distance_list, axis, axis=0)
......@@ -185,3 +185,55 @@ def fade_mask(
vol_faded = vol * fade_array
return vol_faded
def overlay_rgb_images(
background: np.ndarray,
foreground: np.ndarray,
alpha: float = 0.5
) -> np.ndarray:
"""
Overlay an RGB foreground onto an RGB background using alpha blending.
Args:
background (numpy.ndarray): The background RGB image.
foreground (numpy.ndarray): The foreground RGB image (usually masks).
alpha (float, optional): The alpha value for blending. Defaults to 0.5.
Returns:
composite (numpy.ndarray): The composite RGB image with overlaid foreground.
Raises:
ValueError: If input images have different shapes.
Note:
- The function performs alpha blending to overlay the foreground onto the background.
- It ensures that the background and foreground have the same shape before blending.
- It calculates the maximum projection of the foreground and blends them onto the background.
- Brightness outside the foreground is adjusted to maintain consistency with the background.
"""
# Igonore alpha in case its there
background = background[..., :3]
foreground = foreground[..., :3]
# Ensure both images have the same shape
if background.shape != foreground.shape:
raise ValueError("Input images must have the same shape")
# Perform alpha blending
foreground_max_projection = np.amax(foreground, axis=2)
foreground_max_projection = np.stack((foreground_max_projection,) * 3, axis=-1)
# Normalize if we have something
if np.max(foreground_max_projection) > 0:
foreground_max_projection = foreground_max_projection / np.max(
foreground_max_projection
)
composite = background * (1 - alpha) + foreground * alpha
composite = np.clip(composite, 0, 255).astype("uint8")
# Adjust brightness outside foreground
composite = composite + (background * (1 - alpha)) * (1 - foreground_max_projection)
return composite.astype("uint8")
\ No newline at end of file
"""Wrapper for the structure tensor function from the structure_tensor package"""
from typing import Tuple
import logging
import numpy as np
from qim3d.viz.structure_tensor import vectors
def structure_tensor(
......@@ -75,8 +75,12 @@ def structure_tensor(
```
"""
previous_logging_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.CRITICAL)
import structure_tensor as st
logging.getLogger().setLevel(previous_logging_level)
if vol.ndim != 3:
raise ValueError("The input volume must be 3D")
......@@ -88,7 +92,9 @@ def structure_tensor(
# Add small noise to the volume
# FIXME: This is a temporary solution to avoid uniform regions with constant values
# in the volume, which lead to numerical issues in the structure tensor computation
vol_noisy = vol + np.random.default_rng(seed = 0).uniform(0, 1e-10, size=vol.shape)
vol_noisy = vol + np.random.default_rng(seed=0).uniform(
0, 1e-10, size=vol.shape
)
# Compute the structure tensor (of volume with noise)
s_vol = st.structure_tensor_3d(vol_noisy, sigma, rho)
......@@ -101,6 +107,8 @@ def structure_tensor(
val, vec = st.eig_special_3d(s_vol, full=full)
if visualize:
from qim3d.viz.structure_tensor import vectors
display(vectors(vol, vec, **viz_kwargs))
return val, vec
......@@ -2,7 +2,6 @@
from . import doi, internal_tools
from .augmentations import Augmentation
from .data import Dataset, prepare_dataloaders, prepare_datasets
from .img import generate_volume, overlay_rgb_images
from .models import inference, model_summary, train_model
from .preview import image_preview
from .loading_progress_bar import ProgressBar
......
import numpy as np
from typing import Optional, Union, Tuple
import matplotlib.pyplot as plt
import structure_tensor as st
from matplotlib.gridspec import GridSpec
import ipywidgets as widgets
import logging as log
import logging
previous_logging_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.CRITICAL)
import structure_tensor as st
logging.getLogger().setLevel(previous_logging_level)
def vectors(
volume: np.ndarray,
......@@ -109,8 +115,12 @@ def vectors(
# Create three subplots
fig, ax = plt.subplots(1, 3, figsize=figsize, layout="constrained")
blend_hue_saturation = lambda hue, sat : hue * (1 - sat) + 0.5 * sat # Function for blending hue and saturation
blend_slice_colors = lambda slice, colors : 0.5 * (slice + colors) # Function for blending image slice with orientation colors
blend_hue_saturation = (
lambda hue, sat: hue * (1 - sat) + 0.5 * sat
) # Function for blending hue and saturation
blend_slice_colors = lambda slice, colors: 0.5 * (
slice + colors
) # Function for blending image slice with orientation colors
# ----- Subplot 1: Image slice with orientation vectors ----- #
# Create meshgrid with the correct dimensions
......@@ -120,7 +130,9 @@ def vectors(
g = slice(grid_size // 2, None, grid_size)
# Angles from 0 to pi
angles_quiver = np.mod(np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi)
angles_quiver = np.mod(
np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi
)
# Calculate z-component (saturation)
saturation_quiver = (vectors_slice_z[g, g] ** 2)[:, :, np.newaxis]
......@@ -130,8 +142,12 @@ def vectors(
# Blend hue and saturation
rgba_quiver = blend_hue_saturation(hue_quiver, saturation_quiver)
rgba_quiver = np.clip(rgba_quiver, 0, 1) # Ensure rgba values are values within [0, 1]
rgba_quiver_flat = rgba_quiver.reshape((rgba_quiver.shape[0]*rgba_quiver.shape[1], 4)) # Flatten array for quiver plot
rgba_quiver = np.clip(
rgba_quiver, 0, 1
) # Ensure rgba values are values within [0, 1]
rgba_quiver_flat = rgba_quiver.reshape(
(rgba_quiver.shape[0] * rgba_quiver.shape[1], 4)
) # Flatten array for quiver plot
# Plot vectors
ax[0].quiver(
......@@ -152,7 +168,11 @@ def vectors(
)
ax[0].imshow(data_slice, cmap=plt.cm.gray)
ax[0].set_title(f"Orientation vectors (slice {slice_idx})" if not interactive else "Orientation vectors")
ax[0].set_title(
f"Orientation vectors (slice {slice_idx})"
if not interactive
else "Orientation vectors"
)
ax[0].set_axis_off()
# ----- Subplot 2: Orientation histogram ----- #
......@@ -169,15 +189,25 @@ def vectors(
# Calculate z-component (saturation) for each bin
bins = np.digitize(angles.ravel(), bin_edges)
saturation_bin = np.array([np.mean((vectors_slice_z**2).ravel()[bins == i]) \
if np.sum(bins == i) > 0 else 0 for i in range(1, len(bin_edges))])
saturation_bin = np.array(
[
(
np.mean((vectors_slice_z**2).ravel()[bins == i])
if np.sum(bins == i) > 0
else 0
)
for i in range(1, len(bin_edges))
]
)
# Calculate hue for each bin
hue_bin = plt.cm.hsv(bin_centers / np.pi)
# Blend hue and saturation
rgba_bin = hue_bin.copy()
rgba_bin[:, :3] = blend_hue_saturation(hue_bin[:, :3], saturation_bin[:, np.newaxis])
rgba_bin[:, :3] = blend_hue_saturation(
hue_bin[:, :3], saturation_bin[:, np.newaxis]
)
ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
ax[1].set_xlabel("Angle [radians]")
......@@ -200,10 +230,16 @@ def vectors(
rgba = blend_hue_saturation(hue, saturation)
# Grayscale image slice blended with orientation colors
data_slice_orientation_colored = (blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255).astype('uint8')
data_slice_orientation_colored = (
blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
).astype("uint8")
ax[2].imshow(data_slice_orientation_colored)
ax[2].set_title(f"Colored orientations (slice {slice_idx})" if not interactive else "Colored orientations")
ax[2].set_title(
f"Colored orientations (slice {slice_idx})"
if not interactive
else "Colored orientations"
)
ax[2].set_axis_off()
if show:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment