Skip to content
Snippets Groups Projects
Commit e3826e01 authored by Oshkr's avatar Oshkr
Browse files

modification to tests where new dataset is involved

parent f9b1de88
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
......@@ -19,14 +19,14 @@ def test_augment():
def test_resize():
resize_str = 'not valid resize'
with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."):
with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'reshape' or 'padding'."):
augment_class = qim3d.utils.Augmentation(resize = resize_str)
def test_levels():
augment_class = qim3d.utils.Augmentation()
level = 'Not a valid level'
type = 'Not a valid level'
with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
augment_class.augment(256,256,level)
\ No newline at end of file
with pytest.raises(ValueError, match=f"Invalid transformation level: {type}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
augment_class.augment(256,256,type = type)
\ No newline at end of file
......@@ -4,19 +4,6 @@ import pytest
from torch.utils.data.dataloader import DataLoader
from qim3d.utils.internal_tools import temp_data
# unit tests for Dataset()
def test_dataset():
img_shape = (32,32)
folder = 'folder_data'
temp_data(folder, img_shape = img_shape)
images = qim3d.utils.Dataset(folder)
assert images[0][0].shape == img_shape
temp_data(folder,remove=True)
# unit tests for check_resize()
def test_check_resize():
h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
......@@ -38,13 +25,15 @@ def test_check_resize_fail():
def test_prepare_datasets():
n = 3
validation = 1/3
test = 0.1
folder = 'folder_data'
img = temp_data(folder,n = n)
my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=validation,
train_folder="train",test_folder="test")
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
......@@ -67,8 +56,8 @@ def test_prepare_dataloaders():
batch_size = 1
my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation()
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation)
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1,
pin_memory = False)
......
......@@ -13,7 +13,8 @@ def test_model_summary():
unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1,
......@@ -32,7 +33,8 @@ def test_inference():
unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None)
train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
train_set,_,_ = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_, targ,_ = qim3d.utils.inference(train_set,unet)
......@@ -94,7 +96,8 @@ def test_train_model():
unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None)
hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1,
pin_memory = False)
......
......@@ -30,7 +30,8 @@ def test_grid_pred():
model = qim3d.models.UNet()
augmentation = qim3d.utils.Augmentation()
train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation)
train_set,_,_ = qim3d.utils.prepare_datasets(folder,model,augmentation,
train_folder="train",test_folder="test")
in_targ_pred = qim3d.utils.models.inference(train_set,model)
......
......@@ -61,7 +61,9 @@ class Augmentation:
level = self.transform_validation
elif type=='test':
level = self.transform_test
else:
level = type
# Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment