Commit 505419af authored by s193396

removed 2D inference

parent f79041f8
@@ -171,7 +171,6 @@ def inference(
     data: torch.utils.data.Dataset,
     model: torch.nn.Module,
     threshold: float = 0.5,
-    is_3d: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.
@@ -187,7 +186,6 @@ def inference(
             ground truth label data.
         model (torch.nn.Module): The trained network model used for predicting segmentations.
         threshold (float): The threshold value used to binarize the model predictions.
-        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.

     Returns:
         tuple: A tuple containing the input images, target labels, and predicted labels.
@@ -212,151 +210,31 @@ def inference(
     results = []

-    # 3D data
-    if is_3d:
-        for volume, target in data:
-            if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
-                raise ValueError("Data items must consist of tensors")
+    for volume, target in data:
+        if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
+            raise ValueError("Data items must consist of tensors")

-            # Add batch and channel dimensions
-            volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
-            target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+        # Add batch and channel dimensions
+        volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+        target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]

-            with torch.no_grad():
+        with torch.no_grad():
-                # Get model predictions (logits)
-                output = model(volume)
+            # Get model predictions (logits)
+            output = model(volume)

-                # Convert logits to probabilities [0, 1]
-                preds = torch.sigmoid(output)
+            # Convert logits to probabilities [0, 1]
+            preds = torch.sigmoid(output)

-                # Convert to binary mask by thresholding the probabilities
-                preds = (preds > threshold).float()
+            # Convert to binary mask by thresholding the probabilities
+            preds = (preds > threshold).float()

-            # Remove batch and channel dimensions
-            volume = volume.squeeze().cpu().numpy()
-            target = target.squeeze().cpu().numpy()
-            preds = preds.squeeze().cpu().numpy()

-            # Append results to list
-            results.append((volume, target, preds))

-    # 2D data
-    else:
-        # Check if data have the right format
-        if not isinstance(data[0], tuple):
-            raise ValueError("Data items must be tuples")

-        # Check if data is torch tensors
-        for element in data[0]:
-            if not isinstance(element, torch.Tensor):
-                raise ValueError("Data items must consist of tensors")

-        for inputs, targets in data:
-            inputs = inputs.to(device)
-            targets = targets.to(device)

-            with torch.no_grad():
-                outputs = model(inputs)

-            # Prepare data for plotting
-            inputs_cpu = inputs.cpu().squeeze()
-            targets_cpu = targets.cpu().squeeze()
-            if outputs.shape[1] == 1:
-                preds = outputs.cpu().squeeze() > threshold
-            else:
-                preds = outputs.cpu().argmax(axis=1)

-            # If there is only one image
-            if inputs_cpu.dim() == 2:
-                inputs_cpu = inputs_cpu.unsqueeze(0).numpy()
-                targets_cpu = targets_cpu.unsqueeze(0).numpy()
-                preds = preds.unsqueeze(0).numpy()

+        # Remove batch and channel dimensions
+        volume = volume.squeeze().cpu().numpy()
+        target = target.squeeze().cpu().numpy()
+        preds = preds.squeeze().cpu().numpy()

-            # Append results to list
-            results.append((inputs_cpu, targets_cpu, preds))
+        # Append results to list
+        results.append((volume, target, preds))

     return results
-    # Old implementation:
-    # else:
-    #     # Check if data have the right format
-    #     if not isinstance(data[0], tuple):
-    #         raise ValueError("Data items must be tuples")
-    #     # Check if data is torch tensors
-    #     for element in data[0]:
-    #         if not isinstance(element, torch.Tensor):
-    #             raise ValueError("Data items must consist of tensors")
-    #     # Check if input image is (C,H,W) format
-    #     if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
-    #         pass
-    #     else:
-    #         raise ValueError("Input image must be (C,H,W) format")
-    #     # Make new list such that possible augmentations remain identical for all three rows
-    #     plot_data = [data[idx] for idx in range(len(data))]
-    #     # Create input and target batch
-    #     inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
-    #     targets = torch.stack([item[1] for item in plot_data], dim=0)
-    #     # Get output predictions
-    #     with torch.no_grad():
-    #         outputs = model(inputs)
-    #     # Prepare data for plotting
-    #     inputs = inputs.cpu().squeeze()
-    #     targets = targets.squeeze()
-    #     if outputs.shape[1] == 1:
-    #         preds = (
-    #             outputs.cpu().squeeze() > threshold
-    #         )  # TODO: outputs from model are not between [0,1] yet, need to implement that
-    #     else:
-    #         preds = outputs.cpu().argmax(axis=1)
-    #     # if there is only one image
-    #     if inputs.dim() == 2:
-    #         inputs = inputs.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-    #         targets = targets.unsqueeze(0)
-    #         preds = preds.unsqueeze(0)
-    #     return inputs, targets, preds
-def volume_inference(
-    volume: np.ndarray,
-    model: torch.nn.Module,
-    threshold: float = 0.5,
-) -> np.ndarray:
-    """
-    Compute on the entire volume
-
-    Args:
-        volume (numpy.ndarray): A 3D numpy array representing the input volume.
-        model (torch.nn.Module): The trained network model used for inference.
-        threshold (float): The threshold value used to binarize the model predictions.
-
-    Returns:
-        numpy.ndarray: A 3D numpy array representing the model predictions for each slice of the input volume.
-
-    Raises:
-        ValueError: If the input volume is not a 3D numpy array.
-    """
-    if len(volume.shape) != 3:
-        raise ValueError("Input volume must be a 3D numpy array")
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-    model.eval()
-
-    inference_vol = np.zeros_like(volume)
-
-    for idx in np.arange(len(volume)):
-        input_with_channel = np.expand_dims(volume[idx], axis=0)
-        input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device)
-        input_tensor = input_tensor.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-        output = model(input_tensor) > threshold
-        output = output.cpu() if device == "cuda" else output
-        output_detached = output.detach()
-        output_numpy = output_detached.numpy()[0, 0, :, :]
-        inference_vol[idx] = output_numpy
-
-    return inference_vol
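
For reference, a minimal sketch of how the 3D-only inference() might be called after this change. The toy model, the synthetic data, and using a plain list in place of a torch.utils.data.Dataset are illustrative assumptions, not code from this repository:

    import torch

    # Toy stand-in for a trained segmentation network (returns logits with the
    # same spatial shape as the input). Not part of this repository.
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    # (volume, target) pairs shaped [1, D, H, W], as inference() expects.
    data = [
        (torch.rand(1, 8, 32, 32), torch.randint(0, 2, (1, 8, 32, 32)).float())
        for _ in range(2)
    ]

    model = ToyModel()
    results = inference(data, model, threshold=0.5)

    for volume, target, preds in results:
        # Each element is a numpy array of shape (D, H, W) after squeezing.
        print(volume.shape, target.shape, preds.shape)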