QIM / Tools / qim3d · Merge request !4: Vizualization update
Merged: Vizualization update

ofhkr requested to merge vizualization_update into main on Jul 5, 2023

Overview 0 · Commits 6 · Pipelines 0 · Changes 9
Created a new Python file 'models.py'.
Moved the inference part of the 'grid_pred' function (located under viz/img.py) into an 'inference' function (located under utils/models.py); a rough sketch of such a helper is shown after this list.
Fixed an issue when only 1 image is given to 'grid_pred'; solved in 'inference', lines 75-79.
Modified 'Unet.ipynb' to use the new version of 'grid_pred'.
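For reference, here is a minimal sketch of what an `inference` helper along these lines could look like. The function name follows the description above, but the signature, the sigmoid post-processing, and the return values are assumptions for illustration, not the code actually merged into utils/models.py.

```python
import torch

def inference(data, model, device="cpu"):
    """Hypothetical sketch: run a model over a dataset of (image, label) pairs."""
    # Collect the input images from the (image, label) pairs.
    images = [item[0] for item in data]

    # Stacking keeps a batch dimension even when only one image is given,
    # which covers the single-image case mentioned in the description.
    batch = torch.stack(images).to(device)

    model.to(device)
    model.eval()
    with torch.no_grad():  # no gradients needed for prediction
        preds = torch.sigmoid(model(batch))  # sigmoid assumed here to map logits to probabilities

    return batch.cpu(), preds.cpu()
```

With a helper like this, 'grid_pred' would only need to handle plotting the returned batch and predictions.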
Merge request reports

Viewing commit 8211b21d (9 files, +389 −7)

Commit 8211b21d: "First version of dataset class", authored by s184058 on Jul 5, 2023

docs/notebooks/Unet.ipynb · 0 → 100644 · +312 −0
%% Cell type:code id:dd6781ce tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:fa88080a tags:
``` python
from glob import glob
from os.path import join
import os

import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from monai.networks.nets import UNet
from torchvision import transforms
from monai.losses import FocalLoss, DiceLoss

import qim3d

import albumentations as A
from albumentations.pytorch import ToTensorV2

%matplotlib inline
```
%% Cell type:code id:d0a5eade tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str):
    datasets = [
        'belialev2020_side',
        'gaudez2022_3d',
        'guo2023_2d',
        'stan2020_2d',
        'reichardt2021_2d',
        'testcircles_2dbinary',
    ]
    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)

    dataset_idx = datasets.index(name)
    datasets_path = [
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary',
    ]

    return datasets_path[dataset_idx]
```
%% Cell type:markdown id:ee235f48 tags:
# Data loading
%% Cell type:markdown id:c088ceb8 tags:
### Check out https://albumentations.ai/docs/getting_started/transforms_and_targets/
%% Cell type:code id:ddfef29e tags:
``` python
# Training set transformation
aug_train = A.Compose([
    A.Resize(832, 832),
    A.RandomRotate90(),
    A.Normalize(mean=(0.5), std=(0.5)),  # Normalize to [-1, 1]
    ToTensorV2()
])

# Validation/test set transformation
aug_val_test = A.Compose([
    A.Resize(832, 832),
    A.Normalize(mean=(0.5), std=(0.5)),  # Normalize to [-1, 1]
    ToTensorV2()
])
```
%% Cell type:code id:766f0f8e tags:
``` python
### Possible datasets ####
# 'belialev2020_side'
# 'gaudez2022_3d'
# 'guo2023_2d'
# 'stan2020_2d'
# 'reichardt2021_2d'
# 'testcircles_2dbinary'

# Choose dataset
dataset = 'stan2020_2d'

# Define class instances. First, both train and validation set is defined from train
# folder with different transformations and below divided into non-overlapping subsets
train_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset), transform=aug_train)
val_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset), transform=aug_val_test)
test_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset), split='test', transform=aug_val_test)

# Define fraction of training set used for validation
VAL_FRACTION = 0.3
split_idx = int(np.floor(VAL_FRACTION * len(train_set)))

# Define seed
# torch.manual_seed(42)

# Get randomly permuted indices
indices = torch.randperm(len(train_set))

# Define train and validation sets as subsets
train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
```
%% Cell type:markdown id:321495cc tags:
### Data overview
%% Cell type:code id:a794b739 tags:
``` python
# Check if data has mask
has_mask = False  #True if train_set[0][-1] is not None else False

print(f'No. of train images={len(train_set)}')
print(f'No. of validation images={len(val_set)}')
print(f'No. of test images={len(test_set)}')
print(f'{train_set[0][0].dtype=}')
print(f'{train_set[0][1].dtype=}')
print(f'image shape={train_set[0][0].shape}')
print(f'label shape={train_set[0][1].shape}')
print(f'Labels={np.unique(train_set[0][1])}')
print(f'Masked data? {has_mask}')
```
%% Cell type:markdown id:5efa7d33 tags:
### Data visualization
Display the first seven images, labels, and masks if they exist
%% Cell type:code id:170577d3 tags:
``` python
qim3d.qim3d.viz.grid_overview(train_set, num_images=6, alpha=1)
```
%% Cell type:code id:33368063 tags:
``` python
# Define batch sizes
TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4

# Define dataloaders
train_loader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, num_workers=8, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, num_workers=8, pin_memory=True)
```
%% Cell type:markdown id:35e83e38 tags:
# Train U-Net
%% Cell type:code id:36685b25 tags:
``` python
# Define model
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(64, 128, 256, 512, 1024),
    strides=(2, 2, 2, 2),
)
orig_state = model.state_dict()  # Save, so we can reset model to original state later

# Define loss function
#loss_fn = nn.CrossEntropyLoss()
loss_fn = FocalLoss()

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
%% Cell type:markdown id:137be29b tags:
### Run training
%% Cell type:code id:13d8a9f3 tags:
``` python
# Define hyperparameters
NUM_EPOCHS = 5
EVAL_EVERY = 1
PRINT_EVERY = 1
LR = 3e-3

model.load_state_dict(orig_state)  # Restart training every time
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

all_losses = []
all_val_loss = []

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    step = 0

    for data in train_loader:
        if has_mask:
            inputs, targets, masks = data
            masks = masks.to(device).float()
        else:
            inputs, targets = data
        inputs = inputs.to(device)
        targets = targets.to(device).float().unsqueeze(1)

        # Forward -> Backward -> Step
        optimizer.zero_grad()
        outputs = model(inputs)
        #print(f'input {inputs.shape}, target: {targets.shape}, output: {outputs.shape}')
        loss = loss_fn(outputs * masks, targets * masks) if has_mask else loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        step += 1

    # Log and store average epoch loss
    epoch_loss = epoch_loss.item() / step
    all_losses.append(epoch_loss)

    if epoch % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():  # Do not need gradients for this part
            loss_sum = 0
            step = 0

            for data in val_loader:
                if has_mask:
                    inputs, targets, masks = data
                    masks = masks.to(device).float()
                else:
                    inputs, targets = data
                inputs = inputs.to(device)
                targets = targets.to(device).float().unsqueeze(1)

                outputs = model(inputs)
                loss_sum += loss_fn(outputs * masks, targets * masks) if has_mask else loss_fn(outputs, targets)
                step += 1

            val_loss = loss_sum.item() / step
            all_val_loss.append(val_loss)

    # Log and store average accuracy
    if epoch % PRINT_EVERY == 0:
        print(f'Epoch {epoch:3}, train loss: {epoch_loss:.4f}, val loss: {val_loss:.4f}')

print('Min val loss: ', min(all_val_loss))
```
%% Cell type:markdown id:a7a8e9d7 tags:
### Plot train and validation loss
%% Cell type:code id:851463c8 tags:
``` python
plt.figure(figsize=(16, 3))
plt.plot(all_losses, '-', label='Train')
plt.plot(all_val_loss, '-', label='Val.')
plt.legend()
plt.show()
```
%% Cell type:markdown id:1a700f8a tags:
### Inspecting the Predicted Segmentations on training data
%% Cell type:code id:2ac83638 tags:
``` python
qim3d.qim3d.viz.grid_pred(train_set, model, num_images=5, alpha=1)
```
%% Cell type:markdown id:a176ff96 tags:
### Inspecting the Predicted Segmentations on test data
%% Cell type:code id:ffb261c2 tags:
``` python
qim3d.qim3d.viz.grid_pred(test_set, model, alpha=1)
```