Skip to content
Snippets Groups Projects
Commit d0b1d98a authored by s193396's avatar s193396
Browse files

expand image dimensions in dataloader

parent f068dd86
No related branches found
No related tags found
No related merge requests found
...@@ -65,7 +65,9 @@ class Dataset(torch.utils.data.Dataset): ...@@ -65,7 +65,9 @@ class Dataset(torch.utils.data.Dataset):
image_path = self.sample_images[idx] image_path = self.sample_images[idx]
target_path = self.sample_targets[idx] target_path = self.sample_targets[idx]
if image_path.suffix in ['.nii', '.nii.gz']: full_suffix = ''.join(image_path.suffixes)
if full_suffix in ['.nii', '.nii.gz']:
# Load 3D volume # Load 3D volume
image_data = nib.load(str(image_path)) image_data = nib.load(str(image_path))
...@@ -75,21 +77,42 @@ class Dataset(torch.utils.data.Dataset): ...@@ -75,21 +77,42 @@ class Dataset(torch.utils.data.Dataset):
image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype()) image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype()) target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
# Add extra channel dimension
image = np.expand_dims(image, axis=0)
target = np.expand_dims(target, axis=0)
else: else:
# Load 2D image # Load 2D image
image = Image.open(str(image_path)) image = Image.open(str(image_path))
image = np.array(image) image = np.array(image)
target = Image.open(str(target_path)) target = Image.open(str(target_path))
target = np.array(target) target = np.array(target)
# Grayscale image
if len(image.shape) == 2 and len(target.shape) == 2:
# Add channel dimension
image = np.expand_dims(image, axis=0)
target = np.expand_dims(target, axis=0)
# RGB image
elif len(image.shape) == 3 and len(target.shape) == 3:
# Convert to (C, H, W)
image = image.transpose((2, 0, 1))
target = target.transpose((2, 0, 1))
if self.transform: if self.transform:
transformed = self.transform(image=image, mask=target) image = self.transform(image) # uint8
image = transformed["image"] target = self.transform(target) # int32
target = transformed["mask"]
# TODO: Which dtype?
image = image.clone().detach().to(dtype=torch.float32)
target = target.clone().detach().to(dtype=torch.float32)
return image, target return image, target
# TODO: working with images of different sizes # TODO: working with images of different sizes
def check_shape_consistency(self, sample_images: tuple[str]): def check_shape_consistency(self, sample_images: tuple[str]):
image_shapes= [] image_shapes= []
...@@ -234,6 +257,7 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat ...@@ -234,6 +257,7 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
# TODO: Support more formats for 3D images # TODO: Support more formats for 3D images
if full_suffix in ['.nii', '.nii.gz']: if full_suffix in ['.nii', '.nii.gz']:
# Load 3D volume # Load 3D volume
image = nib.load(str(first_img)).get_fdata() image = nib.load(str(first_img)).get_fdata()
orig_shape = image.shape orig_shape = image.shape
...@@ -247,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat ...@@ -247,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
final_shape = check_resize(orig_shape, resize, n_channels, is_3d) final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train)) train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation)) val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test)) test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))
split_idx = int(np.floor(val_fraction * len(train_set))) split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set)) indices = torch.randperm(len(train_set))
...@@ -265,7 +289,7 @@ def prepare_dataloaders(train_set: torch.utils.data, ...@@ -265,7 +289,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
test_set: torch.utils.data, test_set: torch.utils.data,
batch_size: int, batch_size: int,
shuffle_train: bool = True, shuffle_train: bool = True,
num_workers: int = 8, num_workers: int = 8,
pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]: pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
""" """
Prepares the dataloaders for model training. Prepares the dataloaders for model training.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment