From d6f89ddbd3272c6e23f4d9327a1b0016cb442d34 Mon Sep 17 00:00:00 2001 From: s184058 <s184058@student.dtu.dk> Date: Mon, 19 Feb 2024 15:28:28 +0100 Subject: [PATCH] Slice viz overhaul --- README.md | 8 +- docs/index.md | 8 +- docs/releases.md | 2 +- docs/viz.md | 2 +- qim3d/tests/viz/test_img.py | 99 +++++++++--- qim3d/viz/__init__.py | 2 +- qim3d/viz/img.py | 289 +++++++++++++++++++++++------------- 7 files changed, 278 insertions(+), 132 deletions(-) diff --git a/README.md b/README.md index b2d8dd6b..899596a1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True) ``` ## Visualize data -You can easily check slices from your volume using `slice_viz` +You can easily check slices from your volume using `slices` ```python import qim3d @@ -40,13 +40,13 @@ import qim3d img = qim3d.examples.fly_150x256x256 # By default shows the middle slice -qim3d.viz.slice_viz(img) +qim3d.viz.slices(img) # Or we can specifly positions -qim3d.viz.slice_viz(img, position=[0,32,128]) +qim3d.viz.slices(img, position=[0,32,128]) # Parameters for size and colormap are also possible -qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno") +qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno") ``` diff --git a/docs/index.md b/docs/index.md index b2d702e3..ced94f6c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True) ``` ### Visualize data -You can easily check slices from your volume using `slice_viz` +You can easily check slices from your volume using `slices` ```python import qim3d @@ -39,13 +39,13 @@ import qim3d img = qim3d.examples.fly_150x256x256 # By default shows the middle slice -qim3d.viz.slice_viz(img) +qim3d.viz.slices(img) # Or we can specifly positions -qim3d.viz.slice_viz(img, position=[0,32,128]) +qim3d.viz.slices(img, position=[0,32,128]) # Parameters for size and colormap are also possible -qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno") +qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno") ``` diff --git a/docs/releases.md b/docs/releases.md index 5a72d14c..f50c0240 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -31,7 +31,7 @@ Includes new develoments toward the usability of the library, as well as its int - For the local thicknes GUI, now it is possible to pass and receive numpy arrays instead of using the upload functionality. - Improved data loader - Now the extensions `tif`, `h5` and `txm` are supported. -- Added `qim3d.viz.slice_viz` for easy slice visualization. +- Added `qim3d.viz.slices` for easy slice visualization. - U-net model creation - Model availabe from `qim3d.models.UNet` - Data augmentation class at `qim3d.utils.Augmentation` diff --git a/docs/viz.md b/docs/viz.md index 5a92ef56..095e9e52 100644 --- a/docs/viz.md +++ b/docs/viz.md @@ -12,7 +12,7 @@ import qim3d members: - grid_overview - grid_pred - - slice_viz + - slices diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 06fec0d0..17cc3072 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -1,13 +1,12 @@ +import torch +import numpy as np import qim3d -import matplotlib.pyplot as plt import pytest - -from torch import ones from qim3d.utils.internal_tools import temp_data # unit tests for grid overview def test_grid_overview(): - random_tuple = (ones(1,256,256),ones(256,256)) + random_tuple = (torch.ones(1,256,256),torch.ones(256,256)) n_images = 10 train_set = [random_tuple for t in range(n_images)] @@ -16,7 +15,7 @@ def test_grid_overview(): def test_grid_overview_tuple(): - random_tuple = (ones(256,256),ones(256,256)) + random_tuple = (torch.ones(256,256),torch.ones(256,256)) with pytest.raises(ValueError,match="Data elements must be tuples"): qim3d.viz.grid_overview(random_tuple,num_images=1) @@ -42,22 +41,88 @@ def test_grid_pred(): # unit tests for slice visualization -def test_slice_viz(): - example_volume = ones(10,10,10) +def test_slices_numpy_array_input(): + example_volume = np.ones((10, 10, 10)) + img_width = 3 + fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width) + assert fig.get_figwidth() == img_width + +def test_slices_torch_tensor_input(): + example_volume = torch.ones((10,10,10)) img_width = 3 - fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width) + fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width) assert fig.get_figwidth() == img_width +def test_slices_wrong_input_format(): + input = 'not_a_volume' + with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'): + qim3d.viz.slices(input) + +def test_slices_not_volume(): + example_volume = np.ones((10,10)) + with pytest.raises(ValueError, match = 'The provided object is not a volume as it has less than 3 dimensions.'): + qim3d.viz.slices(example_volume) + +def test_slices_wrong_position_format1(): + example_volume = np.ones((10,10,10)) + with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): + qim3d.viz.slices(example_volume, position = 'invalid_slice') + +def test_slices_wrong_position_format2(): + example_volume = np.ones((10,10,10)) + with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): + qim3d.viz.slices(example_volume, position = 1.5) + +def test_slices_wrong_position_format3(): + example_volume = np.ones((10,10,10)) + with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'): + qim3d.viz.slices(example_volume, position = [1, 2, 3.5]) + +def test_slices_invalid_axis_value(): + example_volume = np.ones((10,10,10)) + with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"): + qim3d.viz.slices(example_volume, axis = 3) + +def test_slices_show_title_option(): + example_volume = np.ones((10, 10, 10)) + img_width = 3 + fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False) + # Assert that titles are not shown + assert all(ax.get_title() == '' for ax in fig.get_axes()) -def test_slice_viz_not_volume(): - example_volume = ones(10,10) - dim = example_volume.ndim - with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"): - qim3d.viz.slice_viz(example_volume) +def test_slices_interpolation_option(): + example_volume = torch.ones((10, 10, 10)) + img_width = 3 + interpolation_method = 'bilinear' + fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, interpolation=interpolation_method) + + for ax in fig.get_axes(): + # Access the interpolation method used for each Axes object + actual_interpolation = ax.images[0].get_interpolation() + + # Assert that the actual interpolation method matches the expected method + assert actual_interpolation == interpolation_method + +def test_slices_multiple_slices(): + example_volume = np.ones((10, 10, 10)) + img_width = 3 + n_slices = 3 + fig = qim3d.viz.slices(example_volume, n_slices=n_slices, img_width=img_width) + # Add assertions for the expected number of subplots in the figure + assert len(fig.get_axes()) == n_slices + +def test_slices_axis_argument(): + # Non-symmetric input + example_volume = np.arange(1000).reshape((10, 10, 10)) + img_width = 3 + # Call the function with different values of the axis + fig_axis_0 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=0) + fig_axis_1 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=1) + fig_axis_2 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=2) -def test_slice_viz_wrong_slice(): - example_volume = ones(10,10,10) - with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'): - qim3d.viz.slice_viz(example_volume, position = 'invalid_slice') + # Ensure that different axes result in different plots + assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array()) + assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array()) + assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array()) \ No newline at end of file diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 1c36c157..98132cca 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,2 +1,2 @@ from .visualizations import plot_metrics -from .img import grid_pred, grid_overview, slice_viz +from .img import grid_pred, grid_overview, slices diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 804cdd28..f228800c 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -1,15 +1,22 @@ """ Provides a collection of visualization functions. """ + +from typing import List, Optional, Union import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib import colormaps import torch import numpy as np from qim3d.io.logger import log +import math import qim3d.io +import os + -def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False): +def grid_overview( + data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False +): """Displays an overview grid of images, labels, and masks (if they exist). Labels are the annotated target segmentations @@ -22,14 +29,14 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'. cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'. alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5. - show (bool, optional): If True, displays the plot. Defaults to False. + show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. Raises: ValueError: If the data elements are not tuples. Returns: - fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels. + fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels. Example: ```python @@ -75,7 +82,9 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha # Make new list such that possible augmentations remain identical for all three rows plot_data = [data[idx] for idx in range(num_images)] - fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True) + fig = plt.figure( + figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True + ) # create 2 (3) x 1 subfigs subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1) @@ -92,15 +101,22 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha else: ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im) ax.axis("off") - + if show: plt.show() plt.close() - + return fig -def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False): +def grid_pred( + in_targ_preds, + num_images=7, + cmap_im="gray", + cmap_segm="viridis", + alpha=0.5, + show=False, +): """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. Displays a grid of subplots representing different aspects of the input images and segmentations. @@ -119,7 +135,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", cmap_im (str, optional): Color map for input images. Defaults to "gray". cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis". alpha (float, optional): Alpha value for transparency. Defaults to 0.5. - show (bool, optional): If True, displays the plot. Defaults to False. + show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. Returns: fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models. @@ -131,9 +147,9 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", dataset = MySegmentationDataset() model = MySegmentationModel() in_targ_preds = qim3d.utils.models.inference(dataset,model) - grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5) + grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5) """ - + # Check if dataset have at least specified number of images if len(in_targ_preds[0]) < num_images: log.warning( @@ -144,21 +160,21 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", num_images = len(in_targ_preds[0]) # Take only the number of images from in_targ_preds - inputs,targets,preds = [items[:num_images] for items in in_targ_preds] - + inputs, targets, preds = [items[:num_images] for items in in_targ_preds] + # Adapt segmentation cmap so that background is transparent colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256)) colors_segm[:128, 3] = 0 custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm) - + N = num_images H = inputs[0].shape[-2] W = inputs[0].shape[-1] - comp_rgb = torch.zeros((N,4,H,W)) - comp_rgb[:,1,:,:] = targets.logical_and(preds) - comp_rgb[:,0,:,:] = targets.logical_xor(preds) - comp_rgb[:,3,:,:] = targets.logical_or(preds) + comp_rgb = torch.zeros((N, 4, H, W)) + comp_rgb[:, 1, :, :] = targets.logical_and(preds) + comp_rgb[:, 0, :, :] = targets.logical_xor(preds) + comp_rgb[:, 3, :, :] = targets.logical_or(preds) row_titles = [ "Input images", @@ -187,9 +203,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ax.axis("off") elif row == 2: # Ground truth segmentation ax.imshow(inputs[col], cmap=cmap_im) - ax.imshow( - targets[col], cmap=custom_cmap, alpha=alpha - ) + ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha) ax.axis("off") else: ax.imshow(inputs[col], cmap=cmap_im) @@ -202,105 +216,172 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", return fig -def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False): - """ Displays one or several slices from a 3d array. - By default if `position` is None, slice_viz plots an overview of the entire stack. - If `position` is given as a string or integer, slice_viz will plot an overview with `n_slices` figures around that position. - If `position` is given as a list or array, `n_slices` will be ignored and the idxs from `position` will be plotted. - +def slices( + vol: Union[np.ndarray, torch.Tensor], + axis: int = 0, + position: Optional[Union[str, int, List[int]]] = None, + n_slices: int = 5, + max_cols: int = 5, + cmap: str = "viridis", + img_height: int = 2, + img_width: int = 2, + show: bool = False, + show_position: bool = True, + interpolation: Optional[str] = None, +) -> plt.Figure: + """Displays one or several slices from a 3d volume. + + By default if `position` is None, slices plots `n_slices` linearly spaced slices. + If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position. + If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted. + Args: - input (str, numpy.ndarray): Path to the file or 3-dimensional array. - position (str, int, list, array, optional): One or several slicing levels. - n_slices (int, optional): Defines how many slices the user wants. - cmap (str, optional): Specifies the color map for the image. - axis (bool, optional): Specifies whether the axes should be included. + vol (np.ndarray or torch.Tensor): The 3D volume to be sliced. + axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. + n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5. + max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5. + cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". img_height(int, optional): Height of the figure. img_width(int, optional): Width of the figure. - show (bool, optional): If True, displays the plot. Defaults to False. + show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. + show_position (bool, optional): If True, displays the position of the slices. Defaults to True. + interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. Returns: fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. Raises: - ValueError: If the file or array is not a 3D volume. - ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom). - + ValueError: If the input is not a numpy.ndarray or torch.Tensor. + ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. + ValueError: If the file or array is not a volume with at least 3 dimensions. + ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end". + Example: - image_path = '/my_image_path/my_image.tif' - slice_viz(image_path) + vol_path = '/my_vol_path/my_vol.tif' + vol = qim3d.io.load(vol_path) + slices(vol, axis = 1, position = 'mid', n_slices = 3, cmap = 'viridis', img_height = 4, img_width = 4, show = True, show_position = True, interpolation = None) """ - - # Filepath input - if isinstance(input,str): - vol = qim3d.io.load(input) # Function has its own ValueErrors - dim = vol.ndim - - # Numpy array input - elif isinstance(input,(np.ndarray,torch.Tensor)): - vol = input - dim = input.ndim - - if dim != 3: - raise ValueError(f"Given array is not a volume! Current dimension: {dim}") + # Numpy array or Torch tensor input + if not isinstance(vol, (np.ndarray, torch.Tensor)): + raise ValueError("Input must be a numpy.ndarray or torch.Tensor") + + if vol.ndim < 3: + raise ValueError( + "The provided object is not a volume as it has less than 3 dimensions." + ) + + # Ensure axis is a valid choice + if not (0 <= axis < vol.ndim): + raise ValueError( + f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}." + ) + + # Get total number of slices in the specified dimension + n_total = vol.shape[axis] + + # Position is not provided - will use linearly spaced slices if position is None: - height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int) - + slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int) # Position is a string - elif isinstance(position,str): - - if position.lower() in ['mid','middle']: - expansion_start = int(vol.shape[0]/2) - height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) - - elif position.lower() in ['top','upper', 'start']: - expansion_start = 0 - height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int) - - elif position.lower() in ['bot','bottom', 'end']: - expansion_start = vol.shape[0]-1 - height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int) - - else: - raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".') - - - # Position is an integer - elif isinstance(position,int): - expansion_start = position - n_stacks = vol.shape[0]-1 - - # if linspace would extend beyond n_stacks - if expansion_start + n_slices > n_stacks: - height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int) - - # if linspace would extend below 0 - elif expansion_start - n_slices < 0: - height = np.linspace(0,n_slices-1,n_slices).astype(int) - - else: - height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) - - - # Position is a list or array of integers - elif isinstance(position,(list,np.ndarray)): - height = position - - num_images = len(height) - - - fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True) - axs = fig.subplots(nrows = 1, ncols = num_images) - - for col, ax in enumerate(np.atleast_1d(axs)): - ax.imshow(vol[height[col],:,:],cmap = cmap) - ax.set_title(f'Slice {height[col]}', fontsize=6*img_height) - if not axis: - ax.axis('off') - + elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]: + if position.lower() == "start": + slice_idxs = _get_slice_range(0, n_slices, n_total) + elif position.lower() == "mid": + slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total) + elif position.lower() == "end": + slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total) + # Position is an integer + elif isinstance(position, int): + slice_idxs = _get_slice_range(position, n_slices, n_total) + # Position is a list of integers + elif isinstance(position, list) and all(isinstance(idx, int) for idx in position): + slice_idxs = position + else: + raise ValueError( + 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' + ) + + + # Make grid + nrows = math.ceil(n_slices / max_cols) + ncols = min(n_slices, max_cols) + + # Generate figure + fig, axs = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=(ncols * img_height, nrows * img_width), + constrained_layout=True, + ) + if nrows == 1: + axs = [axs] # Convert to a list for uniformity + + # Convert Torch tensor to NumPy array in order to use the numpy.take method + if isinstance(vol, torch.Tensor): + vol = vol.numpy() + + # Run through each ax of the grid + for i, ax_row in enumerate(axs): + for j, ax in enumerate(np.atleast_1d(ax_row)): + slice_idx = i * max_cols + j + try: + slice_img = vol.take(slice_idxs[slice_idx], axis=axis) + ax.imshow(slice_img, cmap=cmap, interpolation=interpolation) + + if show_position: + ax.text( + 0.0, + 1.0, + f"slice {slice_idxs[slice_idx]} ", + transform=ax.transAxes, + color="white", + fontsize=8, + va="top", + ha="left", + bbox=dict(facecolor="#303030", linewidth=0, pad=0), + ) + + ax.text( + 1.0, + 0.0, + f"axis {axis} ", + transform=ax.transAxes, + color="white", + fontsize=8, + va="bottom", + ha="right", + bbox=dict(facecolor="#303030", linewidth=0, pad=0), + ) + + except IndexError: + # Not a problem, because we simply do not have a slice to show + pass + + # Hide the axis, so that we have a nice grid + ax.axis("off") + if show: plt.show() + plt.close() - return fig \ No newline at end of file + return fig + + +def _get_slice_range(position: int, n_slices: int, n_total): + """Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" + start_idx = position - n_slices // 2 + end_idx = ( + position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1 + ) + slice_idxs = np.arange(start_idx, end_idx) + + if slice_idxs[0] < 0: + slice_idxs = np.arange(0, n_slices) + elif slice_idxs[-1] > n_total: + slice_idxs = np.arange(n_total - n_slices, n_total) + + return slice_idxs -- GitLab