Commit 3d734ce5 authored by s193396

temporary

parent ac6a7a44
Source diff could not be displayed: it is too large.
@@ -53,8 +53,8 @@ class Augmentation:
             ValueError: If `level` is neither None, light, moderate nor heavy.
         """
         from monai.transforms import (
-            Compose, RandRotate90, RandFlip, RandAffine, ToTensor, \
-            RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad
+            Compose, RandRotate90d, RandFlipd, RandAffined, ToTensor, \
+            RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd
         )

         # Check if 2D or 3D
@@ -74,41 +74,64 @@ class Augmentation:
         # For 2D, add normalization to the baseline augmentations
         # TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
         if not self.is_3d:
-            baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
+            # baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
+            baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std))

         # Resize augmentations
         if self.resize == 'crop':
-            resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))] if self.is_3d else [CenterSpatialCrop((im_h, im_w))]
+            # resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))]
+            resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))]
         elif self.resize == 'reshape':
-            resize_aug = [Resize((im_d, im_h, im_w))] if self.is_3d else [Resize((im_h, im_w))]
+            # resize_aug = [Resize((im_d, im_h, im_w))]
+            resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
         elif self.resize == 'padding':
-            resize_aug = [SpatialPad((im_d, im_h, im_w))] if self.is_3d else [SpatialPad((im_h, im_w))]
+            # resize_aug = [SpatialPad((im_d, im_h, im_w))]
+            resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]

         # Level of augmentation
         if level == None:
             # No augmentation for the validation and test sets
             level_aug = []
             resize_aug = []
         elif level == 'light':
-            level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] if self.is_3d else [RandRotate90(prob=1)]
+            # level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))]
+            level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))]
         elif level == 'moderate':
+            # level_aug = [
+            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
+            #     RandFlip(prob=0.3, spatial_axis=0),
+            #     RandFlip(prob=0.3, spatial_axis=1),
+            #     RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
+            #     RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
+            # ]
             level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
-                RandFlip(prob=0.3, spatial_axis=0),
-                RandFlip(prob=0.3, spatial_axis=1),
-                RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
-                RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
-            ]
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
+            ]
         elif level == 'heavy':
-            level_aug = [
-                RandRotate90(prob=1, spatial_axes=(0, 1)) if self.is_3d else RandRotate90(prob=1),
-                RandFlip(prob=0.7, spatial_axis=0),
-                RandFlip(prob=0.7, spatial_axis=1),
-                RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
-                RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
-            ]
+            # level_aug = [
+            #     RandRotate90(prob=1, spatial_axes=(0, 1)),
+            #     RandFlip(prob=0.7, spatial_axis=0),
+            #     RandFlip(prob=0.7, spatial_axis=1),
+            #     RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
+            #     RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
+            # ]
+            level_aug = [
+                RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
+                RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
+                RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
+                RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
+            ]

         return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file
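
The hunk above migrates the augmentation pipeline from MONAI's array transforms to the dictionary-based `d`-suffixed variants, so each random spatial transform is applied with identical parameters to the image and its label, while intensity-only operations (normalization, smoothing) are keyed to the image alone. A minimal, self-contained sketch of this pattern; the shapes and parameter values are illustrative, not taken from the repository:

    import torch
    from monai.transforms import Compose, NormalizeIntensityd, RandFlipd, RandRotate90d

    pipeline = Compose([
        NormalizeIntensityd(keys=["image"], subtrahend=0.5, divisor=0.5),     # image only
        RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),  # image + label in sync
        RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
    ])

    sample = {
        "image": torch.rand(1, 64, 64),             # (C, H, W) float image
        "label": torch.randint(0, 2, (1, 64, 64)),  # binary mask, same spatial shape
    }
    out = pipeline(sample)
    assert out["image"].shape == out["label"].shape  # spatial ops stayed aligned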
@@ -104,8 +104,12 @@ class Dataset(torch.utils.data.Dataset):
             target = target.transpose((2, 0, 1))

         if self.transform:
-            image = self.transform(image) # uint8
-            target = self.transform(target) # int32
+            transformed = self.transform({"image": image, "label": target})
+            image = transformed["image"]
+            target = transformed["label"]
+            # image = self.transform(image) # uint8
+            # target = self.transform(target) # int32

         # TODO: Which dtype?
         image = image.clone().detach().to(dtype=torch.float32)
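
On the `# TODO: Which dtype?` note: a common PyTorch convention, an assumption here rather than anything this commit settles, is float32 network inputs, float targets for BCE-style losses, and int64 class-index targets for `nn.CrossEntropyLoss`. For one sample:

    import torch

    image = torch.randint(0, 256, (1, 64, 64), dtype=torch.uint8)  # raw uint8 image
    target = torch.randint(0, 2, (1, 64, 64), dtype=torch.int32)   # raw int32 mask

    image = image.to(dtype=torch.float32) / 255.0         # network input: float32, scaled to [0, 1]
    bce_target = target.to(dtype=torch.float32)           # for BCE-with-logits losses
    ce_target = target.squeeze(0).to(dtype=torch.int64)   # (H, W) class indices for nn.CrossEntropyLoss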
@@ -160,7 +164,7 @@ def check_resize(
         orig_shape (tuple): Original shape of the image.
         resize (tuple): Desired resize dimensions.
         n_channels (int): Number of channels in the model.
-        is_3d (bool): Whether the data is 3D or not.
+        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.

     Returns:
         tuple: Final resize dimensions.
@@ -230,7 +234,12 @@ def check_resize(
     return final_h, final_w


-def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
+def prepare_datasets(
+    path: str,
+    val_fraction: float,
+    model: nn.Module,
+    augmentation: Augmentation,
+) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
     Splits and augments the train/validation/test datasets.
     ...
@@ -132,6 +132,10 @@ def train_model(
                 f"Epoch {epoch: 3}, train loss: {train_loss['loss'][epoch]:.4f}, "
                 f"val loss: {val_loss['loss'][epoch]:.4f}"
             )

+    # NOTE: Delete this again
+    # Save model checkpoint to .pth file
+    torch.save(model.state_dict(), "C:/Users/s193396/dataset/model.pth")
+
     if plot:
         plot_metrics(train_loss, val_loss, labels=["Train", "Valid."], show=True)
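
A checkpoint written with `torch.save(model.state_dict(), ...)` contains only the weights, so restoring it means rebuilding the architecture first and then loading the state dict. A minimal sketch; the tiny stand-in network and path are placeholders, not the repository's model:

    import torch
    from torch import nn

    # Hypothetical stand-in for the trained architecture; the real checkpoint
    # must be loaded into a module with exactly the same structure.
    model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1))

    torch.save(model.state_dict(), "model.pth")          # as in train_model above
    state = torch.load("model.pth", map_location="cpu")  # load to CPU regardless of training device
    model.load_state_dict(state)
    model.eval()  # freeze dropout/batch-norm behavior for inference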
@@ -163,7 +167,12 @@ def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Modul
     return model_s


-def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def inference(
+    data: torch.utils.data.Dataset,
+    model: torch.nn.Module,
+    threshold: float = 0.5,
+    is_3d: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.

     Performs inference on the input data using the provided model. The input data should be in the form of a list,
@@ -177,6 +186,8 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
         data (torch.utils.data.Dataset): A Torch dataset containing input image and
             ground truth label data.
         model (torch.nn.Module): The trained network model used for predicting segmentations.
+        threshold (float): The threshold value used to binarize the model predictions.
+        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.

     Returns:
         tuple: A tuple containing the input images, target labels, and predicted labels.
@@ -194,59 +205,130 @@ def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[t
         model = MySegmentationModel()
         qim3d.ml.inference(data,model)
     """
-    # Get device
+    # Set model to evaluation mode
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+    model.eval()

-    # Check if data have the right format
-    if not isinstance(data[0], tuple):
-        raise ValueError("Data items must be tuples")
+    results = []

-    # Check if data is torch tensors
-    for element in data[0]:
-        if not isinstance(element, torch.Tensor):
-            raise ValueError("Data items must consist of tensors")
+    # 3D data
+    if is_3d:
+        for volume, target in data:
+            if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
+                raise ValueError("Data items must consist of tensors")

-    # Check if input image is (C,H,W) format
-    if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
-        pass
-    else:
-        raise ValueError("Input image must be (C,H,W) format")
+            # Add batch and channel dimensions
+            volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
+            target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]

-    model.to(device)
-    model.eval()
+            with torch.no_grad():
+                # Get model predictions (logits)
+                output = model(volume)

-    # Make new list such that possible augmentations remain identical for all three rows
-    plot_data = [data[idx] for idx in range(len(data))]
+                # Convert logits to probabilities [0, 1]
+                preds = torch.sigmoid(output)

-    # Create input and target batch
-    inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
-    targets = torch.stack([item[1] for item in plot_data], dim=0)
+                # Convert to binary mask by thresholding the probabilities
+                preds = (preds > threshold).float()

-    # Get output predictions
-    with torch.no_grad():
-        outputs = model(inputs)
+            # Remove batch and channel dimensions
+            volume = volume.squeeze().cpu().numpy()
+            target = target.squeeze().cpu().numpy()
+            preds = preds.squeeze().cpu().numpy()

-    # Prepare data for plotting
-    inputs = inputs.cpu().squeeze()
-    targets = targets.squeeze()
-    if outputs.shape[1] == 1:
-        preds = (
-            outputs.cpu().squeeze() > 0.5
-        )  # TODO: outputs from model are not between [0,1] yet, need to implement that
+            # Append results to list
+            results.append((volume, target, preds))

+    # 2D data
+    else:
-        preds = outputs.cpu().argmax(axis=1)
+        # Check if data have the right format
+        if not isinstance(data[0], tuple):
+            raise ValueError("Data items must be tuples")

-    # if there is only one image
-    if inputs.dim() == 2:
-        inputs = inputs.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
-        targets = targets.unsqueeze(0)
-        preds = preds.unsqueeze(0)
+        # Check if data is torch tensors
+        for element in data[0]:
+            if not isinstance(element, torch.Tensor):
+                raise ValueError("Data items must consist of tensors")

-    return inputs, targets, preds
+        for inputs, targets in data:
+            inputs = inputs.to(device)
+            targets = targets.to(device)

+            with torch.no_grad():
+                outputs = model(inputs)

+            # Prepare data for plotting
+            inputs_cpu = inputs.cpu().squeeze()
+            targets_cpu = targets.cpu().squeeze()
+            if outputs.shape[1] == 1:
+                preds = outputs.cpu().squeeze() > threshold
+            else:
+                preds = outputs.cpu().argmax(axis=1)

+            # If there is only one image
+            if inputs_cpu.dim() == 2:
+                inputs_cpu = inputs_cpu.unsqueeze(0).numpy()
+                targets_cpu = targets_cpu.unsqueeze(0).numpy()
+                preds = preds.unsqueeze(0).numpy()

+            # Append results to list
+            results.append((inputs_cpu, targets_cpu, preds))

+    return results

+    # Old implementation:
+    # else:
+    #     # Check if data have the right format
+    #     if not isinstance(data[0], tuple):
+    #         raise ValueError("Data items must be tuples")
+    #     # Check if data is torch tensors
+    #     for element in data[0]:
+    #         if not isinstance(element, torch.Tensor):
+    #             raise ValueError("Data items must consist of tensors")
+    #     # Check if input image is (C,H,W) format
+    #     if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
+    #         pass
+    #     else:
+    #         raise ValueError("Input image must be (C,H,W) format")
+    #     # Make new list such that possible augmentations remain identical for all three rows
+    #     plot_data = [data[idx] for idx in range(len(data))]
+    #     # Create input and target batch
+    #     inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
+    #     targets = torch.stack([item[1] for item in plot_data], dim=0)
+    #     # Get output predictions
+    #     with torch.no_grad():
+    #         outputs = model(inputs)
+    #     # Prepare data for plotting
+    #     inputs = inputs.cpu().squeeze()
+    #     targets = targets.squeeze()
+    #     if outputs.shape[1] == 1:
+    #         preds = (
+    #             outputs.cpu().squeeze() > threshold
+    #         )  # TODO: outputs from model are not between [0,1] yet, need to implement that
+    #     else:
+    #         preds = outputs.cpu().argmax(axis=1)
+    #     # if there is only one image
+    #     if inputs.dim() == 2:
+    #         inputs = inputs.unsqueeze(0)  # TODO: Not sure if unsqueeze (add extra dimension) is necessary
+    #         targets = targets.unsqueeze(0)
+    #         preds = preds.unsqueeze(0)
+    #     return inputs, targets, preds


-def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float = 0.5) -> np.ndarray:
+def volume_inference(
+    volume: np.ndarray,
+    model: torch.nn.Module,
+    threshold: float = 0.5,
+) -> np.ndarray:
     """
     Compute on the entire volume

     Args:
     ...
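
The rewritten 3D branch makes the prediction step explicit: raw logits from the model are mapped to probabilities with a sigmoid and then binarized at `threshold`, which addresses the old TODO about outputs not yet lying in [0, 1]. A toy illustration of just that step, where the random tensor stands in for real model output:

    import torch

    logits = torch.randn(1, 1, 8, 32, 32)  # [B, C, D, H, W] raw model output (illustrative)
    probs = torch.sigmoid(logits)           # squash into [0, 1]
    mask = (probs > 0.5).float()            # binarize at threshold = 0.5
    print(mask.unique())                    # typically tensor([0., 1.])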
@@ -69,7 +69,8 @@ class UNet(nn.Module):
             in_channels=1,  # TODO: check if image has 1 or multiple input channels
             out_channels=1,
             channels=self.channels,
-            strides=(2,) * (len(self.channels) - 1),
+            strides=(2,) * (len(self.channels) - 1),  # TODO: Check if the strides are correct?
+            num_res_units=2,  # TODO: This was not here before
             kernel_size=self.kernel_size,
             up_kernel_size=self.up_kernel_size,
             act=self.activation,
     ...
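
For context, the wrapped network here is `monai.networks.nets.UNet`, where `strides` needs one entry per downsampling step (hence `len(self.channels) - 1`) and `num_res_units` adds residual units to each encoder/decoder block. A standalone sketch with illustrative hyperparameters, not the repository's configuration:

    import torch
    from monai.networks.nets import UNet

    net = UNet(
        spatial_dims=3,              # 3D volumes
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128),  # feature maps per resolution level
        strides=(2, 2, 2),           # len(channels) - 1 downsampling steps
        num_res_units=2,             # residual units, as introduced in this commit
    )

    x = torch.rand(1, 1, 32, 64, 64)  # [B, C, D, H, W]
    y = net(x)
    assert y.shape == x.shape         # stride-2 down/up path preserves the input size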