Commit 5f38d847 authored by ofhkr, committed by fima

Implementation of Deep Learning unit tests, as well as paths to the 2d data for windows users in the UNet jupyter notebook.
parent fe74b961
1 merge request: !23
%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
``` python
from os.path import join
import qim3d
import os
%matplotlib inline
```
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str, datasets):
    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
    dataset_idx = datasets.index(name)

    if os.name == 'nt':
        # Windows (UNC) paths to the shared data
        datasets_path = [
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
        ]
    else:
        # Linux/cluster paths to the shared data
        datasets_path = [
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
        ]

    return datasets_path[dataset_idx]
```
%% Cell type:markdown id:7d07077a-cce3-4448-89f5-02413345becc tags:
### Datasets
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3]
root = get_dataset_path(dataset,datasets)
# should not use gaudez2022: 3d image
# should not use reichardt2021: multiclass segmentation
```
%% Cell type:markdown id:254dc8cb-6f24-4b57-91c0-98fb6f62602c tags:
### Model and Augmentation
%% Cell type:code id:30098003-ec06-48e0-809f-82f44166fb2b tags:
``` python
# defining model
my_model = qim3d.models.UNet(size = 'medium', dropout = 0.25)
# defining augmentation
my_aug = qim3d.utils.Augmentation(resize = 'crop', transform_train = 'light')
```
%% Cell type:markdown id:7b56c654-720d-4c5f-8545-749daa5dbaf2 tags:
### Loading the data
%% Cell type:code id:84141298-054d-4322-8bda-5ec514528985 tags:
``` python
# level of logging
qim3d.io.logger.level('info')
# datasets and dataloaders
train_set, val_set, test_set = qim3d.utils.prepare_datasets(path = root, val_fraction = 0.3,
                                                             model = my_model, augmentation = my_aug)
train_loader, val_loader, test_loader = qim3d.utils.prepare_dataloaders(train_set, val_set,
                                                                        test_set, batch_size = 6)
```
%% Cell type:code id:f320a4ae-f063-430c-b5a0-0d9fb64c2725 tags:
``` python
qim3d.viz.grid_overview(train_set,alpha = 1)
```
%% Cell type:code id:7fa3aa57-ba61-4c9a-934c-dce26bbc9e97 tags:
``` python
# Summary of model
model_s = qim3d.utils.model_summary(train_loader,my_model)
print(model_s)
```
%% Cell type:markdown id:a665ae28-d9a6-419f-9131-54283b47582c tags:
### Hyperparameters and training
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
``` python
# model hyperparameters
my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=5,
                                                  learning_rate = 5e-3, loss_function='DiceCE', weight_decay=1e-3)
# training model
qim3d.utils.train_model(my_model, my_hyperparameters, train_loader, val_loader, plot=True)
```
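This commit also adds a `return_loss` option to `train_model` (see the diff of the training utilities further down). As a minimal sketch, assuming the model, hyperparameters and dataloaders defined in the cells above, the per-epoch losses could be captured instead of only plotted:

``` python
# Hedged sketch: with return_loss=True the updated train_model returns the
# train and validation loss dictionaries; their 'loss' entries hold one
# averaged value per epoch (the new unit tests rely on exactly this).
train_loss, val_loss = qim3d.utils.train_model(
    my_model, my_hyperparameters, train_loader, val_loader,
    plot = False, return_loss = True)

print(train_loss['loss'])
```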
%% Cell type:markdown id:7e14fac8-4fd3-4725-bd0d-9e2a95552278 tags:
### Plotting
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
``` python
in_targ_preds_test = qim3d.utils.inference(test_set,my_model)
qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
```
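The unit tests below indicate that `qim3d.utils.inference` returns an (inputs, targets, predictions) tuple; a short hedged sketch of unpacking it directly, assuming the `test_set` and `my_model` from above:

``` python
# Hedged sketch: unpack the inference results instead of passing the whole
# tuple to grid_pred. For the binary datasets the targets are expected to
# contain only the values 0 and 1, as asserted in the unit tests.
inputs, targets, preds = qim3d.utils.inference(test_set, my_model)
print(targets[0].unique())
```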
......
import qim3d
import torch

# unit tests for UNet()
def test_starting_unet():
    unet = qim3d.models.UNet()

    assert unet.size == 'medium'

def test_forward_pass():
    unet = qim3d.models.UNet()

    # Size: B x C x H x W
    x = torch.ones([1,1,256,256])
    output = unet(x)

    assert x.shape == output.shape

# unit tests for Hyperparameters()
def test_hyper():
    unet = qim3d.models.UNet()
    hyperparams = qim3d.models.Hyperparameters(unet)

    assert hyperparams.n_epochs == 10

def test_hyper_dict():
    unet = qim3d.models.UNet()
    hyperparams = qim3d.models.Hyperparameters(unet)
    hyper_dict = hyperparams()

    assert type(hyper_dict) == dict
import qim3d
import albumentations
import pytest

# unit tests for Augmentation()
def test_augmentation():
    augment_class = qim3d.utils.Augmentation()

    assert augment_class.resize == 'crop'

def test_augment():
    augment_class = qim3d.utils.Augmentation()
    album_augment = augment_class.augment(256,256)

    assert type(album_augment) == albumentations.core.composition.Compose

# unit tests for ValueErrors in Augmentation()
def test_resize():
    resize_str = 'not valid resize'

    with pytest.raises(ValueError, match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."):
        augment_class = qim3d.utils.Augmentation(resize = resize_str)

def test_levels():
    augment_class = qim3d.utils.Augmentation()
    level = 'Not a valid level'

    with pytest.raises(ValueError, match = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
        augment_class.augment(256,256,level)
import qim3d
import pytest
from torch.utils.data.dataloader import DataLoader
from qim3d.utils.internal_tools import temp_data

# unit tests for Dataset()
def test_dataset():
    img_shape = (32,32)
    folder = 'folder_data'
    temp_data(folder, img_shape = img_shape)

    images = qim3d.utils.Dataset(folder)

    assert images[0][0].shape == img_shape

    temp_data(folder, remove = True)

# unit tests for check_resize()
def test_check_resize():
    h_adjust, w_adjust = qim3d.utils.data.check_resize(240, 240, resize = 'crop', n_channels = 6)

    assert (h_adjust, w_adjust) == (192, 192)

def test_check_resize_pad():
    h_adjust, w_adjust = qim3d.utils.data.check_resize(16, 16, resize = 'padding', n_channels = 6)

    assert (h_adjust, w_adjust) == (64, 64)

def test_check_resize_fail():
    with pytest.raises(ValueError, match = "The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
        h_adjust, w_adjust = qim3d.utils.data.check_resize(16, 16, resize = 'crop', n_channels = 6)

# unit tests for prepare_datasets()
def test_prepare_datasets():
    n = 3
    validation = 1/3

    folder = 'folder_data'
    img = temp_data(folder, n = n)

    my_model = qim3d.models.UNet()
    my_augmentation = qim3d.utils.Augmentation(transform_test = 'light')
    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder, validation, my_model, my_augmentation)

    assert (len(train_set), len(val_set), len(test_set)) == (int((1 - validation) * n), int(n * validation), n)

    temp_data(folder, remove = True)

# unit test for validation in prepare_datasets()
def test_validation():
    validation = 10

    with pytest.raises(ValueError, match = "The validation fraction must be a float between 0 and 1."):
        augment_class = qim3d.utils.prepare_datasets('folder', validation, 'my_model', 'my_augmentation')

# unit test for prepare_dataloaders()
def test_prepare_dataloaders():
    folder = 'folder_data'
    temp_data(folder)

    batch_size = 1
    my_model = qim3d.models.UNet()
    my_augmentation = qim3d.utils.Augmentation()

    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder, 1/3, my_model, my_augmentation)
    _, val_loader, _ = qim3d.utils.prepare_dataloaders(train_set, val_set, test_set,
                                                       batch_size, num_workers = 1,
                                                       pin_memory = False)

    assert type(val_loader) == DataLoader

    temp_data(folder, remove = True)
import qim3d
import pytest
from torch import ones
from qim3d.utils.internal_tools import temp_data

# unit test for model_summary()
def test_model_summary():
    n = 10
    img_shape = (32,32)
    folder = 'folder_data'
    temp_data(folder, img_shape = img_shape, n = n)

    unet = qim3d.models.UNet(size = 'small')
    augment = qim3d.utils.Augmentation(transform_train = None)
    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder, 1/3, unet, augment)
    _, val_loader, _ = qim3d.utils.prepare_dataloaders(train_set, val_set, test_set,
                                                       batch_size = 1, num_workers = 1,
                                                       pin_memory = False)
    summary = qim3d.utils.model_summary(val_loader, unet)

    assert summary.input_size[0] == (1,1) + img_shape

    temp_data(folder, remove = True)

# unit test for inference()
def test_inference():
    folder = 'folder_data'
    temp_data(folder)

    unet = qim3d.models.UNet(size = 'small')
    augment = qim3d.utils.Augmentation(transform_train = None)
    train_set, _, _ = qim3d.utils.prepare_datasets(folder, 1/3, unet, augment)

    _, targ, _ = qim3d.utils.inference(train_set, unet)

    assert tuple(targ[0].unique()) == (0,1)

    temp_data(folder, remove = True)

# unit test for tuple ValueError()
def test_inference_tuple():
    folder = 'folder_data'
    temp_data(folder)

    unet = qim3d.models.UNet(size = 'small')
    data = [1,2,3]

    with pytest.raises(ValueError, match = "Data items must be tuples"):
        qim3d.utils.inference(data, unet)

    temp_data(folder, remove = True)

# unit test for tensor ValueError()
def test_inference_tensor():
    folder = 'folder_data'
    temp_data(folder)

    unet = qim3d.models.UNet(size = 'small')
    data = [(1,2)]

    with pytest.raises(ValueError, match = "Data items must consist of tensors"):
        qim3d.utils.inference(data, unet)

    temp_data(folder, remove = True)

# unit test for dimension ValueError()
def test_inference_dim():
    folder = 'folder_data'
    temp_data(folder)

    unet = qim3d.models.UNet(size = 'small')
    data = [(ones(1), ones(1))]

    # need the r"" for special characters
    with pytest.raises(ValueError, match = r"Input image must be \(C,H,W\) format"):
        qim3d.utils.inference(data, unet)

    temp_data(folder, remove = True)

# unit test for train_model()
def test_train_model():
    folder = 'folder_data'
    temp_data(folder)

    n_epochs = 1

    unet = qim3d.models.UNet(size = 'small')
    augment = qim3d.utils.Augmentation(transform_train = None)
    hyperparams = qim3d.models.Hyperparameters(unet, n_epochs = n_epochs)
    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder, 1/3, unet, augment)
    train_loader, val_loader, _ = qim3d.utils.prepare_dataloaders(train_set, val_set, test_set,
                                                                  batch_size = 1, num_workers = 1,
                                                                  pin_memory = False)

    train_loss, _ = qim3d.utils.train_model(unet, hyperparams, train_loader, val_loader,
                                            plot = False, return_loss = True)

    assert len(train_loss['loss']) == n_epochs

    temp_data(folder, remove = True)
@@ -174,7 +174,7 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
    return train_set, val_set, test_set


def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):
    """
    Prepares the dataloaders for model training.
......
@@ -8,6 +8,9 @@ import matplotlib
import numpy as np
import socket
import os
import shutil
from PIL import Image
from pathlib import Path
@@ -178,6 +181,48 @@ def is_server_running(ip, port):
    except:
        return False

def temp_data(folder, remove = False, n = 3, img_shape = (32,32)):
    """Creates a temporary folder with dummy train/test PNG images and labels
    for the unit tests, or deletes that folder again when remove=True.
    """
    folder_trte = ['train','test']
    sub_folders = ['images','labels']

    # Creating train/test folder
    path_train = Path(folder) / folder_trte[0]
    path_test = Path(folder) / folder_trte[1]

    # Creating folders for images and labels
    path_train_im = path_train / sub_folders[0]
    path_train_lab = path_train / sub_folders[1]
    path_test_im = path_test / sub_folders[0]
    path_test_lab = path_test / sub_folders[1]

    # Random image
    img = np.random.randint(2, size = img_shape, dtype = np.uint8)
    img = Image.fromarray(img)

    if not os.path.exists(path_train):
        os.makedirs(path_train_im)
        os.makedirs(path_test_im)
        os.makedirs(path_train_lab)
        os.makedirs(path_test_lab)

    for i in range(n):
        img.save(path_train_im / f'img_train{i}.png')
        img.save(path_train_lab / f'img_train{i}.png')
        img.save(path_test_im / f'img_test{i}.png')
        img.save(path_test_lab / f'img_test{i}.png')

    if remove:
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))

        os.rmdir(folder)
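
# Illustrative usage sketch (not part of internal_tools.py): the new unit tests
# call temp_data once to generate a throwaway dataset and once more to delete it,
# for example:
#
#     temp_data('folder_data', n = 10, img_shape = (32,32))  # create dummy train/test PNGs
#     images = qim3d.utils.Dataset('folder_data')            # exercise the code under test
#     temp_data('folder_data', remove = True)                # clean the folder up again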

def stringify_path(path):
    """Converts an os.PathLike object to a string
    """
......
@@ -10,7 +10,7 @@ from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm


def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True, return_loss = False):
    """ Function for training Neural Network models.

    Args:
@@ -21,6 +21,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
        eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
        print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.

    Returns:
        tuple:
            train_loss (dict): dictionary with average losses and batch losses for training loop.
@@ -65,10 +66,11 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
        for data in train_loader:
            inputs, targets = data
            inputs = inputs.to(device)
            targets = targets.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backpropagation
@@ -94,7 +96,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
        for data in val_loader:
            inputs, targets = data
            inputs = inputs.to(device)
            targets = targets.to(device).unsqueeze(1)

            with torch.no_grad():
                outputs = model(inputs)
@@ -122,6 +124,9 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
        plot_metrics(val_loss, color = 'orange', label = 'Valid.')
        fig.show()

    if return_loss:
        return train_loss, val_loss


def model_summary(dataloader, model):
    """Prints the summary of a PyTorch model.
@@ -196,7 +201,7 @@ def inference(data, model):
    else:
        raise ValueError("Input image must be (C,H,W) format")

    model.to(device)
    model.eval()

    # Make new list such that possible augmentations remain identical for all three rows
......