Skip to content
Snippets Groups Projects
Commit d82d3a48 authored by fima's avatar fima :beers:
Browse files

Merge branch 'vizualization_update' into 'main'

Vizualization update

See merge request !4
parents d9b70f14 39cca3af
No related branches found
No related tags found
1 merge request!4Vizualization update
This diff is collapsed.
from . import internal_tools
from . import models
from .data import Dataset
\ No newline at end of file
""" Tools performed with trained models."""
import torch
def inference(data,model):
"""Performs inference on input data using the specified model.
Performs inference on the input data using the provided model. The input data should be in the form of a list,
where each item is a tuple containing the input image tensor and the corresponding target label tensor.
The function checks the format and validity of the input data, ensures the model is in evaluation mode,
and generates predictions using the model. The input images, target labels, and predicted labels are returned
as a tuple.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
Returns:
tuple: A tuple containing the input images, target labels, and predicted labels.
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Notes:
- The function does not assume the model is already in evaluation mode (model.eval()).
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
inference(data,model)
"""
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data items must be tuples")
# Check if data is torch tensors
for element in data[0]:
if not isinstance(element, torch.Tensor):
raise ValueError("Data items must consist of tensors")
# Check if input image is (C,H,W) format
if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
pass
else:
raise ValueError("Input image must be (C,H,W) format")
model.eval()
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(len(data))]
# Create input and target batch
inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
targets = torch.stack([item[1] for item in plot_data], dim=0)
# Get output predictions
with torch.no_grad():
outputs = model(inputs)
# Prepare data for plotting
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
preds = outputs.cpu().squeeze() > 0.5
else:
preds = outputs.cpu().argmax(axis=1)
# if there is only one image
if inputs.dim() == 2:
inputs = inputs.unsqueeze(0)
targets = targets.unsqueeze(0)
preds = preds.unsqueeze(0)
return inputs,targets,preds
\ No newline at end of file
......@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
fig.show()
def grid_pred(
data, model, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5
):
"""Displays a grid of input images, predicted segmentations, and ground truth segmentations.
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
num_images (int, optional): The maximum number of images to display. Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying segmentations.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the predicted segmentation overlay.
Defaults to 0.5.
Displays a grid of subplots representing different aspects of the input images and segmentations.
The grid includes the following rows:
- Row 1: Input images
- Row 2: Predicted segmentations overlaying input images
- Row 3: Ground truth segmentations overlaying input images
- Row 4: Comparison between true and predicted segmentations overlaying input images
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset.
The function utilizes various color maps for visualization and applies transparency to the segmentations.
Notes:
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The function does not assume that the model is already in evaluation mode (model.eval()).
- The function will execute faster on a CUDA-enabled GPU.
- The grid layout consists of three rows: input images, predicted segmentations,
and ground truth segmentations.
Args:
in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations.
num_images (int, optional): Number of images to display. Defaults to 7.
cmap_im (str, optional): Color map for input images. Defaults to "gray".
cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
Returns:
None
Raises:
None
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
grid_pred(dataset, model, cmap_im='viridis', alpha=0.5)
in_targ_preds = qim3d.utils.models.inference(dataset,model)
grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
"""
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data items must be tuples")
# Check if data is torch tensors
for element in data[0]:
if not isinstance(element, torch.Tensor):
raise ValueError("Data items must consist of tensors")
# Check if input image is (C,H,W) format
if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
pass
else:
raise ValueError("Input image must be (C,H,W) format")
# Check if dataset have at least specified number of images
if len(data) < num_images:
if len(in_targ_preds[0]) < num_images:
log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d",
num_images,
len(data),
len(in_targ_preds[0]),
)
num_images = len(data)
num_images = len(in_targ_preds[0])
# Take only the number of images from in_targ_preds
inputs,targets,preds = [items[:num_images] for items in in_targ_preds]
# Adapt segmentation cmap so that background is transparent
colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
model.eval()
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(num_images)]
# Create input and target batch
inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
targets = torch.stack([item[1] for item in plot_data], dim=0)
# Get output predictions
with torch.no_grad():
outputs = model(inputs)
# Prepare data for plotting
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
preds = outputs.cpu().squeeze() > 0.5
else:
preds = outputs.cpu().argmax(axis=1)
N = len(plot_data)
H = plot_data[0][0].shape[-2]
W = plot_data[0][0].shape[-1]
N = num_images
H = inputs[0].shape[-2]
W = inputs[0].shape[-1]
comp_rgb = torch.zeros((N,4,H,W))
comp_rgb[:,1,:,:] = targets.logical_and(preds)
......@@ -223,7 +182,7 @@ def grid_pred(
elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(
plot_data[col][1].cpu().squeeze(), cmap=custom_cmap, alpha=alpha
targets[col], cmap=custom_cmap, alpha=alpha
)
ax.axis("off")
else:
......@@ -232,3 +191,5 @@ def grid_pred(
ax.axis("off")
fig.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment