Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
fb0797c5
Commit
fb0797c5
authored
1 year ago
by
ofhkr
Browse files
Options
Downloads
Patches
Plain Diff
update: train val test dataloader for several types of stored data.
parent
ba974ff9
No related branches found
Branches containing commit
No related tags found
1 merge request
!47
(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
qim3d/utils/data.py
+232
-73
232 additions, 73 deletions
qim3d/utils/data.py
qim3d/utils/internal_tools.py
+12
-1
12 additions, 1 deletion
qim3d/utils/internal_tools.py
with
244 additions
and
74 deletions
qim3d/utils/data.py
+
232
−
73
View file @
fb0797c5
...
...
@@ -2,64 +2,80 @@
from
pathlib
import
Path
from
PIL
import
Image
from
qim3d.io.logger
import
log
from
qim3d.utils.internal_tools
import
find_one_image
from
torch.utils.data
import
DataLoader
import
os
import
torch
import
numpy
as
np
class
Dataset
(
torch
.
utils
.
data
.
Dataset
):
"""
Custom Dataset class for building a PyTorch dataset.
Args:
root_path (str): The root directory path of the dataset.
split (str, optional): The split of the dataset, either
"
train
"
or
"
test
"
.
Default is
"
train
"
.
transform (callable, optional): A callable function or transformation to
be applied to the data. Default is None.
'''
Custom Dataset class for building a PyTorch dataset.
Raises:
ValueError: If the provided split is not valid (neither
"
train
"
nor
"
test
"
).
Attributes:
split (str): The split of the dataset (
"
train
"
or
"
test
"
).
transform (callable): The transformation to be applied to the data.
sample_images (list): A list containing the paths to the sample images in the dataset.
sample_targets (list): A list containing the paths to the corresponding target images
in the dataset.
Methods:
__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Returns the image and its target segmentation at the given index.
Usage:
dataset = Dataset(root_path=
"
path/to/dataset
"
, split=
"
train
"
,
transform=albumentations.Compose([ToTensorV2()]))
image, target = dataset[idx]
"""
def
__init__
(
self
,
root_path
:
str
,
split
=
"
train
"
,
transform
=
None
):
Case 1: There are no folder - all images and targets are stored in the same data directory.
The image and corresponding target have similar names (eg: data1.tif, data1mask.tif)
|-- data
|-- img1.tif
|-- img1_mask.tif
|-- img2.tif
|-- img2_mask.tif
|-- ...
Case 2: There are two folders - one with all the images and one with all the targets.
|-- data
|-- images
|-- img1.tif
|-- img2.tif
|-- ...
|-- masks
|-- img1_mask.tif
|-- img2_mask.tif
|-- ...
Case 3: There are many folders - each folder with a case (eg. patient) and multiple images.
|-- data
|-- patient1
|-- p1_img1.tif
|-- p1_img1_mask.tif
|-- p1_img2.tif
|-- p1_img2_mask.tif
|-- p1_img3.tif
|-- p1_img3_mask.tif
|-- ...
|-- patient2
|-- p2_img1.tif
|-- p2_img1_mask.tif
|-- p2_img2.tif
|-- p2_img2_mask.tif
|-- p2_img3.tif
|-- p2_img3_mask.tif
|-- ...
|-- ...
'''
def
__init__
(
self
,
root_path
:
str
,
transform
=
None
):
super
().
__init__
()
# Check if split is valid
if
split
not
in
[
"
train
"
,
"
test
"
]:
raise
ValueError
(
"
Split must be either train or test
"
)
self
.
split
=
split
self
.
root_path
=
root_path
self
.
transform
=
transform
path
=
Path
(
root_path
)
/
split
self
.
sample_images
=
[
file
for
file
in
sorted
((
path
/
"
images
"
).
iterdir
())]
self
.
sample_targets
=
[
file
for
file
in
sorted
((
path
/
"
labels
"
).
iterdir
())]
assert
len
(
self
.
sample_images
)
==
len
(
self
.
sample_targets
)
# scans folders
self
.
_data_scan
()
# finds the images and targets given the folder setup
self
.
_find_samples
()
assert
len
(
self
.
sample_images
)
==
len
(
self
.
sample_targets
)
# checking the characteristics of the dataset
self
.
check_shape_consistency
(
self
.
sample_images
)
def
__len__
(
self
):
return
len
(
self
.
sample_images
)
def
__getitem__
(
self
,
idx
):
image_path
=
self
.
sample_images
[
idx
]
target_path
=
self
.
sample_targets
[
idx
]
...
...
@@ -75,19 +91,110 @@ class Dataset(torch.utils.data.Dataset):
target
=
transformed
[
"
mask
"
]
return
image
,
target
def
_data_scan
(
self
):
'''
Find out which of the three categories the data belongs to.
'''
# how many folders there are:
files
=
os
.
listdir
(
self
.
root_path
)
n_folders
=
0
folder_names
=
[]
for
f
in
files
:
if
os
.
path
.
isdir
(
Path
(
self
.
root_path
,
f
)):
n_folders
+=
1
folder_names
.
append
(
f
)
self
.
n_folders
=
n_folders
self
.
folder_names
=
folder_names
def
_find_samples
(
self
):
'''
Scans and retrieves the images and targets from their given folder configuration.
'''
target_folder_names
=
[
'
mask
'
,
'
label
'
,
'
target
'
]
# Case 1
if
self
.
n_folders
==
0
:
sample_images
=
[]
sample_targets
=
[]
for
file
in
os
.
listdir
(
self
.
root_path
):
# checks if a label extension is in the filename
if
any
(
ext
in
file
.
lower
()
for
ext
in
target_folder_names
):
sample_targets
.
append
(
Path
(
self
.
root_path
,
file
))
# otherwise the file is assumed to be the image
else
:
sample_images
.
append
(
Path
(
self
.
root_path
,
file
))
self
.
sample_images
=
sorted
(
sample_images
)
self
.
sample_targets
=
sorted
(
sample_targets
)
# Case 2
elif
self
.
n_folders
==
2
:
# if the first folder contains the targets:
if
any
(
ext
in
self
.
folder_names
[
0
].
lower
()
for
ext
in
target_folder_names
):
images
=
self
.
folders_names
[
1
]
targets
=
self
.
folder_names
[
0
]
# if the second folder contains the targets:
elif
any
(
ext
in
self
.
folder_names
[
1
].
lower
()
for
ext
in
target_folder_names
):
images
=
self
.
folder_names
[
0
]
targets
=
self
.
folder_names
[
1
]
else
:
raise
ValueError
(
'
Folder names do not match categories such as
"
mask
"
,
"
label
"
or
"
target
"
.
'
)
self
.
sample_images
=
[
image
for
image
in
sorted
(
Path
(
self
.
root_path
,
images
).
iterdir
())]
self
.
sample_targets
=
[
target
for
target
in
sorted
(
Path
(
self
.
root_path
,
targets
).
iterdir
())]
# Case 3
elif
self
.
n_folders
>
2
:
sample_images
=
[]
sample_targets
=
[]
for
folder
in
os
.
listdir
(
self
.
root_path
):
# if some files are not a folder
if
not
os
.
path
.
isdir
(
Path
(
self
.
root_path
,
folder
)):
raise
NotImplementedError
(
f
'
The current data structure is not supported.
{
Path
(
self
.
root_path
,
folder
)
}
is not a folder.
'
)
for
file
in
os
.
listdir
(
Path
(
self
.
root_path
,
folder
)):
# if files are not images:
if
not
os
.
path
.
isfile
(
Path
(
self
.
root_path
,
folder
,
file
)):
raise
NotImplementedError
(
f
'
The current data structure is not supported.
{
Path
(
self
.
root_path
,
folder
,
file
)
}
is not a file.
'
)
# checks if a label extension is in the filename
if
any
(
ext
in
file
for
ext
in
target_folder_names
):
sample_targets
.
append
(
Path
(
self
.
root_path
,
folder
,
file
))
# otherwise the file is assumed to be the image
else
:
sample_images
.
append
(
Path
(
self
.
root_path
,
folder
,
file
))
self
.
sample_images
=
sorted
(
sample_images
)
self
.
sample_targets
=
sorted
(
sample_targets
)
else
:
raise
NotImplementedError
(
'
The current data structure is not supported.
'
)
# TODO: working with images of different sizes
def
check_shape_consistency
(
self
,
sample_images
):
image_shapes
=
[]
for
image_path
in
sample_images
:
image_shapes
=
[]
for
image_path
in
sample_images
[:
100
]
:
image_shape
=
self
.
_get_shape
(
image_path
)
image_shapes
.
append
(
image_shape
)
# check if all images have the same size.
consistency_check
=
all
(
i
==
image_shapes
[
0
]
for
i
in
image_shapes
)
if
not
consistency_check
:
unique_shapes
=
len
(
set
(
image_shapes
)
)
if
unique_shapes
>
1
:
raise
NotImplementedError
(
"
Only images of all the same size can be processed at the moment
"
)
...
...
@@ -95,7 +202,6 @@ class Dataset(torch.utils.data.Dataset):
log
.
debug
(
"
Images are all the same size!
"
)
return
consistency_check
def
_get_shape
(
self
,
image_path
):
return
Image
.
open
(
str
(
image_path
)).
size
...
...
@@ -133,45 +239,98 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
return
h_adjust
,
w_adjust
def
prepare_datasets
(
path
:
str
,
val_fraction
:
float
,
model
,
augmentation
):
"""
Splits and augments the train/validation/test datasets.
def
prepare_datasets
(
path
:
str
,
val_fraction
:
float
,
test_fraction
:
float
,
model
,
augmentation
,
train_folder
:
str
=
None
,
val_folder
:
str
=
None
,
test_folder
:
str
=
None
):
'''
Splits and augments the train/validation/test datasets
Args:
path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
'''
Raises:
ValueError: if the validation fraction is not a float, and is not between 0 and 1.
"""
if
not
isinstance
(
val_fraction
,
float
)
or
not
(
0
<=
val_fraction
<
1
):
raise
ValueError
(
"
The validation fraction must be a float between 0 and 1.
"
)
if
not
isinstance
(
test_fraction
,
float
)
or
not
(
0
<=
test_fraction
<
1
):
raise
ValueError
(
"
The test fraction must be a float between 0 and 1.
"
)
if
(
val_fraction
+
test_fraction
)
>=
1
:
print
(
int
(
val_fraction
+
test_fraction
)
*
100
)
raise
ValueError
(
f
"
The validation and test fractions cover
{
int
((
val_fraction
+
test_fraction
)
*
100
)
}
%.
"
"
Make sure to lower it below 100%, and include some place for the training data.
"
)
# find one image:
image
=
Image
.
open
(
find_one_image
(
path
=
path
))
orig_h
,
orig_w
=
image
.
size
[:
2
]
resize
=
augmentation
.
resize
n_channels
=
len
(
model
.
channels
)
# taking the size of the 1st image in the dataset
im_path
=
Path
(
path
)
/
'
train
'
first_img
=
sorted
((
im_path
/
"
images
"
).
iterdir
())[
0
]
image
=
Image
.
open
(
str
(
first_img
))
orig_h
,
orig_w
=
image
.
size
[:
2
]
final_h
,
final_w
=
check_resize
(
orig_h
,
orig_w
,
resize
,
n_channels
)
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
train
'
))
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
validation
'
))
test_set
=
Dataset
(
root_path
=
path
,
split
=
'
test
'
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
test
'
))
# change number of channels in UNet if needed
if
len
(
np
.
array
(
image
).
shape
)
>
2
:
model
.
img_channels
=
np
.
array
(
image
).
shape
[
2
]
model
.
update_params
()
# Only Train and Test folders are given, splits Train into Train/Val.
if
train_folder
and
test_folder
and
not
val_folder
:
log
.
info
(
'
Only train and test given, splitting train_folder with val fraction.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
Path
(
path
,
test_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
indices
=
torch
.
randperm
(
len
(
train_set
))
split_idx
=
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[
split_idx
:])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[:
split_idx
])
# Only Train and Val folder are given.
elif
train_folder
and
val_folder
and
not
test_folder
:
log
.
info
(
'
Only train and validation folder provided, will not be able to make inference on test data.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
None
# All Train/Val/Test folders are given.
elif
train_folder
and
val_folder
and
test_folder
:
log
.
info
(
'
Retrieving data from train, validation and test folder.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
Path
(
path
,
test_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
# None of the train/val/test folders are given:
elif
not
(
train_folder
or
val_folder
or
test_folder
):
log
.
info
(
'
No specific train/validation/test folders given. Splitting the data into train/validation/test sets.
'
)
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
split_idx
=
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
indices
=
torch
.
randperm
(
len
(
train_set
))
indices
=
torch
.
randperm
(
len
(
train_set
))
train_idx
=
int
(
np
.
floor
((
1
-
val_fraction
-
test_fraction
)
*
len
(
train_set
)))
val_idx
=
train_idx
+
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[
split_idx
:])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[:
split_idx
])
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[:
train_idx
])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[
train_idx
:
val_idx
])
test_set
=
torch
.
utils
.
data
.
Subset
(
test_set
,
indices
[
val_idx
:])
return
train_set
,
val_set
,
test_set
else
:
raise
ValueError
(
"
Your folder configuration cannot be recognized.
"
"
Give a path to the dataset, or paths to the train/validation/test folders.
"
)
return
train_set
,
val_set
,
test_set
def
prepare_dataloaders
(
train_set
,
val_set
,
test_set
,
batch_size
,
shuffle_train
=
True
,
num_workers
=
0
,
pin_memory
=
False
):
...
...
This diff is collapsed.
Click to expand it.
qim3d/utils/internal_tools.py
+
12
−
1
View file @
fb0797c5
...
...
@@ -303,4 +303,15 @@ def get_css():
with
open
(
css_path
,
'
r
'
)
as
file
:
css_content
=
file
.
read
()
return
css_content
\ No newline at end of file
return
css_content
def
find_one_image
(
path
):
for
entry
in
os
.
scandir
(
path
):
if
entry
.
is_dir
():
return
find_one_image
(
entry
.
path
)
elif
entry
.
is_file
():
if
any
(
entry
.
path
.
endswith
(
imagetype
)
for
imagetype
in
[
'
jpg
'
,
'
jpeg
'
,
'
tif
'
,
'
tiff
'
,
'
png
'
,
'
PNG
'
]):
return
entry
.
path
# If all folders/sub-folders do not have anything:
raise
ValueError
(
'
No Images Found.
'
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment