Commit 3d734ce5 authored by s193396

temporary

parent ac6a7a44
(Source diff for one file could not be displayed: it is too large.)
@@ -53,8 +53,8 @@ class Augmentation:
         ValueError: If `level` is neither None, light, moderate nor heavy.
         """
         from monai.transforms import (
-            Compose, RandRotate90, RandFlip, RandAffine, ToTensor, \
-            RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad
+            Compose, RandRotate90d, RandFlipd, RandAffined, ToTensor, \
+            RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd
         )

         # Check if 2D or 3D
@@ -74,41 +74,64 @@ class Augmentation:
         # For 2D, add normalization to the baseline augmentations
         # TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
         if not self.is_3d:
-            baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
+            # baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
+            baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std))

         # Resize augmentations
         if self.resize == 'crop':
-            resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))] if self.is_3d else [CenterSpatialCrop((im_h, im_w))]
+            # resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))]
+            resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))]
         elif self.resize == 'reshape':
-            resize_aug = [Resize((im_d, im_h, im_w))] if self.is_3d else [Resize((im_h, im_w))]
+            # resize_aug = [Resize((im_d, im_h, im_w))]
+            resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
         elif self.resize == 'padding':
-            resize_aug = [SpatialPad((im_d, im_h, im_w))] if self.is_3d else [SpatialPad((im_h, im_w))]
+            # resize_aug = [SpatialPad((im_d, im_h, im_w))]
+            resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]

         # Level of augmentation
         if level == None:
+            # No augmentation for the validation and test sets
             level_aug = []
+            resize_aug = []
         elif level == 'light':
-            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)]
+            # level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))]
+            level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))]
         elif level == 'moderate':
+            # level_aug = [
+            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
+            #     RandFlip(prob=0.3, spatial_axis=0),
+            #     RandFlip(prob=0.3, spatial_axis=1),
+            #     RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
+            #     RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
+            # ]
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
-                RandFlip(prob=0.3, spatial_axis=0),
-                RandFlip(prob=0.3, spatial_axis=1),
-                RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
-                RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
             ]
         elif level == 'heavy':
+            # level_aug = [
+            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
+            #     RandFlip(prob=0.7, spatial_axis=0),
+            #     RandFlip(prob=0.7, spatial_axis=1),
+            #     RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
+            #     RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
+            # ]
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
-                RandFlip(prob=0.7, spatial_axis=0),
-                RandFlip(prob=0.7, spatial_axis=1),
-                RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
-                RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
             ]

         return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file
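This hunk moves the pipeline from MONAI's array transforms (RandFlip, Resize, ...) to their dictionary counterparts (RandFlipd, Resized, ...), which act on a dict of named arrays so image and label pass through together. A minimal sketch of how such a pipeline is invoked; shapes, sizes, and data here are illustrative, assuming MONAI's channel-first convention:

    import numpy as np
    from monai.transforms import Compose, RandFlipd, Resized

    transform = Compose([
        RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
        Resized(keys=["image", "label"], spatial_size=(32, 64, 64)),
    ])

    sample = {
        "image": np.random.rand(1, 28, 60, 60).astype(np.float32),             # (C, D, H, W)
        "label": np.random.randint(0, 2, (1, 28, 60, 60)).astype(np.float32),  # same shape
    }
    out = transform(sample)
    print(out["image"].shape, out["label"].shape)  # both (1, 32, 64, 64)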
@@ -104,8 +104,12 @@ class Dataset(torch.utils.data.Dataset):
             target = target.transpose((2, 0, 1))

         if self.transform:
-            image = self.transform(image)    # uint8
-            target = self.transform(target)  # int32
+            transformed = self.transform({"image": image, "label": target})
+            image = transformed["image"]
+            target = transformed["label"]
+            # image = self.transform(image)    # uint8
+            # target = self.transform(target)  # int32

         # TODO: Which dtype?
         image = image.clone().detach().to(dtype=torch.float32)
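The single dict call is the point of the migration: invoking a random transform twice, as the old code did, draws fresh randomness on each call, so image and label could receive different flips or rotations. One call on the dict samples once and applies the same parameters to every key. A small demonstration with hypothetical data (not part of the commit):

    import numpy as np
    from monai.transforms import RandFlip, RandFlipd

    img = np.arange(8, dtype=np.float32).reshape(1, 2, 4)  # (C, H, W)

    # Old pattern: two independent calls, two random draws -> may disagree
    flip = RandFlip(prob=0.5, spatial_axis=0)
    a, b = flip(img), flip(img)

    # New pattern: one call, one draw, applied identically to both keys
    flipd = RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0)
    out = flipd({"image": img, "label": img.copy()})
    assert np.allclose(out["image"], out["label"])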
@@ -160,7 +164,7 @@ def check_resize(
         orig_shape (tuple): Original shape of the image.
         resize (tuple): Desired resize dimensions.
         n_channels (int): Number of channels in the model.
-        is_3d (bool): Whether the data is 3D or not.
+        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.

     Returns:
         tuple: Final resize dimensions.
@@ -230,7 +234,12 @@ def check_resize(
     return final_h, final_w

-def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
+def prepare_datasets(
+    path: str,
+    val_fraction: float,
+    model: nn.Module,
+    augmentation: Augmentation,
+) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
     Splits and augments the train/validation/test datasets.
...
@@ -133,6 +133,10 @@ def train_model(
                 f"val loss: {val_loss['loss'][epoch]:.4f}"
             )

+    # NOTE: Delete this again
+    # Save model checkpoint to .pth file
+    torch.save(model.state_dict(), "C:/Users/s193396/dataset/model.pth")
+
     if plot:
         plot_metrics(train_loss, val_loss, labels=["Train", "Valid."], show=True)
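The checkpoint save above is flagged as temporary by the author; restoring it later follows the standard state_dict round trip. A sketch — build_model() is a hypothetical constructor standing in for however the network is rebuilt:

    import torch

    model = build_model()  # hypothetical: recreate the same architecture first
    state = torch.load("C:/Users/s193396/dataset/model.pth", map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # switch to inference mode before predicting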
@@ -163,7 +167,12 @@ def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Modul
     return model_s

-def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def inference(
+    data: torch.utils.data.Dataset,
+    model: torch.nn.Module,
+    threshold: float = 0.5,
+    is_3d: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.

     Performs inference on the input data using the provided model. The input data should be in the form of a list,
@@ -177,6 +186,8 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
         data (torch.utils.data.Dataset): A Torch dataset containing input image and
             ground truth label data.
         model (torch.nn.Module): The trained network model used for predicting segmentations.
+        threshold (float): The threshold value used to binarize the model predictions. Defaults to 0.5.
+        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.

     Returns:
         tuple: A tuple containing the input images, target labels, and predicted labels.
@@ -194,10 +205,44 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
         model = MySegmentationModel()
         qim3d.ml.inference(data, model)
     """
+    # Get device
     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Set model to evaluation mode
+    model.to(device)
+    model.eval()
+
+    results = []
+
+    # 3D data
+    if is_3d:
+        for volume, target in data:
+            if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
+                raise ValueError("Data items must consist of tensors")
+
+            # Add batch and channel dimensions
+            volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+            target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+
+            with torch.no_grad():
+                # Get model predictions (logits)
+                output = model(volume)
+                # Convert logits to probabilities [0, 1]
+                preds = torch.sigmoid(output)
+                # Convert to binary mask by thresholding the probabilities
+                preds = (preds > threshold).float()
+
+            # Remove batch and channel dimensions
+            volume = volume.squeeze().cpu().numpy()
+            target = target.squeeze().cpu().numpy()
+            preds = preds.squeeze().cpu().numpy()
+
+            # Append results to list
+            results.append((volume, target, preds))
+
+    # 2D data
+    else:
         # Check if data have the right format
         if not isinstance(data[0], tuple):
             raise ValueError("Data items must be tuples")
@@ -207,46 +252,83 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
             if not isinstance(element, torch.Tensor):
                 raise ValueError("Data items must consist of tensors")

-    # Check if input image is (C,H,W) format
-    if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
-        pass
-    else:
-        raise ValueError("Input image must be (C,H,W) format")
-
-    model.to(device)
-    model.eval()
-
-    # Make new list such that possible augmentations remain identical for all three rows
-    plot_data = [data[idx] for idx in range(len(data))]
-
-    # Create input and target batch
-    inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
-    targets = torch.stack([item[1] for item in plot_data], dim=0)
-
-    # Get output predictions
-    with torch.no_grad():
-        outputs = model(inputs)
-
-    # Prepare data for plotting
-    inputs = inputs.cpu().squeeze()
-    targets = targets.squeeze()
-    if outputs.shape[1] == 1:
-        preds = (
-            outputs.cpu().squeeze() > 0.5
-        )  # TODO: outputs from model are not between [0,1] yet, need to implement that
-    else:
-        preds = outputs.cpu().argmax(axis=1)
-
-    # if there is only one image
-    if inputs.dim() == 2:
-        inputs = inputs.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-        targets = targets.unsqueeze(0)
-        preds = preds.unsqueeze(0)
-
-    return inputs, targets, preds
+        for inputs, targets in data:
+            inputs = inputs.to(device)
+            targets = targets.to(device)
+
+            with torch.no_grad():
+                outputs = model(inputs)
+
+            # Prepare data for plotting
+            inputs_cpu = inputs.cpu().squeeze()
+            targets_cpu = targets.cpu().squeeze()
+            if outputs.shape[1] == 1:
+                preds = outputs.cpu().squeeze() > threshold
+            else:
+                preds = outputs.cpu().argmax(axis=1)
+
+            # If there is only one image
+            if inputs_cpu.dim() == 2:
+                inputs_cpu = inputs_cpu.unsqueeze(0).numpy()
+                targets_cpu = targets_cpu.unsqueeze(0).numpy()
+                preds = preds.unsqueeze(0).numpy()
+
+            # Append results to list
+            results.append((inputs_cpu, targets_cpu, preds))
+
+    return results
+
+    # Old implementation:
+    # else:
+    #     # Check if data have the right format
+    #     if not isinstance(data[0], tuple):
+    #         raise ValueError("Data items must be tuples")
+    #     # Check if data is torch tensors
+    #     for element in data[0]:
+    #         if not isinstance(element, torch.Tensor):
+    #             raise ValueError("Data items must consist of tensors")
+    #     # Check if input image is (C,H,W) format
+    #     if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
+    #         pass
+    #     else:
+    #         raise ValueError("Input image must be (C,H,W) format")
+    #     # Make new list such that possible augmentations remain identical for all three rows
+    #     plot_data = [data[idx] for idx in range(len(data))]
+    #     # Create input and target batch
+    #     inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
+    #     targets = torch.stack([item[1] for item in plot_data], dim=0)
+    #     # Get output predictions
+    #     with torch.no_grad():
+    #         outputs = model(inputs)
+    #     # Prepare data for plotting
+    #     inputs = inputs.cpu().squeeze()
+    #     targets = targets.squeeze()
+    #     if outputs.shape[1] == 1:
+    #         preds = (
+    #             outputs.cpu().squeeze() > threshold
+    #         )  # TODO: outputs from model are not between [0,1] yet, need to implement that
+    #     else:
+    #         preds = outputs.cpu().argmax(axis=1)
+    #     # if there is only one image
+    #     if inputs.dim() == 2:
+    #         inputs = inputs.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
+    #         targets = targets.unsqueeze(0)
+    #         preds = preds.unsqueeze(0)
+    #     return inputs, targets, preds
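With this change, inference returns a list of per-item (input, target, prediction) NumPy triples instead of three stacked tensors (note that the annotated return type, tuple[torch.Tensor, ...], no longer matches the returned list). A hypothetical call, with dataset and model names assumed rather than taken from the repo:

    results = inference(test_dataset, model, threshold=0.5, is_3d=True)
    volume, target, pred = results[0]  # each a NumPy array, e.g. shape (D, H, W)
    print(pred.min(), pred.max())      # binary mask values: 0.0 and 1.0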
-def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold: float = 0.5) -> np.ndarray:
+def volume_inference(
+    volume: np.ndarray,
+    model: torch.nn.Module,
+    threshold: float = 0.5,
+) -> np.ndarray:
     """
     Compute on the entire volume

     Args:
...
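The body of volume_inference is collapsed in this diff. A plausible whole-volume forward pass, consistent with the sigmoid-and-threshold logic used in inference above — a sketch under that assumption, not the author's implementation:

    import numpy as np
    import torch

    def volume_inference_sketch(volume: np.ndarray, model: torch.nn.Module, threshold: float = 0.5) -> np.ndarray:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device).eval()
        # Add batch and channel dimensions: (D, H, W) -> (1, 1, D, H, W)
        vol = torch.from_numpy(volume.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            probs = torch.sigmoid(model(vol))  # logits -> probabilities in [0, 1]
        return (probs > threshold).squeeze().cpu().numpy()  # binary mask, shape (D, H, W)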
@@ -69,7 +69,8 @@ class UNet(nn.Module):
             in_channels=1,  # TODO: check if image has 1 or multiple input channels
             out_channels=1,
             channels=self.channels,
-            strides=(2,) * (len(self.channels) - 1),
+            strides=(2,) * (len(self.channels) - 1),  # TODO: Check if the strides are correct?
+            num_res_units=2,  # TODO: This was not here before
             kernel_size=self.kernel_size,
             up_kernel_size=self.up_kernel_size,
             act=self.activation,
...
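On the two TODOs above: MONAI's UNet requires len(strides) == len(channels) - 1 (one downsampling step between consecutive levels), so the (2,) * (len(self.channels) - 1) expression is correct, and num_res_units=2 adds two residual units to each encode/decode block. A minimal standalone instantiation, with illustrative values for the fields this class reads from self:

    from monai.networks.nets import UNet

    channels = (64, 128, 256)
    net = UNet(
        spatial_dims=3,                      # 3D segmentation network
        in_channels=1,
        out_channels=1,
        channels=channels,
        strides=(2,) * (len(channels) - 1),  # one 2x downsampling per level transition
        num_res_units=2,                     # residual units per block, as added here
    )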