Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
968d7efe
Commit
968d7efe
authored
6 months ago
by
s193396
Browse files
Options
Downloads
Patches
Plain Diff
updated notebook
parent
b2a8a35c
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
docs/notebooks/Unet2D.ipynb
+3
-3
3 additions, 3 deletions
docs/notebooks/Unet2D.ipynb
with
3 additions
and
3 deletions
docs/notebooks/Unet.ipynb
→
docs/notebooks/Unet
2D
.ipynb
+
3
−
3
View file @
968d7efe
...
@@ -97,7 +97,7 @@
...
@@ -97,7 +97,7 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"# defining model\n",
"# defining model\n",
"my_model = qim3d.ml.models.UNet(size = 'medium', dropout = 0.25)\n",
"my_model = qim3d.ml.models.UNet
2D
(size = 'medium', dropout = 0.25)\n",
"# defining augmentation\n",
"# defining augmentation\n",
"my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')"
"my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')"
]
]
...
@@ -340,7 +340,7 @@
...
@@ -340,7 +340,7 @@
],
],
"metadata": {
"metadata": {
"kernelspec": {
"kernelspec": {
"display_name": "
Python 3 (ipykernel)
",
"display_name": "
qim3d
",
"language": "python",
"language": "python",
"name": "python3"
"name": "python3"
},
},
...
@@ -354,7 +354,7 @@
...
@@ -354,7 +354,7 @@
"name": "python",
"name": "python",
"nbconvert_exporter": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"pygments_lexer": "ipython3",
"version": "3.
9
.1
1
"
"version": "3.
10
.1
4
"
}
}
},
},
"nbformat": 4,
"nbformat": 4,
%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
``` python
``` python
%load_ext autoreload
%load_ext autoreload
%autoreload 2
%autoreload 2
```
```
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
``` python
``` python
from os.path import join
from os.path import join
import qim3d
import qim3d
import os
import os
%matplotlib inline
%matplotlib inline
```
```
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
``` python
``` python
# Define function for getting dataset path from string
# Define function for getting dataset path from string
def get_dataset_path(name: str, datasets):
def get_dataset_path(name: str, datasets):
assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
dataset_idx = datasets.index(name)
dataset_idx = datasets.index(name)
if os.name == 'nt':
if os.name == 'nt':
datasets_path = [
datasets_path = [
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
]
]
else:
else:
datasets_path = [
datasets_path = [
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
]
]
return datasets_path[dataset_idx]
return datasets_path[dataset_idx]
```
```
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
### Datasets
### Datasets
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
``` python
``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3]
dataset = datasets[3]
root = get_dataset_path(dataset,datasets)
root = get_dataset_path(dataset,datasets)
# should not use gaudez2022: 3d image
# should not use gaudez2022: 3d image
# reichardt2021: multiclass segmentation
# reichardt2021: multiclass segmentation
```
```
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
### Model and Augmentation
### Model and Augmentation
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
``` python
``` python
# defining model
# defining model
my_model = qim3d.ml.models.UNet(size = 'medium', dropout = 0.25)
my_model = qim3d.ml.models.UNet
2D
(size = 'medium', dropout = 0.25)
# defining augmentation
# defining augmentation
my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
my_aug = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
```
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
### Loading the data
### Loading the data
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
``` python
``` python
# level of logging
# level of logging
qim3d.utils._logger.level('info')
qim3d.utils._logger.level('info')
# datasets and dataloaders
# datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = root, val_fraction = 0.3,
train_set, val_set, test_set = qim3d.ml.prepare_datasets(path = root, val_fraction = 0.3,
model = my_model , augmentation = my_aug)
model = my_model , augmentation = my_aug)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, val_set,
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(train_set, val_set,
test_set, batch_size = 6)
test_set, batch_size = 6)
```
```
%% Output
%% Output
The image size doesn't match the Unet model's depth. The image is changed with 'crop', from (852, 852) to (832, 832).
The image size doesn't match the Unet model's depth. The image is changed with 'crop', from (852, 852) to (832, 832).
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
``` python
``` python
qim3d.viz.grid_overview(train_set,alpha = 1)
qim3d.viz.grid_overview(train_set,alpha = 1)
```
```
%% Output
%% Output
<Figure size 1400x600 with 14 Axes>
<Figure size 1400x600 with 14 Axes>
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
``` python
``` python
# Summary of model
# Summary of model
model_s = qim3d.ml.model_summary(train_loader,my_model)
model_s = qim3d.ml.model_summary(train_loader,my_model)
print(model_s)
print(model_s)
```
```
%% Output
%% Output
=======================================================================================================================================
=======================================================================================================================================
Layer (type:depth-idx) Output Shape Param #
Layer (type:depth-idx) Output Shape Param #
=======================================================================================================================================
=======================================================================================================================================
UNet [6, 1, 832, 832] --
UNet [6, 1, 832, 832] --
├─UNet: 1-1 [6, 1, 832, 832] --
├─UNet: 1-1 [6, 1, 832, 832] --
│ └─Sequential: 2-1 [6, 1, 832, 832] --
│ └─Sequential: 2-1 [6, 1, 832, 832] --
│ │ └─Convolution: 3-1 [6, 64, 416, 416] --
│ │ └─Convolution: 3-1 [6, 64, 416, 416] --
│ │ │ └─Conv2d: 4-1 [6, 64, 416, 416] 640
│ │ │ └─Conv2d: 4-1 [6, 64, 416, 416] 640
│ │ │ └─ADN: 4-2 [6, 64, 416, 416] --
│ │ │ └─ADN: 4-2 [6, 64, 416, 416] --
│ │ │ │ └─InstanceNorm2d: 5-1 [6, 64, 416, 416] --
│ │ │ │ └─InstanceNorm2d: 5-1 [6, 64, 416, 416] --
│ │ │ │ └─Dropout: 5-2 [6, 64, 416, 416] --
│ │ │ │ └─Dropout: 5-2 [6, 64, 416, 416] --
│ │ │ │ └─PReLU: 5-3 [6, 64, 416, 416] 1
│ │ │ │ └─PReLU: 5-3 [6, 64, 416, 416] 1
│ │ └─SkipConnection: 3-2 [6, 128, 416, 416] --
│ │ └─SkipConnection: 3-2 [6, 128, 416, 416] --
│ │ │ └─Sequential: 4-3 [6, 64, 416, 416] --
│ │ │ └─Sequential: 4-3 [6, 64, 416, 416] --
│ │ │ │ └─Convolution: 5-4 [6, 128, 208, 208] --
│ │ │ │ └─Convolution: 5-4 [6, 128, 208, 208] --
│ │ │ │ │ └─Conv2d: 6-1 [6, 128, 208, 208] 73,856
│ │ │ │ │ └─Conv2d: 6-1 [6, 128, 208, 208] 73,856
│ │ │ │ │ └─ADN: 6-2 [6, 128, 208, 208] --
│ │ │ │ │ └─ADN: 6-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-1 [6, 128, 208, 208] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-1 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Dropout: 7-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Dropout: 7-2 [6, 128, 208, 208] --
│ │ │ │ │ │ └─PReLU: 7-3 [6, 128, 208, 208] 1
│ │ │ │ │ │ └─PReLU: 7-3 [6, 128, 208, 208] 1
│ │ │ │ └─SkipConnection: 5-5 [6, 256, 208, 208] --
│ │ │ │ └─SkipConnection: 5-5 [6, 256, 208, 208] --
│ │ │ │ │ └─Sequential: 6-3 [6, 128, 208, 208] --
│ │ │ │ │ └─Sequential: 6-3 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Convolution: 7-4 [6, 256, 104, 104] --
│ │ │ │ │ │ └─Convolution: 7-4 [6, 256, 104, 104] --
│ │ │ │ │ │ │ └─Conv2d: 8-1 [6, 256, 104, 104] 295,168
│ │ │ │ │ │ │ └─Conv2d: 8-1 [6, 256, 104, 104] 295,168
│ │ │ │ │ │ │ └─ADN: 8-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ └─ADN: 8-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-1 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-1 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Dropout: 9-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Dropout: 9-2 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─PReLU: 9-3 [6, 256, 104, 104] 1
│ │ │ │ │ │ │ │ └─PReLU: 9-3 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─SkipConnection: 7-5 [6, 512, 104, 104] --
│ │ │ │ │ │ └─SkipConnection: 7-5 [6, 512, 104, 104] --
│ │ │ │ │ │ │ └─Sequential: 8-3 [6, 256, 104, 104] --
│ │ │ │ │ │ │ └─Sequential: 8-3 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Convolution: 9-4 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ └─Convolution: 9-4 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Conv2d: 10-1 [6, 512, 52, 52] 1,180,160
│ │ │ │ │ │ │ │ │ └─Conv2d: 10-1 [6, 512, 52, 52] 1,180,160
│ │ │ │ │ │ │ │ │ └─ADN: 10-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ └─ADN: 10-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-1 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-1 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-2 [6, 512, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-3 [6, 512, 52, 52] 1
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-3 [6, 512, 52, 52] 1
│ │ │ │ │ │ │ │ └─SkipConnection: 9-5 [6, 1536, 52, 52] --
│ │ │ │ │ │ │ │ └─SkipConnection: 9-5 [6, 1536, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Convolution: 10-3 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ └─Convolution: 10-3 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─Conv2d: 11-4 [6, 1024, 52, 52] 4,719,616
│ │ │ │ │ │ │ │ │ │ └─Conv2d: 11-4 [6, 1024, 52, 52] 4,719,616
│ │ │ │ │ │ │ │ │ │ └─ADN: 11-5 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ └─ADN: 11-5 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 12-1 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 12-1 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─Dropout: 12-2 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─Dropout: 12-2 [6, 1024, 52, 52] --
│ │ │ │ │ │ │ │ │ │ │ └─PReLU: 12-3 [6, 1024, 52, 52] 1
│ │ │ │ │ │ │ │ │ │ │ └─PReLU: 12-3 [6, 1024, 52, 52] 1
│ │ │ │ │ │ │ │ └─Convolution: 9-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ └─Convolution: 9-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ └─ConvTranspose2d: 10-4 [6, 256, 104, 104] 3,539,200
│ │ │ │ │ │ │ │ │ └─ConvTranspose2d: 10-4 [6, 256, 104, 104] 3,539,200
│ │ │ │ │ │ │ │ │ └─ADN: 10-5 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ └─ADN: 10-5 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─InstanceNorm2d: 11-6 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-7 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─Dropout: 11-7 [6, 256, 104, 104] --
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-8 [6, 256, 104, 104] 1
│ │ │ │ │ │ │ │ │ │ └─PReLU: 11-8 [6, 256, 104, 104] 1
│ │ │ │ │ │ └─Convolution: 7-6 [6, 128, 208, 208] --
│ │ │ │ │ │ └─Convolution: 7-6 [6, 128, 208, 208] --
│ │ │ │ │ │ │ └─ConvTranspose2d: 8-4 [6, 128, 208, 208] 589,952
│ │ │ │ │ │ │ └─ConvTranspose2d: 8-4 [6, 128, 208, 208] 589,952
│ │ │ │ │ │ │ └─ADN: 8-5 [6, 128, 208, 208] --
│ │ │ │ │ │ │ └─ADN: 8-5 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-7 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─InstanceNorm2d: 9-7 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─Dropout: 9-8 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─Dropout: 9-8 [6, 128, 208, 208] --
│ │ │ │ │ │ │ │ └─PReLU: 9-9 [6, 128, 208, 208] 1
│ │ │ │ │ │ │ │ └─PReLU: 9-9 [6, 128, 208, 208] 1
│ │ │ │ └─Convolution: 5-6 [6, 64, 416, 416] --
│ │ │ │ └─Convolution: 5-6 [6, 64, 416, 416] --
│ │ │ │ │ └─ConvTranspose2d: 6-4 [6, 64, 416, 416] 147,520
│ │ │ │ │ └─ConvTranspose2d: 6-4 [6, 64, 416, 416] 147,520
│ │ │ │ │ └─ADN: 6-5 [6, 64, 416, 416] --
│ │ │ │ │ └─ADN: 6-5 [6, 64, 416, 416] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-7 [6, 64, 416, 416] --
│ │ │ │ │ │ └─InstanceNorm2d: 7-7 [6, 64, 416, 416] --
│ │ │ │ │ │ └─Dropout: 7-8 [6, 64, 416, 416] --
│ │ │ │ │ │ └─Dropout: 7-8 [6, 64, 416, 416] --
│ │ │ │ │ │ └─PReLU: 7-9 [6, 64, 416, 416] 1
│ │ │ │ │ │ └─PReLU: 7-9 [6, 64, 416, 416] 1
│ │ └─Convolution: 3-3 [6, 1, 832, 832] --
│ │ └─Convolution: 3-3 [6, 1, 832, 832] --
│ │ │ └─ConvTranspose2d: 4-4 [6, 1, 832, 832] 1,153
│ │ │ └─ConvTranspose2d: 4-4 [6, 1, 832, 832] 1,153
=======================================================================================================================================
=======================================================================================================================================
Total params: 10,547,273
Total params: 10,547,273
Trainable params: 10,547,273
Trainable params: 10,547,273
Non-trainable params: 0
Non-trainable params: 0
Total mult-adds (G): 675.50
Total mult-adds (G): 675.50
=======================================================================================================================================
=======================================================================================================================================
Input size (MB): 16.61
Input size (MB): 16.61
Forward/backward pass size (MB): 4153.34
Forward/backward pass size (MB): 4153.34
Params size (MB): 42.19
Params size (MB): 42.19
Estimated Total Size (MB): 4212.15
Estimated Total Size (MB): 4212.15
=======================================================================================================================================
=======================================================================================================================================
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
### Hyperparameters and training
### Hyperparameters and training
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
``` python
``` python
# model hyperparameters
# model hyperparameters
my_hyperparameters = qim3d.ml.Hyperparameters(my_model, n_epochs=5,
my_hyperparameters = qim3d.ml.Hyperparameters(my_model, n_epochs=5,
learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)
learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)
# training model
# training model
qim3d.ml.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)
qim3d.ml.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)
```
```
%% Output
%% Output
Epoch 0, train loss: 0.7937, val loss: 0.5800
Epoch 0, train loss: 0.7937, val loss: 0.5800
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
### Plotting
### Plotting
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
``` python
``` python
in_targ_preds_test = qim3d.ml.inference(test_set,my_model)
in_targ_preds_test = qim3d.ml.inference(test_set,my_model)
qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
```
```
%% Output
%% Output
<Figure size 1400x1000 with 28 Axes>
<Figure size 1400x1000 with 28 Axes>
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment