Skip to content
Snippets Groups Projects
Commit 32fa3f08 authored by ofhkr's avatar ofhkr
Browse files

Changed to less drastic heavy augmentation, changed the way augmentation level...

Changed to less drastic heavy augmentation, changed the way augmentation level is chosen in prepare_datasets.
parent a695f8d0
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
...@@ -40,7 +40,7 @@ class Augmentation: ...@@ -40,7 +40,7 @@ class Augmentation:
self.transform_validation = transform_validation self.transform_validation = transform_validation
self.transform_test = transform_test self.transform_test = transform_test
def augment(self, im_h, im_w, level=None): def augment(self, im_h, im_w, type=None):
""" """
Returns an albumentations.core.composition.Compose class depending on the augmentation level. Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
...@@ -49,12 +49,19 @@ class Augmentation: ...@@ -49,12 +49,19 @@ class Augmentation:
Args: Args:
im_h (int): image height for resize. im_h (int): image height for resize.
im_w (int): image width for resize. im_w (int): image width for resize.
level (str, optional): level of augmentation. type (str, optional): level of augmentation.
Raises: Raises:
ValueError: If `level` is neither None, light, moderate nor heavy. ValueError: If `level` is neither None, light, moderate nor heavy.
""" """
if type=='train':
level = self.transform_train
elif type=='validation':
level = self.transform_validation
elif type=='test':
level = self.transform_test
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
...@@ -98,7 +105,7 @@ class Augmentation: ...@@ -98,7 +105,7 @@ class Augmentation:
A.HorizontalFlip(p = 0.7), A.HorizontalFlip(p = 0.7),
A.VerticalFlip(p = 0.7), A.VerticalFlip(p = 0.7),
A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3), A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15)) A.Affine(scale = [0.9,1.1], translate_percent = (-0.2,0.2), shear = (-5,5))
] ]
augment = A.Compose(level_aug + resize_aug + baseline_aug) augment = A.Compose(level_aug + resize_aug + baseline_aug)
......
...@@ -161,9 +161,9 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation): ...@@ -161,9 +161,9 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels) final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train)) train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'train'))
val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation)) val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, 'validation'))
test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test)) test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, 'test'))
split_idx = int(np.floor(val_fraction * len(train_set))) split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set)) indices = torch.randperm(len(train_set))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment