Skip to content
Snippets Groups Projects
Commit d0b1d98a authored by s193396's avatar s193396
Browse files

expand image dimensions in dataloader

parent f068dd86
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,9 @@ class Dataset(torch.utils.data.Dataset):
image_path = self.sample_images[idx]
target_path = self.sample_targets[idx]
if image_path.suffix in ['.nii', '.nii.gz']:
full_suffix = ''.join(image_path.suffixes)
if full_suffix in ['.nii', '.nii.gz']:
# Load 3D volume
image_data = nib.load(str(image_path))
......@@ -75,21 +77,42 @@ class Dataset(torch.utils.data.Dataset):
image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
# Add extra channel dimension
image = np.expand_dims(image, axis=0)
target = np.expand_dims(target, axis=0)
else:
# Load 2D image
image = Image.open(str(image_path))
image = np.array(image)
target = Image.open(str(target_path))
target = np.array(target)
# Grayscale image
if len(image.shape) == 2 and len(target.shape) == 2:
# Add channel dimension
image = np.expand_dims(image, axis=0)
target = np.expand_dims(target, axis=0)
# RGB image
elif len(image.shape) == 3 and len(target.shape) == 3:
# Convert to (C, H, W)
image = image.transpose((2, 0, 1))
target = target.transpose((2, 0, 1))
if self.transform:
transformed = self.transform(image=image, mask=target)
image = transformed["image"]
target = transformed["mask"]
image = self.transform(image) # uint8
target = self.transform(target) # int32
# TODO: Which dtype?
image = image.clone().detach().to(dtype=torch.float32)
target = target.clone().detach().to(dtype=torch.float32)
return image, target
# TODO: working with images of different sizes
def check_shape_consistency(self, sample_images: tuple[str]):
image_shapes= []
......@@ -234,6 +257,7 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
# TODO: Support more formats for 3D images
if full_suffix in ['.nii', '.nii.gz']:
# Load 3D volume
image = nib.load(str(first_img)).get_fdata()
orig_shape = image.shape
......@@ -247,9 +271,9 @@ def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentat
final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, augmentation.transform_validation))
test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(final_shape, augmentation.transform_test))
train_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(*final_shape, level = augmentation.transform_validation))
test_set = Dataset(root_path=path, split='test', transform=augmentation.augment(*final_shape, level = augmentation.transform_test))
split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set))
......@@ -265,7 +289,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
test_set: torch.utils.data,
batch_size: int,
shuffle_train: bool = True,
num_workers: int = 8,
num_workers: int = 8,
pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
"""
Prepares the dataloaders for model training.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment