Commit c00db7f0 authored by fima

Merge branch 'slice_viz_overhaul' into 'main'

Slice viz overhaul

See merge request !52
parents 28e57057 d6f89ddb
......@@ -32,7 +32,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
```
## Visualize data
You can easily check slices from your volume using `slice_viz`
You can easily check slices from your volume using `slices`
```python
import qim3d
......@@ -40,13 +40,13 @@ import qim3d
img = qim3d.examples.fly_150x256x256
# By default, 5 linearly spaced slices are shown
qim3d.viz.slice_viz(img)
qim3d.viz.slices(img)
# Or we can specify positions
qim3d.viz.slice_viz(img, position=[0,32,128])
qim3d.viz.slices(img, position=[0,32,128])
# Parameters for size and colormap are also possible
qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno")
qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
```
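For reference, the `slices` function introduced in this merge request also accepts an `axis` argument and the named positions "start", "mid" and "end"; a minimal sketch based on the new signature (parameter values are illustrative):
```python
import qim3d

img = qim3d.examples.fly_150x256x256

# Slice along a different axis, centered on the middle of the stack
qim3d.viz.slices(img, axis=1, position="mid", n_slices=3)

# Named positions "start" and "end" also work; show_position annotates each slice index
qim3d.viz.slices(img, position="end", n_slices=3, show_position=True)
```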
......
......@@ -31,7 +31,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
```
### Visualize data
You can easily check slices from your volume using `slice_viz`
You can easily check slices from your volume using `slices`
```python
import qim3d
......@@ -39,13 +39,13 @@ import qim3d
img = qim3d.examples.fly_150x256x256
# By default, 5 linearly spaced slices are shown
qim3d.viz.slice_viz(img)
qim3d.viz.slices(img)
# Or we can specify positions
qim3d.viz.slice_viz(img, position=[0,32,128])
qim3d.viz.slices(img, position=[0,32,128])
# Parameters for size and colormap are also possible
qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno")
qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
```
......
......@@ -31,7 +31,7 @@ Includes new developments toward the usability of the library, as well as its int
- For the local thickness GUI, it is now possible to pass and receive numpy arrays instead of using the upload functionality.
- Improved data loader
- Now the extensions `tif`, `h5` and `txm` are supported.
- Added `qim3d.viz.slice_viz` for easy slice visualization.
- Added `qim3d.viz.slices` for easy slice visualization.
- U-net model creation
- Model available from `qim3d.models.UNet`
- Data augmentation class at `qim3d.utils.Augmentation`
......
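As a quick illustration of the loader and visualization items above, a hedged sketch (the file path is a placeholder; `virtual_stack` is taken from the documentation changes in this merge request):
```python
import qim3d

# The loader supports tif, h5 and txm files; virtual_stack avoids
# reading the full volume into memory (placeholder path)
vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)

# Quick visual check of the loaded volume with the new slices function
qim3d.viz.slices(vol, n_slices=5)
```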
......@@ -12,7 +12,7 @@ import qim3d
members:
- grid_overview
- grid_pred
- slice_viz
- slices
import torch
import numpy as np
import qim3d
import matplotlib.pyplot as plt
import pytest
from torch import ones
from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview
def test_grid_overview():
random_tuple = (ones(1,256,256),ones(256,256))
random_tuple = (torch.ones(1,256,256),torch.ones(256,256))
n_images = 10
train_set = [random_tuple for t in range(n_images)]
......@@ -16,7 +15,7 @@ def test_grid_overview():
def test_grid_overview_tuple():
random_tuple = (ones(256,256),ones(256,256))
random_tuple = (torch.ones(256,256),torch.ones(256,256))
with pytest.raises(ValueError,match="Data elements must be tuples"):
qim3d.viz.grid_overview(random_tuple,num_images=1)
......@@ -42,22 +41,88 @@ def test_grid_pred():
# unit tests for slice visualization
def test_slice_viz():
example_volume = ones(10,10,10)
def test_slices_numpy_array_input():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
assert fig.get_figwidth() == img_width
def test_slices_torch_tensor_input():
example_volume = torch.ones((10,10,10))
img_width = 3
fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width)
fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
assert fig.get_figwidth() == img_width
def test_slices_wrong_input_format():
input = 'not_a_volume'
with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'):
qim3d.viz.slices(input)
def test_slices_not_volume():
example_volume = np.ones((10,10))
with pytest.raises(ValueError, match = 'The provided object is not a volume as it has less than 3 dimensions.'):
qim3d.viz.slices(example_volume)
def test_slices_wrong_position_format1():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = 'invalid_slice')
def test_slices_wrong_position_format2():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = 1.5)
def test_slices_wrong_position_format3():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = [1, 2, 3.5])
def test_slices_invalid_axis_value():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
qim3d.viz.slices(example_volume, axis = 3)
def test_slices_show_title_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
# Assert that titles are not shown
assert all(ax.get_title() == '' for ax in fig.get_axes())
def test_slice_viz_not_volume():
example_volume = ones(10,10)
dim = example_volume.ndim
with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"):
qim3d.viz.slice_viz(example_volume)
def test_slices_interpolation_option():
example_volume = torch.ones((10, 10, 10))
img_width = 3
interpolation_method = 'bilinear'
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, interpolation=interpolation_method)
for ax in fig.get_axes():
# Access the interpolation method used for each Axes object
actual_interpolation = ax.images[0].get_interpolation()
# Assert that the actual interpolation method matches the expected method
assert actual_interpolation == interpolation_method
def test_slices_multiple_slices():
example_volume = np.ones((10, 10, 10))
img_width = 3
n_slices = 3
fig = qim3d.viz.slices(example_volume, n_slices=n_slices, img_width=img_width)
# Add assertions for the expected number of subplots in the figure
assert len(fig.get_axes()) == n_slices
def test_slices_axis_argument():
# Non-symmetric input
example_volume = np.arange(1000).reshape((10, 10, 10))
img_width = 3
# Call the function with different values of the axis
fig_axis_0 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=0)
fig_axis_1 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=1)
fig_axis_2 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=2)
def test_slice_viz_wrong_slice():
example_volume = ones(10,10,10)
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'):
qim3d.viz.slice_viz(example_volume, position = 'invalid_slice')
# Ensure that different axes result in different plots
assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
\ No newline at end of file
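A possible follow-up test for the new `show_position` annotation, sketched here for illustration only (it is not part of this merge request and relies on Matplotlib's standard `Axes.texts` list):
```python
import numpy as np
import qim3d

def test_slices_show_position_option():
    example_volume = np.ones((10, 10, 10))

    # With show_position=False, no text annotations should be drawn
    fig = qim3d.viz.slices(example_volume, n_slices=1, show_position=False)
    assert all(len(ax.texts) == 0 for ax in fig.get_axes())

    # With the default show_position=True, each populated axis gets a
    # slice-index label and an axis label
    fig = qim3d.viz.slices(example_volume, n_slices=1, show_position=True)
    assert len(fig.get_axes()[0].texts) == 2
```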
from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slice_viz
from .img import grid_pred, grid_overview, slices
"""
Provides a collection of visualization functions.
"""
from typing import List, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
import torch
import numpy as np
from qim3d.io.logger import log
import math
import qim3d.io
import os
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False):
def grid_overview(
data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
):
"""Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
......@@ -22,7 +29,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
Raises:
ValueError: If the data elements are not tuples.
......@@ -75,7 +82,9 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(num_images)]
fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True)
fig = plt.figure(
figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True
)
# create 2 (3) x 1 subfigs
subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1)
......@@ -100,7 +109,14 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
return fig
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False):
def grid_pred(
in_targ_preds,
num_images=7,
cmap_im="gray",
cmap_segm="viridis",
alpha=0.5,
show=False,
):
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations.
......@@ -119,7 +135,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
cmap_im (str, optional): Color map for input images. Defaults to "gray".
cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
Returns:
fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models.
......@@ -187,9 +203,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
ax.axis("off")
elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(
targets[col], cmap=custom_cmap, alpha=alpha
)
ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
ax.axis("off")
else:
ax.imshow(inputs[col], cmap=cmap_im)
......@@ -202,105 +216,172 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
return fig
def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False):
""" Displays one or several slices from a 3d array.
By default if `position` is None, slice_viz plots an overview of the entire stack.
If `position` is given as a string or integer, slice_viz will plot an overview with `n_slices` figures around that position.
If `position` is given as a list or array, `n_slices` will be ignored and the idxs from `position` will be plotted.
def slices(
vol: Union[np.ndarray, torch.Tensor],
axis: int = 0,
position: Optional[Union[str, int, List[int]]] = None,
n_slices: int = 5,
max_cols: int = 5,
cmap: str = "viridis",
img_height: int = 2,
img_width: int = 2,
show: bool = False,
show_position: bool = True,
interpolation: Optional[str] = None,
) -> plt.Figure:
"""Displays one or several slices from a 3d volume.
By default, if `position` is None, `slices` plots `n_slices` linearly spaced slices.
If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position.
If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
Args:
input (str, numpy.ndarray): Path to the file or 3-dimensional array.
position (str, int, list, array, optional): One or several slicing levels.
n_slices (int, optional): Defines how many slices the user wants.
cmap (str, optional): Specifies the color map for the image.
axis (bool, optional): Specifies whether the axes should be included.
vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height (int, optional): Height of each slice in the figure. Defaults to 2.
img_width (int, optional): Width of each slice in the figure. Defaults to 2.
show (bool, optional): If True, displays the plot. Defaults to False.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
Raises:
ValueError: If the file or array is not a 3D volume.
ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom).
ValueError: If the input is not a numpy.ndarray or torch.Tensor.
ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
ValueError: If the file or array is not a volume with at least 3 dimensions.
ValueError: If the `position` keyword argument is not an integer, list of integers or one of the following strings: "start", "mid" or "end".
Example:
image_path = '/my_image_path/my_image.tif'
slice_viz(image_path)
vol_path = '/my_vol_path/my_vol.tif'
vol = qim3d.io.load(vol_path)
slices(vol, axis = 1, position = 'mid', n_slices = 3, cmap = 'viridis', img_height = 4, img_width = 4, show = True, show_position = True, interpolation = None)
"""
# Filepath input
if isinstance(input,str):
vol = qim3d.io.load(input) # Function has its own ValueErrors
dim = vol.ndim
# Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, torch.Tensor)):
raise ValueError("Input must be a numpy.ndarray or torch.Tensor")
# Numpy array input
elif isinstance(input,(np.ndarray,torch.Tensor)):
vol = input
dim = input.ndim
if vol.ndim < 3:
raise ValueError(
"The provided object is not a volume as it has less than 3 dimensions."
)
# Ensure axis is a valid choice
if not (0 <= axis < vol.ndim):
raise ValueError(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
)
if dim != 3:
raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
# Get total number of slices in the specified dimension
n_total = vol.shape[axis]
# Position is not provided - will use linearly spaced slices
if position is None:
height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int)
slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
# Position is a string
elif isinstance(position,str):
if position.lower() in ['mid','middle']:
expansion_start = int(vol.shape[0]/2)
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
elif position.lower() in ['top','upper', 'start']:
expansion_start = 0
height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int)
elif position.lower() in ['bot','bottom', 'end']:
expansion_start = vol.shape[0]-1
height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int)
else:
raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".')
elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
if position.lower() == "start":
slice_idxs = _get_slice_range(0, n_slices, n_total)
elif position.lower() == "mid":
slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
elif position.lower() == "end":
slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
# Position is an integer
elif isinstance(position, int):
expansion_start = position
n_stacks = vol.shape[0]-1
# if linspace would extend beyond n_stacks
if expansion_start + n_slices > n_stacks:
height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int)
# if linspace would extend below 0
elif expansion_start - n_slices < 0:
height = np.linspace(0,n_slices-1,n_slices).astype(int)
slice_idxs = _get_slice_range(position, n_slices, n_total)
# Position is a list of integers
elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
slice_idxs = position
else:
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
raise ValueError(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
)
# Position is a list or array of integers
elif isinstance(position,(list,np.ndarray)):
height = position
# Make grid
nrows = math.ceil(n_slices / max_cols)
ncols = min(n_slices, max_cols)
num_images = len(height)
# Generate figure
fig, axs = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=(ncols * img_width, nrows * img_height),
constrained_layout=True,
)
if nrows == 1:
axs = [axs] # Convert to a list for uniformity
# Convert Torch tensor to NumPy array in order to use the numpy.take method
if isinstance(vol, torch.Tensor):
vol = vol.numpy()
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)):
slice_idx = i * max_cols + j
try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
ax.imshow(slice_img, cmap=cmap, interpolation=interpolation)
if show_position:
ax.text(
0.0,
1.0,
f"slice {slice_idxs[slice_idx]} ",
transform=ax.transAxes,
color="white",
fontsize=8,
va="top",
ha="left",
bbox=dict(facecolor="#303030", linewidth=0, pad=0),
)
ax.text(
1.0,
0.0,
f"axis {axis} ",
transform=ax.transAxes,
color="white",
fontsize=8,
va="bottom",
ha="right",
bbox=dict(facecolor="#303030", linewidth=0, pad=0),
)
fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True)
axs = fig.subplots(nrows = 1, ncols = num_images)
except IndexError:
# Not a problem, because we simply do not have a slice to show
pass
for col, ax in enumerate(np.atleast_1d(axs)):
ax.imshow(vol[height[col],:,:],cmap = cmap)
ax.set_title(f'Slice {height[col]}', fontsize=6*img_height)
if not axis:
ax.axis('off')
# Hide the axis, so that we have a nice grid
ax.axis("off")
if show:
plt.show()
plt.close()
return fig
def _get_slice_range(position: int, n_slices: int, n_total):
"""Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
start_idx = position - n_slices // 2
end_idx = (
position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1
)
slice_idxs = np.arange(start_idx, end_idx)
if slice_idxs[0] < 0:
slice_idxs = np.arange(0, n_slices)
elif slice_idxs[-1] >= n_total:
slice_idxs = np.arange(n_total - n_slices, n_total)
return slice_idxs
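To make the centering and clamping behaviour of `_get_slice_range` concrete, a few illustrative calls (assuming the helper is importable from `qim3d.viz.img`, where this diff defines it):
```python
from qim3d.viz.img import _get_slice_range

print(_get_slice_range(position=5, n_slices=3, n_total=10))  # [4 5 6], centered on 5
print(_get_slice_range(position=0, n_slices=3, n_total=10))  # [0 1 2], clamped at the start
print(_get_slice_range(position=9, n_slices=3, n_total=10))  # [7 8 9], clamped to stay in range
```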