diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py
index da6c490d2706011a4dddfe274bb661e3dc303127..4e315c97283fead599a5081de36f91ca1ff02aa7 100644
--- a/qim3d/tests/utils/test_augmentations.py
+++ b/qim3d/tests/utils/test_augmentations.py
@@ -19,14 +19,14 @@ def test_augment():
 def test_resize():
     resize_str = 'not valid resize'
 
-    with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."):
+    with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'reshape' or 'padding'."):
         augment_class = qim3d.utils.Augmentation(resize = resize_str)
 
 
 def test_levels():
     augment_class = qim3d.utils.Augmentation()
 
-    level = 'Not a valid level'
+    type = 'Not a valid level'
 
-    with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
-        augment_class.augment(256,256,level)
\ No newline at end of file
+    with pytest.raises(ValueError, match=f"Invalid transformation level: {type}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
+        augment_class.augment(256,256,type = type)
\ No newline at end of file
diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py
index 928ef26f2f785b7760c073f2519f37eb9ebb61f6..f78ed53e33884938a2007dfe02dfb600c3128ac3 100644
--- a/qim3d/tests/utils/test_data.py
+++ b/qim3d/tests/utils/test_data.py
@@ -4,19 +4,6 @@ import pytest
 from torch.utils.data.dataloader import DataLoader
 from qim3d.utils.internal_tools import temp_data
 
-# unit tests for Dataset()
-def test_dataset():
-    img_shape = (32,32)
-    folder = 'folder_data'
-    temp_data(folder, img_shape = img_shape)
-    
-    images = qim3d.utils.Dataset(folder)
-
-    assert images[0][0].shape == img_shape
-
-    temp_data(folder,remove=True)
-
-
 # unit tests for check_resize()
 def test_check_resize():
     h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
@@ -38,13 +25,15 @@ def test_check_resize_fail():
 def test_prepare_datasets():
     n = 3
     validation = 1/3
+    test = 0.1
     
     folder = 'folder_data'
     img = temp_data(folder,n = n)
 
     my_model = qim3d.models.UNet()
     my_augmentation = qim3d.utils.Augmentation(transform_test='light')
-    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation)
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=validation,
+                                                                train_folder="train",test_folder="test")
 
     assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
 
@@ -67,8 +56,8 @@ def test_prepare_dataloaders():
     batch_size = 1
     my_model = qim3d.models.UNet()
     my_augmentation = qim3d.utils.Augmentation()
-    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation)
-
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,my_model,my_augmentation,val_fraction=1/3,
+                                                                train_folder="train",test_folder="test")
     _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
                                                                            batch_size,num_workers = 1,
                                                                            pin_memory = False)
diff --git a/qim3d/tests/utils/test_models.py b/qim3d/tests/utils/test_models.py
index 37262ad6517b3e603c972de6526e7219cccf2611..56ea94c8741d92838d0f47f8732adcd6a54dc380 100644
--- a/qim3d/tests/utils/test_models.py
+++ b/qim3d/tests/utils/test_models.py
@@ -13,7 +13,8 @@ def test_model_summary():
 
     unet = qim3d.models.UNet(size = 'small')
     augment = qim3d.utils.Augmentation(transform_train=None)
-    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
+                                                              train_folder="train",test_folder="test")
 
     _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
                                                      batch_size = 1,num_workers = 1,
@@ -32,7 +33,8 @@ def test_inference():
 
     unet = qim3d.models.UNet(size = 'small')
     augment = qim3d.utils.Augmentation(transform_train=None)
-    train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+    train_set,_,_ = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
+                                                 train_folder="train",test_folder="test")
 
     _, targ,_ = qim3d.utils.inference(train_set,unet)
 
@@ -94,7 +96,8 @@ def test_train_model():
     unet = qim3d.models.UNet(size = 'small')
     augment = qim3d.utils.Augmentation(transform_train=None)
     hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
-    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,unet,augment,val_fraction=1/3,
+                                                              train_folder="train",test_folder="test")
     train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
                                                                 batch_size = 1,num_workers = 1,
                                                                 pin_memory = False)
diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index 06fec0d0c41d0c0269ffc94291a77a8b63fa2b36..26b48de22769eced35ce472cf266d1b9b5a52ab6 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -30,7 +30,8 @@ def test_grid_pred():
 
     model = qim3d.models.UNet()
     augmentation = qim3d.utils.Augmentation()
-    train_set,_,_ = qim3d.utils.prepare_datasets(folder,0.1,model,augmentation)
+    train_set,_,_ = qim3d.utils.prepare_datasets(folder,model,augmentation,
+                                                 train_folder="train",test_folder="test")
 
     in_targ_pred = qim3d.utils.models.inference(train_set,model)
 
diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index aee677798b0bada9e8a9ad541673a64016cc499e..7539541826eff2e389342790debb10c75b13cab9 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -61,7 +61,9 @@ class Augmentation:
             level = self.transform_validation
         elif type=='test':
             level = self.transform_test
-
+        else:
+            level = type
+        
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")