diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py index da6c490d2706011a4dddfe274bb661e3dc303127..4e315c97283fead599a5081de36f91ca1ff02aa7 100644 --- a/qim3d/tests/utils/test_augmentations.py +++ b/qim3d/tests/utils/test_augmentations.py @@ -19,14 +19,14 @@ def test_augment(): def test_resize(): resize_str = 'not valid resize' - with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."): + with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'reshape' or 'padding'."): augment_class = qim3d.utils.Augmentation(resize = resize_str) def test_levels(): augment_class = qim3d.utils.Augmentation() - level = 'Not a valid level' + type = 'Not a valid level' - with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."): - augment_class.augment(256,256,level) \ No newline at end of file + with pytest.raises(ValueError, match=f"Invalid transformation level: {type}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."): + augment_class.augment(256,256,type = type) \ No newline at end of file diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py index 928ef26f2f785b7760c073f2519f37eb9ebb61f6..f78ed53e33884938a2007dfe02dfb600c3128ac3 100644 --- a/qim3d/tests/utils/test_data.py +++ b/qim3d/tests/utils/test_data.py @@ -4,19 +4,6 @@ import pytest from torch.utils.data.dataloader import DataLoader from qim3d.utils.internal_tools import temp_data -# unit tests for Dataset() -def test_dataset(): - img_shape = (32,32) - folder = 'folder_data' - temp_data(folder, img_shape = img_shape) - - images = qim3d.utils.Dataset(folder) - - assert images[0][0].shape == img_shape - - temp_data(folder,remove=True) - - # unit tests for check_resize() def test_check_resize(): h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6) @@ -38,13 +25,15 @@ def test_check_resize_fail(): def test_prepare_datasets(): n = 3 validation = 1/3 + test = 0.1 folder = 'folder_data' img = temp_data(folder,n = n) my_model = qim3d.models.UNet() my_augmentation = qim3d.utils.Augmentation(transform_test='light') - train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation) + train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=validation, + train_folder="train",test_folder="test") assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n) @@ -67,8 +56,8 @@ def test_prepare_dataloaders(): batch_size = 1 my_model = qim3d.models.UNet() my_augmentation = qim3d.utils.Augmentation() - train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation) - + train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=1/3, + train_folder="train",test_folder="test") _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, batch_size,num_workers = 1, pin_memory = False) diff --git a/qim3d/tests/utils/test_models.py b/qim3d/tests/utils/test_models.py index 37262ad6517b3e603c972de6526e7219cccf2611..56ea94c8741d92838d0f47f8732adcd6a54dc380 100644 --- a/qim3d/tests/utils/test_models.py +++ b/qim3d/tests/utils/test_models.py @@ -13,7 +13,8 @@ def test_model_summary(): unet = qim3d.models.UNet(size = 'small') augment = qim3d.utils.Augmentation(transform_train=None) - train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) + train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3, + train_folder="train",test_folder="test") _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, batch_size = 1,num_workers = 1, @@ -32,7 +33,8 @@ def test_inference(): unet = qim3d.models.UNet(size = 'small') augment = qim3d.utils.Augmentation(transform_train=None) - train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) + train_set,_,_ = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3, + train_folder="train",test_folder="test") _, targ,_ = qim3d.utils.inference(train_set,unet) @@ -94,7 +96,8 @@ def test_train_model(): unet = qim3d.models.UNet(size = 'small') augment = qim3d.utils.Augmentation(transform_train=None) hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs) - train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment) + train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3, + train_folder="train",test_folder="test") train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set, batch_size = 1,num_workers = 1, pin_memory = False) diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 06fec0d0c41d0c0269ffc94291a77a8b63fa2b36..26b48de22769eced35ce472cf266d1b9b5a52ab6 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -30,7 +30,8 @@ def test_grid_pred(): model = qim3d.models.UNet() augmentation = qim3d.utils.Augmentation() - train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation) + train_set,_,_ = qim3d.utils.prepare_datasets(folder,model,augmentation, + train_folder="train",test_folder="test") in_targ_pred = qim3d.utils.models.inference(train_set,model) diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py index aee677798b0bada9e8a9ad541673a64016cc499e..7539541826eff2e389342790debb10c75b13cab9 100644 --- a/qim3d/utils/augmentations.py +++ b/qim3d/utils/augmentations.py @@ -61,7 +61,9 @@ class Augmentation: level = self.transform_validation elif type=='test': level = self.transform_test - + else: + level = type + # Check if one of standard augmentation levels if level not in [None,'light','moderate','heavy']: raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")