QIM / Tools / qim3d · Commits

Commit 9a9ada7d
authored Jan 23, 2024 by ofhkr

    Added docstring to modified files.

parent fb0797c5
No related tags found.
1 merge request: !47 (Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
Showing 2 changed files with 118 additions and 205 deletions:

    docs/notebooks/Unet.ipynb   19 additions, 163 deletions
    qim3d/utils/data.py         99 additions, 42 deletions
docs/notebooks/Unet.ipynb  (+19 −163)

    (diff collapsed)
qim3d/utils/data.py  (+99 −42)
@@ -12,7 +12,33 @@ import numpy as np
class Dataset(torch.utils.data.Dataset):
    '''
    Custom Dataset class for building a PyTorch dataset.

    Args:
        root_path (str): The root directory path of the dataset.
        transform (callable, optional): A callable function or transformation to
            be applied to the data. Default is None.

    Raises:
        ValueError: If the provided split is not valid (neither "train" nor "test").

    Attributes:
        root_path (str): The root directory path of the dataset.
        transform (callable): The transformation to be applied to the data.
        sample_images (list): A list containing the paths to the sample images in the dataset.
        sample_targets (list): A list containing the paths to the corresponding target images
            in the dataset.

    Methods:
        __len__(): Returns the total number of samples in the dataset.
        __getitem__(idx): Returns the image and its target segmentation at the given index.
        _data_scan(): Finds how many folders are in the directory path, as well as their names.
        _find_samples(): Finds the images and targets according to one of the three data-structure cases.

    Usage:
        dataset = Dataset(root_path="path/to/dataset",
                          transform=albumentations.Compose([ToTensorV2()]))
        image, target = dataset[idx]

    Notes:
        Case 1: There are no folders - all images and targets are stored in the same data directory.
            The image and the corresponding target have similar names (e.g. data1.tif, data1mask.tif).
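The docstring above documents the folder layouts and an Albumentations-based transform. As a rough usage sketch only (assuming the class behaves as documented; the resize value and batch size below are illustrative, not from the commit), such a dataset could be fed to a standard PyTorch DataLoader:

    # Rough usage sketch - assumes the Dataset class documented above and a
    # Case 1 layout (images and "*mask*" targets in a single directory).
    import albumentations
    from albumentations.pytorch import ToTensorV2
    from torch.utils.data import DataLoader
    from qim3d.utils.data import Dataset

    transform = albumentations.Compose([
        albumentations.Resize(256, 256),  # illustrative size, not from the commit
        ToTensorV2(),
    ])

    dataset = Dataset(root_path="path/to/dataset", transform=transform)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    for image, target in loader:
        # batched tensors ready for a segmentation model
        print(image.shape, target.shape)
        break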
@@ -112,6 +138,12 @@ class Dataset(torch.utils.data.Dataset):
    def _find_samples(self):
        '''
        Scans and retrieves the images and targets from their given folder configuration.

        Raises:
            ValueError: in Case 2, if no folder contains any of the labels 'mask', 'label', 'target'.
            NotImplementedError: in Case 3, if a file is found among the list of folders.
            NotImplementedError: in Case 3, if a folder is found among the list of files.
            NotImplementedError: If the data structure does not fall into one of the three cases.
        '''
        target_folder_names = ['mask', 'label', 'target']
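The Raises section implies a name-based check for the Case 2 target folder. A minimal, hypothetical illustration of that kind of check (the helper name and error message are assumptions, not the commit's code):

    # Illustration only - a name-based check like the one implied by the
    # Raises section above; not the repository's actual implementation.
    import os

    def _guess_target_folder(root_path: str) -> str:
        target_folder_names = ['mask', 'label', 'target']
        folders = [f for f in os.listdir(root_path)
                   if os.path.isdir(os.path.join(root_path, f))]
        for folder in folders:
            if any(name in folder.lower() for name in target_folder_names):
                return folder
        raise ValueError(
            "Case 2: no folder name contains any of 'mask', 'label' or 'target'."
        )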
@@ -241,16 +273,41 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
 def prepare_datasets(
     path: str,
-    val_fraction: float,
-    test_fraction: float,
     model,
     augmentation,
+    val_fraction: float = 0.1,
+    test_fraction: float = 0.1,
+    train_folder: str = None,
+    val_folder: str = None,
+    test_folder: str = None
 ):
    '''
    Splits and augments the train/validation/test datasets.

    Args:
        path (str): Path to the dataset.
        model (torch.nn.Module): PyTorch model.
        augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
        val_fraction (float, optional): Fraction of the data used for the validation set.
        test_fraction (float, optional): Fraction of the data used for the test set.
        train_folder (str, optional): Can be used to specify where the training data is located.
        val_folder (str, optional): Can be used to specify where the validation data is located.
        test_folder (str, optional): Can be used to specify where the test data is located.

    Raises:
        ValueError: If the validation fraction is not a float or is not between 0 and 1.
        ValueError: If the test fraction is not a float or is not between 0 and 1.
        ValueError: If the sum of the validation and test fractions is equal to or larger than 1.
        ValueError: If the combination of train/val/test_folder strings isn't enough to prepare the data for model training.

    Usage:
        # if all data is stored together:
        prepare_datasets(path="path/to/dataset", val_fraction = 0.2, test_fraction = 0.1,
                         model = qim3d.models.UNet(), augmentation = qim3d.utils.Augmentation())

        # if the data has been pre-configured into training/testing folders:
        prepare_datasets(path="path/to/dataset", val_fraction = 0.2, test_fraction = 0.1,
                         model = qim3d.models.UNet(), augmentation = qim3d.utils.Augmentation(),
                         train_folder = 'training_folder_name', test_folder = 'test_folder_name')
    '''
    if not isinstance(val_fraction, float) or not (0 <= val_fraction < 1):
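Only the first guard clause is visible in this hunk. A sketch of what the full set of checks described in the Raises section might look like (the helper name and error messages are assumptions written to match the documented behaviour, not the commit's code):

    # Sketch of the guard clauses described in the docstring's Raises section.
    # Only the first check is visible in the diff; the rest are assumptions.
    def _check_fractions(val_fraction: float, test_fraction: float) -> None:
        if not isinstance(val_fraction, float) or not (0 <= val_fraction < 1):
            raise ValueError("The validation fraction must be a float between 0 and 1.")
        if not isinstance(test_fraction, float) or not (0 <= test_fraction < 1):
            raise ValueError("The test fraction must be a float between 0 and 1.")
        if val_fraction + test_fraction >= 1:
            raise ValueError(
                f"The validation and test fractions cover "
                f"{int((val_fraction + test_fraction) * 100)}%. "
                "Make sure to lower it below 100%, and leave some room for the training data."
            )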
@@ -264,7 +321,7 @@ def prepare_datasets(
        raise ValueError(f"The validation and test fractions cover {int((val_fraction + test_fraction) * 100)}%."
                         "Make sure to lower it below 100%, and include some place for the training data.")

-    # find one image:
+    # Finds one image:
     image = Image.open(find_one_image(path=path))
     orig_h, orig_w = image.size[:2]
@@ -273,7 +330,7 @@ def prepare_datasets(
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)

-    # change number of channels in UNet if needed
+    # Change number of channels in UNet if needed
     if len(np.array(image).shape) > 2:
         model.img_channels = np.array(image).shape[2]
         model.update_params()
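The block above infers the channel count from a loaded sample: a 2-D array is treated as single-channel, while a third axis gives the channel count. A standalone illustration with a synthetic image instead of one loaded from disk:

    # Standalone illustration of the channel inference above, using a
    # synthetic RGB image rather than a sample loaded with PIL.
    import numpy as np

    image = np.zeros((256, 256, 3), dtype=np.uint8)  # H x W x C

    img_channels = 1
    if len(np.array(image).shape) > 2:
        img_channels = np.array(image).shape[2]

    print(img_channels)  # -> 3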
@@ -281,7 +338,7 @@ def prepare_datasets(
     # Only Train and Test folders are given, splits Train into Train/Val.
     if train_folder and test_folder and not val_folder:
-        log.info('Only train and test given, splitting train_folder with val fraction.')
+        log.info('Only train and test given, splitting train_folder with val_fraction.')
         train_set = Dataset(root_path = Path(path, train_folder), transform = augmentation.augment(final_h, final_w, type = 'train'))
         val_set = Dataset(root_path = Path(path, train_folder), transform = augmentation.augment(final_h, final_w, type = 'validation'))
         test_set = Dataset(root_path = Path(path, test_folder), transform = augmentation.augment(final_h, final_w, type = 'test'))
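Here train_set and val_set are built from the same train_folder with different transforms; the actual index split by val_fraction is not visible in this hunk. One way such a split could be realised (an assumption, not necessarily the commit's approach) is to split the indices once and wrap each dataset in a Subset so the two transforms stay separate:

    # Assumption: a possible train/val split over a shared train_folder.
    # Split the index range once, then wrap each Dataset in a Subset.
    import torch
    from torch.utils.data import Subset

    def split_train_val(train_set, val_set, val_fraction: float, seed: int = 0):
        n = len(train_set)
        n_val = int(n * val_fraction)
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(n, generator=generator).tolist()
        train_subset = Subset(train_set, indices[n_val:])
        val_subset = Subset(val_set, indices[:n_val])
        return train_subset, val_subset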