Skip to content
Snippets Groups Projects
Commit e3826e01 authored by Oshkr's avatar Oshkr
Browse files

modification to tests where new dataset is involved

parent f9b1de88
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
...@@ -19,14 +19,14 @@ def test_augment(): ...@@ -19,14 +19,14 @@ def test_augment():
def test_resize(): def test_resize():
resize_str = 'not valid resize' resize_str = 'not valid resize'
with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."): with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'reshape' or 'padding'."):
augment_class = qim3d.utils.Augmentation(resize = resize_str) augment_class = qim3d.utils.Augmentation(resize = resize_str)
def test_levels(): def test_levels():
augment_class = qim3d.utils.Augmentation() augment_class = qim3d.utils.Augmentation()
level = 'Not a valid level' type = 'Not a valid level'
with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."): with pytest.raises(ValueError, match=f"Invalid transformation level: {type}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
augment_class.augment(256,256,level) augment_class.augment(256,256,type = type)
\ No newline at end of file \ No newline at end of file
...@@ -4,19 +4,6 @@ import pytest ...@@ -4,19 +4,6 @@ import pytest
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for Dataset()
def test_dataset():
img_shape = (32,32)
folder = 'folder_data'
temp_data(folder, img_shape = img_shape)
images = qim3d.utils.Dataset(folder)
assert images[0][0].shape == img_shape
temp_data(folder,remove=True)
# unit tests for check_resize() # unit tests for check_resize()
def test_check_resize(): def test_check_resize():
h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6) h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
...@@ -38,13 +25,15 @@ def test_check_resize_fail(): ...@@ -38,13 +25,15 @@ def test_check_resize_fail():
def test_prepare_datasets(): def test_prepare_datasets():
n = 3 n = 3
validation = 1/3 validation = 1/3
test = 0.1
folder = 'folder_data' folder = 'folder_data'
img = temp_data(folder,n = n) img = temp_data(folder,n = n)
my_model = qim3d.models.UNet() my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation(transform_test='light') my_augmentation = qim3d.utils.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation) train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=validation,
train_folder="train",test_folder="test")
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n) assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
...@@ -67,8 +56,8 @@ def test_prepare_dataloaders(): ...@@ -67,8 +56,8 @@ def test_prepare_dataloaders():
batch_size = 1 batch_size = 1
my_model = qim3d.models.UNet() my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation() my_augmentation = qim3d.utils.Augmentation()
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation) train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1, batch_size,num_workers = 1,
pin_memory = False) pin_memory = False)
......
...@@ -13,7 +13,8 @@ def test_model_summary(): ...@@ -13,7 +13,8 @@ def test_model_summary():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1, batch_size = 1,num_workers = 1,
...@@ -32,7 +33,8 @@ def test_inference(): ...@@ -32,7 +33,8 @@ def test_inference():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,_,_ = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_, targ,_ = qim3d.utils.inference(train_set,unet) _, targ,_ = qim3d.utils.inference(train_set,unet)
...@@ -94,7 +96,8 @@ def test_train_model(): ...@@ -94,7 +96,8 @@ def test_train_model():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs) hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1, batch_size = 1,num_workers = 1,
pin_memory = False) pin_memory = False)
......
...@@ -30,7 +30,8 @@ def test_grid_pred(): ...@@ -30,7 +30,8 @@ def test_grid_pred():
model = qim3d.models.UNet() model = qim3d.models.UNet()
augmentation = qim3d.utils.Augmentation() augmentation = qim3d.utils.Augmentation()
train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation) train_set,_,_ = qim3d.utils.prepare_datasets(folder,model,augmentation,
train_folder="train",test_folder="test")
in_targ_pred = qim3d.utils.models.inference(train_set,model) in_targ_pred = qim3d.utils.models.inference(train_set,model)
......
...@@ -61,6 +61,8 @@ class Augmentation: ...@@ -61,6 +61,8 @@ class Augmentation:
level = self.transform_validation level = self.transform_validation
elif type=='test': elif type=='test':
level = self.transform_test level = self.transform_test
else:
level = type
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light','moderate','heavy']:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment