Skip to content
Snippets Groups Projects
Commit d82d3a48 authored by fima's avatar fima :beers:
Browse files

Merge branch 'vizualization_update' into 'main'

Vizualization update

See merge request !4
parents d9b70f14 39cca3af
Branches
Tags
1 merge request!4Vizualization update
Source diff could not be displayed: it is too large. Options to address this: view the blob.
from . import internal_tools from . import internal_tools
from . import models
from .data import Dataset from .data import Dataset
\ No newline at end of file
""" Tools performed with trained models."""
import torch
def inference(data,model):
"""Performs inference on input data using the specified model.
Performs inference on the input data using the provided model. The input data should be in the form of a list,
where each item is a tuple containing the input image tensor and the corresponding target label tensor.
The function checks the format and validity of the input data, ensures the model is in evaluation mode,
and generates predictions using the model. The input images, target labels, and predicted labels are returned
as a tuple.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
Returns:
tuple: A tuple containing the input images, target labels, and predicted labels.
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Notes:
- The function does not assume the model is already in evaluation mode (model.eval()).
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
inference(data,model)
"""
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data items must be tuples")
# Check if data is torch tensors
for element in data[0]:
if not isinstance(element, torch.Tensor):
raise ValueError("Data items must consist of tensors")
# Check if input image is (C,H,W) format
if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
pass
else:
raise ValueError("Input image must be (C,H,W) format")
model.eval()
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(len(data))]
# Create input and target batch
inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
targets = torch.stack([item[1] for item in plot_data], dim=0)
# Get output predictions
with torch.no_grad():
outputs = model(inputs)
# Prepare data for plotting
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
preds = outputs.cpu().squeeze() > 0.5
else:
preds = outputs.cpu().argmax(axis=1)
# if there is only one image
if inputs.dim() == 2:
inputs = inputs.unsqueeze(0)
targets = targets.unsqueeze(0)
preds = preds.unsqueeze(0)
return inputs,targets,preds
\ No newline at end of file
...@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
fig.show() fig.show()
def grid_pred( def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
data, model, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5 """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
):
"""Displays a grid of input images, predicted segmentations, and ground truth segmentations.
Args: Displays a grid of subplots representing different aspects of the input images and segmentations.
data (torch.utils.data.Dataset): A Torch dataset containing input image and The grid includes the following rows:
ground truth label data. - Row 1: Input images
model (torch.nn.Module): The trained network model used for predicting segmentations. - Row 2: Predicted segmentations overlaying input images
num_images (int, optional): The maximum number of images to display. Defaults to 7. - Row 3: Ground truth segmentations overlaying input images
cmap_im (str, optional): The colormap to be used for displaying input images. - Row 4: Comparison between true and predicted segmentations overlaying input images
Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying segmentations.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the predicted segmentation overlay.
Defaults to 0.5.
Raises: Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset.
ValueError: If the data items are not tuples or data items do not consist of tensors. The function utilizes various color maps for visualization and applies transparency to the segmentations.
ValueError: If the input image is not in (C, H, W) format.
Notes: Args:
- The number of displayed images is limited to the minimum between `num_images` in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations.
and the length of the data. num_images (int, optional): Number of images to display. Defaults to 7.
- The function does not assume that the model is already in evaluation mode (model.eval()). cmap_im (str, optional): Color map for input images. Defaults to "gray".
- The function will execute faster on a CUDA-enabled GPU. cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
- The grid layout consists of three rows: input images, predicted segmentations, alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
and ground truth segmentations.
Returns: Returns:
None None
Raises:
None
Example: Example:
dataset = MySegmentationDataset() dataset = MySegmentationDataset()
model = MySegmentationModel() model = MySegmentationModel()
grid_pred(dataset, model, cmap_im='viridis', alpha=0.5) in_targ_preds = qim3d.utils.models.inference(dataset,model)
grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
""" """
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data items must be tuples")
# Check if data is torch tensors
for element in data[0]:
if not isinstance(element, torch.Tensor):
raise ValueError("Data items must consist of tensors")
# Check if input image is (C,H,W) format
if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
pass
else:
raise ValueError("Input image must be (C,H,W) format")
# Check if dataset have at least specified number of images # Check if dataset have at least specified number of images
if len(data) < num_images: if len(in_targ_preds[0]) < num_images:
log.warning( log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d", "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
num_images, num_images,
len(data), len(in_targ_preds[0]),
) )
num_images = len(data) num_images = len(in_targ_preds[0])
# Take only the number of images from in_targ_preds
inputs,targets,preds = [items[:num_images] for items in in_targ_preds]
# Adapt segmentation cmap so that background is transparent # Adapt segmentation cmap so that background is transparent
colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0 colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
model.eval() N = num_images
H = inputs[0].shape[-2]
# Make new list such that possible augmentations remain identical for all three rows W = inputs[0].shape[-1]
plot_data = [data[idx] for idx in range(num_images)]
# Create input and target batch
inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
targets = torch.stack([item[1] for item in plot_data], dim=0)
# Get output predictions
with torch.no_grad():
outputs = model(inputs)
# Prepare data for plotting
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
preds = outputs.cpu().squeeze() > 0.5
else:
preds = outputs.cpu().argmax(axis=1)
N = len(plot_data)
H = plot_data[0][0].shape[-2]
W = plot_data[0][0].shape[-1]
comp_rgb = torch.zeros((N,4,H,W)) comp_rgb = torch.zeros((N,4,H,W))
comp_rgb[:,1,:,:] = targets.logical_and(preds) comp_rgb[:,1,:,:] = targets.logical_and(preds)
...@@ -223,7 +182,7 @@ def grid_pred( ...@@ -223,7 +182,7 @@ def grid_pred(
elif row == 2: # Ground truth segmentation elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow( ax.imshow(
plot_data[col][1].cpu().squeeze(), cmap=custom_cmap, alpha=alpha targets[col], cmap=custom_cmap, alpha=alpha
) )
ax.axis("off") ax.axis("off")
else: else:
...@@ -232,3 +191,5 @@ def grid_pred( ...@@ -232,3 +191,5 @@ def grid_pred(
ax.axis("off") ax.axis("off")
fig.show() fig.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment