From ddc40406aa3768b7502174276dc4d34850257a68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk> Date: Mon, 17 Feb 2025 14:55:56 +0100 Subject: [PATCH] removed 2D dataloader --- qim3d/ml/_data.py | 147 +++++++++++----------------------------------- 1 file changed, 33 insertions(+), 114 deletions(-) diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py index d2bd93a1..a5717b79 100644 --- a/qim3d/ml/_data.py +++ b/qim3d/ml/_data.py @@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset): image_path = self.sample_images[idx] target_path = self.sample_targets[idx] - full_suffix = ''.join(image_path.suffixes) - - if full_suffix in ['.nii', '.nii.gz']: - - # Load 3D volume - image_data = nib.load(str(image_path)) - target_data = nib.load(str(target_path)) - - # Get data from volume - image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype()) - target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype()) - - # Add extra channel dimension - image = np.expand_dims(image, axis=0) - target = np.expand_dims(target, axis=0) - - else: - - # Load 2D image - image = Image.open(str(image_path)) - image = np.array(image) - target = Image.open(str(target_path)) - target = np.array(target) - - # Grayscale image - if len(image.shape) == 2 and len(target.shape) == 2: + # Load 3D volume + image_data = nib.load(str(image_path)) + target_data = nib.load(str(target_path)) - # Add channel dimension - image = np.expand_dims(image, axis=0) - target = np.expand_dims(target, axis=0) - - # RGB image - elif len(image.shape) == 3 and len(target.shape) == 3: + # Get data from volume + image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype()) + target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype()) - # Convert to (C, H, W) - image = image.transpose((2, 0, 1)) - target = target.transpose((2, 0, 1)) + # Add extra channel dimension + image = np.expand_dims(image, axis=0) + target = np.expand_dims(target, axis=0) if self.transform: transformed = self.transform({"image": image, "label": target}) image = transformed["image"] target = transformed["label"] - - # image = self.transform(image) # uint8 - # target = self.transform(target) # int32 - # TODO: Which dtype? image = image.clone().detach().to(dtype=torch.float32) target = target.clone().detach().to(dtype=torch.float32) @@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset): return consistency_check def _get_shape(self, image_path): - full_suffix = ''.join(image_path.suffixes) - - if full_suffix in ['.nii', '.nii.gz']: - # Load 3D volume - image = nib.load(str(image_path)).get_fdata() - return image.shape - - else: - # Load 2D image - image = Image.open(str(image_path)) - return image.size + # Load 3D volume + image = nib.load(str(image_path)).get_fdata() + return image.shape def check_resize( orig_shape: tuple, resize: tuple, - n_channels: int, - is_3d: bool + n_channels: int, ) -> tuple: """ Checks and adjusts the resize dimensions based on the original shape and the number of channels. @@ -164,7 +125,6 @@ def check_resize( orig_shape (tuple): Original shape of the image. resize (tuple): Desired resize dimensions. n_channels (int): Number of channels in the model. - is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True. Returns: tuple: Final resize dimensions. @@ -174,23 +134,22 @@ def check_resize( """ # 3D images - if is_3d: - orig_d, orig_h, orig_w = orig_shape - final_d = resize[0] if resize[0] else orig_d - final_h = resize[1] if resize[1] else orig_h - final_w = resize[2] if resize[2] else orig_w - - # Finding suitable size to upsize with padding - if resize == 'padding': - final_d = (orig_d // 2**n_channels + 1) * 2**n_channels - final_h = (orig_h // 2**n_channels + 1) * 2**n_channels - final_w = (orig_w // 2**n_channels + 1) * 2**n_channels - - # Finding suitable size to downsize with crop / resize - else: - final_d = (orig_d // 2**n_channels) * 2**n_channels - final_h = (orig_h // 2**n_channels) * 2**n_channels - final_w = (orig_w // 2**n_channels) * 2**n_channels + orig_d, orig_h, orig_w = orig_shape + final_d = resize[0] if resize[0] else orig_d + final_h = resize[1] if resize[1] else orig_h + final_w = resize[2] if resize[2] else orig_w + + # Finding suitable size to upsize with padding + if resize == 'padding': + final_d = (orig_d // 2**n_channels + 1) * 2**n_channels + final_h = (orig_h // 2**n_channels + 1) * 2**n_channels + final_w = (orig_w // 2**n_channels + 1) * 2**n_channels + + # Finding suitable size to downsize with crop / resize + else: + final_d = (orig_d // 2**n_channels) * 2**n_channels + final_h = (orig_h // 2**n_channels) * 2**n_channels + final_w = (orig_w // 2**n_channels) * 2**n_channels # Check if the image size is too small compared to the model's depth if final_d == 0 or final_h == 0 or final_w == 0: @@ -205,35 +164,6 @@ def check_resize( return final_d, final_h, final_w - # 2D images - else: - orig_h, orig_w = orig_shape - final_h = resize[0] if resize[0] else orig_h - final_w = resize[1] if resize[1] else orig_w - - # Finding suitable size to upsize with padding - if resize == 'padding': - final_h = (orig_h // 2**n_channels + 1) * 2**n_channels - final_w = (orig_w // 2**n_channels + 1) * 2**n_channels - - # Finding suitable size to downsize with crop / resize - else: - final_h = (orig_h // 2**n_channels) * 2**n_channels - final_w = (orig_w // 2**n_channels) * 2**n_channels - - # Check if the image size is too small compared to the model's depth - if final_h == 0 or final_w == 0: - msg = "The size of the image is too small compared to the depth of the UNet. \ - Choose a different 'resize' and/or a smaller model." - - raise ValueError(msg) - - if final_h != orig_h or final_w != orig_w: - log.warning(f"The image size doesn't match the Unet model's depth. \ - The image is changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.") - - return final_h, final_w - def prepare_datasets( path: str, val_fraction: float, @@ -262,23 +192,12 @@ def prepare_datasets( # Determine if the dataset is 2D or 3D by checking the first image im_path = Path(path) / 'train' first_img = sorted((im_path / "images").iterdir())[0] - full_suffix = ''.join(first_img.suffixes) - # TODO: Support more formats for 3D images - if full_suffix in ['.nii', '.nii.gz']: - - # Load 3D volume - image = nib.load(str(first_img)).get_fdata() - orig_shape = image.shape - is_3d = True - - else: - # Load 2D image - image = Image.open(str(first_img)) - orig_shape = image.size[:2] - is_3d = False + # Load 3D volume + image = nib.load(str(first_img)).get_fdata() + orig_shape = image.shape - final_shape = check_resize(orig_shape, resize, n_channels, is_3d) + final_shape = check_resize(orig_shape, resize, n_channels) train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train)) val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation)) -- GitLab