Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
39cca3af
Commit
39cca3af
authored
Jul 5, 2023
by
ofhkr
Committed by
fima
Jul 5, 2023
Browse files
Options
Downloads
Patches
Plain Diff
Vizualization update
parent
d9b70f14
Branches
Branches containing commit
No related tags found
1 merge request
!4
Vizualization update
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
docs/notebooks/Unet.ipynb
+100
-36
100 additions, 36 deletions
docs/notebooks/Unet.ipynb
qim3d/utils/__init__.py
+1
-0
1 addition, 0 deletions
qim3d/utils/__init__.py
qim3d/utils/models.py
+81
-0
81 additions, 0 deletions
qim3d/utils/models.py
qim3d/viz/img.py
+35
-74
35 additions, 74 deletions
qim3d/viz/img.py
with
217 additions
and
110 deletions
docs/notebooks/Unet.ipynb
+
100
−
36
View file @
39cca3af
This diff is collapsed.
Click to expand it.
qim3d/utils/__init__.py
+
1
−
0
View file @
39cca3af
from
.
import
internal_tools
from
.
import
internal_tools
from
.
import
models
from
.data
import
Dataset
from
.data
import
Dataset
\ No newline at end of file
This diff is collapsed.
Click to expand it.
qim3d/utils/models.py
0 → 100644
+
81
−
0
View file @
39cca3af
"""
Tools performed with trained models.
"""
import
torch
def
inference
(
data
,
model
):
"""
Performs inference on input data using the specified model.
Performs inference on the input data using the provided model. The input data should be in the form of a list,
where each item is a tuple containing the input image tensor and the corresponding target label tensor.
The function checks the format and validity of the input data, ensures the model is in evaluation mode,
and generates predictions using the model. The input images, target labels, and predicted labels are returned
as a tuple.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
Returns:
tuple: A tuple containing the input images, target labels, and predicted labels.
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Notes:
- The function does not assume the model is already in evaluation mode (model.eval()).
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
inference(data,model)
"""
# Get device
device
=
"
cuda
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
# Check if data have the right format
if
not
isinstance
(
data
[
0
],
tuple
):
raise
ValueError
(
"
Data items must be tuples
"
)
# Check if data is torch tensors
for
element
in
data
[
0
]:
if
not
isinstance
(
element
,
torch
.
Tensor
):
raise
ValueError
(
"
Data items must consist of tensors
"
)
# Check if input image is (C,H,W) format
if
data
[
0
][
0
].
dim
()
==
3
and
(
data
[
0
][
0
].
shape
[
0
]
in
[
1
,
3
]):
pass
else
:
raise
ValueError
(
"
Input image must be (C,H,W) format
"
)
model
.
eval
()
# Make new list such that possible augmentations remain identical for all three rows
plot_data
=
[
data
[
idx
]
for
idx
in
range
(
len
(
data
))]
# Create input and target batch
inputs
=
torch
.
stack
([
item
[
0
]
for
item
in
plot_data
],
dim
=
0
).
to
(
device
)
targets
=
torch
.
stack
([
item
[
1
]
for
item
in
plot_data
],
dim
=
0
)
# Get output predictions
with
torch
.
no_grad
():
outputs
=
model
(
inputs
)
# Prepare data for plotting
inputs
=
inputs
.
cpu
().
squeeze
()
targets
=
targets
.
squeeze
()
if
outputs
.
shape
[
1
]
==
1
:
preds
=
outputs
.
cpu
().
squeeze
()
>
0.5
else
:
preds
=
outputs
.
cpu
().
argmax
(
axis
=
1
)
# if there is only one image
if
inputs
.
dim
()
==
2
:
inputs
=
inputs
.
unsqueeze
(
0
)
targets
=
targets
.
unsqueeze
(
0
)
preds
=
preds
.
unsqueeze
(
0
)
return
inputs
,
targets
,
preds
\ No newline at end of file
This diff is collapsed.
Click to expand it.
qim3d/viz/img.py
+
35
−
74
View file @
39cca3af
...
@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
...
@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
fig
.
show
()
fig
.
show
()
def
grid_pred
(
def
grid_pred
(
in_targ_preds
,
num_images
=
7
,
cmap_im
=
"
gray
"
,
cmap_segm
=
"
viridis
"
,
alpha
=
0.5
):
data
,
model
,
num_images
=
7
,
cmap_im
=
"
gray
"
,
cmap_segm
=
"
viridis
"
,
alpha
=
0.5
"""
Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
):
"""
Displays a grid of input images, predicted segmentations, and ground truth segmentations.
Args:
Displays a grid of subplots representing different aspects of the input images and segmentations.
data (torch.utils.data.Dataset): A Torch dataset containing input image and
The grid includes the following rows:
ground truth label data.
- Row 1: Input images
model (torch.nn.Module): The trained network model used for predicting segmentations.
- Row 2: Predicted segmentations overlaying input images
num_images (int, optional): The maximum number of images to display. Defaults to 7.
- Row 3: Ground truth segmentations overlaying input images
cmap_im (str, optional): The colormap to be used for displaying input images.
- Row 4: Comparison between true and predicted segmentations overlaying input images
Defaults to
'
gray
'
.
cmap_segm (str, optional): The colormap to be used for displaying segmentations.
Defaults to
'
viridis
'
.
alpha (float, optional): The transparency level of the predicted segmentation overlay.
Defaults to 0.5.
Raises:
Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset.
ValueError: If the data items are not tuples or data items do not consist of tensors.
The function utilizes various color maps for visualization and applies transparency to the segmentations.
ValueError: If the input image is not in (C, H, W) format.
Notes:
Args:
- The number of displayed images is limited to the minimum between `num_images`
in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations.
and the length of the data.
num_images (int, optional): Number of images to display. Defaults to 7.
- The function does not assume that the model is already in evaluation mode (model.eval()).
cmap_im (str, optional): Color map for input images. Defaults to
"
gray
"
.
- The function will execute faster on a CUDA-enabled GPU.
cmap_segm (str, optional): Color map for segmentations. Defaults to
"
viridis
"
.
- The grid layout consists of three rows: input images, predicted segmentations,
alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
and ground truth segmentations.
Returns:
Returns:
None
None
Raises:
None
Example:
Example:
dataset = MySegmentationDataset()
dataset = MySegmentationDataset()
model = MySegmentationModel()
model = MySegmentationModel()
grid_pred(dataset, model, cmap_im=
'
viridis
'
, alpha=0.5)
in_targ_preds = qim3d.utils.models.inference(dataset,model)
grid_pred(in_targ_preds, cmap_im=
'
viridis
'
, alpha=0.5)
"""
"""
# Get device
device
=
"
cuda
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
# Check if data have the right format
if
not
isinstance
(
data
[
0
],
tuple
):
raise
ValueError
(
"
Data items must be tuples
"
)
# Check if data is torch tensors
for
element
in
data
[
0
]:
if
not
isinstance
(
element
,
torch
.
Tensor
):
raise
ValueError
(
"
Data items must consist of tensors
"
)
# Check if input image is (C,H,W) format
if
data
[
0
][
0
].
dim
()
==
3
and
(
data
[
0
][
0
].
shape
[
0
]
in
[
1
,
3
]):
pass
else
:
raise
ValueError
(
"
Input image must be (C,H,W) format
"
)
# Check if dataset have at least specified number of images
# Check if dataset have at least specified number of images
if
len
(
data
)
<
num_images
:
if
len
(
in_targ_preds
[
0
]
)
<
num_images
:
log
.
warning
(
log
.
warning
(
"
Not enough images in the dataset. Changing num_images=%d to num_images=%d
"
,
"
Not enough images in the dataset. Changing num_images=%d to num_images=%d
"
,
num_images
,
num_images
,
len
(
data
),
len
(
in_targ_preds
[
0
]
),
)
)
num_images
=
len
(
data
)
num_images
=
len
(
in_targ_preds
[
0
])
# Take only the number of images from in_targ_preds
inputs
,
targets
,
preds
=
[
items
[:
num_images
]
for
items
in
in_targ_preds
]
# Adapt segmentation cmap so that background is transparent
# Adapt segmentation cmap so that background is transparent
colors_segm
=
cm
.
get_cmap
(
cmap_segm
)(
np
.
linspace
(
0
,
1
,
256
))
colors_segm
=
cm
.
get_cmap
(
cmap_segm
)(
np
.
linspace
(
0
,
1
,
256
))
colors_segm
[:
128
,
3
]
=
0
colors_segm
[:
128
,
3
]
=
0
custom_cmap
=
LinearSegmentedColormap
.
from_list
(
"
CustomCmap
"
,
colors_segm
)
custom_cmap
=
LinearSegmentedColormap
.
from_list
(
"
CustomCmap
"
,
colors_segm
)
model
.
eval
()
N
=
num_images
H
=
inputs
[
0
].
shape
[
-
2
]
# Make new list such that possible augmentations remain identical for all three rows
W
=
inputs
[
0
].
shape
[
-
1
]
plot_data
=
[
data
[
idx
]
for
idx
in
range
(
num_images
)]
# Create input and target batch
inputs
=
torch
.
stack
([
item
[
0
]
for
item
in
plot_data
],
dim
=
0
).
to
(
device
)
targets
=
torch
.
stack
([
item
[
1
]
for
item
in
plot_data
],
dim
=
0
)
# Get output predictions
with
torch
.
no_grad
():
outputs
=
model
(
inputs
)
# Prepare data for plotting
inputs
=
inputs
.
cpu
().
squeeze
()
targets
=
targets
.
squeeze
()
if
outputs
.
shape
[
1
]
==
1
:
preds
=
outputs
.
cpu
().
squeeze
()
>
0.5
else
:
preds
=
outputs
.
cpu
().
argmax
(
axis
=
1
)
N
=
len
(
plot_data
)
H
=
plot_data
[
0
][
0
].
shape
[
-
2
]
W
=
plot_data
[
0
][
0
].
shape
[
-
1
]
comp_rgb
=
torch
.
zeros
((
N
,
4
,
H
,
W
))
comp_rgb
=
torch
.
zeros
((
N
,
4
,
H
,
W
))
comp_rgb
[:,
1
,:,:]
=
targets
.
logical_and
(
preds
)
comp_rgb
[:,
1
,:,:]
=
targets
.
logical_and
(
preds
)
...
@@ -223,7 +182,7 @@ def grid_pred(
...
@@ -223,7 +182,7 @@ def grid_pred(
elif
row
==
2
:
# Ground truth segmentation
elif
row
==
2
:
# Ground truth segmentation
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
imshow
(
ax
.
imshow
(
plot_data
[
col
][
1
].
cpu
().
squeeze
()
,
cmap
=
custom_cmap
,
alpha
=
alpha
targets
[
col
]
,
cmap
=
custom_cmap
,
alpha
=
alpha
)
)
ax
.
axis
(
"
off
"
)
ax
.
axis
(
"
off
"
)
else
:
else
:
...
@@ -232,3 +191,5 @@ def grid_pred(
...
@@ -232,3 +191,5 @@ def grid_pred(
ax
.
axis
(
"
off
"
)
ax
.
axis
(
"
off
"
)
fig
.
show
()
fig
.
show
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment