Skip to content
Snippets Groups Projects
Commit 1ebfb049 authored by fima's avatar fima :beers:
Browse files

Merge branch 'Viz_Unittests' into 'main'

Implementation of unittests for vizualization functions + modification of load tests to run for Windows users.

See merge request !28
parents 784049a6 fbf7cb46
Branches
Tags
1 merge request!28Implementation of unittests for vizualization functions + modification of load tests to run for Windows users.
This diff is collapsed.
...@@ -321,8 +321,8 @@ class DataLoader: ...@@ -321,8 +321,8 @@ class DataLoader:
similar_paths = difflib.get_close_matches(path, valid_paths) similar_paths = difflib.get_close_matches(path, valid_paths)
if similar_paths: if similar_paths:
suggestion = similar_paths[0] # Get the closest match suggestion = similar_paths[0] # Get the closest match
message = f"Invalid path.\nDid you mean '{suggestion}'?" message = f"Invalid path. Did you mean '{suggestion}'?"
raise ValueError(message) raise ValueError(repr(message))
else: else:
raise ValueError("Invalid path") raise ValueError("Invalid path")
......
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
from pathlib import Path from pathlib import Path
import os import os
import pytest import pytest
import re
# Load blobs volume into memory # Load blobs volume into memory
vol = qim3d.examples.blobs_256x256 vol = qim3d.examples.blobs_256x256
...@@ -30,5 +31,7 @@ def test_did_you_mean(): ...@@ -30,5 +31,7 @@ def test_did_you_mean():
# Remove last two characters from the path # Remove last two characters from the path
blobs_path_misspelled = str(blobs_path)[:-2] blobs_path_misspelled = str(blobs_path)[:-2]
with pytest.raises(ValueError,match=f"Invalid path.\nDid you mean '{blobs_path}'?"): message = f"Invalid path. Did you mean '{blobs_path}'?"
with pytest.raises(ValueError,match=re.escape(repr(message))):
qim3d.io.load(blobs_path_misspelled) qim3d.io.load(blobs_path_misspelled)
\ No newline at end of file
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import os import os
import hashlib import hashlib
import pytest import pytest
import re
def test_image_exist(): def test_image_exist():
# Create random test image # Create random test image
...@@ -141,7 +142,9 @@ def test_folder_doesnt_exist(): ...@@ -141,7 +142,9 @@ def test_folder_doesnt_exist():
# Create invalid path # Create invalid path
invalid_path = os.path.join('this','path','doesnt','exist.tif') invalid_path = os.path.join('this','path','doesnt','exist.tif')
with pytest.raises(ValueError,match=f'The directory {os.path.dirname(invalid_path)} does not exist. Please provide a valid directory'): message = f'The directory {re.escape(os.path.dirname(invalid_path))} does not exist. Please provide a valid directory'
with pytest.raises(ValueError,match=message):
# Try to save test image to an invalid path # Try to save test image to an invalid path
qim3d.io.save(invalid_path,test_image) qim3d.io.save(invalid_path,test_image)
......
import qim3d
import matplotlib.pyplot as plt
import pytest
from torch import ones
from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview
def test_grid_overview():
random_tuple = (ones(1,256,256),ones(256,256))
n_images = 10
train_set = [random_tuple for t in range(n_images)]
fig = qim3d.viz.grid_overview(train_set,num_images=n_images)
assert fig.get_figwidth() == 2*n_images
def test_grid_overview_tuple():
random_tuple = (ones(256,256),ones(256,256))
with pytest.raises(ValueError,match="Data elements must be tuples"):
qim3d.viz.grid_overview(random_tuple,num_images=1)
# unit tests for grid prediction
def test_grid_pred():
folder = 'folder_data'
n = 4
temp_data(folder,n = n)
model = qim3d.models.UNet()
augmentation = qim3d.utils.Augmentation()
train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation)
in_targ_pred = qim3d.utils.models.inference(train_set,model)
fig = qim3d.viz.grid_pred(in_targ_pred)
assert (fig.get_figwidth(),fig.get_figheight()) == (2*(n),10)
temp_data(folder,remove = True)
# unit tests for slice visualization
def test_slice_viz():
example_volume = ones(10,10,10)
img_width = 3
fig = qim3d.viz.slice_viz(example_volume,img_width = img_width)
assert fig.get_figwidth() == img_width
def test_slice_viz_not_volume():
example_volume = ones(10,10)
dim = example_volume.ndim
with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"):
qim3d.viz.slice_viz(example_volume)
def test_slice_viz_wrong_slice():
example_volume = ones(10,10,10)
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'):
qim3d.viz.slice_viz(example_volume, position = 'invalid_slice')
import qim3d
import pytest
#unit test for plot_metrics()
def test_plot_metrics():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
fig = qim3d.viz.plot_metrics(metrics, figsize=(10,10))
assert (fig.get_figwidth(),fig.get_figheight()) == (10,10)
def test_plot_metrics_labels():
metrics = {'epoch_loss': [0.3,0.2,0.1], 'batch_loss': [0.3,0.2,0.1]}
with pytest.raises(ValueError,match="The number of metrics doesn't match the number of labels."):
qim3d.viz.plot_metrics(metrics,labels = ['a','b'])
\ No newline at end of file
...@@ -119,10 +119,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 ...@@ -119,10 +119,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
) )
if plot: if plot:
fig = plt.figure(figsize=(16, 6), constrained_layout = True) plot_metrics(train_loss, val_loss, labels = ['Train','Valid.'], show = True)
plot_metrics(train_loss, label = 'Train')
plot_metrics(val_loss,color = 'orange', label = 'Valid.')
fig.show()
if return_loss: if return_loss:
return train_loss,val_loss return train_loss,val_loss
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from qim3d.io.logger import log from qim3d.io.logger import log
import qim3d.io import qim3d.io
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5): def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
"""Displays an overview grid of images, labels, and masks (if they exist). """Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations Labels are the annotated target segmentations
...@@ -15,16 +15,12 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -15,16 +15,12 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
sparse labeled data sparse labeled data
Args: Args:
data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image, data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image, label, (and mask data).
label, (and mask data). num_images (int, optional): The maximum number of images to display. Defaults to 7.
num_images (int, optional): The maximum number of images to display. cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
Defaults to 7. cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
cmap_im (str, optional): The colormap to be used for displaying input images. alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
Defaults to 'gray'. show (bool, optional): If True, displays the plot. Defaults to False.
cmap_segm (str, optional): The colormap to be used for displaying labels.
Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays.
Defaults to 0.5.
Raises: Raises:
ValueError: If the data elements are not tuples. ValueError: If the data elements are not tuples.
...@@ -91,10 +87,15 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -91,10 +87,15 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
else: else:
ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im) ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
ax.axis("off") ax.axis("off")
fig.show()
if show:
plt.show()
plt.close()
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5): return fig
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False):
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations. Displays a grid of subplots representing different aspects of the input images and segmentations.
...@@ -113,6 +114,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ...@@ -113,6 +114,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
cmap_im (str, optional): Color map for input images. Defaults to "gray". cmap_im (str, optional): Color map for input images. Defaults to "gray".
cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis". cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
alpha (float, optional): Alpha value for transparency. Defaults to 0.5. alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False.
Returns: Returns:
None None
...@@ -189,9 +191,13 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ...@@ -189,9 +191,13 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha) ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
ax.axis("off") ax.axis("off")
fig.show() if show:
plt.show()
plt.close()
def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2): return fig
def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2,show = False):
""" Displays one or several slices from a 3d array. """ Displays one or several slices from a 3d array.
Args: Args:
...@@ -199,6 +205,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, ...@@ -199,6 +205,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
position (str, int, list, array, optional): One or several slicing levels. position (str, int, list, array, optional): One or several slicing levels.
cmap (str, optional): Specifies the color map for the image. cmap (str, optional): Specifies the color map for the image.
axis (bool, optional): Specifies whether the axes should be included. axis (bool, optional): Specifies whether the axes should be included.
show (bool, optional): If True, displays the plot. Defaults to False.
Raises: Raises:
ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom). ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom).
...@@ -220,7 +227,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, ...@@ -220,7 +227,7 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
# Numpy array input # Numpy array input
elif isinstance(input,np.ndarray): elif isinstance(input,(np.ndarray,torch.Tensor)):
dim = input.ndim dim = input.ndim
if dim == 3: if dim == 3:
...@@ -258,4 +265,8 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, ...@@ -258,4 +265,8 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
if not axis: if not axis:
ax.axis('off') ax.axis('off')
fig.show() if show:
\ No newline at end of file plt.show()
plt.close()
return fig
\ No newline at end of file
...@@ -4,41 +4,68 @@ import numpy as np ...@@ -4,41 +4,68 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as snb import seaborn as snb
def plot_metrics(metric, color = 'blue', linestyle = '-', batch_linestyle = 'dotted', label = None): def plot_metrics(*metrics,
linestyle = '-',
batch_linestyle = 'dotted',
labels:list = None,
figsize:tuple = (16,6),
show = False):
""" """
Plots the metrics over epochs and batches. Plots the metrics over epochs and batches.
Args: Args:
metric (dict): A dictionary containing the metrics per epochs and per batches. *metrics: Variable-length argument list of dictionary containing the metrics per epochs and per batches.
color (str, optional): The color of the plotted lines. Defaults to 'blue'.
linestyle (str, optional): The style of the epoch metric line. Defaults to '-'. linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'. batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
label (str, optional): The label for the epoch metric line. Defaults to None. labels (list[str], optional): Labels for the plotted lines. Defaults to None.
figsize (Tuple[int, int], optional): Figure size (width, height) in inches. Defaults to (16, 6).
show (bool, optional): If True, displays the plot. Defaults to False.
Returns: Returns:
None if return_fig:
fig (matplotlib.figure.Figure): plot with metrics.
Example: Example:
train_loss = {'epoch_loss' : [...], 'batch_loss': [...]} train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss, color = 'red', label='Train') val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
""" """
if labels == None:
labels = [None]*len(metrics)
elif len(metrics) != len(labels):
raise ValueError("The number of metrics doesn't match the number of labels.")
# plotting parameters # plotting parameters
snb.set_style('darkgrid') snb.set_style('darkgrid')
snb.set(font_scale=1.5) snb.set(font_scale=1.5)
plt.rcParams['lines.linewidth'] = 2 plt.rcParams['lines.linewidth'] = 2
fig = plt.figure(figsize = figsize)
palette = snb.color_palette(None,len(metrics))
for i,metric in enumerate(metrics):
metric_name = list(metric.keys())[0] metric_name = list(metric.keys())[0]
epoch_metric = metric[list(metric.keys())[0]] epoch_metric = metric[list(metric.keys())[0]]
batch_metric = metric[list(metric.keys())[1]] batch_metric = metric[list(metric.keys())[1]]
x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric)) x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric))
plt.plot(epoch_metric,linestyle = linestyle, color = color,label = label) plt.plot(epoch_metric,linestyle = linestyle, color = palette[i], label = labels[i])
plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = color, alpha = 0.4) plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = palette[i], alpha = 0.4)
if labels[0] != None:
plt.legend()
plt.ylabel(metric_name) plt.ylabel(metric_name)
plt.xlabel('epoch') plt.xlabel('epoch')
plt.legend()
# reset plotting parameters # reset plotting parameters
snb.set_style('white') snb.set_style('white')
if show:
plt.show()
plt.close()
return fig
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment