Commit 39cca3af, authored Jul 5, 2023 by ofhkr, committed by fima on Jul 5, 2023
Vizualization update
parent d9b70f14
1 merge request: !4 Vizualization update
Showing 4 changed files with 217 additions and 110 deletions:

docs/notebooks/Unet.ipynb: 100 additions, 36 deletions
qim3d/utils/__init__.py: 1 addition, 0 deletions
qim3d/utils/models.py: 81 additions, 0 deletions
qim3d/viz/img.py: 35 additions, 74 deletions
docs/notebooks/Unet.ipynb (+100 −36)

Source diff could not be displayed: it is too large. It can be viewed as a blob instead.
qim3d/utils/__init__.py (+1 −0)

 from . import internal_tools
+from . import models
 from .data import Dataset
\ No newline at end of file
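The added import wires the new models module into the package namespace. A small sketch of what this changes in practice (the behavior below follows from standard Python import semantics, not from anything shown in the diff):

import qim3d.utils
# Because qim3d/utils/__init__.py now runs `from . import models`, the submodule
# is loaded eagerly, so attribute access works without a separate submodule import:
qim3d.utils.models  # -> <module 'qim3d.utils.models'>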
qim3d/utils/models.py (new file, mode 100644, +81 −0)

"""
Tools performed with trained models.
"""
import torch


def inference(data, model):
    """
    Performs inference on input data using the specified model.

    Performs inference on the input data using the provided model. The input data should be
    in the form of a list, where each item is a tuple containing the input image tensor and
    the corresponding target label tensor. The function checks the format and validity of the
    input data, ensures the model is in evaluation mode, and generates predictions using the
    model. The input images, target labels, and predicted labels are returned as a tuple.

    Args:
        data (torch.utils.data.Dataset): A Torch dataset containing input image and
            ground truth label data.
        model (torch.nn.Module): The trained network model used for predicting segmentations.

    Returns:
        tuple: A tuple containing the input images, target labels, and predicted labels.

    Raises:
        ValueError: If the data items are not tuples or data items do not consist of tensors.
        ValueError: If the input image is not in (C, H, W) format.

    Notes:
        - The function does not assume the model is already in evaluation mode (model.eval()).

    Example:
        dataset = MySegmentationDataset()
        model = MySegmentationModel()
        inference(dataset, model)
    """
    # Get device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Check if data has the right format
    if not isinstance(data[0], tuple):
        raise ValueError("Data items must be tuples")

    # Check if data is torch tensors
    for element in data[0]:
        if not isinstance(element, torch.Tensor):
            raise ValueError("Data items must consist of tensors")

    # Check if input image is (C,H,W) format
    if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
        pass
    else:
        raise ValueError("Input image must be (C,H,W) format")

    model.eval()

    # Make new list such that possible augmentations remain identical for all three rows
    plot_data = [data[idx] for idx in range(len(data))]

    # Create input and target batch
    inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
    targets = torch.stack([item[1] for item in plot_data], dim=0)

    # Get output predictions
    with torch.no_grad():
        outputs = model(inputs)

    # Prepare data for plotting
    inputs = inputs.cpu().squeeze()
    targets = targets.squeeze()
    if outputs.shape[1] == 1:
        preds = outputs.cpu().squeeze() > 0.5
    else:
        preds = outputs.cpu().argmax(axis=1)

    # If there is only one image
    if inputs.dim() == 2:
        inputs = inputs.unsqueeze(0)
        targets = targets.unsqueeze(0)
        preds = preds.unsqueeze(0)

    return inputs, targets, preds
\ No newline at end of file
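To illustrate the input format this function expects, here is a minimal runnable sketch; the toy tensors and the stand-in Conv2d "model" are illustrative assumptions, not part of the commit:

import torch
from qim3d.utils.models import inference

# Hypothetical toy dataset: each item is an (image, label) tuple of tensors,
# with images in (C, H, W) format as the validation checks require.
images = [torch.rand(1, 64, 64) for _ in range(4)]
labels = [torch.randint(0, 2, (1, 64, 64)) for _ in range(4)]
data = list(zip(images, labels))

# Stand-in for a trained segmentation model; single-channel outputs are
# thresholded at 0.5, multi-channel outputs go through argmax over channels.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)

inputs, targets, preds = inference(data, model)
print(inputs.shape, targets.shape, preds.shape)  # each (4, 64, 64)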
qim3d/viz/img.py (+35 −74)

@@ -95,100 +95,59 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
     fig.show()


-def grid_pred(data, model, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
+def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5):
     """
-    Displays a grid of input images, predicted segmentations, and ground truth segmentations.
+    Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.

-    Args:
-        data (torch.utils.data.Dataset): A Torch dataset containing input image and
-            ground truth label data.
-        model (torch.nn.Module): The trained network model used for predicting segmentations.
-        num_images (int, optional): The maximum number of images to display. Defaults to 7.
-        cmap_im (str, optional): The colormap to be used for displaying input images.
-            Defaults to 'gray'.
-        cmap_segm (str, optional): The colormap to be used for displaying segmentations.
-            Defaults to 'viridis'.
-        alpha (float, optional): The transparency level of the predicted segmentation overlay.
-            Defaults to 0.5.
+    Displays a grid of subplots representing different aspects of the input images and segmentations.
+    The grid includes the following rows:
+    - Row 1: Input images
+    - Row 2: Predicted segmentations overlaying input images
+    - Row 3: Ground truth segmentations overlaying input images
+    - Row 4: Comparison between true and predicted segmentations overlaying input images

-    Raises:
-        ValueError: If the data items are not tuples or data items do not consist of tensors.
-        ValueError: If the input image is not in (C, H, W) format.
+    Each row consists of `num_images` subplots, where each subplot corresponds to an image from the dataset.
+    The function utilizes various color maps for visualization and applies transparency to the segmentations.

-    Notes:
-        - The number of displayed images is limited to the minimum between `num_images`
-            and the length of the data.
-        - The function does not assume that the model is already in evaluation mode (model.eval()).
-        - The function will execute faster on a CUDA-enabled GPU.
-        - The grid layout consists of three rows: input images, predicted segmentations,
-            and ground truth segmentations.
+    Args:
+        in_targ_preds (tuple): A tuple containing input images, target segmentations, and predicted segmentations.
+        num_images (int, optional): Number of images to display. Defaults to 7.
+        cmap_im (str, optional): Color map for input images. Defaults to "gray".
+        cmap_segm (str, optional): Color map for segmentations. Defaults to "viridis".
+        alpha (float, optional): Alpha value for transparency. Defaults to 0.5.
+
+    Returns:
+        None
+
+    Raises:
+        None

     Example:
         dataset = MySegmentationDataset()
         model = MySegmentationModel()
-        grid_pred(dataset, model, cmap_im='viridis', alpha=0.5)
+        in_targ_preds = qim3d.utils.models.inference(dataset,model)
+        grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
     """
-    # Get device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # Check if data have the right format
-    if not isinstance(data[0], tuple):
-        raise ValueError("Data items must be tuples")
-
-    # Check if data is torch tensors
-    for element in data[0]:
-        if not isinstance(element, torch.Tensor):
-            raise ValueError("Data items must consist of tensors")
-
-    # Check if input image is (C,H,W) format
-    if data[0][0].dim() == 3 and (data[0][0].shape[0] in [1, 3]):
-        pass
-    else:
-        raise ValueError("Input image must be (C,H,W) format")
-
     # Check if dataset have at least specified number of images
-    if len(data) < num_images:
+    if len(in_targ_preds[0]) < num_images:
         log.warning(
             "Not enough images in the dataset. Changing num_images=%d to num_images=%d",
             num_images,
-            len(data),
+            len(in_targ_preds[0]),
         )
-        num_images = len(data)
+        num_images = len(in_targ_preds[0])
+
+    # Take only the number of images from in_targ_preds
+    inputs, targets, preds = [items[:num_images] for items in in_targ_preds]

     # Adapt segmentation cmap so that background is transparent
     colors_segm = cm.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
     colors_segm[:128, 3] = 0
     custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)

-    model.eval()
-
-    # Make new list such that possible augmentations remain identical for all three rows
-    plot_data = [data[idx] for idx in range(num_images)]
-
-    # Create input and target batch
-    inputs = torch.stack([item[0] for item in plot_data], dim=0).to(device)
-    targets = torch.stack([item[1] for item in plot_data], dim=0)
-
-    # Get output predictions
-    with torch.no_grad():
-        outputs = model(inputs)
-
-    # Prepare data for plotting
-    inputs = inputs.cpu().squeeze()
-    targets = targets.squeeze()
-    if outputs.shape[1] == 1:
-        preds = outputs.cpu().squeeze() > 0.5
-    else:
-        preds = outputs.cpu().argmax(axis=1)
-
-    N = len(plot_data)
-    H = plot_data[0][0].shape[-2]
-    W = plot_data[0][0].shape[-1]
+    N = num_images
+    H = inputs[0].shape[-2]
+    W = inputs[0].shape[-1]

     comp_rgb = torch.zeros((N, 4, H, W))
     comp_rgb[:, 1, :, :] = targets.logical_and(preds)

@@ -223,7 +182,7 @@ def grid_pred(
         elif row == 2:  # Ground truth segmentation
             ax.imshow(inputs[col], cmap=cmap_im)
             ax.imshow(
-                plot_data[col][1].cpu().squeeze(), cmap=custom_cmap, alpha=alpha
+                targets[col], cmap=custom_cmap, alpha=alpha
             )
             ax.axis("off")
         else:

@@ -232,3 +191,5 @@ def grid_pred(
             ax.axis("off")

     fig.show()
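The net effect of this commit is that model execution moves out of the visualization layer: grid_pred no longer takes a dataset and a model, but the (inputs, targets, preds) tuple produced by qim3d.utils.models.inference. A sketch of the new workflow, following the docstring example (MySegmentationDataset and MySegmentationModel are hypothetical placeholders):

dataset = MySegmentationDataset()
model = MySegmentationModel()

# Run the model once, then reuse the resulting tuple for plotting.
in_targ_preds = qim3d.utils.models.inference(dataset, model)
qim3d.viz.img.grid_pred(in_targ_preds, num_images=5, alpha=0.5)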
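One detail retained from the old implementation is how the segmentation colormap is made overlay-friendly: the alpha channel of the lower half of the colormap is zeroed, so background pixels vanish when drawn over the input image. A self-contained sketch of the same technique (the base colormap and toy image here are arbitrary choices, not from the commit):

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap

# Sample 256 RGBA colors from the base colormap.
colors_segm = cm.get_cmap("viridis")(np.linspace(0, 1, 256))
# Zero the alpha of the lower half so low values (background) are fully transparent.
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)

# Overlay a binary mask on a grayscale image: only foreground pixels are tinted.
image = np.random.rand(64, 64)
mask = np.zeros((64, 64))
mask[16:48, 16:48] = 1
plt.imshow(image, cmap="gray")
plt.imshow(mask, cmap=custom_cmap, alpha=0.5)
plt.axis("off")
plt.show()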