Skip to content
Snippets Groups Projects

Pre commits

Merged s193396 requested to merge pre_commits into main
+ 45
37
"""Visualization tools"""
import matplotlib
import matplotlib.figure
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import torch
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
from qim3d.utils._logger import log
import matplotlib
def plot_metrics(
*metrics: tuple[dict[str, float]],
linestyle: str = "-",
batch_linestyle: str = "dotted",
linestyle: str = '-',
batch_linestyle: str = 'dotted',
labels: list | None = None,
figsize: tuple = (16, 6),
show: bool = False
show: bool = False,
):
"""
Plots the metrics over epochs and batches.
@@ -35,6 +38,7 @@ def plot_metrics(
train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
"""
import seaborn as snb
@@ -44,9 +48,9 @@ def plot_metrics(
raise ValueError("The number of metrics doesn't match the number of labels.")
# plotting parameters
snb.set_style("darkgrid")
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
plt.rcParams["lines.linewidth"] = 2
plt.rcParams['lines.linewidth'] = 2
fig = plt.figure(figsize=figsize)
@@ -68,10 +72,10 @@ def plot_metrics(
plt.legend()
plt.ylabel(metric_name)
plt.xlabel("epoch")
plt.xlabel('epoch')
# reset plotting parameters
snb.set_style("white")
snb.set_style('white')
if show:
plt.show()
@@ -81,14 +85,15 @@ def plot_metrics(
def grid_overview(
data: list,
data: list | torch.utils.data.Dataset,
num_images: int = 7,
cmap_im: str = "gray",
cmap_segm: str = "viridis",
cmap_im: str = 'gray',
cmap_segm: str = 'viridis',
alpha: float = 0.5,
show: bool = False,
) -> matplotlib.figure.Figure:
"""Displays an overview grid of images, labels, and masks (if they exist).
"""
Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
Masks are applied to the output and target prior to the loss calculation in case of
@@ -120,6 +125,7 @@ def grid_overview(
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The grid layout and dimensions vary based on the presence of a mask.
"""
import torch
@@ -128,12 +134,12 @@ def grid_overview(
# Check if image data is RGB and inform the user if it's the case
if len(data[0][0].squeeze().shape) > 2:
log.info("Input images are RGB: color map is ignored")
log.info('Input images are RGB: color map is ignored')
# Check if dataset have at least specified number of images
if len(data) < num_images:
log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d",
'Not enough images in the dataset. Changing num_images=%d to num_images=%d',
num_images,
len(data),
)
@@ -142,14 +148,14 @@ def grid_overview(
# Adapt segmentation cmap so that background is transparent
colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm)
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data elements must be tuples")
raise ValueError('Data elements must be tuples')
# Define row titles
row_titles = ["Input images", "Ground truth segmentation", "Mask"]
row_titles = ['Input images', 'Ground truth segmentation', 'Mask']
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(num_images)]
@@ -169,10 +175,10 @@ def grid_overview(
if row in [1, 2]: # Ground truth segmentation and mask
ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im)
ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
else:
ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
ax.axis("off")
ax.axis('off')
if show:
plt.show()
@@ -184,12 +190,13 @@ def grid_overview(
def grid_pred(
in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray],
num_images: int = 7,
cmap_im: str = "gray",
cmap_segm: str = "viridis",
cmap_im: str = 'gray',
cmap_segm: str = 'viridis',
alpha: float = 0.5,
show: bool = False,
) -> matplotlib.figure.Figure:
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
"""
Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations.
The grid includes the following rows:
@@ -221,25 +228,26 @@ def grid_pred(
model = MySegmentationModel()
in_targ_preds = qim3d.ml.inference(dataset,model)
qim3d.viz.grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
"""
import torch
# Check if dataset have at least specified number of images
if len(in_targ_preds[0]) < num_images:
log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d",
'Not enough images in the dataset. Changing num_images=%d to num_images=%d',
num_images,
len(in_targ_preds[0]),
)
num_images = len(in_targ_preds[0])
# Take only the number of images from in_targ_preds
inputs, targets, preds = [items[:num_images] for items in in_targ_preds]
inputs, targets, preds = (items[:num_images] for items in in_targ_preds)
# Adapt segmentation cmap so that background is transparent
colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm)
N = num_images
H = inputs[0].shape[-2]
@@ -251,10 +259,10 @@ def grid_pred(
comp_rgb[:, 3, :, :] = targets.logical_or(preds)
row_titles = [
"Input images",
"Predicted segmentation",
"Ground truth segmentation",
"True vs. predicted segmentation",
'Input images',
'Predicted segmentation',
'Ground truth segmentation',
'True vs. predicted segmentation',
]
fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True)
@@ -269,20 +277,20 @@ def grid_pred(
for col, ax in enumerate(np.atleast_1d(axs)):
if row == 0:
ax.imshow(inputs[col], cmap=cmap_im)
ax.axis("off")
ax.axis('off')
elif row == 1: # Predicted segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
else:
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
ax.axis("off")
ax.axis('off')
if show:
plt.show()
@@ -315,8 +323,8 @@ def vol_masked(
"""
background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1
foreground = (vol.astype("float") + viz_delta) * vol_mask
background = (vol.astype('float') + viz_delta) * (1 - vol_mask) * -1
foreground = (vol.astype('float') + viz_delta) * vol_mask
vol_masked_result = background + foreground
return vol_masked_result
Loading