Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
fb0797c5
Commit
fb0797c5
authored
1 year ago
by
ofhkr
Browse files
Options
Downloads
Patches
Plain Diff
update: train val test dataloader for several types of stored data.
parent
ba974ff9
No related branches found
No related tags found
1 merge request
!47
(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
qim3d/utils/data.py
+232
-73
232 additions, 73 deletions
qim3d/utils/data.py
qim3d/utils/internal_tools.py
+12
-1
12 additions, 1 deletion
qim3d/utils/internal_tools.py
with
244 additions
and
74 deletions
qim3d/utils/data.py
+
232
−
73
View file @
fb0797c5
...
@@ -2,56 +2,72 @@
...
@@ -2,56 +2,72 @@
from
pathlib
import
Path
from
pathlib
import
Path
from
PIL
import
Image
from
PIL
import
Image
from
qim3d.io.logger
import
log
from
qim3d.io.logger
import
log
from
qim3d.utils.internal_tools
import
find_one_image
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
import
os
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
class
Dataset
(
torch
.
utils
.
data
.
Dataset
):
class
Dataset
(
torch
.
utils
.
data
.
Dataset
):
"""
'''
Custom Dataset class for building a PyTorch dataset.
Custom Dataset class for building a PyTorch dataset.
Case 1: There are no folder - all images and targets are stored in the same data directory.
Args:
The image and corresponding target have similar names (eg: data1.tif, data1mask.tif)
root_path (str): The root directory path of the dataset.
split (str, optional): The split of the dataset, either
"
train
"
or
"
test
"
.
|-- data
Default is
"
train
"
.
|-- img1.tif
transform (callable, optional): A callable function or transformation to
|-- img1_mask.tif
be applied to the data. Default is None.
|-- img2.tif
|-- img2_mask.tif
Raises:
|-- ...
ValueError: If the provided split is not valid (neither
"
train
"
nor
"
test
"
).
Case 2: There are two folders - one with all the images and one with all the targets.
Attributes:
split (str): The split of the dataset (
"
train
"
or
"
test
"
).
|-- data
transform (callable): The transformation to be applied to the data.
|-- images
sample_images (list): A list containing the paths to the sample images in the dataset.
|-- img1.tif
sample_targets (list): A list containing the paths to the corresponding target images
|-- img2.tif
in the dataset.
|-- ...
|-- masks
Methods:
|-- img1_mask.tif
__len__(): Returns the total number of samples in the dataset.
|-- img2_mask.tif
__getitem__(idx): Returns the image and its target segmentation at the given index.
|-- ...
Usage:
Case 3: There are many folders - each folder with a case (eg. patient) and multiple images.
dataset = Dataset(root_path=
"
path/to/dataset
"
, split=
"
train
"
,
transform=albumentations.Compose([ToTensorV2()]))
|-- data
image, target = dataset[idx]
|-- patient1
"""
|-- p1_img1.tif
def
__init__
(
self
,
root_path
:
str
,
split
=
"
train
"
,
transform
=
None
):
|-- p1_img1_mask.tif
|-- p1_img2.tif
|-- p1_img2_mask.tif
|-- p1_img3.tif
|-- p1_img3_mask.tif
|-- ...
|-- patient2
|-- p2_img1.tif
|-- p2_img1_mask.tif
|-- p2_img2.tif
|-- p2_img2_mask.tif
|-- p2_img3.tif
|-- p2_img3_mask.tif
|-- ...
|-- ...
'''
def
__init__
(
self
,
root_path
:
str
,
transform
=
None
):
super
().
__init__
()
super
().
__init__
()
# Check if split is valid
self
.
root_path
=
root_path
if
split
not
in
[
"
train
"
,
"
test
"
]:
raise
ValueError
(
"
Split must be either train or test
"
)
self
.
split
=
split
self
.
transform
=
transform
self
.
transform
=
transform
path
=
Path
(
root_path
)
/
split
# scans folders
self
.
_data_scan
()
# finds the images and targets given the folder setup
self
.
_find_samples
()
self
.
sample_images
=
[
file
for
file
in
sorted
((
path
/
"
images
"
).
iterdir
())]
self
.
sample_targets
=
[
file
for
file
in
sorted
((
path
/
"
labels
"
).
iterdir
())]
assert
len
(
self
.
sample_images
)
==
len
(
self
.
sample_targets
)
assert
len
(
self
.
sample_images
)
==
len
(
self
.
sample_targets
)
# checking the characteristics of the dataset
# checking the characteristics of the dataset
...
@@ -77,17 +93,108 @@ class Dataset(torch.utils.data.Dataset):
...
@@ -77,17 +93,108 @@ class Dataset(torch.utils.data.Dataset):
return
image
,
target
return
image
,
target
def
_data_scan
(
self
):
'''
Find out which of the three categories the data belongs to.
'''
# how many folders there are:
files
=
os
.
listdir
(
self
.
root_path
)
n_folders
=
0
folder_names
=
[]
for
f
in
files
:
if
os
.
path
.
isdir
(
Path
(
self
.
root_path
,
f
)):
n_folders
+=
1
folder_names
.
append
(
f
)
self
.
n_folders
=
n_folders
self
.
folder_names
=
folder_names
def
_find_samples
(
self
):
'''
Scans and retrieves the images and targets from their given folder configuration.
'''
target_folder_names
=
[
'
mask
'
,
'
label
'
,
'
target
'
]
# Case 1
if
self
.
n_folders
==
0
:
sample_images
=
[]
sample_targets
=
[]
for
file
in
os
.
listdir
(
self
.
root_path
):
# checks if a label extension is in the filename
if
any
(
ext
in
file
.
lower
()
for
ext
in
target_folder_names
):
sample_targets
.
append
(
Path
(
self
.
root_path
,
file
))
# otherwise the file is assumed to be the image
else
:
sample_images
.
append
(
Path
(
self
.
root_path
,
file
))
self
.
sample_images
=
sorted
(
sample_images
)
self
.
sample_targets
=
sorted
(
sample_targets
)
# Case 2
elif
self
.
n_folders
==
2
:
# if the first folder contains the targets:
if
any
(
ext
in
self
.
folder_names
[
0
].
lower
()
for
ext
in
target_folder_names
):
images
=
self
.
folders_names
[
1
]
targets
=
self
.
folder_names
[
0
]
# if the second folder contains the targets:
elif
any
(
ext
in
self
.
folder_names
[
1
].
lower
()
for
ext
in
target_folder_names
):
images
=
self
.
folder_names
[
0
]
targets
=
self
.
folder_names
[
1
]
else
:
raise
ValueError
(
'
Folder names do not match categories such as
"
mask
"
,
"
label
"
or
"
target
"
.
'
)
self
.
sample_images
=
[
image
for
image
in
sorted
(
Path
(
self
.
root_path
,
images
).
iterdir
())]
self
.
sample_targets
=
[
target
for
target
in
sorted
(
Path
(
self
.
root_path
,
targets
).
iterdir
())]
# Case 3
elif
self
.
n_folders
>
2
:
sample_images
=
[]
sample_targets
=
[]
for
folder
in
os
.
listdir
(
self
.
root_path
):
# if some files are not a folder
if
not
os
.
path
.
isdir
(
Path
(
self
.
root_path
,
folder
)):
raise
NotImplementedError
(
f
'
The current data structure is not supported.
{
Path
(
self
.
root_path
,
folder
)
}
is not a folder.
'
)
for
file
in
os
.
listdir
(
Path
(
self
.
root_path
,
folder
)):
# if files are not images:
if
not
os
.
path
.
isfile
(
Path
(
self
.
root_path
,
folder
,
file
)):
raise
NotImplementedError
(
f
'
The current data structure is not supported.
{
Path
(
self
.
root_path
,
folder
,
file
)
}
is not a file.
'
)
# checks if a label extension is in the filename
if
any
(
ext
in
file
for
ext
in
target_folder_names
):
sample_targets
.
append
(
Path
(
self
.
root_path
,
folder
,
file
))
# otherwise the file is assumed to be the image
else
:
sample_images
.
append
(
Path
(
self
.
root_path
,
folder
,
file
))
self
.
sample_images
=
sorted
(
sample_images
)
self
.
sample_targets
=
sorted
(
sample_targets
)
else
:
raise
NotImplementedError
(
'
The current data structure is not supported.
'
)
# TODO: working with images of different sizes
# TODO: working with images of different sizes
def
check_shape_consistency
(
self
,
sample_images
):
def
check_shape_consistency
(
self
,
sample_images
):
image_shapes
=
[]
image_shapes
=
[]
for
image_path
in
sample_images
:
for
image_path
in
sample_images
[:
100
]
:
image_shape
=
self
.
_get_shape
(
image_path
)
image_shape
=
self
.
_get_shape
(
image_path
)
image_shapes
.
append
(
image_shape
)
image_shapes
.
append
(
image_shape
)
# check if all images have the same size.
# check if all images have the same size.
consistency_check
=
all
(
i
==
image_shapes
[
0
]
for
i
in
image_shapes
)
unique_shapes
=
len
(
set
(
image_shapes
)
)
if
not
consistency_check
:
if
unique_shapes
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"
Only images of all the same size can be processed at the moment
"
"
Only images of all the same size can be processed at the moment
"
)
)
...
@@ -95,7 +202,6 @@ class Dataset(torch.utils.data.Dataset):
...
@@ -95,7 +202,6 @@ class Dataset(torch.utils.data.Dataset):
log
.
debug
(
log
.
debug
(
"
Images are all the same size!
"
"
Images are all the same size!
"
)
)
return
consistency_check
def
_get_shape
(
self
,
image_path
):
def
_get_shape
(
self
,
image_path
):
return
Image
.
open
(
str
(
image_path
)).
size
return
Image
.
open
(
str
(
image_path
)).
size
...
@@ -133,47 +239,100 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
...
@@ -133,47 +239,100 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
return
h_adjust
,
w_adjust
return
h_adjust
,
w_adjust
def
prepare_datasets
(
path
:
str
,
val_fraction
:
float
,
model
,
augmentation
):
def
prepare_datasets
(
"""
path
:
str
,
Splits and augments the train/validation/test datasets.
val_fraction
:
float
,
test_fraction
:
float
,
Args:
model
,
path (str): Path to the dataset.
augmentation
,
val_fraction (float): Fraction of the data for the validation set.
train_folder
:
str
=
None
,
model (torch.nn.Module): PyTorch Model.
val_folder
:
str
=
None
,
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
test_folder
:
str
=
None
):
'''
Splits and augments the train/validation/test datasets
Raises:
'''
ValueError: if the validation fraction is not a float, and is not between 0 and 1.
"""
if
not
isinstance
(
val_fraction
,
float
)
or
not
(
0
<=
val_fraction
<
1
):
if
not
isinstance
(
val_fraction
,
float
)
or
not
(
0
<=
val_fraction
<
1
):
raise
ValueError
(
"
The validation fraction must be a float between 0 and 1.
"
)
raise
ValueError
(
"
The validation fraction must be a float between 0 and 1.
"
)
resize
=
augmentation
.
resize
if
not
isinstance
(
test_fraction
,
float
)
or
not
(
0
<=
test_fraction
<
1
):
n_channels
=
len
(
model
.
channels
)
raise
ValueError
(
"
The test fraction must be a float between 0 and 1.
"
)
if
(
val_fraction
+
test_fraction
)
>=
1
:
print
(
int
(
val_fraction
+
test_fraction
)
*
100
)
raise
ValueError
(
f
"
The validation and test fractions cover
{
int
((
val_fraction
+
test_fraction
)
*
100
)
}
%.
"
"
Make sure to lower it below 100%, and include some place for the training data.
"
)
# taking the size of the 1st image in the dataset
# find one image:
im_path
=
Path
(
path
)
/
'
train
'
image
=
Image
.
open
(
find_one_image
(
path
=
path
))
first_img
=
sorted
((
im_path
/
"
images
"
).
iterdir
())[
0
]
image
=
Image
.
open
(
str
(
first_img
))
orig_h
,
orig_w
=
image
.
size
[:
2
]
orig_h
,
orig_w
=
image
.
size
[:
2
]
resize
=
augmentation
.
resize
n_channels
=
len
(
model
.
channels
)
final_h
,
final_w
=
check_resize
(
orig_h
,
orig_w
,
resize
,
n_channels
)
final_h
,
final_w
=
check_resize
(
orig_h
,
orig_w
,
resize
,
n_channels
)
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
train
'
))
# change number of channels in UNet if needed
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
validation
'
))
if
len
(
np
.
array
(
image
).
shape
)
>
2
:
test_set
=
Dataset
(
root_path
=
path
,
split
=
'
test
'
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
'
test
'
))
model
.
img_channels
=
np
.
array
(
image
).
shape
[
2
]
model
.
update_params
()
# Only Train and Test folders are given, splits Train into Train/Val.
if
train_folder
and
test_folder
and
not
val_folder
:
log
.
info
(
'
Only train and test given, splitting train_folder with val fraction.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
Path
(
path
,
test_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
split_idx
=
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
indices
=
torch
.
randperm
(
len
(
train_set
))
indices
=
torch
.
randperm
(
len
(
train_set
))
split_idx
=
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[
split_idx
:])
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[
split_idx
:])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[:
split_idx
])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[:
split_idx
])
# Only Train and Val folder are given.
elif
train_folder
and
val_folder
and
not
test_folder
:
log
.
info
(
'
Only train and validation folder provided, will not be able to make inference on test data.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
None
# All Train/Val/Test folders are given.
elif
train_folder
and
val_folder
and
test_folder
:
log
.
info
(
'
Retrieving data from train, validation and test folder.
'
)
train_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
Path
(
path
,
train_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
Path
(
path
,
test_folder
),
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
# None of the train/val/test folders are given:
elif
not
(
train_folder
or
val_folder
or
test_folder
):
log
.
info
(
'
No specific train/validation/test folders given. Splitting the data into train/validation/test sets.
'
)
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
train
'
))
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
validation
'
))
test_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_h
,
final_w
,
type
=
'
test
'
))
indices
=
torch
.
randperm
(
len
(
train_set
))
train_idx
=
int
(
np
.
floor
((
1
-
val_fraction
-
test_fraction
)
*
len
(
train_set
)))
val_idx
=
train_idx
+
int
(
np
.
floor
(
val_fraction
*
len
(
train_set
)))
train_set
=
torch
.
utils
.
data
.
Subset
(
train_set
,
indices
[:
train_idx
])
val_set
=
torch
.
utils
.
data
.
Subset
(
val_set
,
indices
[
train_idx
:
val_idx
])
test_set
=
torch
.
utils
.
data
.
Subset
(
test_set
,
indices
[
val_idx
:])
else
:
raise
ValueError
(
"
Your folder configuration cannot be recognized.
"
"
Give a path to the dataset, or paths to the train/validation/test folders.
"
)
return
train_set
,
val_set
,
test_set
return
train_set
,
val_set
,
test_set
def
prepare_dataloaders
(
train_set
,
val_set
,
test_set
,
batch_size
,
shuffle_train
=
True
,
num_workers
=
0
,
pin_memory
=
False
):
def
prepare_dataloaders
(
train_set
,
val_set
,
test_set
,
batch_size
,
shuffle_train
=
True
,
num_workers
=
0
,
pin_memory
=
False
):
"""
"""
Prepares the dataloaders for model training.
Prepares the dataloaders for model training.
...
...
This diff is collapsed.
Click to expand it.
qim3d/utils/internal_tools.py
+
12
−
1
View file @
fb0797c5
...
@@ -304,3 +304,14 @@ def get_css():
...
@@ -304,3 +304,14 @@ def get_css():
css_content
=
file
.
read
()
css_content
=
file
.
read
()
return
css_content
return
css_content
def
find_one_image
(
path
):
for
entry
in
os
.
scandir
(
path
):
if
entry
.
is_dir
():
return
find_one_image
(
entry
.
path
)
elif
entry
.
is_file
():
if
any
(
entry
.
path
.
endswith
(
imagetype
)
for
imagetype
in
[
'
jpg
'
,
'
jpeg
'
,
'
tif
'
,
'
tiff
'
,
'
png
'
,
'
PNG
'
]):
return
entry
.
path
# If all folders/sub-folders do not have anything:
raise
ValueError
(
'
No Images Found.
'
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment