Skip to content
Snippets Groups Projects
Commit ddc40406 authored by s193396's avatar s193396
Browse files

removed 2D dataloader

parent 6f8cb981
No related branches found
No related tags found
No related merge requests found
...@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset): ...@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset):
image_path = self.sample_images[idx] image_path = self.sample_images[idx]
target_path = self.sample_targets[idx] target_path = self.sample_targets[idx]
full_suffix = ''.join(image_path.suffixes) # Load 3D volume
image_data = nib.load(str(image_path))
if full_suffix in ['.nii', '.nii.gz']: target_data = nib.load(str(target_path))
# Load 3D volume
image_data = nib.load(str(image_path))
target_data = nib.load(str(target_path))
# Get data from volume
image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
# Add extra channel dimension
image = np.expand_dims(image, axis=0)
target = np.expand_dims(target, axis=0)
else:
# Load 2D image
image = Image.open(str(image_path))
image = np.array(image)
target = Image.open(str(target_path))
target = np.array(target)
# Grayscale image
if len(image.shape) == 2 and len(target.shape) == 2:
# Add channel dimension # Get data from volume
image = np.expand_dims(image, axis=0) image = np.asarray(image_data.dataobj, dtype=image_data.get_data_dtype())
target = np.expand_dims(target, axis=0) target = np.asarray(target_data.dataobj, dtype=target_data.get_data_dtype())
# RGB image
elif len(image.shape) == 3 and len(target.shape) == 3:
# Convert to (C, H, W) # Add extra channel dimension
image = image.transpose((2, 0, 1)) image = np.expand_dims(image, axis=0)
target = target.transpose((2, 0, 1)) target = np.expand_dims(target, axis=0)
if self.transform: if self.transform:
transformed = self.transform({"image": image, "label": target}) transformed = self.transform({"image": image, "label": target})
image = transformed["image"] image = transformed["image"]
target = transformed["label"] target = transformed["label"]
# image = self.transform(image) # uint8
# target = self.transform(target) # int32
# TODO: Which dtype?
image = image.clone().detach().to(dtype=torch.float32) image = image.clone().detach().to(dtype=torch.float32)
target = target.clone().detach().to(dtype=torch.float32) target = target.clone().detach().to(dtype=torch.float32)
...@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset): ...@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset):
return consistency_check return consistency_check
def _get_shape(self, image_path):
    """Return the shape of the 3D volume stored at *image_path*.

    Args:
        image_path: Path to a NIfTI volume readable by nibabel.

    Returns:
        tuple: The (D, H, W) shape of the loaded volume data.
    """
    # Read the NIfTI file and materialize its voxel data as an array,
    # then report that array's dimensions.
    volume = nib.load(str(image_path)).get_fdata()
    return volume.shape
def check_resize( def check_resize(
orig_shape: tuple, orig_shape: tuple,
resize: tuple, resize: tuple,
n_channels: int, n_channels: int,
is_3d: bool
) -> tuple: ) -> tuple:
""" """
Checks and adjusts the resize dimensions based on the original shape and the number of channels. Checks and adjusts the resize dimensions based on the original shape and the number of channels.
...@@ -164,7 +125,6 @@ def check_resize( ...@@ -164,7 +125,6 @@ def check_resize(
orig_shape (tuple): Original shape of the image. orig_shape (tuple): Original shape of the image.
resize (tuple): Desired resize dimensions. resize (tuple): Desired resize dimensions.
n_channels (int): Number of channels in the model. n_channels (int): Number of channels in the model.
is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.
Returns: Returns:
tuple: Final resize dimensions. tuple: Final resize dimensions.
...@@ -174,23 +134,22 @@ def check_resize( ...@@ -174,23 +134,22 @@ def check_resize(
""" """
# 3D images # 3D images
if is_3d: orig_d, orig_h, orig_w = orig_shape
orig_d, orig_h, orig_w = orig_shape final_d = resize[0] if resize[0] else orig_d
final_d = resize[0] if resize[0] else orig_d final_h = resize[1] if resize[1] else orig_h
final_h = resize[1] if resize[1] else orig_h final_w = resize[2] if resize[2] else orig_w
final_w = resize[2] if resize[2] else orig_w
# Finding suitable size to upsize with padding
# Finding suitable size to upsize with padding if resize == 'padding':
if resize == 'padding': final_d = (orig_d // 2**n_channels + 1) * 2**n_channels
final_d = (orig_d // 2**n_channels + 1) * 2**n_channels final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
final_h = (orig_h // 2**n_channels + 1) * 2**n_channels final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
# Finding suitable size to downsize with crop / resize
# Finding suitable size to downsize with crop / resize else:
else: final_d = (orig_d // 2**n_channels) * 2**n_channels
final_d = (orig_d // 2**n_channels) * 2**n_channels final_h = (orig_h // 2**n_channels) * 2**n_channels
final_h = (orig_h // 2**n_channels) * 2**n_channels final_w = (orig_w // 2**n_channels) * 2**n_channels
final_w = (orig_w // 2**n_channels) * 2**n_channels
# Check if the image size is too small compared to the model's depth # Check if the image size is too small compared to the model's depth
if final_d == 0 or final_h == 0 or final_w == 0: if final_d == 0 or final_h == 0 or final_w == 0:
...@@ -205,35 +164,6 @@ def check_resize( ...@@ -205,35 +164,6 @@ def check_resize(
return final_d, final_h, final_w return final_d, final_h, final_w
# 2D images
else:
orig_h, orig_w = orig_shape
final_h = resize[0] if resize[0] else orig_h
final_w = resize[1] if resize[1] else orig_w
# Finding suitable size to upsize with padding
if resize == 'padding':
final_h = (orig_h // 2**n_channels + 1) * 2**n_channels
final_w = (orig_w // 2**n_channels + 1) * 2**n_channels
# Finding suitable size to downsize with crop / resize
else:
final_h = (orig_h // 2**n_channels) * 2**n_channels
final_w = (orig_w // 2**n_channels) * 2**n_channels
# Check if the image size is too small compared to the model's depth
if final_h == 0 or final_w == 0:
msg = "The size of the image is too small compared to the depth of the UNet. \
Choose a different 'resize' and/or a smaller model."
raise ValueError(msg)
if final_h != orig_h or final_w != orig_w:
log.warning(f"The image size doesn't match the Unet model's depth. \
The image is changed with '{resize}', from {orig_h, orig_w} to {final_h, final_w}.")
return final_h, final_w
def prepare_datasets( def prepare_datasets(
path: str, path: str,
val_fraction: float, val_fraction: float,
...@@ -262,23 +192,12 @@ def prepare_datasets( ...@@ -262,23 +192,12 @@ def prepare_datasets(
# Determine if the dataset is 2D or 3D by checking the first image # Determine if the dataset is 2D or 3D by checking the first image
im_path = Path(path) / 'train' im_path = Path(path) / 'train'
first_img = sorted((im_path / "images").iterdir())[0] first_img = sorted((im_path / "images").iterdir())[0]
full_suffix = ''.join(first_img.suffixes)
# TODO: Support more formats for 3D images # Load 3D volume
if full_suffix in ['.nii', '.nii.gz']: image = nib.load(str(first_img)).get_fdata()
orig_shape = image.shape
# Load 3D volume
image = nib.load(str(first_img)).get_fdata()
orig_shape = image.shape
is_3d = True
else:
# Load 2D image
image = Image.open(str(first_img))
orig_shape = image.size[:2]
is_3d = False
final_shape = check_resize(orig_shape, resize, n_channels, is_3d) final_shape = check_resize(orig_shape, resize, n_channels)
train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train)) train_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_train))
val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation)) val_set = Dataset(root_path=path, transform=augmentation.augment(final_shape, level = augmentation.transform_validation))
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment