diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 7bba1467fd1687f62a11dd7e18c162ddf9f69c13..03c3c399a1b70aa6f70b9e551d0d179b70590a5d 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -65,7 +65,9 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]

-        if image_path.suffix in ['.nii', '.nii.gz']:
+        full_suffix = ''.join(image_path.suffixes)
+
+        if full_suffix in ['.nii', '.nii.gz']:
             # Load 3D volume
             image_data = nib.load(str(image_path))
@@ -75,21 +77,42 @@ class Dataset(torch.utils.data.Dataset):
             image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
             target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())

+            # Add extra channel dimension
+            image = np.expand_dims(image, axis=0)
+            target = np.expand_dims(target, axis=0)
+
         else:
+            # Load 2D image
             image = Image.open(str(image_path))
             image = np.array(image)

             target = Image.open(str(target_path))
             target = np.array(target)

+            # Grayscale image
+            if len(image.shape) == 2 and len(target.shape) == 2:
+
+                # Add channel dimension
+                image = np.expand_dims(image, axis=0)
+                target = np.expand_dims(target, axis=0)
+
+            # RGB image
+            elif len(image.shape) == 3 and len(target.shape) == 3:
+
+                # Convert to (C, H, W)
+                image = image.transpose((2, 0, 1))
+                target = target.transpose((2, 0, 1))
+
         if self.transform:
-            transformed = self.transform(image=image, mask=target)
-            image = transformed["image"]
-            target = transformed["mask"]
+            image = self.transform(image)  # uint8
+            target = self.transform(target)  # int32
+
+        # TODO: Which dtype?
+        image = image.clone().detach().to(dtype=torch.float32)
+        target = target.clone().detach().to(dtype=torch.float32)

         return image, target

-    # TODO: working with images of different sizes
     def check_shape_consistency(self, sample_images: tuple[str]):
         image_shapes= []
@@ -234,6 +257,7 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
     # TODO: Support more formats for 3D images
     if full_suffix in ['.nii', '.nii.gz']:
+        # Load 3D volume
         image = nib.load(str(first_img)).get_fdata()
         orig_shape = image.shape
@@ -247,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
     final_shape = check_resize(orig_shape, resize, n_channels, is_3d)

-    train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
-    val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
-    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))
+    train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train))
+    val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
+    test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))

     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
@@ -265,7 +289,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
                         test_set: torch.utils.data,
                         batch_size: int,
                         shuffle_train: bool = True,
-                        num_workers: int = 8,
+                        num_workers: int = 8,
                         pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Prepares the dataloaders for model training.
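
Note on the `full_suffix` change: `pathlib.Path.suffix` returns only the final extension, so the old check could never match a compressed NIfTI file ending in `.nii.gz`. A minimal sketch of the behavior (the file name is hypothetical):

```python
from pathlib import Path

p = Path('scan_001.nii.gz')  # hypothetical compressed NIfTI file

# Path.suffix yields only the last extension, so the old check
# `image_path.suffix in ['.nii', '.nii.gz']` never matched '.nii.gz'.
print(p.suffix)             # '.gz'

# Joining Path.suffixes recovers the full compound extension.
print(''.join(p.suffixes))  # '.nii.gz'
```

One caveat of this approach: a file name containing extra dots (e.g. `scan.v2.nii`) would yield a longer compound suffix such as `.v2.nii` and fall through to the 2D branch.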
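
Note on the channel handling: the new branches in `__getitem__` normalize every sample to a channel-first layout before the transform is applied, which is the layout PyTorch models conventionally expect. A minimal sketch of the resulting shapes, using dummy arrays:

```python
import numpy as np

# Grayscale 2D image: (H, W) -> (1, H, W)
gray = np.zeros((64, 64), dtype=np.uint8)
print(np.expand_dims(gray, axis=0).shape)   # (1, 64, 64)

# RGB 2D image: (H, W, C) -> (C, H, W)
rgb = np.zeros((64, 64, 3), dtype=np.uint8)
print(rgb.transpose((2, 0, 1)).shape)       # (3, 64, 64)

# 3D volume: (D, H, W) -> (1, D, H, W)
vol = np.zeros((32, 64, 64), dtype=np.float32)
print(np.expand_dims(vol, axis=0).shape)    # (1, 32, 64, 64)
```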