From 070fc09af3f4b71654015e5a347e578363b46d43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk>
Date: Tue, 11 Feb 2025 09:54:37 +0100
Subject: [PATCH] added support for 3D in data preparation functions

---
 qim3d/ml/_data.py | 160 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 123 insertions(+), 37 deletions(-)

diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 253da4a1..7bba1467 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -4,6 +4,7 @@ from PIL import Image
 from qim3d.utils import log
 import torch
 import numpy as np
+import nibabel as nib
 from typing import Optional, Callable
 import torch.nn as nn
 from ._augmentations import Augmentation
@@ -54,7 +55,7 @@ class Dataset(torch.utils.data.Dataset):
         self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
         assert len(self.sample_images) == len(self.sample_targets)
 
-        # checking the characteristics of the dataset
+        # Checking the characteristics of the dataset
         self.check_shape_consistency(self.sample_images)
 
     def __len__(self):
@@ -64,10 +65,22 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
-        image = Image.open(str(image_path))
-        image = np.array(image)
-        target = Image.open(str(target_path))
-        target = np.array(target)
+        # Use the full suffix so stacked extensions like '.nii.gz' are matched
+        # (Path.suffix alone would only return '.gz')
+        if ''.join(image_path.suffixes) in ['.nii', '.nii.gz']:
+            # Load 3D volumes
+            image_data = nib.load(str(image_path))
+            target_data = nib.load(str(target_path))
+
+            # Get data from volumes
+            image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
+            target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
+
+        else:
+            # Load 2D images
+            image = Image.open(str(image_path))
+            image = np.array(image)
+            target = Image.open(str(target_path))
+            target = np.array(target)
 
         if self.transform:
             transformed = self.transform(image=image, mask=target)
@@ -78,13 +91,13 @@ class Dataset(torch.utils.data.Dataset):
 
     # TODO: working with images of different sizes
 
-    def check_shape_consistency(self,sample_images: tuple[str]):
+    def check_shape_consistency(self, sample_images: tuple[str]):
         image_shapes= []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
             image_shapes.append(image_shape)
 
-        # check if all images have the same size.
+        # Check if all images have the same size
         consistency_check = all(i == image_shapes[0] for i in image_shapes)
 
         if not consistency_check:
@@ -97,41 +110,102 @@ class Dataset(torch.utils.data.Dataset):
             )
         return consistency_check
 
-    def _get_shape(self,image_path):
-        return Image.open(str(image_path)).size
+    def _get_shape(self, image_path):
+        full_suffix = ''.join(image_path.suffixes)
+
+        if full_suffix in ['.nii', '.nii.gz']:
+            # Read the 3D volume shape from the header, without loading voxel data
+            return nib.load(str(image_path)).shape
+
+        else:
+            # Load 2D image
+            image = Image.open(str(image_path))
+            return image.size
 
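+# Each level of the UNet halves the spatial dimensions, so every dimension
+# must be a multiple of 2**n_channels: check_resize rounds up when padding
+# and down when cropping or resizing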
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
+def check_resize(
+        orig_shape: tuple,
+        resize: str,
+        n_channels: int,
+        is_3d: bool
+    ) -> tuple:
     """
-    Checks the compatibility of the image shape with the depth of the model.
-    If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
-    If so, the image is changed to the closest appropriate dimension, and the user is notified with a warning.
+    Checks the compatibility of the image shape with the depth of the model.
+    If the dimensions cannot be divided by 2 `n_channels` times, they are adjusted
+    to the closest compatible size, and the user is notified with a warning.
 
     Args:
-        im_height (int) : Height of the original image from the dataset.
-        im_width (int) : Width of the original image from the dataset.
-        resize (str) : Type of resize to be used on the image if the size doesn't fit the model.
+        orig_shape (tuple): Original shape of the image, (D, H, W) for 3D or (H, W) for 2D.
+        resize (str): Type of resize to be used on the image if the size doesn't fit the model.
         n_channels (int): Number of channels in the model.
+        is_3d (bool): Whether the data is 3D or not.
+
+    Returns:
+        tuple: Final image dimensions, compatible with the model's depth.
 
     Raises:
         ValueError: If the image size is smaller than minimum required for the model's depth.
     """
-    # finding suitable size to upsize with padding
-    if resize == 'padding':
-        h_adjust, w_adjust = (im_height // 2**n_channels+1) * 2**n_channels , (im_width // 2**n_channels+1) * 2**n_channels
+    # 3D volumes
+    if is_3d:
+        orig_d, orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
 
-    # finding suitable size to downsize with crop / resize
-    else:
-        h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels , (im_width // 2**n_channels) * 2**n_channels
-
-    if h_adjust == 0 or w_adjust == 0:
-        raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
-
-    elif h_adjust != im_height or w_adjust != im_width:
-        log.warning(f"The image size doesn't match the Unet model's depth. The image is changed with '{resize}', from {im_height, im_width} to {h_adjust, w_adjust}.")
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_d = (orig_d // 2**n_channels) * 2**n_channels
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_d == 0 or final_h == 0 or final_w == 0:
+            raise ValueError(
+                "The size of the image is too small compared to the depth of the UNet. "
+                "Choose a different 'resize' and/or a smaller model."
+            )
+
+        if final_d != orig_d or final_h != orig_h or final_w != orig_w:
+            log.warning(
+                f"The image size doesn't match the UNet model's depth. The image is changed "
+                f"with '{resize}', from {orig_d, orig_h, orig_w} to {final_d, final_h, final_w}."
+            )
+
+        return final_d, final_h, final_w
 
-    return h_adjust, w_adjust
+    # 2D images
+    else:
+        orig_h, orig_w = orig_shape
+
+        # Finding suitable size to upsize with padding
+        if resize == 'padding':
+            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+
+        # Finding suitable size to downsize with crop / resize
+        else:
+            final_h = (orig_h // 2**n_channels) * 2**n_channels
+            final_w = (orig_w // 2**n_channels) * 2**n_channels
+
+        # Check if the image size is too small compared to the model's depth
+        if final_h == 0 or final_w == 0:
+            raise ValueError(
+                "The size of the image is too small compared to the depth of the UNet. "
+                "Choose a different 'resize' and/or a smaller model."
+            )
+
+        if final_h != orig_h or final_w != orig_w:
+            log.warning(
+                f"The image size doesn't match the UNet model's depth. The image is changed "
+                f"with '{resize}', from {orig_h, orig_w} to {final_h, final_w}."
+            )
+
+        return final_h, final_w
 
 def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
@@ -153,17 +227,29 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
     resize = augmentation.resize
     n_channels = len(model.channels)
 
-    # taking the size of the 1st image in the dataset
+    # Determine if the dataset is 2D or 3D by checking the first image
     im_path = Path(path) / 'train'
    first_img = sorted((im_path / "images").iterdir())[0]
-    image = Image.open(str(first_img))
-    orig_h, orig_w = image.size[:2]
-
-    final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
+    full_suffix = ''.join(first_img.suffixes)
+
+    # TODO: Support more formats for 3D images
+    if full_suffix in ['.nii', '.nii.gz']:
+        # Read the 3D volume shape from the header, without loading voxel data
+        orig_shape = nib.load(str(first_img)).shape
+        is_3d = True
+
+    else:
+        # Load 2D image (PIL's size is (W, H), so reverse it to (H, W))
+        image = Image.open(str(first_img))
+        orig_shape = image.size[::-1]
+        is_3d = False
+
+    final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
 
-    train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
-    val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
-    test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
+    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
-- 
GitLab
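
A minimal end-to-end sketch of the new 3D path, exercising check_resize and prepare_datasets as changed above. It assumes only what the diff shows: the dataset location is hypothetical, and DummyModel / DummyAugmentation are duck-typed stand-ins exposing just the attributes prepare_datasets reads (model.channels, augmentation.resize, augmentation.transform_*, and augmentation.augment(final_shape, transform)), since the real UNet and Augmentation classes are outside this diff.

from pathlib import Path

import nibabel as nib
import numpy as np
import torch.nn as nn

from qim3d.ml._data import check_resize, prepare_datasets

# check_resize on a 3D shape with a 3-level model: each dimension is rounded
# down to the nearest multiple of 2**3 = 8, and a warning is logged
print(check_resize((100, 240, 240), resize='crop', n_channels=3, is_3d=True))
# -> (96, 240, 240)

# Tiny synthetic NIfTI dataset so the sketch is self-contained
root = Path('synthetic_3d_dataset')  # hypothetical location
for split in ('train', 'test'):
    (root / split / 'images').mkdir(parents=True, exist_ok=True)
    (root / split / 'labels').mkdir(parents=True, exist_ok=True)
    for i in range(4):
        vol = np.random.rand(32, 64, 64).astype(np.float32)
        lab = (vol > 0.5).astype(np.uint8)
        nib.save(nib.Nifti1Image(vol, np.eye(4)), str(root / split / 'images' / f'vol_{i:03d}.nii.gz'))
        nib.save(nib.Nifti1Image(lab, np.eye(4)), str(root / split / 'labels' / f'vol_{i:03d}.nii.gz'))

class DummyModel(nn.Module):
    # Stand-in exposing the only attribute prepare_datasets reads
    def __init__(self):
        super().__init__()
        self.channels = [64, 128, 256]  # 3 levels -> shapes must divide by 8

class DummyAugmentation:
    # Duck-typed stand-in for the Augmentation object used by prepare_datasets
    resize = 'crop'
    transform_train = transform_validation = transform_test = None

    def augment(self, final_shape, transform):
        # The real method would build a resize/augmentation pipeline for
        # final_shape; returning None makes Dataset.__getitem__ skip transforms
        return None

train_set, val_set, test_set = prepare_datasets(
    path=str(root), val_fraction=0.2,
    model=DummyModel(), augmentation=DummyAugmentation(),
)
image, target = train_set[0]
print(image.shape, target.shape)  # (32, 64, 64), loaded via nibabel

With the transforms skipped, Dataset.__getitem__ returns the raw NumPy arrays loaded through nibabel, which makes it easy to verify that volumes survive the round trip with the expected (D, H, W) shape.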