Skip to content
Snippets Groups Projects
Commit ac6a7a44 authored by s193396's avatar s193396
Browse files

dimensions as tuple instead of unpacked

parent d0b1d98a
No related branches found
No related tags found
No related merge requests found
......@@ -41,14 +41,12 @@ class Augmentation:
self.transform_test = transform_test
self.is_3d = is_3d
def augment(self, im_h: int, im_w: int, im_d: int | None = None, level: str | None = None):
def augment(self, img_shape: tuple, level: str | None = None):
"""
Creates an augmentation pipeline based on the specified level.
Args:
im_h (int): Height of the image.
im_w (int): Width of the image.
im_d (int, optional): Depth of the image (for 3D).
img_shape (tuple): Dimensions of the image.
level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy'].
Raises:
......@@ -59,6 +57,13 @@ class Augmentation:
RandGaussianSmooth, NormalizeIntensity, Resize, CenterSpatialCrop, SpatialPad
)
# Check if 2D or 3D
if len(img_shape) == 2:
im_h, im_w = img_shape
elif len(img_shape) == 3:
im_d, im_h, im_w = img_shape
# Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
......
......@@ -271,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))
train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation))
test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, level = augmentation.transform_test))
split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment