diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index f62ffc2a7e2342363fa7e6313ecd9894fd70ba06..06fec0d0c41d0c0269ffc94291a77a8b63fa2b36 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -35,7 +35,7 @@ def test_grid_pred():
     in_targ_pred = qim3d.utils.models.inference(train_set,model)
 
     fig = qim3d.viz.grid_pred(in_targ_pred)
-
+    
     assert (fig.get_figwidth(),fig.get_figheight()) == (2*(n),10)
 
     temp_data(folder,remove = True)
@@ -45,7 +45,7 @@ def test_grid_pred():
 def test_slice_viz():
     example_volume = ones(10,10,10)
     img_width = 3
-    fig = qim3d.viz.slice_viz(example_volume,img_width = img_width)
+    fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width)
 
     assert fig.get_figwidth() == img_width
 
diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index a57880e7c52ede52247e7fe666f7abd8a40b68aa..19154b7010f92a893bae81f8ea8dbce48e227e4f 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -1,7 +1,6 @@
 """Class for choosing the level of data augmentations with albumentations"""
 import albumentations as A
 from albumentations.pytorch import ToTensorV2
-from qim3d.io.logger import log
 
 class Augmentation:
     """
diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index b8a4b6b7bcb86be7cb4496b576dd68381db57556..bc06e5f9285cdc2062903e9f8ec651484c7fb985 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -9,9 +9,10 @@ import numpy as np
 import socket
 import os
 import shutil
+
 from PIL import Image
 from pathlib import Path
-
+from qim3d.io.logger import log
 
 
 def mock_plot():
@@ -182,6 +183,21 @@ def is_server_running(ip, port):
         return False
 
 def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
+    """Creates a temporary folder to test deep learning tools.
+    
+    Creates two folders, 'train' and 'test', each of which contains two subfolders, 'images' and 'labels'.
+    n random images are then added to each of the four subfolders.
+    If 'remove' is True, the folders and their contents are removed instead.
+
+    Args:
+        folder (str): The path where the folders should be placed.
+        remove (bool, optional): If True, all folders are removed from their location.
+        n (int, optional): Number of random images and labels in the temporary dataset.
+        img_shape (tuple, optional): Tuple with the height and width of the images and labels.
+
+    Example:
+        >>> temp_data('temporary_folder', n = 10, img_shape = (16,16))
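+        >>> # calling with remove = True deletes the temporary folders again
+        >>> temp_data('temporary_folder', remove = True)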
+    """
     folder_trte = ['train','test']
     sub_folders = ['images','labels']
 
@@ -219,7 +235,7 @@ def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
                 elif os.path.isdir(file_path):
                     shutil.rmtree(file_path)
             except Exception as e:
-                print('Failed to delete %s. Reason: %s' % (file_path, e))
+                log.warning('Failed to delete %s. Reason: %s' % (file_path, e))
         
         os.rmdir(folder)
     
diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py
index 0294f14a0aa8bd8971f29d3ae6873ec55088685e..19a2a844fe225932e03b041c922328129aa771e8 100644
--- a/qim3d/utils/models.py
+++ b/qim3d/utils/models.py
@@ -15,17 +15,19 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
     
     Args:
         model (torch.nn.Module): PyTorch model.
-        hyperparameters (class): dictionary with n_epochs, optimizer and criterion.
+        hyperparameters (class): Dictionary with n_epochs, optimizer and criterion.
         train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
         val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
-        eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
-        print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
+        eval_every (int, optional): Frequency of model evaluation. Defaults to every epoch.
+        print_every (int, optional): Frequency of log for model performance. Defaults to every 5 epochs.
+        plot (bool, optional): If True, plots the training and validation loss after the model is done training.
+        return_loss (bool, optional): If True, returns a dictionary with the history of the training and validation losses.
         
-
     Returns:
-        tuple:
-            train_loss (dict): dictionary with average losses and batch losses for training loop.
-            val_loss (dict): dictionary with average losses and batch losses for validation loop.
+        If `return_loss` is True:
+            tuple:
+                train_loss (dict): Dictionary with average losses and batch losses for the training loop.
+                val_loss (dict): Dictionary with average losses and batch losses for the validation loop.
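+
+            Illustrative unpacking (assuming the model, hyperparameters and dataloaders from the example below):
+                train_loss, val_loss = train_model(model, hyperparameters, train_loader, val_loader, return_loss = True)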
         
     Example:
         # defining the model.
@@ -141,7 +143,6 @@ def model_summary(dataloader,model):
         summary = model_summary(model, dataloader)
         print(summary)
     """
-    
     images,_ = next(iter(dataloader)) 
     batch_size = tuple(images.shape)
     model_s = summary(model,batch_size,depth = torch.inf)
diff --git a/qim3d/utils/system.py b/qim3d/utils/system.py
index 5fd6e8eed040d3a51de5219ce1188d74622932b8..383060df8c8fad875df4c4984325b7fb0f92ed6a 100644
--- a/qim3d/utils/system.py
+++ b/qim3d/utils/system.py
@@ -1,4 +1,4 @@
-"""Provides tools for obtaining informaion about the system."""
+"""Provides tools for obtaining information about the system."""
 import psutil
 from qim3d.utils.internal_tools import sizeof
 from qim3d.io.logger import log
diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py
index 3a30005db39ff6b34ea996ee3104cd5d7ff31ec4..83edde50b5c98af6a5b21cbd2b46e32dda8515a1 100644
--- a/qim3d/viz/img.py
+++ b/qim3d/viz/img.py
@@ -32,7 +32,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
         - The grid layout and dimensions vary based on the presence of a mask.
 
     Returns:
-        None
+        fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels.
 
     Example:
         data = [(image1, label1, mask1), (image2, label2, mask2)]
@@ -117,7 +117,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
         show (bool, optional): If True, displays the plot. Defaults to False.
 
     Returns:
-        None
+        fig (matplotlib.figure.Figure): The figure with the images, their labels and the corresponding predictions from the trained model.
 
     Raises:
         None
@@ -197,20 +197,31 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
 
     return fig
 
-def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2,show = False):
+def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False):
     """ Displays one or several slices from a 3d array.
 
+    By default, when `position` is None, slice_viz plots an overview of the entire stack.
+    If `position` is given as a string or integer, slice_viz plots `n_slices` slices around that position.
+    If `position` is given as a list or array, `n_slices` is ignored and the indices in `position` are plotted.
+    
     Args:
         input (str, numpy.ndarray): Path to the file or 3-dimensional array.
         position (str, int, list, array, optional): One or several slicing levels.
+        n_slices (int, optional): Number of slices to display. Defaults to 5.
         cmap (str, optional): Specifies the color map for the image.
         axis (bool, optional): Specifies whether the axes should be included.
+        img_height (int, optional): Height of the figure. Defaults to 4.
+        img_width (int, optional): Width of each slice in the figure. Defaults to 4.
         show (bool, optional): If True, displays the plot. Defaults to False.
 
+    Returns:
+        fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
+
     Raises:
+        ValueError: If the file or array is not a 3D volume.
         ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom).
-
-    Usage:
+    
+    Example:
         image_path = '/my_image_path/my_image.tif'
         slice_viz(image_path)
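+
+        # illustrative only: an in-memory volume with the new position/n_slices arguments
+        vol = np.ones((32, 32, 32))
+        slice_viz(vol, position = 'mid', n_slices = 3)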
     """
@@ -218,50 +229,68 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
     # Filepath input
     if isinstance(input,str):
         vol = qim3d.io.load(input) # Function has its own ValueErrors
-        dim = vol.ndim
-        
-        if dim == 3:
-            pass
-        else:
-            raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
-            
+        dim = vol.ndim
         
     # Numpy array input
     elif isinstance(input,(np.ndarray,torch.Tensor)):
+        vol = input
         dim = input.ndim
         
-        if dim == 3:
-            vol = input
-        else:
-            raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
+    if dim != 3:
+        raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
 
+    if position is None:
+        height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int)
+    
     # Position is a string
-    if isinstance(position,str):
+    elif isinstance(position,str):
+        
         if position.lower() in ['mid','middle']:
-            height = [int(vol.shape[0]/2)]
+            expansion_start = int(vol.shape[0]/2)
+            height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
+            
         elif position.lower() in ['top','upper', 'start']:
-            height = [0]
+            expansion_start = 0
+            height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int)
+            
         elif position.lower() in ['bot','bottom', 'end']:
-            height = [vol.shape[0]-1]
+            expansion_start = vol.shape[0]-1
+            height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int)
+        
         else:
             raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".')
+
     
     # Position is an integer
     elif isinstance(position,int):
-        height = [position]
+        expansion_start = position
+        n_stacks = vol.shape[0]-1
+
+        # if linspace would extend beyond n_stacks
+        if expansion_start + n_slices > n_stacks:
+            height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int)
+        
+        # if linspace would extend below 0 
+        elif expansion_start - n_slices < 0:
+            height = np.linspace(0,n_slices-1,n_slices).astype(int)
 
+        else:
+            height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
+
+    
     # Position is a list or array of integers
     elif isinstance(position,(list,np.ndarray)):
         height = position
 
     num_images = len(height)
-    
+
+
     fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True)
     axs = fig.subplots(nrows = 1, ncols = num_images)
     
     for col, ax in enumerate(np.atleast_1d(axs)):
         ax.imshow(vol[height[col],:,:],cmap = cmap)
-        ax.set_title(f'Slice {height[col]}', fontsize=8)
+        ax.set_title(f'Slice {height[col]}', fontsize=6*img_height)
         if not axis:
             ax.axis('off')
     
diff --git a/qim3d/viz/visualizations.py b/qim3d/viz/visualizations.py
index f412a2b60204c054fb968502092ea0d763a0c919..4d64b0ce5b397486c90d4e3af69363e4a333dc2c 100644
--- a/qim3d/viz/visualizations.py
+++ b/qim3d/viz/visualizations.py
@@ -22,15 +22,13 @@ def plot_metrics(*metrics,
         show (bool, optional): If True, displays the plot. Defaults to False.
 
     Returns:
-        if return_fig:
-            fig (matplotlib.figure.Figure): plot with metrics.
+        fig (matplotlib.figure.Figure): Plot with metrics.
 
     Example:
         train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
         val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
         plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
     """
-    
     if labels == None:
         labels = [None]*len(metrics)
     elif len(metrics) != len(labels):