Skip to content
Snippets Groups Projects
Commit f79041f8 authored by s193396's avatar s193396
Browse files

removed 2D augmentations

parent ddc40406
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,6 @@ class Augmentation:
transform_test (str, optional): level of transformation for the test set.
mean (float, optional): The mean value for normalizing pixel intensities.
std (float, optional): The standard deviation value for normalizing pixel intensities.
is_3d (bool, optional): Specifies if the images are 3D or 2D. Default is True.
Raises:
ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'.
......@@ -27,7 +26,6 @@ class Augmentation:
transform_test: str | None = None,
mean: float = 0.5,
std: float = 0.5,
is_3d: bool = True,
):
if resize not in ['crop', 'reshape', 'padding']:
......@@ -39,7 +37,6 @@ class Augmentation:
self.transform_train = transform_train
self.transform_validation = transform_validation
self.transform_test = transform_test
self.is_3d = is_3d
def augment(self, img_shape: tuple, level: str | None = None):
"""
......@@ -57,37 +54,31 @@ class Augmentation:
RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd
)
# Check if 2D or 3D
if len(img_shape) == 2:
im_h, im_w = img_shape
elif len(img_shape) == 3:
# Check if image is 3D
if len(img_shape) == 3:
im_d, im_h, im_w = img_shape
else:
msg = f"Invalid image shape: {img_shape}. Must be 3D."
raise ValueError(msg)
# Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
if level not in [None,'light', 'moderate', 'heavy']:
msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
raise ValueError(msg)
# Baseline augmentations
# TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
baseline_aug = [ToTensor()]
# For 2D, add normalization to the baseline augmentations
# TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise)
if not self.is_3d:
# baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std))
baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std))
# Resize augmentations
if self.resize == 'crop':
# resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))]
resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))]
elif self.resize == 'reshape':
# resize_aug = [Resize((im_d, im_h, im_w))]
resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
elif self.resize == 'padding':
# resize_aug = [SpatialPad((im_d, im_h, im_w))]
resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))]
# Level of augmentation
......@@ -98,40 +89,25 @@ class Augmentation:
resize_aug = []
elif level == 'light':
# level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))]
# TODO: Do rotations along other axes?
level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))]
elif level == 'moderate':
# level_aug = [
# RandRotate90(prob=1, spatial_axes=(0, 1)),
# RandFlip(prob=0.3, spatial_axis=0),
# RandFlip(prob=0.3, spatial_axis=1),
# RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1),
# RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
# ]
level_aug = [
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)),
]
elif level == 'heavy':
# level_aug = [
# RandRotate90(prob=1, spatial_axes=(0, 1)),
# RandFlip(prob=0.7, spatial_axis=0),
# RandFlip(prob=0.7, spatial_axis=1),
# RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3),
# RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
# ]
level_aug = [
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1),
RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3),
RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15))
]
return Compose(baseline_aug + resize_aug + level_aug)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment