diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index b90c00148b5414950bfe9e15dc96548c97f67e60..65bffbe7b9fec508ed147a28633b94c19a3894db 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -171,7 +171,6 @@ def inference(
         data: torch.utils.data.Dataset, 
         model: torch.nn.Module,
         threshold: float = 0.5,
-        is_3d: bool = True,
         ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.
 
@@ -187,7 +186,6 @@ def inference(
             ground truth label data.
         model (torch.nn.Module): The trained network model used for predicting segmentations.
         threshold (float): The threshold value used to binarize the model predictions.
-        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.
 
     Returns:
         tuple: A tuple containing the input images, target labels, and predicted labels.
@@ -212,151 +210,31 @@ def inference(
 
     results = []
 
-    # 3D data
-    if is_3d:
-        for volume, target in data:
-            if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
-                raise ValueError("Data items must consist of tensors")
+    for volume, target in data:
+        if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
+            raise ValueError("Data items must consist of tensors")
 
-            # Add batch and channel dimensions
-            volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
-            target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+        # Add batch and channel dimensions
+        volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+        target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
 
-            with torch.no_grad():
+        with torch.no_grad():
 
-                # Get model predictions (logits)
-                output = model(volume)
+            # Get model predictions (logits)
+            output = model(volume)
 
-                # Convert logits to probabilities [0, 1]
-                preds = torch.sigmoid(output)
+            # Convert logits to probabilities [0, 1]
+            preds = torch.sigmoid(output)
 
-                # Convert to binary mask by thresholding the probabilities
-                preds = (preds > threshold).float()
+            # Convert to binary mask by thresholding the probabilities
+            preds = (preds > threshold).float()
 
-                # Remove batch and channel dimensions
-                volume = volume.squeeze().cpu().numpy()
-                target = target.squeeze().cpu().numpy()
-                preds = preds.squeeze().cpu().numpy()
-
-            # Append results to list
-            results.append((volume, target, preds))
-
-    # 2D data
-    else:
-        # Check if data have the right format
-        if not isinstance(data[0], tuple):
-            raise ValueError("Data items must be tuples")
-
-        # Check if data is torch tensors
-        for element in data[0]:
-            if not isinstance(element, torch.Tensor):
-                raise ValueError("Data items must consist of tensors")
-
-        for inputs, targets in data:
-            inputs = inputs.to(device)
-            targets = targets.to(device)
-
-            with torch.no_grad():
-                outputs = model(inputs)
-
-            # Prepare data for plotting
-            inputs_cpu = inputs.cpu().squeeze()
-            targets_cpu = targets.cpu().squeeze()
-            if outputs.shape[1] == 1:
-                preds = outputs.cpu().squeeze() > threshold
-            else:
-                preds = outputs.cpu().argmax(axis=1)
-
-            # If there is only one image
-            if inputs_cpu.dim() == 2:
-                inputs_cpu = inputs_cpu.unsqueeze(0).numpy()
-                targets_cpu = targets_cpu.unsqueeze(0).numpy()
-                preds = preds.unsqueeze(0).numpy()
+            # Remove batch and channel dimensions
+            volume = volume.squeeze().cpu().numpy()
+            target = target.squeeze().cpu().numpy()
+            preds = preds.squeeze().cpu().numpy()
 
         # Append results to list
-        results.append((inputs_cpu, targets_cpu, preds))
+        results.append((volume, target, preds))
 
     return results
-
-    # Old implementation: 
-    # else:
-    #     # Check if data have the right format
-    #     if not isinstance(data[0], tuple):
-    #         raise ValueError("Data items must be tuples")
-
-    #     # Check if data is torch tensors
-    #     for element in data[0]:
-    #         if not isinstance(element, torch.Tensor):
-    #             raise ValueError("Data items must consist of tensors")
-
-    #     # Check if input image is (C,H,W) format
-    #     if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
-    #         pass
-    #     else:
-    #         raise ValueError("Input image must be (C,H,W) format")
-
-    #     # Make new list such that possible augmentations remain identical for all three rows
-    #     plot_data = [data[idx] for idx in range(len(data))]
-
-    #     # Create input and target batch
-    #     inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
-    #     targets = torch.stack([item[1] for item in plot_data], dim=0)
-
-    #     # Get output predictions
-    #     with torch.no_grad():
-    #         outputs = model(inputs)
-
-    #     # Prepare data for plotting
-    #     inputs = inputs.cpu().squeeze()
-    #     targets = targets.squeeze()
-    #     if outputs.shape[1] == 1:
-    #         preds = (
-    #             outputs.cpu().squeeze() > threshold
-    #         )  # TODO: outputs from model are not between [0,1] yet, need to implement that
-    #     else:
-    #         preds = outputs.cpu().argmax(axis=1)
-
-    #     # if there is only one image
-    #     if inputs.dim() == 2:
-    #         inputs = inputs.unsqueeze(0)    # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-    #         targets = targets.unsqueeze(0)
-    #         preds = preds.unsqueeze(0)
-
-    #     return inputs, targets, preds
-
-def volume_inference(
-        volume: np.ndarray, 
-        model: torch.nn.Module, 
-        threshold:float = 0.5,
-        ) -> np.ndarray:
-    """
-    Compute on the entire volume
-    Args:
-        volume (numpy.ndarray): A 3D numpy array representing the input volume.
-        model (torch.nn.Module): The trained network model used for inference.
-        threshold (float): The threshold value used to binarize the model predictions.
-    Returns:
-        numpy.ndarray: A 3D numpy array representing the model predictions for each slice of the input volume.
-    Raises:
-        ValueError: If the input volume is not a 3D numpy array.
-    """
-    if len(volume.shape) != 3:
-        raise ValueError("Input volume must be a 3D numpy array")
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-    model.eval()
-
-    inference_vol = np.zeros_like(volume)
-
-    for idx in np.arange(len(volume)):
-        input_with_channel = np.expand_dims(volume[idx], axis=0)
-        input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device)
-        input_tensor = input_tensor.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-        output = model(input_tensor) > threshold
-        output = output.cpu() if device == "cuda" else output
-        output_detached = output.detach()
-        output_numpy = output_detached.numpy()[0, 0, :, :]
-        inference_vol[idx] = output_numpy
-
-    return inference_vol