diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index f268dc2556fb74de04acc836cec184875f42d80e..d926024935635123f4b1a186ffb6dd1a1c528840 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -81,7 +81,7 @@ def train_model(
             for data in train_loader:
                 inputs, targets = data
                 inputs = inputs.to(device)
-                targets = targets.to(device).unsqueeze(1)
+                targets = targets.to(device) #.unsqueeze(1)
 
                 optimizer.zero_grad()
                 outputs = model(inputs)
@@ -111,7 +111,7 @@ def train_model(
                 for data in val_loader:
                     inputs, targets = data
                     inputs = inputs.to(device)
-                    targets = targets.to(device).unsqueeze(1)
+                    targets = targets.to(device) #.unsqueeze(1)
 
                     with torch.no_grad():
                         outputs = model(inputs)
@@ -239,7 +239,7 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
 
     # if there is only one image
     if inputs.dim() == 2:
-        inputs = inputs.unsqueeze(0)
+        inputs = inputs.unsqueeze(0)    # TODO: Not sure if unsqueeze (add extra dimension) is necessary
         targets = targets.unsqueeze(0)
         preds = preds.unsqueeze(0)
 
@@ -270,7 +270,7 @@ def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float
     for idx in np.arange(len(volume)):
         input_with_channel = np.expand_dims(volume[idx], axis=0)
         input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device)
-        input_tensor = input_tensor.unsqueeze(0)
+        input_tensor = input_tensor.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
         output = model(input_tensor) > threshold
         output = output.cpu() if device == "cuda" else output
         output_detached = output.detach()