Skip to content
Snippets Groups Projects

Implementation of Deep Learning unit tests, as well as paths to the 2d data for windows users in the UNet jupyter notebook.

Merged ofhkr requested to merge DL_unittests into main
5 files
+ 193
49
Compare changes
  • Side-by-side
  • Inline

Files

+ 37
23
import qim3d
from os import name
import pytest
from albumentations import Compose
from torch.utils.data.dataloader import DataLoader
from qim3d.utils.internal_tools import temp_data
# unit tests for Dataset()
def test_dataset():
if name == 'nt':
path = '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
else:
path = '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
img_shape = (32,32)
folder = 'folder_data'
temp_data(folder, img_shape = img_shape)
images = qim3d.utils.Dataset(folder)
assert images[0][0].shape == img_shape
images = qim3d.utils.Dataset(path)
temp_data(folder,remove=True)
assert images[0][0].shape == (256,256)
# unit tests for check_resize()
def test_check_resize():
@@ -30,35 +33,46 @@ def test_check_resize_fail():
with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'crop',n_channels = 6)
# unit tests for prepare_datasets()
def test_prepare_datasets():
if name == 'nt':
path = '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
else:
path = '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
n = 3
validation = 1/3
folder = 'folder_data'
img = temp_data(folder,n = n)
my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.utils.prepare_datasets(path,0.1,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation)
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
temp_data(folder,remove=True)
# unit test for validation in prepare_datasets()
def test_validation():
validation = 10
assert type(test_set.transform) == Compose
with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."):
augment_class = qim3d.utils.prepare_datasets('folder',validation,'my_model','my_augmentation')
# unit test for prepare_dataloaders()
def test_prepare_dataloaders():
batch_size = 1
if name == 'nt':
path = '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
else:
path = '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side'
folder = 'folder_data'
temp_data(folder)
batch_size = 1
my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation()
train_set, val_set, test_set = qim3d.utils.prepare_datasets(path,0.1,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation)
train_loader,val_loader, test_loader = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1,
pin_memory = False)
img,_ = next(iter(train_loader))
assert type(val_loader) == DataLoader
assert img.shape[0] == batch_size
\ No newline at end of file
temp_data(folder,remove=True)
\ No newline at end of file
Loading