Skip to content
Snippets Groups Projects
Commit b4ea1bf3 authored by s193396's avatar s193396
Browse files

removed unsqueeze since extra dimension is added in dataloader

parent ed05746d
No related branches found
No related tags found
No related merge requests found
......@@ -81,7 +81,7 @@ def train_model(
for data in train_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).unsqueeze(1)
targets = targets.to(device) #.unsqueeze(1)
optimizer.zero_grad()
outputs = model(inputs)
......@@ -111,7 +111,7 @@ def train_model(
for data in val_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).unsqueeze(1)
targets = targets.to(device) #.unsqueeze(1)
with torch.no_grad():
outputs = model(inputs)
......@@ -239,7 +239,7 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
# if there is only one image
if inputs.dim() == 2:
inputs = inputs.unsqueeze(0)
inputs = inputs.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary
targets = targets.unsqueeze(0)
preds = preds.unsqueeze(0)
......@@ -270,7 +270,7 @@ def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float
for idx in np.arange(len(volume)):
input_with_channel = np.expand_dims(volume[idx], axis=0)
input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device)
input_tensor = input_tensor.unsqueeze(0)
input_tensor = input_tensor.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary
output = model(input_tensor) > threshold
output = output.cpu() if device == "cuda" else output
output_detached = output.detach()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment