Skip to content
Snippets Groups Projects
Commit d6f89ddb authored by s184058's avatar s184058 Committed by fima
Browse files

Slice viz overhaul

parent 28e57057
No related branches found
No related tags found
1 merge request!52Slice viz overhaul
...@@ -32,7 +32,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True) ...@@ -32,7 +32,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
``` ```
## Visualize data ## Visualize data
You can easily check slices from your volume using `slice_viz` You can easily check slices from your volume using `slices`
```python ```python
import qim3d import qim3d
...@@ -40,13 +40,13 @@ import qim3d ...@@ -40,13 +40,13 @@ import qim3d
img = qim3d.examples.fly_150x256x256 img = qim3d.examples.fly_150x256x256
# By default shows the middle slice # By default shows the middle slice
qim3d.viz.slice_viz(img) qim3d.viz.slices(img)
# Or we can specifly positions # Or we can specifly positions
qim3d.viz.slice_viz(img, position=[0,32,128]) qim3d.viz.slices(img, position=[0,32,128])
# Parameters for size and colormap are also possible # Parameters for size and colormap are also possible
qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno") qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
``` ```
......
...@@ -31,7 +31,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True) ...@@ -31,7 +31,7 @@ vol = qim3d.io.load("path/to/file.tif", virtual_stack=True)
``` ```
### Visualize data ### Visualize data
You can easily check slices from your volume using `slice_viz` You can easily check slices from your volume using `slices`
```python ```python
import qim3d import qim3d
...@@ -39,13 +39,13 @@ import qim3d ...@@ -39,13 +39,13 @@ import qim3d
img = qim3d.examples.fly_150x256x256 img = qim3d.examples.fly_150x256x256
# By default shows the middle slice # By default shows the middle slice
qim3d.viz.slice_viz(img) qim3d.viz.slices(img)
# Or we can specifly positions # Or we can specifly positions
qim3d.viz.slice_viz(img, position=[0,32,128]) qim3d.viz.slices(img, position=[0,32,128])
# Parameters for size and colormap are also possible # Parameters for size and colormap are also possible
qim3d.viz.slice_viz(img, img_width=6, img_height=6, cmap="inferno") qim3d.viz.slices(img, img_width=6, img_height=6, cmap="inferno")
``` ```
......
...@@ -31,7 +31,7 @@ Includes new develoments toward the usability of the library, as well as its int ...@@ -31,7 +31,7 @@ Includes new develoments toward the usability of the library, as well as its int
- For the local thicknes GUI, now it is possible to pass and receive numpy arrays instead of using the upload functionality. - For the local thicknes GUI, now it is possible to pass and receive numpy arrays instead of using the upload functionality.
- Improved data loader - Improved data loader
- Now the extensions `tif`, `h5` and `txm` are supported. - Now the extensions `tif`, `h5` and `txm` are supported.
- Added `qim3d.viz.slice_viz` for easy slice visualization. - Added `qim3d.viz.slices` for easy slice visualization.
- U-net model creation - U-net model creation
- Model availabe from `qim3d.models.UNet` - Model availabe from `qim3d.models.UNet`
- Data augmentation class at `qim3d.utils.Augmentation` - Data augmentation class at `qim3d.utils.Augmentation`
......
...@@ -12,7 +12,7 @@ import qim3d ...@@ -12,7 +12,7 @@ import qim3d
members: members:
- grid_overview - grid_overview
- grid_pred - grid_pred
- slice_viz - slices
import torch
import numpy as np
import qim3d import qim3d
import matplotlib.pyplot as plt
import pytest import pytest
from torch import ones
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for grid overview # unit tests for grid overview
def test_grid_overview(): def test_grid_overview():
random_tuple = (ones(1,256,256),ones(256,256)) random_tuple = (torch.ones(1,256,256),torch.ones(256,256))
n_images = 10 n_images = 10
train_set = [random_tuple for t in range(n_images)] train_set = [random_tuple for t in range(n_images)]
...@@ -16,7 +15,7 @@ def test_grid_overview(): ...@@ -16,7 +15,7 @@ def test_grid_overview():
def test_grid_overview_tuple(): def test_grid_overview_tuple():
random_tuple = (ones(256,256),ones(256,256)) random_tuple = (torch.ones(256,256),torch.ones(256,256))
with pytest.raises(ValueError,match="Data elements must be tuples"): with pytest.raises(ValueError,match="Data elements must be tuples"):
qim3d.viz.grid_overview(random_tuple,num_images=1) qim3d.viz.grid_overview(random_tuple,num_images=1)
...@@ -42,22 +41,88 @@ def test_grid_pred(): ...@@ -42,22 +41,88 @@ def test_grid_pred():
# unit tests for slice visualization # unit tests for slice visualization
def test_slice_viz(): def test_slices_numpy_array_input():
example_volume = ones(10,10,10) example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width)
assert fig.get_figwidth() == img_width
def test_slices_torch_tensor_input():
example_volume = torch.ones((10,10,10))
img_width = 3 img_width = 3
fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width) fig = qim3d.viz.slices(example_volume,n_slices = 1, img_width = img_width)
assert fig.get_figwidth() == img_width assert fig.get_figwidth() == img_width
def test_slices_wrong_input_format():
input = 'not_a_volume'
with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'):
qim3d.viz.slices(input)
def test_slices_not_volume():
example_volume = np.ones((10,10))
with pytest.raises(ValueError, match = 'The provided object is not a volume as it has less than 3 dimensions.'):
qim3d.viz.slices(example_volume)
def test_slices_wrong_position_format1():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = 'invalid_slice')
def test_slices_wrong_position_format2():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = 1.5)
def test_slices_wrong_position_format3():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'):
qim3d.viz.slices(example_volume, position = [1, 2, 3.5])
def test_slices_invalid_axis_value():
example_volume = np.ones((10,10,10))
with pytest.raises(ValueError, match = "Invalid value for 'axis'. It should be an integer between 0 and 2"):
qim3d.viz.slices(example_volume, axis = 3)
def test_slices_show_title_option():
example_volume = np.ones((10, 10, 10))
img_width = 3
fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, show_title=False)
# Assert that titles are not shown
assert all(ax.get_title() == '' for ax in fig.get_axes())
def test_slice_viz_not_volume(): def test_slices_interpolation_option():
example_volume = ones(10,10) example_volume = torch.ones((10, 10, 10))
dim = example_volume.ndim img_width = 3
with pytest.raises(ValueError, match = f"Given array is not a volume! Current dimension: {dim}"): interpolation_method = 'bilinear'
qim3d.viz.slice_viz(example_volume) fig = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, interpolation=interpolation_method)
for ax in fig.get_axes():
# Access the interpolation method used for each Axes object
actual_interpolation = ax.images[0].get_interpolation()
# Assert that the actual interpolation method matches the expected method
assert actual_interpolation == interpolation_method
def test_slices_multiple_slices():
example_volume = np.ones((10, 10, 10))
img_width = 3
n_slices = 3
fig = qim3d.viz.slices(example_volume, n_slices=n_slices, img_width=img_width)
# Add assertions for the expected number of subplots in the figure
assert len(fig.get_axes()) == n_slices
def test_slices_axis_argument():
# Non-symmetric input
example_volume = np.arange(1000).reshape((10, 10, 10))
img_width = 3
# Call the function with different values of the axis
fig_axis_0 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=0)
fig_axis_1 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=1)
fig_axis_2 = qim3d.viz.slices(example_volume, n_slices=1, img_width=img_width, axis=2)
def test_slice_viz_wrong_slice(): # Ensure that different axes result in different plots
example_volume = ones(10,10,10) assert not np.allclose(fig_axis_0.get_axes()[0].images[0].get_array(), fig_axis_1.get_axes()[0].images[0].get_array())
with pytest.raises(ValueError, match = 'Position not recognized. Choose an integer, list, array or "start","mid","end".'): assert not np.allclose(fig_axis_1.get_axes()[0].images[0].get_array(), fig_axis_2.get_axes()[0].images[0].get_array())
qim3d.viz.slice_viz(example_volume, position = 'invalid_slice') assert not np.allclose(fig_axis_2.get_axes()[0].images[0].get_array(), fig_axis_0.get_axes()[0].images[0].get_array())
\ No newline at end of file
from .visualizations import plot_metrics from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slice_viz from .img import grid_pred, grid_overview, slices
""" """
Provides a collection of visualization functions. Provides a collection of visualization functions.
""" """
from typing import List, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps from matplotlib import colormaps
import torch import torch
import numpy as np import numpy as np
from qim3d.io.logger import log from qim3d.io.logger import log
import math
import qim3d.io import qim3d.io
import os
def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show = False): def grid_overview(
data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
):
"""Displays an overview grid of images, labels, and masks (if they exist). """Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations Labels are the annotated target segmentations
...@@ -22,7 +29,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -22,7 +29,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'. cmap_im (str, optional): The colormap to be used for displaying input images. Defaults to 'gray'.
cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'. cmap_segm (str, optional): The colormap to be used for displaying labels. Defaults to 'viridis'.
alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5. alpha (float, optional): The transparency level of the label and mask overlays. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
Raises: Raises:
ValueError: If the data elements are not tuples. ValueError: If the data elements are not tuples.
...@@ -75,7 +82,9 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -75,7 +82,9 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
# Make new list such that possible augmentations remain identical for all three rows # Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(num_images)] plot_data = [data[idx] for idx in range(num_images)]
fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True) fig = plt.figure(
figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True
)
# create 2 (3) x 1 subfigs # create 2 (3) x 1 subfigs
subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1) subfigs = fig.subfigures(nrows=3 if has_mask else 2, ncols=1)
...@@ -100,7 +109,14 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha ...@@ -100,7 +109,14 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
return fig return fig
def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5,show = False): def grid_pred(
in_targ_preds,
num_images=7,
cmap_im="gray",
cmap_segm="viridis",
alpha=0.5,
show=False,
):
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations. Displays a grid of subplots representing different aspects of the input images and segmentations.
...@@ -119,7 +135,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ...@@ -119,7 +135,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
cmap_im (str, optional): Color map for input images. Defaults to "gray". cmap_im (str, optional): Color map for input images. Defaults to "gray".
cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis". cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
alpha (float, optional): Alpha value for transparency. Defaults to 0.5. alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
show (bool, optional): If True, displays the plot. Defaults to False. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
Returns: Returns:
fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models. fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models.
...@@ -187,9 +203,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ...@@ -187,9 +203,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
ax.axis("off") ax.axis("off")
elif row == 2: # Ground truth segmentation elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow( ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
targets[col], cmap=custom_cmap, alpha=alpha
)
ax.axis("off") ax.axis("off")
else: else:
ax.imshow(inputs[col], cmap=cmap_im) ax.imshow(inputs[col], cmap=cmap_im)
...@@ -202,105 +216,172 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", ...@@ -202,105 +216,172 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
return fig return fig
def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False):
""" Displays one or several slices from a 3d array.
By default if `position` is None, slice_viz plots an overview of the entire stack. def slices(
If `position` is given as a string or integer, slice_viz will plot an overview with `n_slices` figures around that position. vol: Union[np.ndarray, torch.Tensor],
If `position` is given as a list or array, `n_slices` will be ignored and the idxs from `position` will be plotted. axis: int = 0,
position: Optional[Union[str, int, List[int]]] = None,
n_slices: int = 5,
max_cols: int = 5,
cmap: str = "viridis",
img_height: int = 2,
img_width: int = 2,
show: bool = False,
show_position: bool = True,
interpolation: Optional[str] = None,
) -> plt.Figure:
"""Displays one or several slices from a 3d volume.
By default if `position` is None, slices plots `n_slices` linearly spaced slices.
If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position.
If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
Args: Args:
input (str, numpy.ndarray): Path to the file or 3-dimensional array. vol (np.ndarray or torch.Tensor): The 3D volume to be sliced.
position (str, int, list, array, optional): One or several slicing levels. axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
n_slices (int, optional): Defines how many slices the user wants. position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
cmap (str, optional): Specifies the color map for the image. n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
axis (bool, optional): Specifies whether the axes should be included. max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
img_height(int, optional): Height of the figure. img_height(int, optional): Height of the figure.
img_width(int, optional): Width of the figure. img_width(int, optional): Width of the figure.
show (bool, optional): If True, displays the plot. Defaults to False. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns: Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
Raises: Raises:
ValueError: If the file or array is not a 3D volume. ValueError: If the input is not a numpy.ndarray or torch.Tensor.
ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom). ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
ValueError: If the file or array is not a volume with at least 3 dimensions.
ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
Example: Example:
image_path = '/my_image_path/my_image.tif' vol_path = '/my_vol_path/my_vol.tif'
slice_viz(image_path) vol = qim3d.io.load(vol_path)
slices(vol, axis = 1, position = 'mid', n_slices = 3, cmap = 'viridis', img_height = 4, img_width = 4, show = True, show_position = True, interpolation = None)
""" """
# Filepath input # Numpy array or Torch tensor input
if isinstance(input,str): if not isinstance(vol, (np.ndarray, torch.Tensor)):
vol = qim3d.io.load(input) # Function has its own ValueErrors raise ValueError("Input must be a numpy.ndarray or torch.Tensor")
dim = vol.ndim
# Numpy array input if vol.ndim < 3:
elif isinstance(input,(np.ndarray,torch.Tensor)): raise ValueError(
vol = input "The provided object is not a volume as it has less than 3 dimensions."
dim = input.ndim )
# Ensure axis is a valid choice
if not (0 <= axis < vol.ndim):
raise ValueError(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
)
if dim != 3: # Get total number of slices in the specified dimension
raise ValueError(f"Given array is not a volume! Current dimension: {dim}") n_total = vol.shape[axis]
# Position is not provided - will use linearly spaced slices
if position is None: if position is None:
height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int) slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
# Position is a string # Position is a string
elif isinstance(position,str): elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
if position.lower() == "start":
if position.lower() in ['mid','middle']: slice_idxs = _get_slice_range(0, n_slices, n_total)
expansion_start = int(vol.shape[0]/2) elif position.lower() == "mid":
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
elif position.lower() == "end":
elif position.lower() in ['top','upper', 'start']: slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
expansion_start = 0
height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int)
elif position.lower() in ['bot','bottom', 'end']:
expansion_start = vol.shape[0]-1
height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int)
else:
raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".')
# Position is an integer # Position is an integer
elif isinstance(position, int): elif isinstance(position, int):
expansion_start = position slice_idxs = _get_slice_range(position, n_slices, n_total)
n_stacks = vol.shape[0]-1 # Position is a list of integers
elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
# if linspace would extend beyond n_stacks slice_idxs = position
if expansion_start + n_slices > n_stacks:
height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int)
# if linspace would extend below 0
elif expansion_start - n_slices < 0:
height = np.linspace(0,n_slices-1,n_slices).astype(int)
else: else:
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) raise ValueError(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
)
# Position is a list or array of integers # Make grid
elif isinstance(position,(list,np.ndarray)): nrows = math.ceil(n_slices / max_cols)
height = position ncols = min(n_slices, max_cols)
num_images = len(height) # Generate figure
fig, axs = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=(ncols * img_height, nrows * img_width),
constrained_layout=True,
)
if nrows == 1:
axs = [axs] # Convert to a list for uniformity
# Convert Torch tensor to NumPy array in order to use the numpy.take method
if isinstance(vol, torch.Tensor):
vol = vol.numpy()
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)):
slice_idx = i * max_cols + j
try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
ax.imshow(slice_img, cmap=cmap, interpolation=interpolation)
if show_position:
ax.text(
0.0,
1.0,
f"slice {slice_idxs[slice_idx]} ",
transform=ax.transAxes,
color="white",
fontsize=8,
va="top",
ha="left",
bbox=dict(facecolor="#303030", linewidth=0, pad=0),
)
ax.text(
1.0,
0.0,
f"axis {axis} ",
transform=ax.transAxes,
color="white",
fontsize=8,
va="bottom",
ha="right",
bbox=dict(facecolor="#303030", linewidth=0, pad=0),
)
fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True) except IndexError:
axs = fig.subplots(nrows = 1, ncols = num_images) # Not a problem, because we simply do not have a slice to show
pass
for col, ax in enumerate(np.atleast_1d(axs)): # Hide the axis, so that we have a nice grid
ax.imshow(vol[height[col],:,:],cmap = cmap) ax.axis("off")
ax.set_title(f'Slice {height[col]}', fontsize=6*img_height)
if not axis:
ax.axis('off')
if show: if show:
plt.show() plt.show()
plt.close() plt.close()
return fig return fig
def _get_slice_range(position: int, n_slices: int, n_total):
"""Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
start_idx = position - n_slices // 2
end_idx = (
position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1
)
slice_idxs = np.arange(start_idx, end_idx)
if slice_idxs[0] < 0:
slice_idxs = np.arange(0, n_slices)
elif slice_idxs[-1] > n_total:
slice_idxs = np.arange(n_total - n_slices, n_total)
return slice_idxs
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment