diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index e2d74d07bda9623337585f827ca5e07fe70e51ca..71729993b837b23ceac2e722cf3b4ca9fae18f62 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -72,7 +72,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha row_titles = ["Input images", "Ground truth segmentation", "Mask"] # Make new list such that possible augmentations remain identical for all three rows - plot_data = list(data[:num_images]) + plot_data = [data[idx] for idx in range(num_images)] fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True) @@ -168,7 +168,7 @@ def grid_pred( model.eval() # Make new list such that possible augmentations remain identical for all three rows - plot_data = list(data[:num_images]) + plot_data = [data[idx] for idx in range(num_images)] # Create input and target batch inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)