Commit b19d3cd4 authored by fima

Merge branch 'unet_crop' into 'main'

Adding crop and padding as resizing features.

See merge request !17
parents 2e4ba1b2 8fea762c
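For context, a minimal usage sketch of the new resize modes follows (the import path and parameter values are illustrative, not taken from this merge request):

from augmentation import Augmentation  # hypothetical import path

# 'crop' (the new default) centre-crops images down to a size the UNet depth supports
aug_crop = Augmentation(resize='crop', transform_train='heavy')

# 'reshape' rescales with interpolation, matching the previous A.Resize behaviour
aug_reshape = Augmentation(resize='reshape', transform_train='moderate')

# 'padding' pads images up to the next compatible size instead of interpolating
aug_pad = Augmentation(resize='padding', transform_train='moderate')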
@@ -8,7 +8,7 @@ class Augmentation:
Class for defining image augmentation transformations using the Albumentations library.
Args:
resize ((int,tuple), optional): The target size to resize the image.
resize (str, optional): Specifies how the images are resized to a shape compatible with the model. One of 'crop', 'reshape' or 'padding'.
transform_train (str, optional): level of transformation for the training set.
transform_validation (str, optional): level of transformation for the validation set.
transform_test (str, optional): level of transformation for the test set.
@@ -16,14 +16,14 @@ class Augmentation:
std (float, optional): The standard deviation value for normalizing pixel intensities.
Raises:
ValueError: If `resize` is neither a None, int nor tuple.
ValueError: If `resize` is not one of 'crop', 'reshape' or 'padding'.
Example:
my_augmentation = Augmentation(resize = (256,256), transform_train = 'heavy')
my_augmentation = Augmentation(resize = 'crop', transform_train = 'heavy')
"""
def __init__(self,
resize = None,
resize = 'crop',
transform_train = 'moderate',
transform_validation = None,
transform_test = None,
@@ -31,8 +31,8 @@ class Augmentation:
std: float = 0.5
):
if not isinstance(resize,(type(None),int,tuple)):
raise ValueError(f"Invalid input for resize: {resize}. Use an integer or tuple to modify the data.")
if resize not in ['crop', 'reshape', 'padding']:
raise ValueError(f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'.")
self.resize = resize
self.mean = mean
@@ -62,10 +62,21 @@ class Augmentation:
# Baseline
baseline_aug = [
A.Resize(im_h, im_w),
A.Normalize(mean = (self.mean),std = (self.std)),
ToTensorV2()
]
if self.resize == 'crop':
resize_aug = [
A.CenterCrop(im_h,im_w)
]
elif self.resize == 'reshape':
resize_aug =[
A.Resize(im_h,im_w)
]
elif self.resize == 'padding':
resize_aug = [
A.PadIfNeeded(im_h,im_w,border_mode = 0) # OpenCV border mode
]
# Level of augmentation
if level == None:
@@ -91,6 +102,6 @@ class Augmentation:
A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
]
augment = A.Compose(level_aug + baseline_aug)
augment = A.Compose(level_aug + resize_aug + baseline_aug)
return augment
\ No newline at end of file
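As an aside, a small self-contained sketch of what each resize branch above does to a sample image (the 300x300 input and 256/384 targets are made-up values, not from this commit):

import numpy as np
import albumentations as A

img = np.zeros((300, 300, 3), dtype=np.uint8)

cropped  = A.CenterCrop(256, 256)(image=img)["image"]                  # (256, 256, 3): centre region kept
rescaled = A.Resize(256, 256)(image=img)["image"]                      # (256, 256, 3): interpolated
padded   = A.PadIfNeeded(384, 384, border_mode=0)(image=img)["image"]  # (384, 384, 3): OpenCV constant border, as in the commit

print(cropped.shape, rescaled.shape, padded.shape)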
@@ -101,27 +101,34 @@ class Dataset(torch.utils.data.Dataset):
return Image.open(str(image_path)).size
def check_resize(im_height: int, im_width: int, n_channels: int):
def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
"""
Checks the compatibility of the image shape with the depth of the model.
If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
If so, the image is reshaped to the closest appropriate dimension, and the user is notified with a warning.
If so, the image is changed to the closest appropriate dimension, and the user is notified with a warning.
Args:
im_height (int): Height of the image chosen by the user.
im_width (int): Width of the image chosen by the user.
im_height (int) : Height of the original image from the dataset.
im_width (int) : Width of the original image from the dataset.
resize (str) : Type of resize to be used on the image if the size doesn't fit the model.
n_channels (int): Number of channels in the model.
Raises:
ValueError: If the image size is smaller than the minimum required for the model's depth.
"""
# finding suitable size to upsize with padding
if resize == 'padding':
h_adjust, w_adjust = (im_height // 2**n_channels+1) * 2**n_channels , (im_width // 2**n_channels+1) * 2**n_channels
# finding suitable size to downsize with crop / resize
else:
h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels , (im_width // 2**n_channels) * 2**n_channels
if h_adjust == 0 or w_adjust == 0:
raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
elif (h_adjust!=im_height) or (w_adjust != im_width):
log.warning(f"The image size doesn't match the Unet model's depth. The image is resized to: {h_adjust,w_adjust}")
elif h_adjust != im_height or w_adjust != im_width:
log.warning(f"The image size doesn't match the Unet model's depth. The image is changed with '{resize}', from {im_height, im_width} to {h_adjust, w_adjust}.")
return h_adjust, w_adjust
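For reference, a tiny worked example of the adjustment logic above (the 300x300 image and 4-level depth are illustrative values):

im_height, im_width, n_channels = 300, 300, 4   # sizes must be divisible by 2**4 = 16

# 'crop' / 'reshape' round down to the nearest multiple of 16 -> 288
h_down = (im_height // 2**n_channels) * 2**n_channels

# 'padding' rounds up to the next multiple of 16 -> 304
h_up = (im_height // 2**n_channels + 1) * 2**n_channels

print(h_down, h_up)   # 288 304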
@@ -146,20 +153,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
resize = augmentation.resize
n_channels = len(model.channels)
if isinstance(resize,type(None)):
# OPEN THE FIRST IMAGE
# taking the size of the 1st image in the dataset
im_path = Path(path) / 'train'
first_img = sorted((im_path / "images").iterdir())[0]
image = Image.open(str(first_img))
im_h, im_w = image.size[:2]
log.info("User did not choose a specific value for 'resize'. Checking the first image in the dataset...")
elif isinstance(resize,int):
im_h, im_w = resize, resize
else:
im_h,im_w = resize
orig_h, orig_w = image.size[:2]
final_h, final_w = check_resize(im_h, im_w, n_channels)
final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
@@ -140,9 +140,6 @@ def model_summary(dataloader,model):
print(summary)
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
images,_ = next(iter(dataloader))
batch_size = tuple(images.shape)
model_s = summary(model,batch_size,depth = torch.inf)