Commit ddc40406 authored by s193396

removed 2D dataloader

parent 6f8cb981
@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
-        full_suffix = ''.join(image_path.suffixes)
-        if full_suffix in ['.nii', '.nii.gz']:
-            # Load 3D volume
-            image_data = nib.load(str(image_path))
-            target_data = nib.load(str(target_path))
-            # Get data from volume
-            image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
-            target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
-            # Add extra channel dimension
-            image = np.expand_dims(image, axis=0)
-            target = np.expand_dims(target, axis=0)
-        else:
-            # Load 2D image
-            image = Image.open(str(image_path))
-            image = np.array(image)
-            target = Image.open(str(target_path))
-            target = np.array(target)
-            # Grayscale image
-            if len(image.shape) == 2 and len(target.shape) == 2:
-                # Add channel dimension
-                image = np.expand_dims(image, axis=0)
-                target = np.expand_dims(target, axis=0)
-            # RGB image
-            elif len(image.shape) == 3 and len(target.shape) == 3:
-                # Convert to (C, H, W)
-                image = image.transpose((2, 0, 1))
-                target = target.transpose((2, 0, 1))
+        # Load 3D volume
+        image_data = nib.load(str(image_path))
+        target_data = nib.load(str(target_path))
+        # Get data from volume
+        image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
+        target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
+        # Add extra channel dimension
+        image = np.expand_dims(image, axis=0)
+        target = np.expand_dims(target, axis=0)
         if self.transform:
             transformed = self.transform({"image": image, "label": target})
             image = transformed["image"]
             target = transformed["label"]
             # image = self.transform(image) # uint8
             # target = self.transform(target) # int32
         # TODO: Which dtype?
         image = image.clone().detach().to(dtype=torch.float32)
         target = target.clone().detach().to(dtype=torch.float32)
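For reference, a minimal standalone sketch of the 3D loading path this commit keeps (the file name is a placeholder and the transform step is omitted):

import nibabel as nib
import numpy as np
import torch

image_data = nib.load("volume.nii.gz")  # placeholder path to a NIfTI file
# Read the voxel data in its on-disk dtype, as the dataset class does
image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
image = np.expand_dims(image, axis=0)               # add channel dim -> (1, D, H, W)
image = torch.from_numpy(image.astype(np.float32))  # float32 tensor for the model
print(image.shape)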
@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset):
         return consistency_check
     def _get_shape(self, image_path):
-        full_suffix = ''.join(image_path.suffixes)
-        if full_suffix in ['.nii', '.nii.gz']:
-            # Load 3D volume
-            image = nib.load(str(image_path)).get_fdata()
-            return image.shape
-        else:
-            # Load 2D image
-            image = Image.open(str(image_path))
-            return image.size
+        # Load 3D volume
+        image = nib.load(str(image_path)).get_fdata()
+        return image.shape
 def check_resize(
     orig_shape: tuple,
     resize: tuple,
-    n_channels: int,
-    is_3d: bool
+    n_channels: int,
 ) -> tuple:
     """
     Checks and adjusts the resize dimensions based on the original shape and the number of channels.
@@ -164,7 +125,6 @@ def check_resize(
         orig_shape (tuple): Original shape of the image.
         resize (tuple): Desired resize dimensions.
         n_channels (int): Number of channels in the model.
-        is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.
     Returns:
         tuple: Final resize dimensions.
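As an illustration of these arguments, a hypothetical call with made-up numbers (not taken from the repository):

# With n_channels=3, every axis must end up divisible by 2**3 = 8
final_shape = check_resize(
    orig_shape=(100, 120, 67),  # original (D, H, W) volume shape
    resize='padding',           # pad each axis up to the next multiple of 8
    n_channels=3,
)
# final_shape == (104, 128, 72)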
@@ -174,23 +134,22 @@
     """
-    # 3D images
-    if is_3d:
-        orig_d, orig_h, orig_w = orig_shape
-        final_d = resize[0] if resize[0] else orig_d
-        final_h = resize[1] if resize[1] else orig_h
-        final_w = resize[2] if resize[2] else orig_w
-        # Finding suitable size to upsize with padding
-        if resize == 'padding':
-            final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
-            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
-            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
-        # Finding suitable size to downsize with crop / resize
-        else:
-            final_d = (orig_d // 2**n_channels) * 2**n_channels
-            final_h = (orig_h // 2**n_channels) * 2**n_channels
-            final_w = (orig_w // 2**n_channels) * 2**n_channels
+    orig_d, orig_h, orig_w = orig_shape
+    final_d = resize[0] if resize[0] else orig_d
+    final_h = resize[1] if resize[1] else orig_h
+    final_w = resize[2] if resize[2] else orig_w
+    # Finding suitable size to upsize with padding
+    if resize == 'padding':
+        final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
+        final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
+        final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
+    # Finding suitable size to downsize with crop / resize
+    else:
+        final_d = (orig_d // 2**n_channels) * 2**n_channels
+        final_h = (orig_h // 2**n_channels) * 2**n_channels
+        final_w = (orig_w // 2**n_channels) * 2**n_channels
     # Check if the image size is too small compared to the model's depth
     if final_d == 0 or final_h == 0 or final_w == 0:
@@ -205,35 +164,6 @@ def check_resize(
         return final_d, final_h, final_w
-    # 2D images
-    else:
-        orig_h, orig_w = orig_shape
-        final_h = resize[0] if resize[0] else orig_h
-        final_w = resize[1] if resize[1] else orig_w
-        # Finding suitable size to upsize with padding
-        if resize == 'padding':
-            final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
-            final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
-        # Finding suitable size to downsize with crop / resize
-        else:
-            final_h = (orig_h // 2**n_channels) * 2**n_channels
-            final_w = (orig_w // 2**n_channels) * 2**n_channels
-        # Check if the image size is too small compared to the model's depth
-        if final_h == 0 or final_w == 0:
-            msg = "The size of the image is too small compared to the depth of the UNet. \
-                Choose a different 'resize' and/or a smaller model."
-            raise ValueError(msg)
-        if final_h != orig_h or final_w != orig_w:
-            log.warning(f"The image size doesn't match the Unet model's depth. \
-                The image is changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.")
-        return final_h, final_w
 def prepare_datasets(
     path: str,
     val_fraction: float,
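The rounding in the removed 2D branch follows the same pattern as the retained 3D code. To make the crop/resize branch concrete, a small sketch with made-up sizes (not from the repository), including the degenerate case the ValueError guards against:

n_channels = 3
factor = 2 ** n_channels         # 8: each axis must become a multiple of this
print((100 // factor) * factor)  # 96 -> a 100-voxel axis is cropped/resized down to 96
print((6 // factor) * factor)    # 0  -> an axis shorter than 8 collapses to zero,
                                 #       which is what triggers the ValueError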
@@ -262,23 +192,12 @@ def prepare_datasets(
     # Determine if the dataset is 2D or 3D by checking the first image
     im_path = Path(path) / 'train'
     first_img = sorted((im_path / "images").iterdir())[0]
-    full_suffix = ''.join(first_img.suffixes)
-    # TODO: Support more formats for 3D images
-    if full_suffix in ['.nii', '.nii.gz']:
-        # Load 3D volume
-        image = nib.load(str(first_img)).get_fdata()
-        orig_shape = image.shape
-        is_3d = True
-    else:
-        # Load 2D image
-        image = Image.open(str(first_img))
-        orig_shape = image.size[:2]
-        is_3d = False
+    # Load 3D volume
+    image = nib.load(str(first_img)).get_fdata()
+    orig_shape = image.shape
-    final_shape = check_resize(orig_shape, resize, n_channels, is_3d)
+    final_shape = check_resize(orig_shape, resize, n_channels)
     train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train))
     val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation))
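A hedged usage sketch of the resulting 3D-only datasets, assuming __getitem__ returns an (image, target) pair as set up above; batch size and worker count are arbitrary example values:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=2)

for image, target in train_loader:
    # float32 tensors shaped (batch, 1, D, H, W) after the channel dimension is added
    print(image.shape, target.shape)
    break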