Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
ddc40406
Commit
ddc40406
authored
1 month ago
by
s193396
Browse files
Options
Downloads
Patches
Plain Diff
removed 2D dataloader
parent
6f8cb981
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
qim3d/ml/_data.py
+33
-114
33 additions, 114 deletions
qim3d/ml/_data.py
with
33 additions
and
114 deletions
qim3d/ml/_data.py
+
33
−
114
View file @
ddc40406
...
@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset):
...
@@ -65,53 +65,23 @@ class Dataset(torch.utils.data.Dataset):
image_path
=
self
.
sample_images
[
idx
]
image_path
=
self
.
sample_images
[
idx
]
target_path
=
self
.
sample_targets
[
idx
]
target_path
=
self
.
sample_targets
[
idx
]
full_suffix
=
''
.
join
(
image_path
.
suffixes
)
# Load 3D volume
image_data
=
nib
.
load
(
str
(
image_path
))
if
full_suffix
in
[
'
.nii
'
,
'
.nii.gz
'
]:
target_data
=
nib
.
load
(
str
(
target_path
))
# Load 3D volume
image_data
=
nib
.
load
(
str
(
image_path
))
target_data
=
nib
.
load
(
str
(
target_path
))
# Get data from volume
image
=
np
.
asarray
(
image_data
.
dataobj
,
dtype
=
image_data
.
get_data_dtype
())
target
=
np
.
asarray
(
target_data
.
dataobj
,
dtype
=
target_data
.
get_data_dtype
())
# Add extra channel dimension
image
=
np
.
expand_dims
(
image
,
axis
=
0
)
target
=
np
.
expand_dims
(
target
,
axis
=
0
)
else
:
# Load 2D image
image
=
Image
.
open
(
str
(
image_path
))
image
=
np
.
array
(
image
)
target
=
Image
.
open
(
str
(
target_path
))
target
=
np
.
array
(
target
)
# Grayscale image
if
len
(
image
.
shape
)
==
2
and
len
(
target
.
shape
)
==
2
:
# Add channel dimension
# Get data from volume
image
=
np
.
expand_dims
(
image
,
axis
=
0
)
image
=
np
.
asarray
(
image_data
.
dataobj
,
dtype
=
image_data
.
get_data_dtype
())
target
=
np
.
expand_dims
(
target
,
axis
=
0
)
target
=
np
.
asarray
(
target_data
.
dataobj
,
dtype
=
target_data
.
get_data_dtype
())
# RGB image
elif
len
(
image
.
shape
)
==
3
and
len
(
target
.
shape
)
==
3
:
# Convert to (C, H, W)
# Add extra channel dimension
image
=
image
.
transpose
((
2
,
0
,
1
)
)
image
=
np
.
expand_dims
(
image
,
axis
=
0
)
target
=
target
.
transpose
((
2
,
0
,
1
)
)
target
=
np
.
expand_dims
(
target
,
axis
=
0
)
if
self
.
transform
:
if
self
.
transform
:
transformed
=
self
.
transform
({
"
image
"
:
image
,
"
label
"
:
target
})
transformed
=
self
.
transform
({
"
image
"
:
image
,
"
label
"
:
target
})
image
=
transformed
[
"
image
"
]
image
=
transformed
[
"
image
"
]
target
=
transformed
[
"
label
"
]
target
=
transformed
[
"
label
"
]
# image = self.transform(image) # uint8
# target = self.transform(target) # int32
# TODO: Which dtype?
image
=
image
.
clone
().
detach
().
to
(
dtype
=
torch
.
float32
)
image
=
image
.
clone
().
detach
().
to
(
dtype
=
torch
.
float32
)
target
=
target
.
clone
().
detach
().
to
(
dtype
=
torch
.
float32
)
target
=
target
.
clone
().
detach
().
to
(
dtype
=
torch
.
float32
)
...
@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset):
...
@@ -138,24 +108,15 @@ class Dataset(torch.utils.data.Dataset):
return
consistency_check
return
consistency_check
def
_get_shape
(
self
,
image_path
):
def
_get_shape
(
self
,
image_path
):
full_suffix
=
''
.
join
(
image_path
.
suffixes
)
if
full_suffix
in
[
'
.nii
'
,
'
.nii.gz
'
]:
# Load 3D volume
image
=
nib
.
load
(
str
(
image_path
)).
get_fdata
()
return
image
.
shape
else
:
# Load 2D image
image
=
Image
.
open
(
str
(
image_path
))
return
image
.
size
# Load 3D volume
image
=
nib
.
load
(
str
(
image_path
)).
get_fdata
()
return
image
.
shape
def
check_resize
(
def
check_resize
(
orig_shape
:
tuple
,
orig_shape
:
tuple
,
resize
:
tuple
,
resize
:
tuple
,
n_channels
:
int
,
n_channels
:
int
,
is_3d
:
bool
)
->
tuple
:
)
->
tuple
:
"""
"""
Checks and adjusts the resize dimensions based on the original shape and the number of channels.
Checks and adjusts the resize dimensions based on the original shape and the number of channels.
...
@@ -164,7 +125,6 @@ def check_resize(
...
@@ -164,7 +125,6 @@ def check_resize(
orig_shape (tuple): Original shape of the image.
orig_shape (tuple): Original shape of the image.
resize (tuple): Desired resize dimensions.
resize (tuple): Desired resize dimensions.
n_channels (int): Number of channels in the model.
n_channels (int): Number of channels in the model.
is_3d (bool): If True, the input data is 3D. Otherwise the input data is 2D. Defaults to True.
Returns:
Returns:
tuple: Final resize dimensions.
tuple: Final resize dimensions.
...
@@ -174,23 +134,22 @@ def check_resize(
...
@@ -174,23 +134,22 @@ def check_resize(
"""
"""
# 3D images
# 3D images
if
is_3d
:
orig_d
,
orig_h
,
orig_w
=
orig_shape
orig_d
,
orig_h
,
orig_w
=
orig_shape
final_d
=
resize
[
0
]
if
resize
[
0
]
else
orig_d
final_d
=
resize
[
0
]
if
resize
[
0
]
else
orig_d
final_h
=
resize
[
1
]
if
resize
[
1
]
else
orig_h
final_h
=
resize
[
1
]
if
resize
[
1
]
else
orig_h
final_w
=
resize
[
2
]
if
resize
[
2
]
else
orig_w
final_w
=
resize
[
2
]
if
resize
[
2
]
else
orig_w
# Finding suitable size to upsize with padding
# Finding suitable size to upsize with padding
if
resize
==
'
padding
'
:
if
resize
==
'
padding
'
:
final_d
=
(
orig_d
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_d
=
(
orig_d
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_h
=
(
orig_h
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_h
=
(
orig_h
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
+
1
)
*
2
**
n_channels
# Finding suitable size to downsize with crop / resize
# Finding suitable size to downsize with crop / resize
else
:
else
:
final_d
=
(
orig_d
//
2
**
n_channels
)
*
2
**
n_channels
final_d
=
(
orig_d
//
2
**
n_channels
)
*
2
**
n_channels
final_h
=
(
orig_h
//
2
**
n_channels
)
*
2
**
n_channels
final_h
=
(
orig_h
//
2
**
n_channels
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
)
*
2
**
n_channels
# Check if the image size is too small compared to the model's depth
# Check if the image size is too small compared to the model's depth
if
final_d
==
0
or
final_h
==
0
or
final_w
==
0
:
if
final_d
==
0
or
final_h
==
0
or
final_w
==
0
:
...
@@ -205,35 +164,6 @@ def check_resize(
...
@@ -205,35 +164,6 @@ def check_resize(
return
final_d
,
final_h
,
final_w
return
final_d
,
final_h
,
final_w
# 2D images
else
:
orig_h
,
orig_w
=
orig_shape
final_h
=
resize
[
0
]
if
resize
[
0
]
else
orig_h
final_w
=
resize
[
1
]
if
resize
[
1
]
else
orig_w
# Finding suitable size to upsize with padding
if
resize
==
'
padding
'
:
final_h
=
(
orig_h
//
2
**
n_channels
+
1
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
+
1
)
*
2
**
n_channels
# Finding suitable size to downsize with crop / resize
else
:
final_h
=
(
orig_h
//
2
**
n_channels
)
*
2
**
n_channels
final_w
=
(
orig_w
//
2
**
n_channels
)
*
2
**
n_channels
# Check if the image size is too small compared to the model's depth
if
final_h
==
0
or
final_w
==
0
:
msg
=
"
The size of the image is too small compared to the depth of the UNet.
\
Choose a different
'
resize
'
and/or a smaller model.
"
raise
ValueError
(
msg
)
if
final_h
!=
orig_h
or
final_w
!=
orig_w
:
log
.
warning
(
f
"
The image size doesn
'
t match the Unet model
'
s depth.
\
The image is changed with
'
{
resize
}
'
, from
{
orig_h
,
orig_w
}
to
{
final_h
,
final_w
}
.
"
)
return
final_h
,
final_w
def
prepare_datasets
(
def
prepare_datasets
(
path
:
str
,
path
:
str
,
val_fraction
:
float
,
val_fraction
:
float
,
...
@@ -262,23 +192,12 @@ def prepare_datasets(
...
@@ -262,23 +192,12 @@ def prepare_datasets(
# Determine if the dataset is 2D or 3D by checking the first image
# Determine if the dataset is 2D or 3D by checking the first image
im_path
=
Path
(
path
)
/
'
train
'
im_path
=
Path
(
path
)
/
'
train
'
first_img
=
sorted
((
im_path
/
"
images
"
).
iterdir
())[
0
]
first_img
=
sorted
((
im_path
/
"
images
"
).
iterdir
())[
0
]
full_suffix
=
''
.
join
(
first_img
.
suffixes
)
# TODO: Support more formats for 3D images
# Load 3D volume
if
full_suffix
in
[
'
.nii
'
,
'
.nii.gz
'
]:
image
=
nib
.
load
(
str
(
first_img
)).
get_fdata
()
orig_shape
=
image
.
shape
# Load 3D volume
image
=
nib
.
load
(
str
(
first_img
)).
get_fdata
()
orig_shape
=
image
.
shape
is_3d
=
True
else
:
# Load 2D image
image
=
Image
.
open
(
str
(
first_img
))
orig_shape
=
image
.
size
[:
2
]
is_3d
=
False
final_shape
=
check_resize
(
orig_shape
,
resize
,
n_channels
,
is_3d
)
final_shape
=
check_resize
(
orig_shape
,
resize
,
n_channels
)
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_shape
,
level
=
augmentation
.
transform_train
))
train_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_shape
,
level
=
augmentation
.
transform_train
))
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_shape
,
level
=
augmentation
.
transform_validation
))
val_set
=
Dataset
(
root_path
=
path
,
transform
=
augmentation
.
augment
(
final_shape
,
level
=
augmentation
.
transform_validation
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment