Commit 8fea762c authored by ofhkr, committed by fima

Adding the crop and padding as resizing features.

parent 2e4ba1b2
Tag: v0.3.2
Merge request !17: Adding the crop and padding as resizing features.
@@ -8,7 +8,7 @@ class Augmentation:
    Class for defining image augmentation transformations using the Albumentations library.

    Args:
-        resize ((int,tuple), optional): The target size to resize the image.
+        resize (str, optional): Specifies how the images should be reshaped to the appropriate size.
        transform_train (str, optional): level of transformation for the training set.
        transform_validation (str, optional): level of transformation for the validation set.
        transform_test (str, optional): level of transformation for the test set.
@@ -16,14 +16,14 @@ class Augmentation:
        std (float, optional): The standard deviation value for normalizing pixel intensities.

    Raises:
-        ValueError: If `resize` is neither a None, int nor tuple.
+        ValueError: If `resize` is not one of 'crop', 'reshape' or 'padding'.

    Example:
-        my_augmentation = Augmentation(resize = (256,256), transform_train = 'heavy')
+        my_augmentation = Augmentation(resize = 'crop', transform_train = 'heavy')
    """

    def __init__(self,
-                 resize = None,
+                 resize = 'crop',
                  transform_train = 'moderate',
                  transform_validation = None,
                  transform_test = None,
@@ -31,8 +31,8 @@ class Augmentation:
                 std: float = 0.5
                 ):

-        if not isinstance(resize, (type(None), int, tuple)):
-            raise ValueError(f"Invalid input for resize: {resize}. Use an integer or tuple to modify the data.")
+        if resize not in ['crop', 'reshape', 'padding']:
+            raise ValueError(f"Invalid resize type: {resize}. Use either 'crop', 'reshape' or 'padding'.")

        self.resize = resize
        self.mean = mean
@@ -62,10 +62,21 @@ class Augmentation:
        # Baseline
        baseline_aug = [
-            A.Resize(im_h, im_w),
            A.Normalize(mean = (self.mean), std = (self.std)),
            ToTensorV2()
        ]

+        if self.resize == 'crop':
+            resize_aug = [
+                A.CenterCrop(im_h, im_w)
+            ]
+        elif self.resize == 'reshape':
+            resize_aug = [
+                A.Resize(im_h, im_w)
+            ]
+        elif self.resize == 'padding':
+            resize_aug = [
+                A.PadIfNeeded(im_h, im_w, border_mode = 0)  # border_mode 0 is cv2.BORDER_CONSTANT (zero padding)
+            ]

        # Level of augmentation
        if level == None:
@@ -91,6 +102,6 @@ class Augmentation:
                A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
            ]

-        augment = A.Compose(level_aug + baseline_aug)
+        augment = A.Compose(level_aug + resize_aug + baseline_aug)
        return augment
\ No newline at end of file
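For reference, a minimal sketch of what the three resize modes do to a sample image. This is illustrative only: the dummy numpy image and target size are assumptions, while the CenterCrop/Resize/PadIfNeeded calls mirror the diff above.

import albumentations as A
import numpy as np

image = np.zeros((300, 300, 3), dtype=np.uint8)  # dummy 300x300 RGB image
target_h, target_w = 256, 256                    # hypothetical model-compatible size

strategies = {
    'crop':    A.CenterCrop(target_h, target_w),                  # cut away the borders
    'reshape': A.Resize(target_h, target_w),                      # interpolate to the new size
    'padding': A.PadIfNeeded(target_h, target_w, border_mode=0),  # constant-pad smaller images
}

for name, transform in strategies.items():
    print(name, transform(image=image)['image'].shape)
# crop    -> (256, 256, 3)
# reshape -> (256, 256, 3)
# padding -> (300, 300, 3), since PadIfNeeded only pads images below the minimum size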
@@ -101,27 +101,34 @@ class Dataset(torch.utils.data.Dataset):
    return Image.open(str(image_path)).size

-def check_resize(im_height: int, im_width: int, n_channels: int):
+def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
    """
    Checks the compatibility of the image shape with the depth of the model.
    If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
-    If so, the image is reshaped to the closest appropriate dimension, and the user is notified with a warning.
+    If so, the image is changed to the closest appropriate dimension, and the user is notified with a warning.

    Args:
-        im_height (int): Height of the image chosen by the user.
-        im_width (int): Width of the image chosen by the user.
+        im_height (int): Height of the original image from the dataset.
+        im_width (int): Width of the original image from the dataset.
+        resize (str): Type of resize to be used on the image if the size doesn't fit the model.
        n_channels (int): Number of channels in the model.

    Raises:
        ValueError: If the image size is smaller than minimum required for the model's depth.
    """
-    h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels, (im_width // 2**n_channels) * 2**n_channels
+    # finding suitable size to upsize with padding
+    if resize == 'padding':
+        h_adjust, w_adjust = (im_height // 2**n_channels + 1) * 2**n_channels, (im_width // 2**n_channels + 1) * 2**n_channels
+    # finding suitable size to downsize with crop / reshape
+    else:
+        h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels, (im_width // 2**n_channels) * 2**n_channels

    if h_adjust == 0 or w_adjust == 0:
        raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
-    elif (h_adjust != im_height) or (w_adjust != im_width):
-        log.warning(f"The image size doesn't match the Unet model's depth. The image is resized to: {h_adjust, w_adjust}")
+    elif h_adjust != im_height or w_adjust != im_width:
+        log.warning(f"The image size doesn't match the Unet model's depth. The image is changed with '{resize}', from {im_height, im_width} to {h_adjust, w_adjust}.")

    return h_adjust, w_adjust
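As a quick sanity check of the new rounding arithmetic, with hypothetical numbers (n_channels = 4, so each side must be a multiple of 2**4 = 16):

n_channels = 4
im_h = 300

# 'crop' / 'reshape' round down to the nearest multiple of 16
print((im_h // 2**n_channels) * 2**n_channels)      # 288
# 'padding' rounds up to the next multiple of 16
print((im_h // 2**n_channels + 1) * 2**n_channels)  # 304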
@@ -146,20 +153,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
    resize = augmentation.resize
    n_channels = len(model.channels)

-    if isinstance(resize, type(None)):
-        # OPEN THE FIRST IMAGE
-        im_path = Path(path) / 'train'
-        first_img = sorted((im_path / "images").iterdir())[0]
-        image = Image.open(str(first_img))
-        im_h, im_w = image.size[:2]
-        log.info("User did not choose a specific value for 'resize'. Checking the first image in the dataset...")
-    elif isinstance(resize, int):
-        im_h, im_w = resize, resize
-    else:
-        im_h, im_w = resize
+    # taking the size of the 1st image in the dataset
+    im_path = Path(path) / 'train'
+    first_img = sorted((im_path / "images").iterdir())[0]
+    image = Image.open(str(first_img))
+    orig_h, orig_w = image.size[:2]

-    final_h, final_w = check_resize(im_h, im_w, n_channels)
+    final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)

    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
    val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
...
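Putting the pieces together, a hedged usage sketch based only on the signatures visible in this diff; the 288x288 target and the 'padding' choice are arbitrary:

my_augmentation = Augmentation(resize = 'padding', transform_train = 'moderate')
# prepare_datasets derives the final size internally via check_resize;
# calling augment() directly with a known-good size is shown here for illustration
pipeline = my_augmentation.augment(288, 288, my_augmentation.transform_train)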
@@ -140,9 +140,6 @@ def model_summary(dataloader, model):
        print(summary)
    """

-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-
    images, _ = next(iter(dataloader))
    batch_size = tuple(images.shape)
    model_s = summary(model, batch_size, depth = torch.inf)
...
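If `summary` here is torchinfo.summary (an assumption based on the call shape), it runs its own forward pass and handles device placement itself, which would explain why the explicit device move is no longer needed:

from torchinfo import summary  # assumed import; only the call shape appears in the diff

images, _ = next(iter(dataloader))  # one batch to infer the input shape
model_s = summary(model, input_size=tuple(images.shape), depth=3)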