diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index e2d74d07bda9623337585f827ca5e07fe70e51ca..71729993b837b23ceac2e722cf3b4ca9fae18f62 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -72,7 +72,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
     row_titles = ["Input images", "Ground truth segmentation", "Mask"]
 
     # Make new list such that possible augmentations remain identical for all three rows
-    plot_data = list(data[:num_images])
+    plot_data = [data[idx] for idx in range(num_images)]
 
     fig = plt.figure(figsize=(2 * num_images, 9 if has_mask else 6), constrained_layout=True)
 
@@ -168,7 +168,7 @@ def grid_pred(
     model.eval()
 
     # Make new list such that possible augmentations remain identical for all three rows
-    plot_data = list(data[:num_images])
+    plot_data = [data[idx] for idx in range(num_images)]
 
     # Create input and target batch
     inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)