diff --git a/docs/notebooks/Unet.ipynb b/docs/notebooks/Unet.ipynb
index 4b7a3d9e52aeb8466c67fab035daa131c97e5b2b..37d71cb90229f6636711354d793519de508452e2 100644
--- a/docs/notebooks/Unet.ipynb
+++ b/docs/notebooks/Unet.ipynb
@@ -20,6 +20,7 @@
    "source": [
     "from os.path import join\n",
     "import qim3d\n",
+    "import os\n",
     "\n",
     "%matplotlib inline"
    ]
@@ -35,14 +36,26 @@
     "def get_dataset_path(name: str, datasets):\n",
     "    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)\n",
     "    dataset_idx = datasets.index(name)\n",
-    "    datasets_path = [\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
-    "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
-    "    ]\n",
+    "    if os.name == 'nt':\n",
+    "        datasets_path = [\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
+    "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
+    "        ]\n",
+    "    else:\n",
+    "        datasets_path = [\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
+    "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
+    "        ]\n",
+    "\n",
     "    return datasets_path[dataset_idx]"
    ]
   },
@@ -154,7 +167,7 @@
    "outputs": [],
    "source": [
     "# model hyperparameters\n",
-    "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=25,\n",
+    "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=5,\n",
     "                                                  learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)\n",
     "\n",
     "# training model\n",
@@ -197,7 +210,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.11.6"
   }
  },
  "nbformat": 4,
diff --git a/qim3d/io/load.py b/qim3d/io/load.py
index 620a2eb81d10a0913677f3505a256c1370f1035b..3721c07463505752d669e64b37dc5058ef5b3a45 100644
--- a/qim3d/io/load.py
+++ b/qim3d/io/load.py
@@ -401,4 +401,4 @@ class ImgExamples:
 
         # Generate loader for each image found
         for idx, name in enumerate(img_names):
-            exec(f"self.{name} = qim3d.io.load(path = img_paths[idx])")
+            exec(f"self.{name} = qim3d.io.load(path = img_paths[idx])")
\ No newline at end of file
diff --git a/qim3d/tests/models/test_unet.py b/qim3d/tests/models/test_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b17605329af78f3b04cdcedd35de60608577f6
--- /dev/null
+++ b/qim3d/tests/models/test_unet.py
@@ -0,0 +1,34 @@
+import qim3d
+import torch
+
+# unit tests for UNet()
+def test_starting_unet():
+    unet = qim3d.models.UNet()
+
+    assert unet.size == 'medium'
+
+
+def test_forward_pass():
+    unet = qim3d.models.UNet()
+
+    # Size: B x C x H x W
+    x = torch.ones([1,1,256,256])
+
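+    # UNet is fully convolutional, so the output shape should match the input shape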
+    output = unet(x)
+    assert x.shape == output.shape
+
+# unit tests for Hyperparameters()
+def test_hyper():
+    unet = qim3d.models.UNet()
+    hyperparams = qim3d.models.Hyperparameters(unet)
+
+    assert hyperparams.n_epochs == 10
+
+def test_hyper_dict():
+    unet = qim3d.models.UNet()
+    hyperparams = qim3d.models.Hyperparameters(unet)
+
+    hyper_dict = hyperparams()
+
+    assert isinstance(hyper_dict, dict)
diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..da6c490d2706011a4dddfe274bb661e3dc303127
--- /dev/null
+++ b/qim3d/tests/utils/test_augmentations.py
@@ -0,0 +1,33 @@
+import qim3d
+import albumentations
+import pytest
+
+# unit tests for Augmentation()
+def test_augmentation():
+    augment_class = qim3d.utils.Augmentation()
+
+    assert augment_class.resize == 'crop'
+
+def test_augment():
+    augment_class = qim3d.utils.Augmentation()
+
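+    # augment() should return the composed albumentations pipeline for a 256x256 image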
+    album_augment = augment_class.augment(256,256)
+
+    assert isinstance(album_augment, albumentations.Compose)
+
+# unit tests for ValueErrors in Augmentation()
+def test_resize():
+    resize_str = 'not valid resize'
+
+    with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."):
+        qim3d.utils.Augmentation(resize = resize_str)
+
+
+def test_levels():
+    augment_class = qim3d.utils.Augmentation()
+
+    level = 'Not a valid level'
+
+    with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
+        augment_class.augment(256,256,level)
\ No newline at end of file
diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..928ef26f2f785b7760c073f2519f37eb9ebb61f6
--- /dev/null
+++ b/qim3d/tests/utils/test_data.py
@@ -0,0 +1,81 @@
+import qim3d
+import pytest
+
+from torch.utils.data import DataLoader
+from qim3d.utils.internal_tools import temp_data
+
+# unit tests for Dataset()
+def test_dataset():
+    img_shape = (32,32)
+    folder = 'folder_data'
+    temp_data(folder, img_shape = img_shape)
+
+    images = qim3d.utils.Dataset(folder)
+
+    assert images[0][0].shape == img_shape
+
+    temp_data(folder,remove=True)
+
+
+# unit tests for check_resize()
+def test_check_resize():
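+    # with n_channels = 6 the sides should be cropped down to a multiple of 2**6 = 64, so 240 -> 192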
+    h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
+
+    assert (h_adjust,w_adjust) == (192,192)
+
+def test_check_resize_pad():
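+    # padding should instead grow 16 up to 64, the next multiple of 2**6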
+    h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'padding',n_channels = 6)
+
+    assert (h_adjust,w_adjust) == (64,64)
+
+def test_check_resize_fail():
+
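+    # 16 px cannot be cropped down to a positive multiple of 64, so this should raise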
+    with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
+        h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'crop',n_channels = 6)
+
+
+# unit tests for prepare_datasets()
+def test_prepare_datasets():
+    n = 3
+    validation = 1/3
+
+    folder = 'folder_data'
+    temp_data(folder,n = n)
+
+    my_model = qim3d.models.UNet()
+    my_augmentation = qim3d.utils.Augmentation(transform_test='light')
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation)
+
+    assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for validation in prepare_datasets()
+def test_validation():
+    validation = 10
+    
+    with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."):
+        qim3d.utils.prepare_datasets('folder',validation,'my_model','my_augmentation')
+
+
+# unit test for prepare_dataloaders()
+def test_prepare_dataloaders():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    batch_size = 1
+    my_model = qim3d.models.UNet()
+    my_augmentation = qim3d.utils.Augmentation()
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation)
+
+    _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                     batch_size,num_workers = 1,
+                                                     pin_memory = False)
+
+    assert isinstance(val_loader, DataLoader)
+
+    temp_data(folder,remove=True)
\ No newline at end of file
diff --git a/qim3d/tests/utils/test_models.py b/qim3d/tests/utils/test_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..37262ad6517b3e603c972de6526e7219cccf2611
--- /dev/null
+++ b/qim3d/tests/utils/test_models.py
@@ -0,0 +1,108 @@
+import qim3d
+import pytest
+from torch import ones
+
+from qim3d.utils.internal_tools import temp_data
+
+# unit test for model_summary()
+def test_model_summary():
+    n = 10
+    img_shape = (32,32)
+    folder = 'folder_data'
+    temp_data(folder,img_shape=img_shape,n = n)
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+
+    _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                     batch_size = 1,num_workers = 1,
+                                                     pin_memory = False)
+    summary = qim3d.utils.model_summary(val_loader,unet)
+
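+    # the summary input size should be (batch, channels) + image shape = (1, 1, 32, 32)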
+    assert summary.input_size[0] == (1,1) + img_shape
+
+    temp_data(folder,remove=True)
+
+
+# unit test for inference()
+def test_inference():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+
+    _, targ,_ = qim3d.utils.inference(train_set,unet)
+
+    assert tuple(targ[0].unique()) == (0,1)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for the tuple ValueError
+def test_inference_tuple():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [1,2,3]
+    with pytest.raises(ValueError,match="Data items must be tuples"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for the tensor ValueError
+def test_inference_tensor():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [(1,2)]
+    with pytest.raises(ValueError,match="Data items must consist of tensors"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for the dimension ValueError
+def test_inference_dim():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [(ones(1),ones(1))]
+    # need the r"" for special characters
+    with pytest.raises(ValueError,match=r"Input image must be \(C,H,W\) format"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for train_model()
+def test_train_model():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    n_epochs = 1
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+    train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                                batch_size = 1,num_workers = 1,
+                                                                pin_memory = False)
+
+    train_loss,_ = qim3d.utils.train_model(unet,hyperparams,train_loader,val_loader,
+                                           plot = False, return_loss = True)
+
+    assert len(train_loss['loss']) == n_epochs
+
+    temp_data(folder,remove=True)
\ No newline at end of file
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index 5eb9ccab061ad3ca6225622987bbed4cb11491bf..332856e43248ffc083059ac93276697625a48de3 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -160,21 +160,21 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
     orig_h, orig_w = image.size[:2]
         
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
-    
+
     train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
     val_set   = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
     test_set  = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
-    
+
     train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
     val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
     
     return train_set, val_set, test_set
 
 
-def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = True):  
+def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):  
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index e31001266288c405f3fb1d9092341d4e51cf1e98..b8a4b6b7bcb86be7cb4496b576dd68381db57556 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -8,6 +8,9 @@ import matplotlib
 import numpy as np
 import socket
 import os
+import shutil
+from PIL import Image
+from pathlib import Path
 
 
 
@@ -177,10 +180,53 @@
         return True
     except:
         return False
+
+def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
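+    """Creates a small dummy train/test image dataset for tests, or removes it if remove=True."""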
+    folder_trte = ['train','test']
+    sub_folders = ['images','labels']
+
+    # Creating train/test folder
+    path_train = Path(folder) / folder_trte[0]
+    path_test = Path(folder) / folder_trte[1]
+
+    # Creating folders for images and labels
+    path_train_im = path_train / sub_folders[0]
+    path_train_lab = path_train / sub_folders[1]
+    path_test_im = path_test / sub_folders[0]
+    path_test_lab = path_test / sub_folders[1]
+
+    # Random binary image
+    img = np.random.randint(2,size = img_shape,dtype = np.uint8)
+    img = Image.fromarray(img)
+
+    if not os.path.exists(path_train):
+        os.makedirs(path_train_im)
+        os.makedirs(path_test_im)
+        os.makedirs(path_train_lab)
+        os.makedirs(path_test_lab)
+        for i in range(n):
+            img.save(path_train_im / f'img_train{i}.png')
+            img.save(path_train_lab / f'img_train{i}.png')
+            img.save(path_test_im / f'img_test{i}.png')
+            img.save(path_test_lab / f'img_test{i}.png')
+
+    if remove:
+        for filename in os.listdir(folder):
+            file_path = os.path.join(folder, filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print(f'Failed to delete {file_path}. Reason: {e}')
+
+        os.rmdir(folder)
     
 def stringify_path(path):
     """Converts an os.PathLike object to a string
     """
     if isinstance(path,os.PathLike):
         path = path.__fspath__()
-    return path
\ No newline at end of file
+    return path
diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py
index a011fe77ac42846ce963ad4a60c38ba9d52f2d0b..f0693186aecb2634a3d2d4113788bad74b6010ef 100644
--- a/qim3d/utils/models.py
+++ b/qim3d/utils/models.py
@@ -10,7 +10,7 @@ from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
 
-def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
+def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True, return_loss = False):
     """ Function for training Neural Network models.
     
     Args:
@@ -20,6 +20,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
         val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
         eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
         print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
+        return_loss (bool, optional): if True, the training and validation loss dictionaries are returned. Defaults to False.
 
     Returns:
         tuple:
@@ -65,10 +66,11 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
             for data in train_loader:
                 inputs, targets = data
                 inputs = inputs.to(device)
-                targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
+                targets = targets.to(device).unsqueeze(1)
     
                 optimizer.zero_grad()
                 outputs = model(inputs)
+
                 loss = criterion(outputs, targets)
     
                 # Backpropagation
@@ -94,8 +96,8 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
                 for data in val_loader:
                     inputs, targets = data
                     inputs = inputs.to(device)
-                    targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
-                
+                    targets = targets.to(device).unsqueeze(1)
+
                     with torch.no_grad():
                         outputs = model(inputs)
                         loss = criterion(outputs, targets)
@@ -122,6 +124,9 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
         plot_metrics(val_loss,color = 'orange', label = 'Valid.')
         fig.show()
 
+    if return_loss:
+        return train_loss,val_loss
+
 
 def model_summary(dataloader,model):
     """Prints the summary of a PyTorch model.
@@ -196,7 +201,7 @@ def inference(data,model):
     else:
         raise ValueError("Input image must be (C,H,W) format")
 
-    
+    model.to(device)
     model.eval()
 
     # Make new list such that possible augmentations remain identical for all three rows