diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index 219c208c16c1b720c92e6db6f98722043d1be590..276e33f08304ec548bd637881a84934c4312cdac 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -11,7 +11,6 @@ class Augmentation: transform_test (str, optional): level of transformation for the test set. mean (float, optional): The mean value for normalizing pixel intensities. std (float, optional): The standard deviation value for normalizing pixel intensities. - is_3d (bool, optional): Specifies if the images are 3D or 2D. Default is True. Raises: ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'. @@ -27,7 +26,6 @@ class Augmentation: transform_test: str | None = None, mean: float = 0.5, std: float = 0.5, - is_3d: bool = True, ): if resize not in ['crop', 'reshape', 'padding']: @@ -39,7 +37,6 @@ class Augmentation: self.transform_train = transform_train self.transform_validation = transform_validation self.transform_test = transform_test - self.is_3d = is_3d def augment(self, img_shape: tuple, level: str | None = None): """ @@ -57,37 +54,31 @@ class Augmentation: RandGaussianSmoothd, NormalizeIntensityd, Resized, CenterSpatialCropd, SpatialPadd ) - # Check if 2D or 3D - if len(img_shape) == 2: - im_h, im_w = img_shape - - elif len(img_shape) == 3: + # Check if image is 3D + if len(img_shape) == 3: im_d, im_h, im_w = img_shape + + else: + msg = f"Invalid image shape: {img_shape}. Must be 3D." + raise ValueError(msg) # Check if one of standard augmentation levels - if level not in [None,'light','moderate','heavy']: - raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") + if level not in [None,'light', 'moderate', 'heavy']: + msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'." + raise ValueError(msg) # Baseline augmentations + # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise) baseline_aug = [ToTensor()] - # For 2D, add normalization to the baseline augmentations - # TODO: Figure out how to properly do this in 3D (normalization should be done channel-wise) - if not self.is_3d: - # baseline_aug.append(NormalizeIntensity(subtrahend=self.mean, divisor=self.std)) - baseline_aug.append(NormalizeIntensityd(keys=["image"], subtrahend=self.mean, divisor=self.std)) - # Resize augmentations if self.resize == 'crop': - # resize_aug = [CenterSpatialCrop((im_d, im_h, im_w))] resize_aug = [CenterSpatialCropd(keys=["image", "label"], roi_size=(im_d, im_h, im_w))] elif self.resize == 'reshape': - # resize_aug = [Resize((im_d, im_h, im_w))] resize_aug = [Resized(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))] elif self.resize == 'padding': - # resize_aug = [SpatialPad((im_d, im_h, im_w))] resize_aug = [SpatialPadd(keys=["image", "label"], spatial_size=(im_d, im_h, im_w))] # Level of augmentation @@ -98,40 +89,25 @@ class Augmentation: resize_aug = [] elif level == 'light': - # level_aug = [RandRotate90(prob=1, spatial_axes=(0, 1))] + # TODO: Do rotations along other axes? level_aug = [RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1))] elif level == 'moderate': - # level_aug = [ - # RandRotate90(prob=1, spatial_axes=(0, 1)), - # RandFlip(prob=0.3, spatial_axis=0), - # RandFlip(prob=0.3, spatial_axis=1), - # RandGaussianSmooth(sigma_x=(0.7, 0.7), prob=0.1), - # RandAffine(prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)), - # ] level_aug = [ - RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), - RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0), - RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1), - RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1), - RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)), + RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), + RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=0), + RandFlipd(keys=["image", "label"], prob=0.3, spatial_axis=1), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.7, 0.7), prob=0.1), + RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.1, 0.1), scale_range=(0.9, 1.1)), ] elif level == 'heavy': - # level_aug = [ - # RandRotate90(prob=1, spatial_axes=(0, 1)), - # RandFlip(prob=0.7, spatial_axis=0), - # RandFlip(prob=0.7, spatial_axis=1), - # RandGaussianSmooth(sigma_x=(1.2, 1.2), prob=0.3), - # RandAffine(prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15)) - # ] - level_aug = [ - RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), - RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0), - RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1), - RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3), - RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15)) + RandRotate90d(keys=["image", "label"], prob=1, spatial_axes=(0, 1)), + RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=0), + RandFlipd(keys=["image", "label"], prob=0.7, spatial_axis=1), + RandGaussianSmoothd(keys=["image"], sigma_x=(1.2, 1.2), prob=0.3), + RandAffined(keys=["image", "label"], prob=0.5, translate_range=(0.2, 0.2), scale_range=(0.8, 1.4), shear_range=(-15, 15)) ] return Compose(baseline_aug + resize_aug + level_aug) \ No newline at end of file