Notebook unit testing

%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
``` python
from os.path import join
import qim3d
import os
%matplotlib inline
```
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str, datasets):
    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
    dataset_idx = datasets.index(name)
    if os.name == 'nt':
        datasets_path = [
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
        ]
    else:
        datasets_path = [
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
        ]
    return datasets_path[dataset_idx]
```
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
### Datasets
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3]
root = get_dataset_path(dataset, datasets)
# Note: gaudez2022 should not be used here (3D image),
# nor reichardt2021 (multiclass segmentation)
```
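%% Cell type:markdown tags:
The dataset roots above are network mounts that differ between Windows and Linux, so a quick existence check can fail early with a clear message instead of a confusing error during loading. The cell below is a minimal sketch using only the standard library and the `root` and `dataset` variables defined above.
%% Cell type:code tags:
``` python
# Optional sanity check: fail early if the selected dataset root is not reachable
# (e.g. the network share is not mounted on this machine)
if not os.path.isdir(root):
    raise FileNotFoundError(f'Dataset path not found or not mounted: {root}')
print(f"Using dataset '{dataset}' at {root}")
```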
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
### Model and Augmentation
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
``` python
# Define the model
my_model = qim3d.ml.models.UNet(size = 'medium', dropout = 0.25)

# Define the augmentation
my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
### Loading the data
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
``` python
# Set the logging level
qim3d.utils._logger.level('info')

# Datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = root, val_fraction = 0.3, model = my_model, augmentation = my_aug
)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set, val_set, test_set, batch_size = 6
)
```
%% Output
The image size doesn't match the Unet model's depth. The image is changed with 'crop', from (852, 852) to (832, 832).
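%% Cell type:markdown tags:
Before training, it can be useful to look at the shape of one batch. The cell below assumes the loaders returned by `qim3d.ml.prepare_dataloaders` are standard PyTorch `DataLoader`s yielding `(image, target)` pairs; adjust the unpacking if the structure differs.
%% Cell type:code tags:
``` python
# Inspect a single batch from the training loader
images, targets = next(iter(train_loader))
print('Image batch shape: ', images.shape)   # expected (6, 1, 832, 832) given batch_size = 6 and the crop above
print('Target batch shape:', targets.shape)
```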
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
``` python
qim3d.viz.grid_overview(train_set, alpha = 1)
```
%% Output
<Figure size 1400x600 with 14 Axes>
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
``` python
# Summary of the model
model_s = qim3d.ml.model_summary(train_loader, my_model)
print(model_s)
```
%% Output
=======================================================================================================================================
Layer (type:depth-idx) Output Shape Param #
=======================================================================================================================================
UNet [6, 1, 832, 832] --
├─UNet: 1-1 [6, 1, 832, 832] --
│ └─Sequential: 2-1 [6, 1, 832, 832] --
│ │ └─Convolution: 3-1 [6, 64, 416, 416] --
│ │ │ └─Conv2d: 4-1 [6, 64, 416, 416] 640
│ │ │ └─ADN: 4-2 [6, 64, 416, 416] --
│ │ │ │ └─InstanceNorm2d: 5-1 [6, 64, 416, 416] --
│ │ │ │ └─Dropout: 5-2 [6, 64, 416, 416] --
│ │ │ │ └─PReLU: 5-3 [6, 64, 416, 416] 1
│ │ └─SkipConnection: 3-2 [6, 128, 416, 416] --
│ │ │ └─Sequential: 4-3 [6, 64, 416, 416] --
│ │ │ │ └─Convolution: 5-4 [6, 128, 208, 208] --
│ │ │ │ │ └─Conv2d: 6-1 [6, 128, 208, 208] 73,856
│ │ │ │ │ └─ADN: 6-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-1 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Dropout: 7-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─PReLU: 7-3 [6, 128, 208, 208] 1
│ │ │ │ └─SkipConnection: 5-5 [6, 256, 208, 208] --
│ │ │ │ │ └─Sequential: 6-3 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Convolution: 7-4 [6, 256, 104, 104] --
│ │ │ │ │ │ │ └─Conv2d: 8-1 [6, 256, 104, 104] 295,168
│ │ │ │ │ │ │ └─ADN: 8-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-1 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Dropout: 9-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─PReLU: 9-3 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─SkipConnection: 7-5 [6, 512, 104, 104] --
│ │ │ │ │ │ │ └─Sequential: 8-3 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Convolution: 9-4 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Conv2d: 10-1 [6, 512, 52, 52] 1,180,160
│ │ │ │ │ │ │ │ │ └─ADN: 10-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-1 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-3 [6, 512, 52, 52] 1
│ │ │ │ │ │ │ │ └─SkipConnection: 9-5 [6, 1536, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Convolution: 10-3 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Conv2d: 11-4 [6, 1024, 52, 52] 4,719,616
│ │ │ │ │ │ │ │ │ │ └─ADN: 11-5 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 12-1 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─Dropout: 12-2 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─PReLU: 12-3 [6, 1024, 52, 52] 1
│ │ │ │ │ │ │ │ └─Convolution: 9-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ └─ConvTranspose2d: 10-4 [6, 256, 104, 104] 3,539,200
│ │ │ │ │ │ │ │ │ └─ADN: 10-5 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-7 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-8 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─Convolution: 7-6 [6, 128, 208, 208] --
│ │ │ │ │ │ │ └─ConvTranspose2d: 8-4 [6, 128, 208, 208] 589,952
│ │ │ │ │ │ │ └─ADN: 8-5 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-7 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─Dropout: 9-8 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─PReLU: 9-9 [6, 128, 208, 208] 1
│ │ │ │ └─Convolution: 5-6 [6, 64, 416, 416] --
│ │ │ │ │ └─ConvTranspose2d: 6-4 [6, 64, 416, 416] 147,520
│ │ │ │ │ └─ADN: 6-5 [6, 64, 416, 416] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-7 [6, 64, 416, 416] --
│ │ │ │ │ │ └─Dropout: 7-8 [6, 64, 416, 416] --
│ │ │ │ │ │ └─PReLU: 7-9 [6, 64, 416, 416] 1
│ │ └─Convolution: 3-3 [6, 1, 832, 832] --
│ │ │ └─ConvTranspose2d: 4-4 [6, 1, 832, 832] 1,153
=======================================================================================================================================
Total params: 10,547,273
Trainable params: 10,547,273
Non-trainable params: 0
Total mult-adds (G): 675.50
=======================================================================================================================================
Input size (MB): 16.61
Forward/backward pass size (MB): 4153.34
Params size (MB): 42.19
Estimated Total Size (MB): 4212.15
=======================================================================================================================================
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
### Hyperparameters and training
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
``` python
# Model hyperparameters
my_hyperparameters = qim3d.ml.Hyperparameters(
    my_model, n_epochs = 5, learning_rate = 5e-3, loss_function = 'DiceCE', weight_decay = 1e-3
)

# Train the model
qim3d.ml.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot = True)
```
%% Output
Epoch 0, train loss: 0.7937, val loss: 0.5800
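%% Cell type:markdown tags:
To keep the trained weights after the notebook run, a plain PyTorch checkpoint is one option. This is a sketch under the assumption that `my_model` is a regular `torch.nn.Module` (the summary above shows standard PyTorch/MONAI layers); the file name is arbitrary.
%% Cell type:code tags:
``` python
import torch

# Save the trained weights next to the notebook (assumes my_model is a torch.nn.Module)
checkpoint_path = 'unet_medium_weights.pt'
torch.save(my_model.state_dict(), checkpoint_path)
print(f'Saved model weights to {checkpoint_path}')
```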
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
### Plotting
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
``` python
in_targ_preds_test = qim3d.ml.inference(test_set, my_model)
qim3d.viz.grid_pred(in_targ_preds_test, alpha = 1)
```
%% Output
<Figure size 1400x1000 with 28 Axes>
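%% Cell type:markdown tags:
Besides the visual overview, a single summary number is handy for unit testing. The cell below sketches a rough Dice score over the test predictions; the unpacking of `in_targ_preds_test` into inputs, targets and predictions is an assumption based on the variable name, so adjust it if `qim3d.ml.inference` returns a different structure.
%% Cell type:code tags:
``` python
import torch

# Rough Dice score over the whole test set (binary segmentation)
_, targets, preds = in_targ_preds_test               # assumed (inputs, targets, predictions)
targets = torch.as_tensor(targets).float()
preds_bin = (torch.as_tensor(preds) > 0.5).float()   # threshold in case preds are probabilities
intersection = (preds_bin * targets).sum()
dice = 2 * intersection / (preds_bin.sum() + targets.sum() + 1e-8)
print(f'Mean Dice over the test set: {dice.item():.3f}')
```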