diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index b90c00148b5414950bfe9e15dc96548c97f67e60..65bffbe7b9fec508ed147a28633b94c19a3894db 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -171,7 +171,6 @@ def inference( data: torch.utils.data.Dataset, model: torch.nn.Module, threshold: float = 0.5, - is_3d: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Performs inference on input data using the specified model. @@ -187,7 +186,6 @@ def inference( ground truth label data. model (torch.nn.Module): The trained network model used for predicting segmentations. threshold (float): The threshold value used to binarize the model predictions. - is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True. Returns: tuple: A tuple containing the input images, target labels, and predicted labels. @@ -212,151 +210,31 @@ def inference( results = [] - # 3D data - if is_3d: - for volume, target in data: - if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor): - raise ValueError("Data items must consist of tensors") + for volume, target in data: + if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor): + raise ValueError("Data items must consist of tensors") - # Add batch and channel dimensions - volume = volume.unsqueeze(0).to(device) # Shape: [1, 1, D, H, W] - target = target.unsqueeze(0).to(device) # Shape: [1, 1, D, H, W] + # Add batch and channel dimensions + volume = volume.unsqueeze(0).to(device) # Shape: [1, 1, D, H, W] + target = target.unsqueeze(0).to(device) # Shape: [1, 1, D, H, W] - with torch.no_grad(): + with torch.no_grad(): - # Get model predictions (logits) - output = model(volume) + # Get model predictions (logits) + output = model(volume) - # Convert logits to probabilities [0, 1] - preds = torch.sigmoid(output) + # Convert logits to probabilities [0, 1] + preds = torch.sigmoid(output) - # Convert to binary mask by thresholding the probabilities - preds = (preds > threshold).float() + # Convert to binary mask by thresholding the probabilities + preds = (preds > threshold).float() - # Remove batch and channel dimensions - volume = volume.squeeze().cpu().numpy() - target = target.squeeze().cpu().numpy() - preds = preds.squeeze().cpu().numpy() - - # Append results to list - results.append((volume, target, preds)) - - # 2D data - else: - # Check if data have the right format - if not isinstance(data[0], tuple): - raise ValueError("Data items must be tuples") - - # Check if data is torch tensors - for element in data[0]: - if not isinstance(element, torch.Tensor): - raise ValueError("Data items must consist of tensors") - - for inputs, targets in data: - inputs = inputs.to(device) - targets = targets.to(device) - - with torch.no_grad(): - outputs = model(inputs) - - # Prepare data for plotting - inputs_cpu = inputs.cpu().squeeze() - targets_cpu = targets.cpu().squeeze() - if outputs.shape[1] == 1: - preds = outputs.cpu().squeeze() > threshold - else: - preds = outputs.cpu().argmax(axis=1) - - # If there is only one image - if inputs_cpu.dim() == 2: - inputs_cpu = inputs_cpu.unsqueeze(0).numpy() - targets_cpu = targets_cpu.unsqueeze(0).numpy() - preds = preds.unsqueeze(0).numpy() + # Remove batch and channel dimensions + volume = volume.squeeze().cpu().numpy() + target = target.squeeze().cpu().numpy() + preds = preds.squeeze().cpu().numpy() # Append results to list - results.append((inputs_cpu, targets_cpu, preds)) + results.append((volume, target, preds)) return results - - # Old implementation: - # else: - # # Check if data have the right format - # if not isinstance(data[0], tuple): - # raise ValueError("Data items must be tuples") - - # # Check if data is torch tensors - # for element in data[0]: - # if not isinstance(element, torch.Tensor): - # raise ValueError("Data items must consist of tensors") - - # # Check if input image is (C,H,W) format - # if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]): - # pass - # else: - # raise ValueError("Input image must be (C,H,W) format") - - # # Make new list such that possible augmentations remain identical for all three rows - # plot_data = [data[idx] for idx in range(len(data))] - - # # Create input and target batch - # inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device) - # targets = torch.stack([item[1] for item in plot_data], dim=0) - - # # Get output predictions - # with torch.no_grad(): - # outputs = model(inputs) - - # # Prepare data for plotting - # inputs = inputs.cpu().squeeze() - # targets = targets.squeeze() - # if outputs.shape[1] == 1: - # preds = ( - # outputs.cpu().squeeze() > threshold - # ) # TODO: outputs from model are not between [0,1] yet, need to implement that - # else: - # preds = outputs.cpu().argmax(axis=1) - - # # if there is only one image - # if inputs.dim() == 2: - # inputs = inputs.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary - # targets = targets.unsqueeze(0) - # preds = preds.unsqueeze(0) - - # return inputs, targets, preds - -def volume_inference( - volume: np.ndarray, - model: torch.nn.Module, - threshold:float = 0.5, - ) -> np.ndarray: - """ - Compute on the entire volume - Args: - volume (numpy.ndarray): A 3D numpy array representing the input volume. - model (torch.nn.Module): The trained network model used for inference. - threshold (float): The threshold value used to binarize the model predictions. - Returns: - numpy.ndarray: A 3D numpy array representing the model predictions for each slice of the input volume. - Raises: - ValueError: If the input volume is not a 3D numpy array. - """ - if len(volume.shape) != 3: - raise ValueError("Input volume must be a 3D numpy array") - - device = "cuda" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - - inference_vol = np.zeros_like(volume) - - for idx in np.arange(len(volume)): - input_with_channel = np.expand_dims(volume[idx], axis=0) - input_tensor = torch.tensor(input_with_channel, dtype=torch.float32).to(device) - input_tensor = input_tensor.unsqueeze(0) # TODO: Not sure if unsqueeze (add extra dimension) is necessary - output = model(input_tensor) > threshold - output = output.cpu() if device == "cuda" else output - output_detached = output.detach() - output_numpy = output_detached.numpy()[0, 0, :, :] - inference_vol[idx] = output_numpy - - return inference_vol