Skip to content
Snippets Groups Projects
Commit 968d7efe authored by s193396's avatar s193396
Browse files

updated notebook

parent b2a8a35c
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags: %% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
   
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
``` ```
   
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags: %% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
   
``` python ``` python
from os.path import join from os.path import join
import qim3d import qim3d
import os import os
   
%matplotlib inline %matplotlib inline
``` ```
   
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags: %% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
   
``` python ``` python
# Define function for getting dataset path from string # Define function for getting dataset path from string
def get_dataset_path(name: str, datasets): def get_dataset_path(name: str, datasets):
assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets) assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
dataset_idx = datasets.index(name) dataset_idx = datasets.index(name)
if os.name == 'nt': if os.name == 'nt':
datasets_path = [ datasets_path = [
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side', '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d', '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/', '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d', '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d', '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary' '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
] ]
else: else:
datasets_path = [ datasets_path = [
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side', '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d', '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/', '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d', '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d', '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary' '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
] ]
   
return datasets_path[dataset_idx] return datasets_path[dataset_idx]
``` ```
   
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags: %% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
   
### Datasets ### Datasets
   
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags: %% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
   
``` python ``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary'] datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3] dataset = datasets[3]
root = get_dataset_path(dataset,datasets) root = get_dataset_path(dataset,datasets)
   
# should not use gaudez2022: 3d image # should not use gaudez2022: 3d image
# reichardt2021: multiclass segmentation # reichardt2021: multiclass segmentation
``` ```
   
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags: %% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
   
### Model and Augmentation ### Model and Augmentation
   
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags: %% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
   
``` python ``` python
# defining model # defining model
my_model = qim3d.ml.models.UNet(size = 'medium', dropout = 0.25) my_model = qim3d.ml.models.UNet2D(size = 'medium', dropout = 0.25)
# defining augmentation # defining augmentation
my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light') my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
``` ```
   
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags: %% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
   
### Loading the data ### Loading the data
   
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags: %% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
   
``` python ``` python
# level of logging # level of logging
qim3d.utils._logger.level('info') qim3d.utils._logger.level('info')
   
# datasets and dataloaders # datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = root, val_fraction = 0.3, train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = root, val_fraction = 0.3,
model = my_model , augmentation = my_aug) model = my_model , augmentation = my_aug)
   
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, val_set, train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, val_set,
test_set, batch_size = 6) test_set, batch_size = 6)
``` ```
   
%% Output %% Output
   
The image size doesn't match the Unet model's depth. The image is changed with 'crop', from (852, 852) to (832, 832). The image size doesn't match the Unet model's depth. The image is changed with 'crop', from (852, 852) to (832, 832).
   
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags: %% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
   
``` python ``` python
qim3d.viz.grid_overview(train_set,alpha = 1) qim3d.viz.grid_overview(train_set,alpha = 1)
``` ```
   
%% Output %% Output
   
<Figure size 1400x600 with 14 Axes> <Figure size 1400x600 with 14 Axes>
   
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags: %% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
   
``` python ``` python
# Summary of model # Summary of model
model_s = qim3d.ml.model_summary(train_loader,my_model) model_s = qim3d.ml.model_summary(train_loader,my_model)
print(model_s) print(model_s)
``` ```
   
%% Output %% Output
   
======================================================================================================================================= =======================================================================================================================================
Layer (type:depth-idx) Output Shape Param # Layer (type:depth-idx) Output Shape Param #
======================================================================================================================================= =======================================================================================================================================
UNet [6, 1, 832, 832] -- UNet [6, 1, 832, 832] --
├─UNet: 1-1 [6, 1, 832, 832] -- ├─UNet: 1-1 [6, 1, 832, 832] --
│ └─Sequential: 2-1 [6, 1, 832, 832] -- │ └─Sequential: 2-1 [6, 1, 832, 832] --
│ │ └─Convolution: 3-1 [6, 64, 416, 416] -- │ │ └─Convolution: 3-1 [6, 64, 416, 416] --
│ │ │ └─Conv2d: 4-1 [6, 64, 416, 416] 640 │ │ │ └─Conv2d: 4-1 [6, 64, 416, 416] 640
│ │ │ └─ADN: 4-2 [6, 64, 416, 416] -- │ │ │ └─ADN: 4-2 [6, 64, 416, 416] --
│ │ │ │ └─InstanceNorm2d: 5-1 [6, 64, 416, 416] -- │ │ │ │ └─InstanceNorm2d: 5-1 [6, 64, 416, 416] --
│ │ │ │ └─Dropout: 5-2 [6, 64, 416, 416] -- │ │ │ │ └─Dropout: 5-2 [6, 64, 416, 416] --
│ │ │ │ └─PReLU: 5-3 [6, 64, 416, 416] 1 │ │ │ │ └─PReLU: 5-3 [6, 64, 416, 416] 1
│ │ └─SkipConnection: 3-2 [6, 128, 416, 416] -- │ │ └─SkipConnection: 3-2 [6, 128, 416, 416] --
│ │ │ └─Sequential: 4-3 [6, 64, 416, 416] -- │ │ │ └─Sequential: 4-3 [6, 64, 416, 416] --
│ │ │ │ └─Convolution: 5-4 [6, 128, 208, 208] -- │ │ │ │ └─Convolution: 5-4 [6, 128, 208, 208] --
│ │ │ │ │ └─Conv2d: 6-1 [6, 128, 208, 208] 73,856 │ │ │ │ │ └─Conv2d: 6-1 [6, 128, 208, 208] 73,856
│ │ │ │ │ └─ADN: 6-2 [6, 128, 208, 208] -- │ │ │ │ │ └─ADN: 6-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-1 [6, 128, 208, 208] -- │ │ │ │ │ │ └─InstanceNorm2d: 7-1 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Dropout: 7-2 [6, 128, 208, 208] -- │ │ │ │ │ │ └─Dropout: 7-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─PReLU: 7-3 [6, 128, 208, 208] 1 │ │ │ │ │ │ └─PReLU: 7-3 [6, 128, 208, 208] 1
│ │ │ │ └─SkipConnection: 5-5 [6, 256, 208, 208] -- │ │ │ │ └─SkipConnection: 5-5 [6, 256, 208, 208] --
│ │ │ │ │ └─Sequential: 6-3 [6, 128, 208, 208] -- │ │ │ │ │ └─Sequential: 6-3 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Convolution: 7-4 [6, 256, 104, 104] -- │ │ │ │ │ │ └─Convolution: 7-4 [6, 256, 104, 104] --
│ │ │ │ │ │ │ └─Conv2d: 8-1 [6, 256, 104, 104] 295,168 │ │ │ │ │ │ │ └─Conv2d: 8-1 [6, 256, 104, 104] 295,168
│ │ │ │ │ │ │ └─ADN: 8-2 [6, 256, 104, 104] -- │ │ │ │ │ │ │ └─ADN: 8-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-1 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-1 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Dropout: 9-2 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ └─Dropout: 9-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─PReLU: 9-3 [6, 256, 104, 104] 1 │ │ │ │ │ │ │ │ └─PReLU: 9-3 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─SkipConnection: 7-5 [6, 512, 104, 104] -- │ │ │ │ │ │ └─SkipConnection: 7-5 [6, 512, 104, 104] --
│ │ │ │ │ │ │ └─Sequential: 8-3 [6, 256, 104, 104] -- │ │ │ │ │ │ │ └─Sequential: 8-3 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Convolution: 9-4 [6, 512, 52, 52] -- │ │ │ │ │ │ │ │ └─Convolution: 9-4 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Conv2d: 10-1 [6, 512, 52, 52] 1,180,160 │ │ │ │ │ │ │ │ │ └─Conv2d: 10-1 [6, 512, 52, 52] 1,180,160
│ │ │ │ │ │ │ │ │ └─ADN: 10-2 [6, 512, 52, 52] -- │ │ │ │ │ │ │ │ │ └─ADN: 10-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-1 [6, 512, 52, 52] -- │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-1 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-2 [6, 512, 52, 52] -- │ │ │ │ │ │ │ │ │ │ └─Dropout: 11-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-3 [6, 512, 52, 52] 1 │ │ │ │ │ │ │ │ │ │ └─PReLU: 11-3 [6, 512, 52, 52] 1
│ │ │ │ │ │ │ │ └─SkipConnection: 9-5 [6, 1536, 52, 52] -- │ │ │ │ │ │ │ │ └─SkipConnection: 9-5 [6, 1536, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Convolution: 10-3 [6, 1024, 52, 52] -- │ │ │ │ │ │ │ │ │ └─Convolution: 10-3 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Conv2d: 11-4 [6, 1024, 52, 52] 4,719,616 │ │ │ │ │ │ │ │ │ │ └─Conv2d: 11-4 [6, 1024, 52, 52] 4,719,616
│ │ │ │ │ │ │ │ │ │ └─ADN: 11-5 [6, 1024, 52, 52] -- │ │ │ │ │ │ │ │ │ │ └─ADN: 11-5 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 12-1 [6, 1024, 52, 52] -- │ │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 12-1 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─Dropout: 12-2 [6, 1024, 52, 52] -- │ │ │ │ │ │ │ │ │ │ │ └─Dropout: 12-2 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─PReLU: 12-3 [6, 1024, 52, 52] 1 │ │ │ │ │ │ │ │ │ │ │ └─PReLU: 12-3 [6, 1024, 52, 52] 1
│ │ │ │ │ │ │ │ └─Convolution: 9-6 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ └─Convolution: 9-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ └─ConvTranspose2d: 10-4 [6, 256, 104, 104] 3,539,200 │ │ │ │ │ │ │ │ │ └─ConvTranspose2d: 10-4 [6, 256, 104, 104] 3,539,200
│ │ │ │ │ │ │ │ │ └─ADN: 10-5 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ │ └─ADN: 10-5 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-6 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-7 [6, 256, 104, 104] -- │ │ │ │ │ │ │ │ │ │ └─Dropout: 11-7 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-8 [6, 256, 104, 104] 1 │ │ │ │ │ │ │ │ │ │ └─PReLU: 11-8 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─Convolution: 7-6 [6, 128, 208, 208] -- │ │ │ │ │ │ └─Convolution: 7-6 [6, 128, 208, 208] --
│ │ │ │ │ │ │ └─ConvTranspose2d: 8-4 [6, 128, 208, 208] 589,952 │ │ │ │ │ │ │ └─ConvTranspose2d: 8-4 [6, 128, 208, 208] 589,952
│ │ │ │ │ │ │ └─ADN: 8-5 [6, 128, 208, 208] -- │ │ │ │ │ │ │ └─ADN: 8-5 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-7 [6, 128, 208, 208] -- │ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-7 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─Dropout: 9-8 [6, 128, 208, 208] -- │ │ │ │ │ │ │ │ └─Dropout: 9-8 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─PReLU: 9-9 [6, 128, 208, 208] 1 │ │ │ │ │ │ │ │ └─PReLU: 9-9 [6, 128, 208, 208] 1
│ │ │ │ └─Convolution: 5-6 [6, 64, 416, 416] -- │ │ │ │ └─Convolution: 5-6 [6, 64, 416, 416] --
│ │ │ │ │ └─ConvTranspose2d: 6-4 [6, 64, 416, 416] 147,520 │ │ │ │ │ └─ConvTranspose2d: 6-4 [6, 64, 416, 416] 147,520
│ │ │ │ │ └─ADN: 6-5 [6, 64, 416, 416] -- │ │ │ │ │ └─ADN: 6-5 [6, 64, 416, 416] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-7 [6, 64, 416, 416] -- │ │ │ │ │ │ └─InstanceNorm2d: 7-7 [6, 64, 416, 416] --
│ │ │ │ │ │ └─Dropout: 7-8 [6, 64, 416, 416] -- │ │ │ │ │ │ └─Dropout: 7-8 [6, 64, 416, 416] --
│ │ │ │ │ │ └─PReLU: 7-9 [6, 64, 416, 416] 1 │ │ │ │ │ │ └─PReLU: 7-9 [6, 64, 416, 416] 1
│ │ └─Convolution: 3-3 [6, 1, 832, 832] -- │ │ └─Convolution: 3-3 [6, 1, 832, 832] --
│ │ │ └─ConvTranspose2d: 4-4 [6, 1, 832, 832] 1,153 │ │ │ └─ConvTranspose2d: 4-4 [6, 1, 832, 832] 1,153
======================================================================================================================================= =======================================================================================================================================
Total params: 10,547,273 Total params: 10,547,273
Trainable params: 10,547,273 Trainable params: 10,547,273
Non-trainable params: 0 Non-trainable params: 0
Total mult-adds (G): 675.50 Total mult-adds (G): 675.50
======================================================================================================================================= =======================================================================================================================================
Input size (MB): 16.61 Input size (MB): 16.61
Forward/backward pass size (MB): 4153.34 Forward/backward pass size (MB): 4153.34
Params size (MB): 42.19 Params size (MB): 42.19
Estimated Total Size (MB): 4212.15 Estimated Total Size (MB): 4212.15
======================================================================================================================================= =======================================================================================================================================
   
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags: %% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
   
### Hyperparameters and training ### Hyperparameters and training
   
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags: %% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
   
``` python ``` python
# model hyperparameters # model hyperparameters
my_hyperparameters = qim3d.ml.Hyperparameters(my_model, n_epochs=5, my_hyperparameters = qim3d.ml.Hyperparameters(my_model, n_epochs=5,
learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3) learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)
   
# training model # training model
qim3d.ml.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True) qim3d.ml.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)
``` ```
   
%% Output %% Output
   
   
Epoch 0, train loss: 0.7937, val loss: 0.5800 Epoch 0, train loss: 0.7937, val loss: 0.5800
   
   
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags: %% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
   
### Plotting ### Plotting
   
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags: %% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
   
``` python ``` python
in_targ_preds_test = qim3d.ml.inference(test_set,my_model) in_targ_preds_test = qim3d.ml.inference(test_set,my_model)
qim3d.viz.grid_pred(in_targ_preds_test,alpha=1) qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
``` ```
   
%% Output %% Output
   
<Figure size 1400x1000 with 28 Axes> <Figure size 1400x1000 with 28 Axes>
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment