Skip to content
Snippets Groups Projects
Commit 1ebfb049 authored by fima's avatar fima :beers:
Browse files

Merge branch 'Viz_Unittests' into 'main'

Implementation of unittests for vizualization functions + modification of load tests to run for Windows users.

See merge request !28
parents 784049a6 fbf7cb46
No related branches found
No related tags found
1 merge request!28Implementation of unittests for vizualization functions + modification of load tests to run for Windows users.
This diff is collapsed.
......@@ -321,8 +321,8 @@ class DataLoader:
similar_paths = difflib.get_close_matches(path, valid_paths)
if similar_paths:
suggestion = similar_paths[0] # Get the closest match
message = f"Invalid path.\nDid you mean '{suggestion}'?"
raise ValueError(message)
message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(repr(message))
else:
raise ValueError("Invalid path")
......
......@@ -3,6 +3,7 @@ import numpy as np
from pathlib import Path
import os
import pytest
import re
# Load blobs volume into memory
vol = qim3d.examples.blobs_256x256
......@@ -30,5 +31,7 @@ def test_did_you_mean():
# Remove last two characters from the path
blobs_path_misspelled = str(blobs_path)[:-2]
with pytest.raises(ValueError,match=f"Invalid path.\nDid you mean '{blobs_path}'?"):
message = f"Invalid path. Did you mean '{blobs_path}'?"
with pytest.raises(ValueError,match=re.escape(repr(message))):
qim3d.io.load(blobs_path_misspelled)
\ No newline at end of file
......@@ -4,6 +4,7 @@ import numpy as np
import os
import hashlib
import pytest
import re
def test_image_exist():
# Create random test image
......@@ -141,7 +142,9 @@ def test_folder_doesnt_exist():
# Create invalid path
invalid_path = os.path.join('this','path','doesnt','exist.tif')
with pytest.raises(ValueError,match=f'The directory {os.path.dirname(invalid_path)} does not exist. Please provide a valid directory'):
message = f'The directory {re.escape(os.path.dirname(invalid_path))} does not exist. Please provide a valid directory'
with pytest.raises(ValueError,match=message):
# Try to save test image to an invalid path
qim3d.io.save(invalid_path,test_image)
......
import qim3d
import matplotlib.pyplot as plt
import pytest
from torch import ones
from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview
def test_grid_overview():
random_tuple = (ones(1,256,256),ones(256,256))
n_images = 10
train_set = [random_tuple for t in range(n_images)]
fig = qim3d.viz.grid_overview(train_set,num_images=n_images)
assert fig.get_figwidth() == 2*n_images
def test_grid_overview_tuple():
random_tuple = (ones(256,256),ones(256,256))
with pytest.raises(ValueError,match="Data elements must be tuples"):
qim3d.viz.grid_overview(random_tuple,num_images=1)
# unit tests for grid prediction
def test_grid_pred():
folder = 'folder_data'
n = 4
temp_data(folder,n = n)
model = qim3d.models.UNet()
augmentation = qim3d.utils.Augmentation()
train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation)
in_targ_pred = qim3d.utils.models.inference(train_set,model)
fig = qim3d.viz.grid_pred(in_targ_pred)
assert (fig.get_figwidth(),fig.get_figheight()) == (2*(n),10)
temp_data(folder,remove = True)
# unit tests for slice visualization
def test_slice_viz():
example_volume = ones(10,10,10)
img_width = 3
fig = qim3d.viz.slice_viz(example_volume,img_width = img_width)
assert fig.get_figwidth() == img_width
def test_slice_viz_not_volume():
example_volume = ones(10,10)
dim = example_volume.ndim
with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"):
qim3d.viz.slice_viz(example_volume)
def test_slice_viz_wrong_slice():
example_volume = ones(10,10,10)
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'):
qim3d.viz.slice_viz(example_volume, position = 'invalid_slice')
import qim3d
import pytest
#unit test for plot_metrics()
def test_plot_metrics():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
fig = qim3d.viz.plot_metrics(metrics, figsize=(10,10))
assert (fig.get_figwidth(),fig.get_figheight()) == (10,10)
def test_plot_metrics_labels():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
with pytest.raises(ValueError,match="The number of metrics doesn't match the number of labels."):
qim3d.viz.plot_metrics(metrics,labels = ['a','b'])
\ No newline at end of file
......@@ -119,10 +119,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
)
if plot:
fig = plt.figure(figsize=(16, 6), constrained_layout = True)
plot_metrics(train_loss, label = 'Train')
plot_metrics(val_loss,color = 'orange', label = 'Valid.')
fig.show()
plot_metrics(train_loss, val_loss, labels = ['Train','Valid.'], show = True)
if return_loss:
return train_loss,val_loss
......
......@@ -7,7 +7,7 @@ import numpy as np
from qim3d.io.logger import log
import qim3d.io
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
"""Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
......@@ -15,16 +15,12 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
sparse labeled data
Args:
data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image,
label, (and mask data).
num_images (int, optional): The maximum number of images to display.
Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying labels.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays.
Defaults to 0.5.
data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image, label, (and mask data).
num_images (int, optional): The maximum number of images to display. Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False.
Raises:
ValueError: If the data elements are not tuples.
......@@ -91,10 +87,15 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
else:
ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
ax.axis("off")
fig.show()
if show:
plt.show()
plt.close()
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
return fig
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False):
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations.
......@@ -113,6 +114,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
cmap_im (str, optional): Color map for input images. Defaults to "gray".
cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False.
Returns:
None
......@@ -189,9 +191,13 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
ax.axis("off")
fig.show()
if show:
plt.show()
plt.close()
def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2):
return fig
def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2,show = False):
""" Displays one or several slices from a 3d array.
Args:
......@@ -199,6 +205,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
position (str, int, list, array, optional): One or several slicing levels.
cmap (str, optional): Specifies the color map for the image.
axis (bool, optional): Specifies whether the axes should be included.
show (bool, optional): If True, displays the plot. Defaults to False.
Raises:
ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom).
......@@ -220,7 +227,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
# Numpy array input
elif isinstance(input,np.ndarray):
elif isinstance(input,(np.ndarray,torch.Tensor)):
dim = input.ndim
if dim == 3:
......@@ -258,4 +265,8 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
if not axis:
ax.axis('off')
fig.show()
\ No newline at end of file
if show:
plt.show()
plt.close()
return fig
\ No newline at end of file
......@@ -4,41 +4,68 @@ import numpy as np
import matplotlib.pyplot as plt
import seaborn as snb
def plot_metrics(metric, color = 'blue', linestyle = '-', batch_linestyle = 'dotted', label = None):
def plot_metrics(*metrics,
linestyle = '-',
batch_linestyle = 'dotted',
labels:list = None,
figsize:tuple = (16,6),
show = False):
"""
Plots the metrics over epochs and batches.
Args:
metric (dict): A dictionary containing the metrics per epochs and per batches.
color (str, optional): The color of the plotted lines. Defaults to 'blue'.
*metrics: Variable-length argument list of dictionary containing the metrics per epochs and per batches.
linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
label (str, optional): The label for the epoch metric line. Defaults to None.
labels (list[str], optional): Labels for the plotted lines. Defaults to None.
figsize (Tuple[int, int], optional): Figure size (width, height) in inches. Defaults to (16, 6).
show (bool, optional): If True, displays the plot. Defaults to False.
Returns:
None
if return_fig:
fig (matplotlib.figure.Figure): plot with metrics.
Example:
train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss, color = 'red', label='Train')
val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
"""
if labels == None:
labels = [None]*len(metrics)
elif len(metrics) != len(labels):
raise ValueError("The number of metrics doesn't match the number of labels.")
# plotting parameters
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
plt.rcParams['lines.linewidth'] = 2
fig = plt.figure(figsize = figsize)
palette = snb.color_palette(None,len(metrics))
for i,metric in enumerate(metrics):
metric_name = list(metric.keys())[0]
epoch_metric = metric[list(metric.keys())[0]]
batch_metric = metric[list(metric.keys())[1]]
x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric))
plt.plot(epoch_metric,linestyle = linestyle, color = color,label = label)
plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = color, alpha = 0.4)
plt.plot(epoch_metric,linestyle = linestyle, color = palette[i], label = labels[i])
plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = palette[i], alpha = 0.4)
if labels[0] != None:
plt.legend()
plt.ylabel(metric_name)
plt.xlabel('epoch')
plt.legend()
# reset plotting parameters
snb.set_style('white')
if show:
plt.show()
plt.close()
return fig
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment