Skip to content
Snippets Groups Projects
Commit e8e54ebe authored by Oshkr's avatar Oshkr
Browse files

typo fix in data.py

parent 9a9ada7d
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
``` python
from os.path import join
import qim3d
import os
%matplotlib inline
```
%% Cell type:code id:813a3454 tags:
``` python
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
```
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str, datasets):
assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
dataset_idx = datasets.index(name)
if os.name == 'nt':
datasets_path = [
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
]
else:
datasets_path = [
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
]
return datasets_path[dataset_idx]
```
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
### Datasets
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3]
dataset = datasets[-1]
root = get_dataset_path(dataset,datasets)
# should not use gaudez2022: 3d image
# reichardt2021: multiclass segmentation
```
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
### Model and Augmentation
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
``` python
# defining model
my_model = qim3d.models.UNet(size = 'medium', dropout = 0.25)
my_model = qim3d.models.UNet(size = 'small', dropout = 0.25)
# defining augmentation
my_aug = qim3d.utils.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
### Loading the data
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
``` python
# level of logging
qim3d.io.logger.level('info')
# datasets and dataloaders
train_set, val_set, test_set = qim3d.utils.prepare_datasets(path = root, model = my_model , augmentation = my_aug,
val_fraction = 0.3,test_fraction = 0.1,
val_fraction = 0.3, test_fraction = 0.1,
train_folder='train', test_folder='test')
train_loader, val_loader, test_loader = qim3d.utils.prepare_dataloaders(train_set, val_set,
test_set, batch_size = 6)
```
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
``` python
qim3d.viz.grid_overview(train_set,alpha = 1)
```
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
``` python
# Summary of model
model_s = qim3d.utils.model_summary(train_loader,my_model)
print(model_s)
```
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
### Hyperparameters and training
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
``` python
# model hyperparameters
my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=5,
my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=20,
learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)
# training model
qim3d.utils.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)
```
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
### Plotting
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
``` python
in_targ_preds_test = qim3d.utils.inference(test_set,my_model)
qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
```
......
......@@ -171,7 +171,7 @@ class Dataset(torch.utils.data.Dataset):
# if the first folder contains the targets:
if any(ext in self.folder_names[0].lower() for ext in target_folder_names):
images = self.folders_names[1]
images = self.folder_names[1]
targets = self.folder_names[0]
# if the second folder contains the targets:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment