diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index f268dc2556fb74de04acc836cec184875f42d80e..d926024935635123f4b1a186ffb6dd1a1c528840 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -81,7 +81,7 @@ def train_model( for data in train_loader: inputs, targets = data inputs = inputs.to(device) - targets = targets.to(device).unsqueeze(1) + targets = targets.to(device) #.unsqueeze(1) optimizer.zero_grad() outputs = model(inputs) @@ -111,7 +111,7 @@ def train_model( for data in val_loader: inputs, targets = data inputs = inputs.to(device) - targets = targets.to(device).unsqueeze(1) + targets = targets.to(device) #.unsqueeze(1) with torch.no_grad(): outputs = model(inputs) @@ -239,7 +239,7 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t # if there is only one image if inputs.dim() == 2: - inputs = inputs.unsqueeze(0) + inputs = inputs.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary targets = targets.unsqueeze(0) preds = preds.unsqueeze(0) @@ -270,7 +270,7 @@ def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float for idx in np.arange(len(volume)): input_with_channel = np.expand_dims(volume[idx], axis=0) input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device) - input_tensor = input_tensor.unsqueeze(0) + input_tensor = input_tensor.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary output = model(input_tensor) > threshold output = output.cpu() if device == "cuda" else output output_detached = output.detach()