diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index dc82e8e5a208ff10a0b9b7e51373208a881ac0f3..30a2a7f0df34ccea9917e51d6abd0d0f8aeff7b5 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -41,14 +41,12 @@ class Augmentation: self.transform_test = transform_test self.is_3d = is_3d - def augment(self, im_h: int, im_w: int, im_d: int | None = None, level: str | None = None): + def augment(self, img_shape: tuple, level: str | None = None): """ Creates an augmentation pipeline based on the specified level. Args: - im_h (int): Height of the image. - im_w (int): Width of the image. - im_d (int, optional): Depth of the image (for 3D). + img_shape (tuple): Dimensions of the image. level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy']. Raises: @@ -59,6 +57,13 @@ class Augmentation: RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad ) + # Check if 2D or 3D + if len(img_shape) == 2: + im_h, im_w = img_shape + + elif len(img_shape) == 3: + im_d, im_h, im_w = img_shape + # Check if one of standard augmentation levels if level not in [None,'light','moderate','heavy']: raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py index 03c3c399a1b70aa6f70b9e551d0d179b70590a5d..0a8fca48a613dee02c361e987c54b9a7d678279b 100644 --- a/qim3d/ml/_data.py +++ b/qim3d/ml/_data.py @@ -271,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat final_shape = check_resize(orig_shape, resize, n_channels, is_3d) - train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train)) - val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation)) - test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test)) + train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train)) + val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation)) + test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, level = augmentation.transform_test)) split_idx = int(np.floor(val_fraction * len(train_set))) indices = torch.randperm(len(train_set))