Skip to content
Snippets Groups Projects
Commit ab793d64 authored by fima's avatar fima :beers:
Browse files

Merge branch 'synthetic_objects' into 'main'

Synthetic collection generation

See merge request !108
parents db5ac596 5d1ba6f3
No related branches found
No related tags found
1 merge request!108Synthetic collection generation
Showing
with 651 additions and 60 deletions
docs/assets/screenshots/synthetic_blob_slices.png

61.2 KiB

docs/assets/screenshots/synthetic_collection_cc.gif

525 KiB

docs/assets/screenshots/synthetic_collection_default.gif

620 KiB

docs/assets/screenshots/synthetic_collection_default_labels.gif

312 KiB

# Generating synthetic data
The `qim3d` library provides a set of methods for generating volumes consisting of a single synthetic blob or a collection of multiple synthetic blobs.
::: qim3d.generate
options:
members:
- blob
- collection
\ No newline at end of file
......@@ -5,7 +5,6 @@ Here, we provide functionalities designed specifically for 3D image analysis and
::: qim3d.processing
options:
members:
- test_blob_detection
- blob_detection
- structure_tensor
- local_thickness
......@@ -26,5 +25,5 @@ Here, we provide functionalities designed specifically for 3D image analysis and
members:
- remove_background
- watershed
- edge_fade
- fade_mask
- overlay_rgb_images
......@@ -2,12 +2,6 @@
A set of tools to ease managment of the system, with the common needs for large data in mind.
::: qim3d.utils.img
options:
members:
- generate_volume
- overlay_rgb_images
::: qim3d.utils.system
options:
members:
......
......@@ -10,6 +10,7 @@ nav:
- qim3d: index.md
- Input & Output: io.md
- Data Generation: generate.md
- Processing: processing.md
- Visualization: viz.md
- Utils: utils.md
......
......@@ -15,6 +15,7 @@ from . import gui
from . import viz
from . import utils
from . import processing
from . import generate
# commented out to avoid torch import
# from . import models
......
from .blob_ import blob
from .collection_ import collection
\ No newline at end of file
......@@ -2,65 +2,17 @@ import numpy as np
import scipy.ndimage
from noise import pnoise3
def overlay_rgb_images(background, foreground, alpha=0.5):
"""Overlay a RGB foreground onto an RGB background using alpha blending.
Args:
background (numpy.ndarray): The background RGB image.
foreground (numpy.ndarray): The foreground RGB image (usually masks).
alpha (float, optional): The alpha value for blending. Defaults to 0.5.
Returns:
numpy.ndarray: The composite RGB image with overlaid foreground.
Raises:
ValueError: If input images have different shapes.
Note:
- The function performs alpha blending to overlay the foreground onto the background.
- It ensures that the background and foreground have the same shape before blending.
- It calculates the maximum projection of the foreground and blends them onto the background.
- Brightness outside the foreground is adjusted to maintain consistency with the background.
"""
# Igonore alpha in case its there
background = background[..., :3]
foreground = foreground[..., :3]
# Ensure both images have the same shape
if background.shape != foreground.shape:
raise ValueError("Input images must have the same shape")
# Perform alpha blending
foreground_max_projection = np.amax(foreground, axis=2)
foreground_max_projection = np.stack((foreground_max_projection,) * 3, axis=-1)
# Normalize if we have something
if np.max(foreground_max_projection) > 0:
foreground_max_projection = foreground_max_projection / np.max(
foreground_max_projection
)
composite = background * (1 - alpha) + foreground * alpha
composite = np.clip(composite, 0, 255).astype("uint8")
# Adjust brightness outside foreground
composite = composite + (background * (1 - alpha)) * (1 - foreground_max_projection)
return composite.astype("uint8")
def generate_volume(
base_shape=(128, 128, 128),
final_shape=(128, 128, 128),
noise_scale=0.05,
order=1,
gamma=1.0,
max_value=255,
threshold=0.5,
dtype="uint8",
):
def blob(
base_shape: tuple = (128, 128, 128),
final_shape: tuple = (128, 128, 128),
noise_scale: float = 0.05,
order: int = 1,
gamma: int = 1.0,
max_value: int = 255,
threshold: float = 0.5,
smooth_borders: bool = False,
dtype: str = "uint8",
) -> np.ndarray:
"""
Generate a 3D volume with Perlin noise, spherical gradient, and optional scaling and gamma correction.
......@@ -72,55 +24,65 @@ def generate_volume(
gamma (float, optional): Gamma correction factor. Defaults to 1.0.
max_value (int, optional): Maximum value for the volume intensity. Defaults to 255.
threshold (float, optional): Threshold value for clipping low intensity values. Defaults to 0.5.
smooth_borders (bool, optional): Flag for automatic computation of the threshold value to ensure a blob with no straight edges. If True, the `threshold` parameter is ignored. Defaults to False.
dtype (str, optional): Desired data type of the output volume. Defaults to "uint8".
Returns:
numpy.ndarray: Generated 3D volume with specified parameters.
synthetic_blob (numpy.ndarray): Generated 3D volume with specified parameters.
Raises:
ValueError: If `final_shape` is not a tuple or does not have three elements.
TypeError: If `final_shape` is not a tuple or does not have three elements.
ValueError: If `dtype` is not a valid numpy number type.
Example:
```python
import qim3d
vol = qim3d.utils.generate_volume(noise_scale=0.05, threshold=0.4)
qim3d.viz.slices(vol, vmin=0, vmax=255, n_slices=15)
# Generate synthetic blob
synthetic_blob = qim3d.generate.blob(noise_scale = 0.05)
# Visualize slices
qim3d.viz.slices(synthetic_blob, vmin = 0, vmax = 255, n_slices = 15)
```
![generate_volume](assets/screenshots/generate_volume.png)
![synthetic_blob](assets/screenshots/synthetic_blob_slices.png)
```python
qim3d.viz.vol(vol)
# Visualize 3D volume
qim3d.viz.vol(synthetic_blob)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_volume.html" width="100%" height="500" frameborder="0"></iframe>
<iframe src="https://platform.qim.dk/k3d/synthetic_blob.html" width="100%" height="500" frameborder="0"></iframe>
"""
if not isinstance(final_shape, tuple) or len(final_shape) != 3:
raise ValueError("Size must be a tuple")
raise TypeError("Size must be a tuple of 3 dimensions")
if not np.issubdtype(dtype, np.number):
raise ValueError("Invalid data type")
# Define the dimensions of the shape for generating Perlin noise
# Initialize the 3D array for the shape
volume = np.empty((base_shape[0], base_shape[1], base_shape[2]), dtype=np.float32)
# Fill the 3D array with values from the Perlin noise function
for i in range(base_shape[0]):
for j in range(base_shape[1]):
for k in range(base_shape[2]):
# Generate grid of coordinates
z, y, x = np.indices(base_shape)
# Calculate the distance from the center of the shape
dist = np.sqrt(
(i - base_shape[0] / 2) ** 2
+ (j - base_shape[1] / 2) ** 2
+ (k - base_shape[2] / 2) ** 2
) / np.sqrt(3 * ((base_shape[0] / 2) ** 2))
center = np.array(base_shape) / 2
dist = np.sqrt((z - center[0])**2 +
(y - center[1])**2 +
(x - center[2])**2)
dist /= np.sqrt(3 * (center[0]**2))
# Generate Perlin noise and adjust the values based on the distance from the center
# This creates a spherical shape with noise
volume[i][j][k] = (
1 + pnoise3(i * noise_scale, j * noise_scale, k * noise_scale)
) * (1 - dist)
vectorized_pnoise3 = np.vectorize(pnoise3) # Vectorize pnoise3, since it only takes scalar input
noise = vectorized_pnoise3(z.flatten() * noise_scale,
y.flatten() * noise_scale,
x.flatten() * noise_scale
).reshape(base_shape)
volume = (1 + noise) * (1 - dist)
# Normalize
volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
......@@ -131,7 +93,18 @@ def generate_volume(
# Scale the volume to the maximum value
volume = volume * max_value
# clip the low values of the volume to create a coherent volume
if smooth_borders:
# Maximum value among the six sides of the 3D volume
max_border_value = np.max([
np.max(volume[0, :, :]), np.max(volume[-1, :, :]),
np.max(volume[:, 0, :]), np.max(volume[:, -1, :]),
np.max(volume[:, :, 0]), np.max(volume[:, :, -1])
])
# Compute threshold such that there will be no straight cuts in the blob
threshold = max_border_value / max_value
# Clip the low values of the volume to create a coherent volume
volume[volume < threshold * max_value] = 0
# Clip high values
......
import numpy as np
import scipy.ndimage
from tqdm.notebook import tqdm
from skimage.filters import threshold_li
from qim3d.generate import blob as generate_blob
from qim3d.processing import get_3d_cc
from qim3d.io.logger import log
def random_placement(
collection: np.ndarray,
blob: np.ndarray,
rng: np.random.Generator,
) -> tuple[np.ndarray, bool]:
"""
Place blob at random available position in collection.
Args:
collection (numpy.ndarray): 3D volume of the collection.
blob (numpy.ndarray): 3D volume of the blob.
rng (numpy.random.Generator): Random number generator.
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
"""
# Find available (zero) elements in collection
available_z, available_y, available_x = np.where(collection == 0)
# Flag for placement success
placed = False
# Attempt counter
j = 1
while (not placed) and (j < 200_000):
# Select a random available position in collection
idx = rng.choice(len(available_z))
z, y, x = available_z[idx], available_y[idx], available_x[idx]
start = np.array([z, y, x]) # Start position of blob placement
end = start + np.array(blob.shape) # End position of blob placement
# Check if the blob fits in the selected region (without overlap)
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
)
placed = True
# Increment attempt counter
j += 1
return collection, placed
def specific_placement(
collection: np.ndarray,
blob: np.ndarray,
positions: list[tuple],
) -> tuple[np.ndarray, bool]:
"""
Place blob at one of the specified positions in the collection.
Args:
collection (numpy.ndarray): 3D volume of the collection.
blob (numpy.ndarray): 3D volume of the blob.
positions (list[tuple]): List of specified positions as (z, y, x) coordinates for the blobs.
Returns:
collection (numpy.ndarray): 3D volume of the collection with the blob placed.
placed (bool): Flag for placement success.
positions (list[tuple]): List of remaining positions to place blobs.
"""
# Flag for placement success
placed = False
for position in positions:
# Get coordinates of next position
z, y, x = position
# Place blob with center at specified position
start = (
np.array([z, y, x]) - np.array(blob.shape) // 2
) # Start position of blob placement
end = start + np.array(blob.shape) # End position of blob placement
# Check if the blob fits in the selected region (without overlap)
if np.all(
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] == 0
):
# Check if placement is within bounds (bool)
within_bounds = np.all(start >= 0) and np.all(
end <= np.array(collection.shape)
)
if within_bounds:
# Place blob
collection[start[0] : end[0], start[1] : end[1], start[2] : end[2]] = (
blob
)
placed = True
# Remove position from list
positions.remove(position)
break
return collection, placed, positions
def collection(
collection_shape: tuple = (200, 200, 200),
num_objects: int = 15,
positions: list[tuple] = None,
min_shape: tuple = (40, 40, 40),
max_shape: tuple = (60, 60, 60),
object_shape_zoom: tuple = (1.0, 1.0, 1.0),
min_object_noise: float = 0.02,
max_object_noise: float = 0.05,
min_rotation_degrees: int = 0,
max_rotation_degrees: int = 360,
rotation_axes: list[tuple] = [(0, 1), (0, 2), (1, 2)],
min_gamma: float = 0.8,
max_gamma: float = 1.2,
min_high_value: int = 128,
max_high_value: int = 255,
min_threshold: float = 0.5,
max_threshold: float = 0.6,
smooth_borders: bool = False,
seed: int = 0,
verbose: bool = False,
) -> tuple[np.ndarray, object]:
"""
Generate a 3D volume of multiple synthetic objects using Perlin noise.
Args:
collection_shape (tuple, optional): Shape of the final collection volume to generate. Defaults to (200, 200, 200).
num_objects (int, optional): Number of synthetic objects to include in the collection. Defaults to 15.
positions (list[tuple], optional): List of specific positions as (z, y, x) coordinates for the objects. If not provided, they are placed randomly into the collection. Defaults to None.
min_shape (tuple, optional): Minimum shape of the objects. Defaults to (40, 40, 40).
max_shape (tuple, optional): Maximum shape of the objects. Defaults to (60, 60, 60).
object_shape_zoom (tuple, optional): Scaling factors for each dimension of each object. Defaults to (1.0, 1.0, 1.0).
min_object_noise (float, optional): Minimum scale factor for Perlin noise. Defaults to 0.02.
max_object_noise (float, optional): Maximum scale factor for Perlin noise. Defaults to 0.05.
min_rotation_degrees (int, optional): Minimum rotation angle in degrees. Defaults to 0.
max_rotation_degrees (int, optional): Maximum rotation angle in degrees. Defaults to 360.
rotation_axes (list[tuple], optional): List of axis pairs that will be randomly chosen to rotate around. Defaults to [(0, 1), (0, 2), (1, 2)].
min_gamma (float, optional): Minimum gamma correction factor. Defaults to 0.8.
max_gamma (float, optional): Maximum gamma correction factor. Defaults to 1.2.
min_high_value (int, optional): Minimum maximum value for the volume intensity. Defaults to 128.
max_high_value (int, optional): Maximum maximum value for the volume intensity. Defaults to 255.
min_threshold (float, optional): Minimum threshold value for clipping low intensity values. Defaults to 0.5.
max_threshold (float, optional): Maximum threshold value for clipping low intensity values. Defaults to 0.6.
smooth_borders (bool, optional): Flag for smoothing blob borders to avoid straight edges in the objects. If True, the `min_threshold` and `max_threshold` parameters are ignored. Defaults to False.
seed (int, optional): Seed for reproducibility. Defaults to 0.
verbose (bool, optional): Flag to enable verbose logging. Defaults to False.
Returns:
synthetic_collection (numpy.ndarray): 3D volume of the generated collection of synthetic blobs with specified parameters.
CC (object): A ConnectedComponents object containing the connected components and the number of connected components found in the collection.
Raises:
TypeError: If `collection_shape` is not 3D.
ValueError: If blob parameters are invalid.
Note:
- The function places objects without overlap.
- The function can either place objects at random positions in the collection (if `positions = None`) or at specific positions provided in the `positions` argument. If specific positions are provided, the number of blobs must match the number of positions (e.g. `num_objects = 2` with `positions = [(12, 8, 10), (24, 20, 18)]`).
- If not all `num_objects` can be placed, the function returns the `synthetic_collection` volume with as many blobs as possible in it, and logs an error.
- Labels for all objects are returned, even if they are not a sigle connected component.
Example:
```python
import qim3d
# Generate synthetic collection of blobs
num_objects = 15
synthetic_collection, labels = qim3d.generate.collection(num_objects = num_objects)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_default.html" width="100%" height="500" frameborder="0"></iframe>
```python
qim3d.viz.slicer(synthetic_collection)
```
![synthetic_collection](assets/screenshots/synthetic_collection_default.gif)
```python
# Visualize labels
cmap = qim3d.viz.colormaps.objects(nlabels=num_objects)
qim3d.viz.slicer(labels, cmap=cmap, vmax=num_objects)
```
![synthetic_collection](assets/screenshots/synthetic_collection_default_labels.gif)
Example:
```python
import qim3d
# Generate synthetic collection of dense blobs
synthetic_collection, labels = qim3d.generate.collection(
min_high_value = 255,
max_high_value = 255,
min_object_noise = 0.05,
max_object_noise = 0.05,
min_threshold = 0.99,
max_threshold = 0.99,
min_gamma = 0.02,
max_gamma = 0.02)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_dense.html" width="100%" height="500" frameborder="0"></iframe>
Example:
```python
import qim3d
# Generate synthetic collection of tubular structures
synthetic_collection, labels = qim3d.generate.collection(
num_objects=10,
collection_shape=(200,100,100),
min_shape = (190, 50, 50),
max_shape = (200, 60, 60),
object_shape_zoom = (1, 0.2, 0.2),
min_object_noise = 0.01,
max_object_noise = 0.02,
max_rotation_degrees=10,
min_threshold = 0.95,
max_threshold = 0.98,
min_gamma = 0.02,
max_gamma = 0.03
)
# Visualize synthetic collection
qim3d.viz.vol(synthetic_collection)
```
<iframe src="https://platform.qim.dk/k3d/synthetic_collection_tubular.html" width="100%" height="500" frameborder="0"></iframe>
"""
if verbose:
original_log_level = log.getEffectiveLevel()
log.setLevel("DEBUG")
# Check valid input types
if not isinstance(collection_shape, tuple) or len(collection_shape) != 3:
raise TypeError(
"Shape of collection must be a tuple with three dimensions (z, y, x)"
)
if len(min_shape) != len(max_shape):
raise ValueError("Object shapes must be tuples of the same length")
# if not isinstance(blob_shapes, list) or \
# len(blob_shapes) != 2 or len(blob_shapes[0]) != 3 or len(blob_shapes[1]) != 3:
# raise TypeError("Blob shapes must be a list of two tuples with three dimensions (z, y, x)")
if (positions is not None) and (len(positions) != num_objects):
raise ValueError(
"Number of objects must match number of positions, otherwise set positions = None"
)
# Set seed for random number generator
rng = np.random.default_rng(seed)
# Initialize the 3D array for the shape
collection_array = np.zeros(
(collection_shape[0], collection_shape[1], collection_shape[2]), dtype=np.uint8
)
labels = np.zeros_like(collection_array)
# Fill the 3D array with synthetic blobs
for i in tqdm(range(num_objects), desc="Objects placed"):
log.debug(f"\nObject #{i+1}")
# Sample from blob parameter ranges
if min_shape == max_shape:
blob_shape = min_shape
else:
blob_shape = tuple(
rng.integers(low=min_shape[i], high=max_shape[i]) for i in range(3)
)
log.debug(f"- Blob shape: {blob_shape}")
# Sample noise scale
noise_scale = rng.uniform(low=min_object_noise, high=max_object_noise)
log.debug(f"- Object noise scale: {noise_scale:.4f}")
gamma = rng.uniform(low=min_gamma, high=max_gamma)
log.debug(f"- Gamma correction: {gamma:.3f}")
if max_high_value > min_high_value:
max_value = rng.integers(low=min_high_value, high=max_high_value)
else:
max_value = min_high_value
log.debug(f"- Max value: {max_value}")
threshold = rng.uniform(low=min_threshold, high=max_threshold)
log.debug(f"- Threshold: {threshold:.3f}")
# Generate synthetic blob
blob = generate_blob(
base_shape=blob_shape,
final_shape=tuple(l * r for l, r in zip(blob_shape, object_shape_zoom)),
noise_scale=noise_scale,
gamma=gamma,
max_value=max_value,
threshold=threshold,
smooth_borders=smooth_borders,
)
# Rotate object
if max_rotation_degrees > 0:
angle = rng.uniform(
low=min_rotation_degrees, high=max_rotation_degrees
) # Sample rotation angle
axes = rng.choice(rotation_axes) # Sample the two axes to rotate around
log.debug(f"- Rotation angle: {angle:.2f} at axes: {axes}")
blob = scipy.ndimage.rotate(blob, angle, axes, order=0)
# Place synthetic object into the collection
# If positions are specified, place blob at one of the specified positions
collection_before = collection_array.copy()
if positions:
collection_array, placed, positions = specific_placement(
collection_array, blob, positions
)
# Otherwise, place blob at a random available position
else:
collection_array, placed = random_placement(collection_array, blob, rng)
# Break if blob could not be placed
if not placed:
break
# Update labels
new_labels = np.where(collection_array != collection_before, i + 1, 0).astype(
labels.dtype
)
labels += new_labels
if not placed:
# Log error if not all num_objects could be placed (this line of code has to be here, otherwise it will interfere with tqdm progress bar)
log.error(
f"Object #{i+1} could not be placed in the collection, no space found. Collection contains {i}/{num_objects} objects."
)
if verbose:
log.setLevel(original_log_level)
return collection_array, labels
......@@ -29,7 +29,7 @@ import numpy as np
from PIL import Image
from qim3d.io import load, save
from qim3d.utils import overlay_rgb_images
from qim3d.processing.operations import overlay_rgb_images
from qim3d.gui.interface import BaseInterface
# TODO: img in launch should be self.img
......
"Testing docstring"
from .local_thickness_ import local_thickness
from .structure_tensor_ import structure_tensor
from .detection import blob_detection
......
......@@ -112,22 +112,22 @@ def fade_mask(
vol: np.ndarray,
decay_rate: float = 10,
ratio: float = 0.5,
geometry: str = "sphere",
invert=False,
geometry: str = "spherical",
invert: bool = False,
axis: int = 0,
):
) -> np.ndarray:
"""
Apply edge fading to a volume.
Args:
vol (np.ndarray): The volume to apply edge fading to.
decay_rate (float, optional): The decay rate of the fading. Defaults to 10.
ratio (float, optional): The ratio of the volume to fade. Defaults to 0.
ratio (float, optional): The ratio of the volume to fade. Defaults to 0.5.
geometric (str, optional): The geometric shape of the fading. Can be 'spherical' or 'cylindrical'. Defaults to 'spherical'.
axis (int, optional): The axis along which to apply the fading. Defaults to 0.
Returns:
np.ndarray: The volume with edge fading applied.
vol_faded (np.ndarray): The volume with edge fading applied.
Example:
```python
......@@ -157,9 +157,9 @@ def fade_mask(
center = np.array([(s - 1) / 2 for s in shape])
# Calculate the distance of each point from the center
if geometry == "sphere":
if geometry == "spherical":
distance = np.linalg.norm([z - center[0], y - center[1], x - center[2]], axis=0)
elif geometry == "cilinder":
elif geometry == "cylindrical":
distance_list = np.array([z - center[0], y - center[1], x - center[2]])
# remove the axis along which the fading is not applied
distance_list = np.delete(distance_list, axis, axis=0)
......@@ -185,3 +185,55 @@ def fade_mask(
vol_faded = vol * fade_array
return vol_faded
def overlay_rgb_images(
background: np.ndarray,
foreground: np.ndarray,
alpha: float = 0.5
) -> np.ndarray:
"""
Overlay an RGB foreground onto an RGB background using alpha blending.
Args:
background (numpy.ndarray): The background RGB image.
foreground (numpy.ndarray): The foreground RGB image (usually masks).
alpha (float, optional): The alpha value for blending. Defaults to 0.5.
Returns:
composite (numpy.ndarray): The composite RGB image with overlaid foreground.
Raises:
ValueError: If input images have different shapes.
Note:
- The function performs alpha blending to overlay the foreground onto the background.
- It ensures that the background and foreground have the same shape before blending.
- It calculates the maximum projection of the foreground and blends them onto the background.
- Brightness outside the foreground is adjusted to maintain consistency with the background.
"""
# Igonore alpha in case its there
background = background[..., :3]
foreground = foreground[..., :3]
# Ensure both images have the same shape
if background.shape != foreground.shape:
raise ValueError("Input images must have the same shape")
# Perform alpha blending
foreground_max_projection = np.amax(foreground, axis=2)
foreground_max_projection = np.stack((foreground_max_projection,) * 3, axis=-1)
# Normalize if we have something
if np.max(foreground_max_projection) > 0:
foreground_max_projection = foreground_max_projection / np.max(
foreground_max_projection
)
composite = background * (1 - alpha) + foreground * alpha
composite = np.clip(composite, 0, 255).astype("uint8")
# Adjust brightness outside foreground
composite = composite + (background * (1 - alpha)) * (1 - foreground_max_projection)
return composite.astype("uint8")
\ No newline at end of file
"""Wrapper for the structure tensor function from the structure_tensor package"""
from typing import Tuple
import logging
import numpy as np
from qim3d.viz.structure_tensor import vectors
def structure_tensor(
......@@ -75,8 +75,12 @@ def structure_tensor(
```
"""
previous_logging_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.CRITICAL)
import structure_tensor as st
logging.getLogger().setLevel(previous_logging_level)
if vol.ndim != 3:
raise ValueError("The input volume must be 3D")
......@@ -88,7 +92,9 @@ def structure_tensor(
# Add small noise to the volume
# FIXME: This is a temporary solution to avoid uniform regions with constant values
# in the volume, which lead to numerical issues in the structure tensor computation
vol_noisy = vol + np.random.default_rng(seed = 0).uniform(0, 1e-10, size=vol.shape)
vol_noisy = vol + np.random.default_rng(seed=0).uniform(
0, 1e-10, size=vol.shape
)
# Compute the structure tensor (of volume with noise)
s_vol = st.structure_tensor_3d(vol_noisy, sigma, rho)
......@@ -101,6 +107,8 @@ def structure_tensor(
val, vec = st.eig_special_3d(s_vol, full=full)
if visualize:
from qim3d.viz.structure_tensor import vectors
display(vectors(vol, vec, **viz_kwargs))
return val, vec
......@@ -2,7 +2,6 @@
from . import doi, internal_tools
from .augmentations import Augmentation
from .data import Dataset, prepare_dataloaders, prepare_datasets
from .img import generate_volume, overlay_rgb_images
from .models import inference, model_summary, train_model
from .preview import image_preview
from .loading_progress_bar import ProgressBar
......
import numpy as np
from typing import Optional, Union, Tuple
import matplotlib.pyplot as plt
import structure_tensor as st
from matplotlib.gridspec import GridSpec
import ipywidgets as widgets
import logging as log
import logging
previous_logging_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.CRITICAL)
import structure_tensor as st
logging.getLogger().setLevel(previous_logging_level)
def vectors(
volume: np.ndarray,
......@@ -109,8 +115,12 @@ def vectors(
# Create three subplots
fig, ax = plt.subplots(1, 3, figsize=figsize, layout="constrained")
blend_hue_saturation = lambda hue, sat : hue * (1 - sat) + 0.5 * sat # Function for blending hue and saturation
blend_slice_colors = lambda slice, colors : 0.5 * (slice + colors) # Function for blending image slice with orientation colors
blend_hue_saturation = (
lambda hue, sat: hue * (1 - sat) + 0.5 * sat
) # Function for blending hue and saturation
blend_slice_colors = lambda slice, colors: 0.5 * (
slice + colors
) # Function for blending image slice with orientation colors
# ----- Subplot 1: Image slice with orientation vectors ----- #
# Create meshgrid with the correct dimensions
......@@ -120,7 +130,9 @@ def vectors(
g = slice(grid_size // 2, None, grid_size)
# Angles from 0 to pi
angles_quiver = np.mod(np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi)
angles_quiver = np.mod(
np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi
)
# Calculate z-component (saturation)
saturation_quiver = (vectors_slice_z[g, g] ** 2)[:, :, np.newaxis]
......@@ -130,8 +142,12 @@ def vectors(
# Blend hue and saturation
rgba_quiver = blend_hue_saturation(hue_quiver, saturation_quiver)
rgba_quiver = np.clip(rgba_quiver, 0, 1) # Ensure rgba values are values within [0, 1]
rgba_quiver_flat = rgba_quiver.reshape((rgba_quiver.shape[0]*rgba_quiver.shape[1], 4)) # Flatten array for quiver plot
rgba_quiver = np.clip(
rgba_quiver, 0, 1
) # Ensure rgba values are values within [0, 1]
rgba_quiver_flat = rgba_quiver.reshape(
(rgba_quiver.shape[0] * rgba_quiver.shape[1], 4)
) # Flatten array for quiver plot
# Plot vectors
ax[0].quiver(
......@@ -152,7 +168,11 @@ def vectors(
)
ax[0].imshow(data_slice, cmap=plt.cm.gray)
ax[0].set_title(f"Orientation vectors (slice {slice_idx})" if not interactive else "Orientation vectors")
ax[0].set_title(
f"Orientation vectors (slice {slice_idx})"
if not interactive
else "Orientation vectors"
)
ax[0].set_axis_off()
# ----- Subplot 2: Orientation histogram ----- #
......@@ -169,15 +189,25 @@ def vectors(
# Calculate z-component (saturation) for each bin
bins = np.digitize(angles.ravel(), bin_edges)
saturation_bin = np.array([np.mean((vectors_slice_z**2).ravel()[bins == i]) \
if np.sum(bins == i) > 0 else 0 for i in range(1, len(bin_edges))])
saturation_bin = np.array(
[
(
np.mean((vectors_slice_z**2).ravel()[bins == i])
if np.sum(bins == i) > 0
else 0
)
for i in range(1, len(bin_edges))
]
)
# Calculate hue for each bin
hue_bin = plt.cm.hsv(bin_centers / np.pi)
# Blend hue and saturation
rgba_bin = hue_bin.copy()
rgba_bin[:, :3] = blend_hue_saturation(hue_bin[:, :3], saturation_bin[:, np.newaxis])
rgba_bin[:, :3] = blend_hue_saturation(
hue_bin[:, :3], saturation_bin[:, np.newaxis]
)
ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
ax[1].set_xlabel("Angle [radians]")
......@@ -200,10 +230,16 @@ def vectors(
rgba = blend_hue_saturation(hue, saturation)
# Grayscale image slice blended with orientation colors
data_slice_orientation_colored = (blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255).astype('uint8')
data_slice_orientation_colored = (
blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
).astype("uint8")
ax[2].imshow(data_slice_orientation_colored)
ax[2].set_title(f"Colored orientations (slice {slice_idx})" if not interactive else "Colored orientations")
ax[2].set_title(
f"Colored orientations (slice {slice_idx})"
if not interactive
else "Colored orientations"
)
ax[2].set_axis_off()
if show:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment