Skip to content
Snippets Groups Projects
Commit e3826e01 authored by Oshkr's avatar Oshkr
Browse files

modification to tests where new dataset is involved

parent f9b1de88
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
...@@ -19,14 +19,14 @@ def test_augment(): ...@@ -19,14 +19,14 @@ def test_augment():
def test_resize(): def test_resize():
resize_str = 'not valid resize' resize_str = 'not valid resize'
with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."): with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'reshape' or 'padding'."):
augment_class = qim3d.utils.Augmentation(resize = resize_str) augment_class = qim3d.utils.Augmentation(resize = resize_str)
def test_levels(): def test_levels():
augment_class = qim3d.utils.Augmentation() augment_class = qim3d.utils.Augmentation()
level = 'Not a valid level' type = 'Not a valid level'
with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."): with pytest.raises(ValueError, match=f"Invalid transformation level: {type}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
augment_class.augment(256,256,level) augment_class.augment(256,256,type = type)
\ No newline at end of file \ No newline at end of file
...@@ -4,19 +4,6 @@ import pytest ...@@ -4,19 +4,6 @@ import pytest
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from qim3d.utils.internal_tools import temp_data from qim3d.utils.internal_tools import temp_data
# unit tests for Dataset()
def test_dataset():
img_shape = (32,32)
folder = 'folder_data'
temp_data(folder, img_shape = img_shape)
images = qim3d.utils.Dataset(folder)
assert images[0][0].shape == img_shape
temp_data(folder,remove=True)
# unit tests for check_resize() # unit tests for check_resize()
def test_check_resize(): def test_check_resize():
h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6) h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
...@@ -38,13 +25,15 @@ def test_check_resize_fail(): ...@@ -38,13 +25,15 @@ def test_check_resize_fail():
def test_prepare_datasets(): def test_prepare_datasets():
n = 3 n = 3
validation = 1/3 validation = 1/3
test = 0.1
folder = 'folder_data' folder = 'folder_data'
img = temp_data(folder,n = n) img = temp_data(folder,n = n)
my_model = qim3d.models.UNet() my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation(transform_test='light') my_augmentation = qim3d.utils.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation) train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=validation,
train_folder="train",test_folder="test")
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n) assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
...@@ -67,8 +56,8 @@ def test_prepare_dataloaders(): ...@@ -67,8 +56,8 @@ def test_prepare_dataloaders():
batch_size = 1 batch_size = 1
my_model = qim3d.models.UNet() my_model = qim3d.models.UNet()
my_augmentation = qim3d.utils.Augmentation() my_augmentation = qim3d.utils.Augmentation()
train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation) train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1, batch_size,num_workers = 1,
pin_memory = False) pin_memory = False)
......
...@@ -13,7 +13,8 @@ def test_model_summary(): ...@@ -13,7 +13,8 @@ def test_model_summary():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1, batch_size = 1,num_workers = 1,
...@@ -32,7 +33,8 @@ def test_inference(): ...@@ -32,7 +33,8 @@ def test_inference():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,_,_ = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
_, targ,_ = qim3d.utils.inference(train_set,unet) _, targ,_ = qim3d.utils.inference(train_set,unet)
...@@ -94,7 +96,8 @@ def test_train_model(): ...@@ -94,7 +96,8 @@ def test_train_model():
unet = qim3d.models.UNet(size = 'small') unet = qim3d.models.UNet(size = 'small')
augment = qim3d.utils.Augmentation(transform_train=None) augment = qim3d.utils.Augmentation(transform_train=None)
hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs) hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
train_folder="train",test_folder="test")
train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
batch_size = 1,num_workers = 1, batch_size = 1,num_workers = 1,
pin_memory = False) pin_memory = False)
......
...@@ -30,7 +30,8 @@ def test_grid_pred(): ...@@ -30,7 +30,8 @@ def test_grid_pred():
model = qim3d.models.UNet() model = qim3d.models.UNet()
augmentation = qim3d.utils.Augmentation() augmentation = qim3d.utils.Augmentation()
train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation) train_set,_,_ = qim3d.utils.prepare_datasets(folder,model,augmentation,
train_folder="train",test_folder="test")
in_targ_pred = qim3d.utils.models.inference(train_set,model) in_targ_pred = qim3d.utils.models.inference(train_set,model)
......
...@@ -61,6 +61,8 @@ class Augmentation: ...@@ -61,6 +61,8 @@ class Augmentation:
level = self.transform_validation level = self.transform_validation
elif type=='test': elif type=='test':
level = self.transform_test level = self.transform_test
else:
level = type
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light','moderate','heavy']:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment