Commit 070fc09a authored by s193396's avatar s193396

added support for 3D in data preparation functions

parent 93942fe3
@@ -4,6 +4,7 @@ from PIL import Image
 from qim3d.utils import log
 import torch
 import numpy as np
+import nibabel as nib
 from typing import Optional, Callable
 import torch.nn as nn
 from ._augmentations import Augmentation
@@ -54,7 +55,7 @@ class Dataset(torch.utils.data.Dataset):
         self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
         assert len(self.sample_images) == len(self.sample_targets)

-        # checking the characteristics of the dataset
+        # Checking the characteristics of the dataset
         self.check_shape_consistency(self.sample_images)

     def __len__(self):
@@ -64,10 +65,22 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]

-        image = Image.open(str(image_path))
-        image = np.array(image)
-        target = Image.open(str(target_path))
-        target = np.array(target)
+        # Path.suffix would give only '.gz' for '.nii.gz', so join all suffixes
+        if ''.join(image_path.suffixes) in ['.nii', '.nii.gz']:
+            # Load 3D volume
+            image_data = nib.load(str(image_path))
+            target_data = nib.load(str(target_path))
+
+            # Get data from volume, keeping the on-disk dtype
+            image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
+            target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
+        else:
+            # Load 2D image
+            image = Image.open(str(image_path))
+            image = np.array(image)
+            target = Image.open(str(target_path))
+            target = np.array(target)

         if self.transform:
             transformed = self.transform(image=image, mask=target)
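A side note on the `dataobj` access pattern used in the 3D branch above: nibabel's `get_fdata()` always returns float64, which inflates memory for typical uint8/int16 scans, while reading through the array proxy keeps the on-disk dtype. A minimal sketch, assuming a NIfTI file at the hypothetical path `volume.nii.gz`:

import nibabel as nib
import numpy as np

img = nib.load("volume.nii.gz")   # hypothetical example file

vol_f64 = img.get_fdata()         # always float64, regardless of on-disk dtype
vol = np.asarray(img.dataobj, dtype=img.get_data_dtype())  # keeps e.g. int16

print(vol_f64.dtype, vol.dtype)   # float64 int16 (for an int16 scan)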
@@ -78,13 +91,13 @@ class Dataset(torch.utils.data.Dataset):
     # TODO: working with images of different sizes
-    def check_shape_consistency(self,sample_images: tuple[str]):
+    def check_shape_consistency(self, sample_images: tuple[str]):
         image_shapes = []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
             image_shapes.append(image_shape)

-        # check if all images have the same size.
+        # Check if all images have the same size
         consistency_check = all(i == image_shapes[0] for i in image_shapes)

         if not consistency_check:
@@ -97,41 +110,102 @@ class Dataset(torch.utils.data.Dataset):
             )

         return consistency_check

-    def _get_shape(self,image_path):
-        return Image.open(str(image_path)).size
+    def _get_shape(self, image_path):
+        full_suffix = ''.join(image_path.suffixes)
+
+        if full_suffix in ['.nii', '.nii.gz']:
+            # Read the 3D volume shape from the header, without loading the data
+            return nib.load(str(image_path)).shape
+        else:
+            # Load 2D image
+            image = Image.open(str(image_path))
+            return image.size
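One detail the new suffix handling gets right: `Path.suffix` returns only the final extension, so it cannot recognize the double extension of gzipped NIfTI files. Joining `Path.suffixes` does (with the caveat that dotted names like `scan.v2.nii` pick up the extra parts and would not match either):

from pathlib import Path

Path("scan.nii").suffix                  # '.nii'    -- matches
Path("scan.nii.gz").suffix               # '.gz'     -- would never match
''.join(Path("scan.nii.gz").suffixes)    # '.nii.gz' -- matches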
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
+def check_resize(
+    orig_shape: tuple,
+    resize: str,
+    n_channels: int,
+    is_3d: bool,
+) -> tuple:
     """
-    Checks the compatibility of the image shape with the depth of the model.
-    If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
-    If so, the image is changed to the closest appropriate dimension, and the user is notified with a warning.
+    Checks and adjusts the resize dimensions based on the original shape and the depth of the model.

     Args:
-        im_height (int): Height of the original image from the dataset.
-        im_width (int): Width of the original image from the dataset.
+        orig_shape (tuple): Original shape of the image.
         resize (str): Type of resize to be used on the image if the size doesn't fit the model.
         n_channels (int): Number of channels in the model.
+        is_3d (bool): Whether the data is 3D or not.
+
+    Returns:
+        tuple: Final resize dimensions.
+
+    Raises:
+        ValueError: If the image size is smaller than the minimum required for the model's depth.
     """
-    # finding suitable size to upsize with padding
-    if resize == 'padding':
-        h_adjust, w_adjust = (im_height // 2**n_channels+1) * 2**n_channels , (im_width // 2**n_channels+1) * 2**n_channels
-
-    # finding suitable size to downsize with crop / resize
-    else:
-        h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels , (im_width // 2**n_channels) * 2**n_channels
-
-    if h_adjust == 0 or w_adjust == 0:
-        raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
-    elif h_adjust != im_height or w_adjust != im_width:
-        log.warning(f"The image size doesn't match the Unet model's depth. The image is changed with '{resize}', from {im_height, im_width} to {h_adjust, w_adjust}.")
-
-    return h_adjust, w_adjust
+    # 3D volumes
+    if is_3d:
+        orig_d, orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_d = (orig_d // 2**n_channels) * 2**n_channels
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_d == 0 or final_h == 0 or final_w == 0:
+            msg = ("The size of the image is too small compared to the depth of the UNet. "
+                   "Choose a different 'resize' and/or a smaller model.")
+            raise ValueError(msg)
+
+        if final_d != orig_d or final_h != orig_h or final_w != orig_w:
+            log.warning(f"The image size doesn't match the UNet model's depth. The image is "
+                        f"changed with '{resize}', from {orig_d, orig_h, orig_w} to {final_d, final_h, final_w}.")
+
+        return final_d, final_h, final_w
+
+    # 2D images
+    else:
+        orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_h == 0 or final_w == 0:
+            msg = ("The size of the image is too small compared to the depth of the UNet. "
+                   "Choose a different 'resize' and/or a smaller model.")
+            raise ValueError(msg)
+
+        if final_h != orig_h or final_w != orig_w:
+            log.warning(f"The image size doesn't match the UNet model's depth. The image is "
+                        f"changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.")
+
+        return final_h, final_w
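To make the adjustment arithmetic concrete: with a model of depth n_channels = 3, every spatial dimension must be divisible by 2**3 = 8. A quick sketch against the function above, with values worked by hand:

check_resize((100, 100), 'padding', 3, False)   # rounds up:   (104, 104)
check_resize((100, 100), 'crop', 3, False)      # rounds down: (96, 96)
check_resize((50, 100, 100), 'crop', 3, True)   # 3D:          (48, 96, 96)
check_resize((4, 4), 'crop', 3, False)          # (4 // 8) * 8 = 0 -> ValueError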
 def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
@@ -153,17 +227,29 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
     resize = augmentation.resize
     n_channels = len(model.channels)

-    # taking the size of the 1st image in the dataset
+    # Determine if the dataset is 2D or 3D by checking the first image
     im_path = Path(path) / 'train'
     first_img = sorted((im_path / "images").iterdir())[0]

-    image = Image.open(str(first_img))
-    orig_h, orig_w = image.size[:2]
-    final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
+    full_suffix = ''.join(first_img.suffixes)
+
+    # TODO: Support more formats for 3D images
+    if full_suffix in ['.nii', '.nii.gz']:
+        # Read the 3D volume shape from the header
+        orig_shape = nib.load(str(first_img)).shape
+        is_3d = True
+    else:
+        # Load 2D image
+        image = Image.open(str(first_img))
+        orig_shape = image.size[:2]
+        is_3d = False
+
+    final_shape = check_resize(orig_shape, resize, n_channels, is_3d)

-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
+    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))

     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
......
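For anyone trying the new 3D path end to end, a minimal usage sketch. The `train/{images,labels}` layout follows the conventions visible in the `Dataset` class above; the `UNet` and `Augmentation` import paths and constructor arguments are assumptions for illustration, not part of this commit:

from qim3d.models import UNet              # assumed import path
from qim3d.utils import Augmentation       # assumed import path

# Hypothetical layout: dataset/train/{images,labels} and dataset/test/{images,labels},
# each holding matching .nii.gz volumes
model = UNet()                                 # hypothetical default config
augmentation = Augmentation(resize='padding')  # hypothetical arguments

train_set, val_set, test_set = prepare_datasets(
    path="dataset/",
    val_fraction=0.3,
    model=model,
    augmentation=augmentation,
)

The train/validation subsets are then carved out of train_set with the torch.randperm split shown in the last hunk.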