From b4ea1bf3536a428183e57b446672db2b099ed99d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk> Date: Tue, 11 Feb 2025 15:00:08 +0100 Subject: [PATCH] removed unsqueeze since extra dimension is added in dataloader --- qim3d/ml/_ml_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index f268dc25..d9260249 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -81,7 +81,7 @@ def train_model( for data in train_loader: inputs, targets = data inputs = inputs.to(device) - targets = targets.to(device).unsqueeze(1) + targets = targets.to(device) #.unsqueeze(1) optimizer.zero_grad() outputs = model(inputs) @@ -111,7 +111,7 @@ def train_model( for data in val_loader: inputs, targets = data inputs = inputs.to(device) - targets = targets.to(device).unsqueeze(1) + targets = targets.to(device) #.unsqueeze(1) with torch.no_grad(): outputs = model(inputs) @@ -239,7 +239,7 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t # if there is only one image if inputs.dim() == 2: - inputs = inputs.unsqueeze(0) + inputs = inputs.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary targets = targets.unsqueeze(0) preds = preds.unsqueeze(0) @@ -270,7 +270,7 @@ def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float for idx in np.arange(len(volume)): input_with_channel = np.expand_dims(volume[idx], axis=0) input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device) - input_tensor = input_tensor.unsqueeze(0) + input_tensor = input_tensor.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary output = model(input_tensor) > threshold output = output.cpu() if device == "cuda" else output output_detached = output.detach() -- GitLab