Commit c8085470 authored by s184058, committed by fima

Grid viz

parent 7af67039
import qim3d.io
import qim3d.gui
import qim3d.tools
import qim3d.viz
import logging

# (end of file; assumed path: qim3d/__init__.py)

from .img import grid_pred, grid_overview
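
# Usage sketch (not part of the commit): with this re-export, both functions
# become reachable from the package namespace; `data` below is a placeholder
# for a list of (image, label) tuples or a Torch dataset:
#
#     import qim3d
#     qim3d.viz.grid_overview(data)   # resolves to qim3d.viz.img.grid_overview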

# (end of file; assumed path: qim3d/viz/__init__.py; the next file is qim3d/viz/img.py)

""" Provides a collection of visualization functions."""
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
import torch
import numpy as np
from qim3d.io.logger import log
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
"""Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
Masks are applied to the output and target prior to the loss calculation in case of
sparse labeled data
Args:
data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image,
label, (and mask data).
num_images (int, optional): The maximum number of images to display.
Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying labels.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays.
Defaults to 0.5.
Raises:
ValueError: If the data elements are not tuples.
Notes:
- If the image data is RGB, the color map is ignored and the user is informed.
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The grid layout and dimensions vary based on the presence of a mask.
Returns:
None
Example:
data = [(image1, label1, mask1), (image2, label2, mask2)]
grid_overview(data, num_images=5, cmap_im='viridis', cmap_segm='hot', alpha=0.8)
"""
    # Check that data elements have the right format before indexing into them
    if not isinstance(data[0], tuple):
        raise ValueError("Data elements must be tuples")

    # Check if data has a mask
    has_mask = len(data[0]) > 2 and data[0][-1] is not None

    # Check if image data is RGB and inform the user if that is the case
    if len(data[0][0].squeeze().shape) > 2:
        log.info("Input images are RGB: color map is ignored")

    # Check if the dataset has at least the specified number of images
    if len(data) < num_images:
        log.warning(
            "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
            num_images,
            len(data),
        )
        num_images = len(data)

    # Adapt the segmentation cmap so that the background is transparent
    # (zero out the alpha channel for the lower half of the value range)
    colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
    colors_segm[:128, 3] = 0
    custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)

    # Define row titles
    row_titles = ["Input images", "Ground truth segmentation", "Mask"]

    # Make a new list so that possible augmentations remain identical for all rows
    plot_data = list(data[:num_images])

    fig = plt.figure(
        figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True
    )

    # Create 2 x 1 subfigs (3 x 1 if masks are present)
    subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1)
    for row, subfig in enumerate(subfigs):
        subfig.suptitle(row_titles[row], fontsize=22)

        # Create 1 x num_images subplots per subfig
        axs = subfig.subplots(nrows=1, ncols=num_images)
        for col, ax in enumerate(np.atleast_1d(axs)):
            if row in [1, 2]:  # Ground truth segmentation and mask
                ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im)
                ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha)
                ax.axis("off")
            else:
                ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
                ax.axis("off")
    fig.show()
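

# Usage sketch (hypothetical tensors, not part of the commit): grid_overview
# accepts any indexable sequence of (image, label[, mask]) tuples, e.g.
#
#     img = torch.rand(1, 64, 64)    # (C, H, W) grayscale image
#     lbl = (img > 0.5).long()       # dummy binary label
#     grid_overview([(img, lbl)] * 5, num_images=5)
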
def grid_pred(
    data, model, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5
):
"""Displays a grid of input images, predicted segmentations, and ground truth segmentations.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
num_images (int, optional): The maximum number of images to display. Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying segmentations.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the predicted segmentation overlay.
Defaults to 0.5.
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Notes:
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The function does not assume that the model is already in evaluation mode (model.eval()).
- The function will execute faster on a CUDA-enabled GPU.
- The grid layout consists of three rows: input images, predicted segmentations,
and ground truth segmentations.
Returns:
None
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
grid_pred(dataset, model, cmap_im='viridis', alpha=0.5)
"""
    # Get device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Check that data items have the right format
    if not isinstance(data[0], tuple):
        raise ValueError("Data items must be tuples")

    # Check that data items consist of torch tensors
    for element in data[0]:
        if not isinstance(element, torch.Tensor):
            raise ValueError("Data items must consist of tensors")

    # Check that the input image is in (C, H, W) format
    if not (data[0][0].dim() == 3 and data[0][0].shape[0] in [1, 3]):
        raise ValueError("Input image must be (C,H,W) format")

    # Check if the dataset has at least the specified number of images
    if len(data) < num_images:
        log.warning(
            "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
            num_images,
            len(data),
        )
        num_images = len(data)

    # Adapt the segmentation cmap so that the background is transparent
    colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
    colors_segm[:128, 3] = 0
    custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)

    model.eval()

    # Make a new list so that possible augmentations remain identical for all rows
    plot_data = list(data[:num_images])

    # Create input and target batches
    inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
    targets = torch.stack([item[1] for item in plot_data], dim=0)

    # Get output predictions
    with torch.no_grad():
        outputs = model(inputs)

    # Prepare data for plotting
    inputs = inputs.cpu().squeeze()
    targets = targets.squeeze()
    if outputs.shape[1] == 1:
        # Single-channel output: threshold the prediction
        preds = outputs.cpu().squeeze() > 0.5
    else:
        # Multi-channel output: take the class with the highest score
        preds = outputs.cpu().argmax(axis=1)

    # Build an RGBA comparison image: green where target and prediction agree
    # (true positives), red where they disagree, transparent where both are background
    N = len(plot_data)
    H = plot_data[0][0].shape[-2]
    W = plot_data[0][0].shape[-1]
    comp_rgb = torch.zeros((N, 4, H, W))
    comp_rgb[:, 1, :, :] = targets.logical_and(preds)
    comp_rgb[:, 0, :, :] = targets.logical_xor(preds)
    comp_rgb[:, 3, :, :] = targets.logical_or(preds)
    row_titles = [
        "Input images",
        "Predicted segmentation",
        "Ground truth segmentation",
        "True vs. predicted segmentation",
    ]

    fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True)

    # Create 4 x 1 subfigs
    subfigs = fig.subfigures(nrows=4, ncols=1)
    for row, subfig in enumerate(subfigs):
        subfig.suptitle(row_titles[row], fontsize=22)

        # Create 1 x num_images subplots per subfig
        axs = subfig.subplots(nrows=1, ncols=num_images)
        for col, ax in enumerate(np.atleast_1d(axs)):
            if row == 0:  # Input images
                ax.imshow(inputs[col], cmap=cmap_im)
                ax.axis("off")
            elif row == 1:  # Predicted segmentation
                ax.imshow(inputs[col], cmap=cmap_im)
                ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha)
                ax.axis("off")
            elif row == 2:  # Ground truth segmentation
                ax.imshow(inputs[col], cmap=cmap_im)
                ax.imshow(
                    plot_data[col][1].cpu().squeeze(), cmap=custom_cmap, alpha=alpha
                )
                ax.axis("off")
            else:  # True vs. predicted segmentation
                ax.imshow(inputs[col], cmap=cmap_im)
                ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
                ax.axis("off")
    fig.show()
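

# Minimal smoke test (not part of the commit; synthetic data and a hypothetical
# stand-in model): a 1x1 convolution plus sigmoid plays the role of a trained
# single-channel segmentation network, just to exercise grid_pred end to end.
if __name__ == "__main__":
    import torch.nn as nn

    images = torch.rand(8, 1, 32, 32)    # batch of (C, H, W) images
    labels = (images > 0.5).long()       # dummy binary ground truth
    dataset = [(images[i], labels[i]) for i in range(8)]

    # grid_pred moves the inputs to CUDA when available, so the stand-in model
    # has to live on the same device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    toy_model = nn.Sequential(nn.Conv2d(1, 1, kernel_size=1), nn.Sigmoid()).to(device)

    grid_pred(dataset, toy_model, num_images=4)
    plt.show()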