diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py
index f1f41785c4d94a8f156bfddceaf54dff6a4a27fa..778f4bff7d524bdaddb67cc4833d81d44bd4a0d0 100644
--- a/qim3d/viz/_metrics.py
+++ b/qim3d/viz/_metrics.py
@@ -6,14 +6,14 @@ import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 from qim3d.utils._logger import log
-import torch
 import matplotlib
 
+
 def plot_metrics(
     *metrics: tuple[dict[str, float]],
     linestyle: str = "-",
     batch_linestyle: str = "dotted",
-    labels: list|None = None,
+    labels: list | None = None,
     figsize: tuple = (16, 6),
     show: bool = False
 ):
@@ -81,13 +81,13 @@ def plot_metrics(
 
 
 def grid_overview(
-    data: list|torch.utils.data.Dataset, 
-    num_images: int = 7, 
-    cmap_im: str = "gray", 
-    cmap_segm: str = "viridis", 
-    alpha: float = 0.5, 
-    show: bool = False
-)-> matplotlib.figure.Figure:
+    data: list,
+    num_images: int = 7,
+    cmap_im: str = "gray",
+    cmap_segm: str = "viridis",
+    alpha: float = 0.5,
+    show: bool = False,
+) -> matplotlib.figure.Figure:
     """Displays an overview grid of images, labels, and masks (if they exist).
 
     Labels are the annotated target segmentations
@@ -121,6 +121,7 @@ def grid_overview(
             and the length of the data.
         - The grid layout and dimensions vary based on the presence of a mask.
     """
+    import torch
 
     # Check if data has a mask
     has_mask = len(data[0]) > 2 and data[0][-1] is not None
@@ -187,7 +188,7 @@ def grid_pred(
     cmap_segm: str = "viridis",
     alpha: float = 0.5,
     show: bool = False,
-)-> matplotlib.figure.Figure:
+) -> matplotlib.figure.Figure:
     """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
 
     Displays a grid of subplots representing different aspects of the input images and segmentations.
@@ -290,7 +291,9 @@ def grid_pred(
     return fig
 
 
-def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray:
+def vol_masked(
+    vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int = 128
+) -> np.ndarray:
     """
     Applies masking to a volume based on a binary volume mask.