Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
c8085470
Commit
c8085470
authored
Jun 29, 2023
by
s184058
Committed by
fima
Jun 29, 2023
Browse files
Options
Downloads
Patches
Plain Diff
Grid viz
parent
7af67039
Branches
Branches containing commit
No related tags found
1 merge request
!1
Grid viz
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
qim3d/__init__.py
+2
-2
2 additions, 2 deletions
qim3d/__init__.py
qim3d/viz/__init__.py
+1
-0
1 addition, 0 deletions
qim3d/viz/__init__.py
qim3d/viz/img.py
+234
-0
234 additions, 0 deletions
qim3d/viz/img.py
with
237 additions
and
2 deletions
qim3d/__init__.py
+
2
−
2
View file @
c8085470
import
qim3d.io
import
qim3d.gui
import
qim3d.tools
import
qim3d.viz
import
logging
\ No newline at end of file
This diff is collapsed.
Click to expand it.
qim3d/viz/__init__.py
0 → 100644
+
1
−
0
View file @
c8085470
from
.img
import
grid_pred
,
grid_overview
\ No newline at end of file
This diff is collapsed.
Click to expand it.
qim3d/viz/img.py
0 → 100644
+
234
−
0
View file @
c8085470
"""
Provides a collection of visualization functions.
"""
import
matplotlib.pyplot
as
plt
from
matplotlib.colors
import
LinearSegmentedColormap
from
matplotlib
import
cm
import
torch
import
numpy
as
np
from
qim3d.io.logger
import
log
def
grid_overview
(
data
,
num_images
=
7
,
cmap_im
=
"
gray
"
,
cmap_segm
=
"
viridis
"
,
alpha
=
0.5
):
"""
Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
Masks are applied to the output and target prior to the loss calculation in case of
sparse labeled data
Args:
data (list or torch.utils.data.Dataset): A list of tuples or Torch dataset containing image,
label, (and mask data).
num_images (int, optional): The maximum number of images to display.
Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to
'
gray
'
.
cmap_segm (str, optional): The colormap to be used for displaying labels.
Defaults to
'
viridis
'
.
alpha (float, optional): The transparency level of the label and mask overlays.
Defaults to 0.5.
Raises:
ValueError: If the data elements are not tuples.
Notes:
- If the image data is RGB, the color map is ignored and the user is informed.
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The grid layout and dimensions vary based on the presence of a mask.
Returns:
None
Example:
data = [(image1, label1, mask1), (image2, label2, mask2)]
grid_overview(data, num_images=5, cmap_im=
'
viridis
'
, cmap_segm=
'
hot
'
, alpha=0.8)
"""
# Check if data has a mask
has_mask
=
len
(
data
[
0
])
>
2
and
data
[
0
][
-
1
]
is
not
None
# Check if image data is RGB and inform the user if it's the case
if
len
(
data
[
0
][
0
].
squeeze
().
shape
)
>
2
:
log
.
info
(
"
Input images are RGB: color map is ignored
"
)
# Check if dataset have at least specified number of images
if
len
(
data
)
<
num_images
:
log
.
warning
(
"
Not enough images in the dataset. Changing num_images=%d to num_images=%d
"
,
num_images
,
len
(
data
),
)
num_images
=
len
(
data
)
# Adapt segmentation cmap so that background is transparent
colors_segm
=
cm
.
get_cmap
(
cmap_segm
)(
np
.
linspace
(
0
,
1
,
256
))
colors_segm
[:
128
,
3
]
=
0
custom_cmap
=
LinearSegmentedColormap
.
from_list
(
"
CustomCmap
"
,
colors_segm
)
# Check if data have the right format
if
not
isinstance
(
data
[
0
],
tuple
):
raise
ValueError
(
"
Data elements must be tuples
"
)
# Define row titles
row_titles
=
[
"
Input images
"
,
"
Ground truth segmentation
"
,
"
Mask
"
]
# Make new list such that possible augmentations remain identical for all three rows
plot_data
=
list
(
data
[:
num_images
])
fig
=
plt
.
figure
(
figsize
=
(
2
*
num_images
,
9
if
has_mask
else
6
),
constrained_layout
=
True
)
# create 2 (3) x 1 subfigs
subfigs
=
fig
.
subfigures
(
nrows
=
3
if
has_mask
else
2
,
ncols
=
1
)
for
row
,
subfig
in
enumerate
(
subfigs
):
subfig
.
suptitle
(
row_titles
[
row
],
fontsize
=
22
)
# create 1 x num_images subplots per subfig
axs
=
subfig
.
subplots
(
nrows
=
1
,
ncols
=
num_images
)
for
col
,
ax
in
enumerate
(
np
.
atleast_1d
(
axs
)):
if
row
in
[
1
,
2
]:
# Ground truth segmentation and mask
ax
.
imshow
(
plot_data
[
col
][
0
].
squeeze
(),
cmap
=
cmap_im
)
ax
.
imshow
(
plot_data
[
col
][
row
].
squeeze
(),
cmap
=
custom_cmap
,
alpha
=
alpha
)
ax
.
axis
(
"
off
"
)
else
:
ax
.
imshow
(
plot_data
[
col
][
row
].
squeeze
(),
cmap
=
cmap_im
)
ax
.
axis
(
"
off
"
)
fig
.
show
()
def
grid_pred
(
data
,
model
,
num_images
=
7
,
cmap_im
=
"
gray
"
,
cmap_segm
=
"
viridis
"
,
alpha
=
0.5
):
"""
Displays a grid of input images, predicted segmentations, and ground truth segmentations.
Args:
data (torch.utils.data.Dataset): A Torch dataset containing input image and
ground truth label data.
model (torch.nn.Module): The trained network model used for predicting segmentations.
num_images (int, optional): The maximum number of images to display. Defaults to 7.
cmap_im (str, optional): The colormap to be used for displaying input images.
Defaults to
'
gray
'
.
cmap_segm (str, optional): The colormap to be used for displaying segmentations.
Defaults to
'
viridis
'
.
alpha (float, optional): The transparency level of the predicted segmentation overlay.
Defaults to 0.5.
Raises:
ValueError: If the data items are not tuples or data items do not consist of tensors.
ValueError: If the input image is not in (C, H, W) format.
Notes:
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The function does not assume that the model is already in evaluation mode (model.eval()).
- The function will execute faster on a CUDA-enabled GPU.
- The grid layout consists of three rows: input images, predicted segmentations,
and ground truth segmentations.
Returns:
None
Example:
dataset = MySegmentationDataset()
model = MySegmentationModel()
grid_pred(dataset, model, cmap_im=
'
viridis
'
, alpha=0.5)
"""
# Get device
device
=
"
cuda
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
# Check if data have the right format
if
not
isinstance
(
data
[
0
],
tuple
):
raise
ValueError
(
"
Data items must be tuples
"
)
# Check if data is torch tensors
for
element
in
data
[
0
]:
if
not
isinstance
(
element
,
torch
.
Tensor
):
raise
ValueError
(
"
Data items must consist of tensors
"
)
# Check if input image is (C,H,W) format
if
data
[
0
][
0
].
dim
()
==
3
and
(
data
[
0
][
0
].
shape
[
0
]
in
[
1
,
3
]):
pass
else
:
raise
ValueError
(
"
Input image must be (C,H,W) format
"
)
# Check if dataset have at least specified number of images
if
len
(
data
)
<
num_images
:
log
.
warning
(
"
Not enough images in the dataset. Changing num_images=%d to num_images=%d
"
,
num_images
,
len
(
data
),
)
num_images
=
len
(
data
)
# Adapt segmentation cmap so that background is transparent
colors_segm
=
cm
.
get_cmap
(
cmap_segm
)(
np
.
linspace
(
0
,
1
,
256
))
colors_segm
[:
128
,
3
]
=
0
custom_cmap
=
LinearSegmentedColormap
.
from_list
(
"
CustomCmap
"
,
colors_segm
)
model
.
eval
()
# Make new list such that possible augmentations remain identical for all three rows
plot_data
=
list
(
data
[:
num_images
])
# Create input and target batch
inputs
=
torch
.
stack
([
item
[
0
]
for
item
in
plot_data
],
dim
=
0
).
to
(
device
)
targets
=
torch
.
stack
([
item
[
1
]
for
item
in
plot_data
],
dim
=
0
)
# Get output predictions
with
torch
.
no_grad
():
outputs
=
model
(
inputs
)
# Prepare data for plotting
inputs
=
inputs
.
cpu
().
squeeze
()
targets
=
targets
.
squeeze
()
if
outputs
.
shape
[
1
]
==
1
:
preds
=
outputs
.
cpu
().
squeeze
()
>
0.5
else
:
preds
=
outputs
.
cpu
().
argmax
(
axis
=
1
)
N
=
len
(
plot_data
)
H
=
plot_data
[
0
][
0
].
shape
[
-
2
]
W
=
plot_data
[
0
][
0
].
shape
[
-
1
]
comp_rgb
=
torch
.
zeros
((
N
,
4
,
H
,
W
))
comp_rgb
[:,
1
,:,:]
=
targets
.
logical_and
(
preds
)
comp_rgb
[:,
0
,:,:]
=
targets
.
logical_xor
(
preds
)
comp_rgb
[:,
3
,:,:]
=
targets
.
logical_or
(
preds
)
row_titles
=
[
"
Input images
"
,
"
Predicted segmentation
"
,
"
Ground truth segmentation
"
,
"
True vs. predicted segmentation
"
,
]
fig
=
plt
.
figure
(
figsize
=
(
2
*
num_images
,
10
),
constrained_layout
=
True
)
# create 3 x 1 subfigs
subfigs
=
fig
.
subfigures
(
nrows
=
4
,
ncols
=
1
)
for
row
,
subfig
in
enumerate
(
subfigs
):
subfig
.
suptitle
(
row_titles
[
row
],
fontsize
=
22
)
# create 1 x num_images subplots per subfig
axs
=
subfig
.
subplots
(
nrows
=
1
,
ncols
=
num_images
)
for
col
,
ax
in
enumerate
(
np
.
atleast_1d
(
axs
)):
if
row
==
0
:
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
axis
(
"
off
"
)
elif
row
==
1
:
# Predicted segmentation
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
imshow
(
preds
[
col
],
cmap
=
custom_cmap
,
alpha
=
alpha
)
ax
.
axis
(
"
off
"
)
elif
row
==
2
:
# Ground truth segmentation
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
imshow
(
plot_data
[
col
][
1
].
cpu
().
squeeze
(),
cmap
=
custom_cmap
,
alpha
=
alpha
)
ax
.
axis
(
"
off
"
)
else
:
ax
.
imshow
(
inputs
[
col
],
cmap
=
cmap_im
)
ax
.
imshow
(
comp_rgb
[
col
].
permute
(
1
,
2
,
0
),
alpha
=
alpha
)
ax
.
axis
(
"
off
"
)
fig
.
show
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment