Skip to content
Snippets Groups Projects
Commit f79041f8 authored by s193396's avatar s193396
Browse files

removed 2D augmentations

parent ddc40406
No related branches found
No related tags found
No related merge requests found
...@@ -11,7 +11,6 @@ class Augmentation: ...@@ -11,7 +11,6 @@ class Augmentation:
transform_test (str, optional): level of transformation for the test set. transform_test (str, optional): level of transformation for the test set.
mean (float, optional): The mean value for normalizing pixel intensities. mean (float, optional): The mean value for normalizing pixel intensities.
std (float, optional): The standard deviation value for normalizing pixel intensities. std (float, optional): The standard deviation value for normalizing pixel intensities.
is_3d (bool, optional): Specifies if the images are 3D or 2D. Default is True.
Raises: Raises:
ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'. ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'.
...@@ -27,7 +26,6 @@ class Augmentation: ...@@ -27,7 +26,6 @@ class Augmentation:
transform_test: str | None = None, transform_test: str | None = None,
mean: float = 0.5, mean: float = 0.5,
std: float = 0.5, std: float = 0.5,
is_3d: bool = True,
): ):
if resize not in ['crop', 'reshape', 'padding']: if resize not in ['crop', 'reshape', 'padding']:
...@@ -39,7 +37,6 @@ class Augmentation: ...@@ -39,7 +37,6 @@ class Augmentation:
self.transform_train = transform_train self.transform_train = transform_train
self.transform_validation = transform_validation self.transform_validation = transform_validation
self.transform_test = transform_test self.transform_test = transform_test
self.is_3d = is_3d
def augment(self, img_shape: tuple, level: str | None = None): def augment(self, img_shape: tuple, level: str | None = None):
""" """
...@@ -57,37 +54,31 @@ class Augmentation: ...@@ -57,37 +54,31 @@ class Augmentation:
RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd
) )
# Check if 2D or 3D # Check if image is 3D
if len(img_shape) == 2: if len(img_shape) == 3:
im_h, im_w = img_shape
elif len(img_shape) == 3:
im_d, im_h, im_w = img_shape im_d, im_h, im_w = img_shape
else:
msg = f"Invalid image shape: {img_shape}. Must be 3D."
raise ValueError(msg)
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light', 'moderate', 'heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
raise ValueError(msg)
# Baseline augmentations # Baseline augmentations
# TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
baseline_aug = [ToTensor()] baseline_aug = [ToTensor()]
# For 2D, add normalization to the baseline augmentations
# TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
if not self.is_3d:
# baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std))
# Resize augmentations # Resize augmentations
if self.resize == 'crop': if self.resize == 'crop':
# resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))]
resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))] resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))]
elif self.resize == 'reshape': elif self.resize == 'reshape':
# resize_aug = [Resize((im_d, im_h, im_w))]
resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))] resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
elif self.resize == 'padding': elif self.resize == 'padding':
# resize_aug = [SpatialPad((im_d, im_h, im_w))]
resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))] resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
# Level of augmentation # Level of augmentation
...@@ -98,40 +89,25 @@ class Augmentation: ...@@ -98,40 +89,25 @@ class Augmentation:
resize_aug = [] resize_aug = []
elif level == 'light': elif level == 'light':
# level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] # TODO: Do rotations along other axes?
level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))] level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))]
elif level == 'moderate': elif level == 'moderate':
# level_aug = [
# RandRotate90(prob=1, spatial_axes=(0, 1)),
# RandFlip(prob=0.3, spatial_axis=0),
# RandFlip(prob=0.3, spatial_axis=1),
# RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
# RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
# ]
level_aug = [ level_aug = [
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0), RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1), RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1), RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)), RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
] ]
elif level == 'heavy': elif level == 'heavy':
# level_aug = [
# RandRotate90(prob=1, spatial_axes=(0, 1)),
# RandFlip(prob=0.7, spatial_axis=0),
# RandFlip(prob=0.7, spatial_axis=1),
# RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
# RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
# ]
level_aug = [ level_aug = [
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0), RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1), RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3), RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15)) RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
] ]
return Compose(baseline_aug + resize_aug + level_aug) return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment