diff --git a/qim3d/detection/_common_detection_methods.py b/qim3d/detection/_common_detection_methods.py
index 0a985d8543eacd858e968963b4a6c8491a011b4d..131ec1e45e40ebcd15c83de260938bcdbd360c59 100644
--- a/qim3d/detection/_common_detection_methods.py
+++ b/qim3d/detection/_common_detection_methods.py
@@ -15,7 +15,7 @@ def blobs(
     overlap: float = 0.5,
     threshold_rel: float = None,
     exclude_border: bool = False,
-) -> np.ndarray:
+) -> tuple[np.ndarray, np.ndarray]:
     """
     Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True
 
diff --git a/qim3d/features/_common_features_methods.py b/qim3d/features/_common_features_methods.py
index e69b6cba63867c27cbedeba05999e201917f77b7..63838ef17634d493fd5de0ec4e806d4568200918 100644
--- a/qim3d/features/_common_features_methods.py
+++ b/qim3d/features/_common_features_methods.py
@@ -5,7 +5,9 @@ import trimesh
 import qim3d
 
 
-def volume(obj, **mesh_kwargs) -> float:
+def volume(obj: np.ndarray|trimesh.Trimesh, 
+           **mesh_kwargs
+           ) -> float:
     """
     Compute the volume of a 3D volume or mesh.
 
@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
     return obj.volume
 
 
-def area(obj, **mesh_kwargs) -> float:
+def area(obj: np.ndarray|trimesh.Trimesh, 
+         **mesh_kwargs
+         ) -> float:
     """
     Compute the surface area of a 3D volume or mesh.
 
@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
     return obj.area
 
 
-def sphericity(obj, **mesh_kwargs) -> float:
+def sphericity(obj: np.ndarray|trimesh.Trimesh, 
+               **mesh_kwargs
+               ) -> float:
     """
     Compute the sphericity of a 3D volume or mesh.
 
diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py
index f2d421979259e1f676a54f5ec9512cf2cb6c364e..59b0be1291ff2f3480b44dc3df524c9794a2c6ed 100644
--- a/qim3d/filters/_common_filter_methods.py
+++ b/qim3d/filters/_common_filter_methods.py
@@ -26,7 +26,10 @@ __all__ = [
 
 
 class FilterBase:
-    def __init__(self, dask=False, chunks="auto", *args, **kwargs):
+    def __init__(self, 
+                 dask: bool = False, 
+                 chunks: str = "auto", 
+                 *args, **kwargs):
         """
         Base class for image filters.
 
@@ -40,7 +43,7 @@ class FilterBase:
         self.kwargs = kwargs
 
 class Gaussian(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a Gaussian filter to the input.
 
@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
 
 
 class Median(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a median filter to the input.
 
@@ -68,7 +71,7 @@ class Median(FilterBase):
 
 
 class Maximum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a maximum filter to the input.
 
@@ -82,7 +85,7 @@ class Maximum(FilterBase):
 
 
 class Minimum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a minimum filter to the input.
 
@@ -95,7 +98,7 @@ class Minimum(FilterBase):
         return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
 
 class Tophat(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a tophat filter to the input.
 
@@ -210,7 +213,10 @@ class Pipeline:
         return input
 
 
-def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
+def gaussian(vol: np.ndarray, 
+             dask: bool = False, 
+             chunks: str = 'auto',
+             *args, **kwargs) -> np.ndarray:
     """
     Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
 
@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
         return res
 
 
-def median(vol, dask=False, chunks='auto', **kwargs):
+def median(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str ='auto', 
+           **kwargs) -> np.ndarray:
     """
     Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
 
@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def maximum(vol, dask=False, chunks='auto', **kwargs):
+def maximum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
 
@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def minimum(vol, dask=False, chunks='auto', **kwargs):
+def minimum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
 
@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
         res = ndimage.minimum_filter(vol, **kwargs)
         return res
 
-def tophat(vol, dask=False, chunks='auto', **kwargs):
+def tophat(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str = 'auto', 
+           **kwargs) -> np.ndarray:
     """
     Remove background from the volume.
 
diff --git a/qim3d/generate/_aggregators.py b/qim3d/generate/_aggregators.py
index 9d10df37e4efb270ac539e7ff4c3f22cf7cd6b06..262b047ef703ef379432f4a18441617cf5475613 100644
--- a/qim3d/generate/_aggregators.py
+++ b/qim3d/generate/_aggregators.py
@@ -142,7 +142,7 @@ def noise_object_collection(
     object_shape: str = None,
     seed: int = 0,
     verbose: bool = False,
-) -> tuple[np.ndarray, object]:
+) -> tuple[np.ndarray, np.ndarray]:
     """
     Generate a 3D volume of multiple synthetic objects using Perlin noise.
 
diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py
index 4498e1ad2fda3201efec9937d997232ec98db3b5..5ebae5f76cefe4d983f9dd1ac1dd24e5127feb62 100644
--- a/qim3d/gui/annotation_tool.py
+++ b/qim3d/gui/annotation_tool.py
@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
 
 
 class Interface(BaseInterface):
-    def __init__(self, name_suffix: str = "", verbose: bool = False, img=None):
+    def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
         super().__init__(
             title="Annotation Tool",
             height=768,
@@ -55,7 +55,7 @@ class Interface(BaseInterface):
         self.masks_rgb = None
         self.temp_files = []
 
-    def get_result(self):
+    def get_result(self) -> dict:
         # Get the temporary files from gradio
         temp_path_list = []
         for filename in os.listdir(self.temp_dir):
@@ -95,13 +95,13 @@ class Interface(BaseInterface):
         except FileNotFoundError:
             files = None
 
-    def create_preview(self, img_editor):
+    def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
         background = img_editor["background"]
         masks = img_editor["layers"][0]
         overlay_image = overlay_rgb_images(background, masks)
         return overlay_image
 
-    def cerate_download_list(self, img_editor):
+    def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
         masks_rgb = img_editor["layers"][0]
         mask_threshold = 200  # This value is based
 
diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py
index ecabeb81135bdbf1d4b0f64c184be62af3dbbf65..ee3ba16d5a90f3facbebc8580b906e52c745bc50 100644
--- a/qim3d/gui/data_explorer.py
+++ b/qim3d/gui/data_explorer.py
@@ -20,6 +20,7 @@ import os
 import re
 
 import gradio as gr
+import matplotlib.figure
 import matplotlib.pyplot as plt
 import numpy as np
 import outputformat as ouf
@@ -29,6 +30,8 @@ from qim3d.utils._logger import log
 from qim3d.utils import _misc
 
 from qim3d.gui.interface import BaseInterface
+from typing import Callable, Any, Dict
+import matplotlib
 
 
 class Interface(BaseInterface):
@@ -271,7 +274,7 @@ class Interface(BaseInterface):
         operations.change(fn=self.show_results, inputs = operations, outputs = results)
         cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
 
-    def update_explorer(self, new_path):
+    def update_explorer(self, new_path: str):
         new_path = os.path.expanduser(new_path)
 
         # In case we have a directory
@@ -367,7 +370,7 @@ class Interface(BaseInterface):
         except Exception as error_message:
             self.error_message = F"Error when loading data: {error_message}"
     
-    def run_operations(self, operations, *args):
+    def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
         outputs = []
         self.calculated_operations = []
         for operation in self.all_operations:
@@ -411,7 +414,7 @@ class Interface(BaseInterface):
             case _:
                 raise NotImplementedError(F"Operation '{operation} is not defined")
 
-    def show_results(self, operations):
+    def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
         update_list = []
         for operation in self.all_operations:
             if operation in operations and operation in self.calculated_operations:
@@ -426,7 +429,7 @@ class Interface(BaseInterface):
 #
 #######################################################
 
-    def create_img_fig(self, img, **kwargs):
+    def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
         fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
 
         ax.imshow(img, interpolation="nearest", **kwargs)
@@ -437,8 +440,8 @@ class Interface(BaseInterface):
 
         return fig
 
-    def update_slice_wrapper(self, letter):
-        def update_slice(position_slider:float, cmap:str):
+    def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
+        def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]:
             """
             position_slider: float from gradio slider, saying which relative slice we want to see
             cmap: string gradio drop down menu, saying what cmap we want to use for display
@@ -465,7 +468,7 @@ class Interface(BaseInterface):
             return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
         return update_slice
     
-    def vol_histogram(self, nbins, min_value, max_value):
+    def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
         # Start histogram
         vol_hist = np.zeros(nbins)
 
@@ -478,7 +481,7 @@ class Interface(BaseInterface):
 
         return vol_hist, bin_edges
 
-    def plot_histogram(self):
+    def plot_histogram(self) -> matplotlib.figure.Figure:
         # The Histogram needs results from the projections
         if not self.projections_calculated:
             _ = self.get_projections()
@@ -498,7 +501,7 @@ class Interface(BaseInterface):
 
         return fig
     
-    def create_projections_figs(self):
+    def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
         if not self.projections_calculated:
             projections = self.get_projections()
             self.max_projection = projections[0]
@@ -519,7 +522,7 @@ class Interface(BaseInterface):
         self.projections_calculated = True
         return max_projection_fig, min_projection_fig
 
-    def get_projections(self):
+    def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
         # Create arrays for iteration
         max_projection = np.zeros(np.shape(self.vol[0]))
         min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py
index c559685095cfb69ecf950fcf37e961c7ad0c6bda..36d820dde8ab24ac13c94590ae1f69c542f1a841 100644
--- a/qim3d/gui/interface.py
+++ b/qim3d/gui/interface.py
@@ -6,6 +6,7 @@ import gradio as gr
 
 from .qim_theme import QimTheme
 import qim3d.gui
+import numpy as np
 
 
 # TODO: when offline it throws an error in cli
@@ -48,10 +49,10 @@ class BaseInterface(ABC):
     def set_invisible(self):
         return gr.update(visible=False)
     
-    def change_visibility(self, is_visible):
+    def change_visibility(self, is_visible: bool):
         return gr.update(visible = is_visible)
 
-    def launch(self, img=None, force_light_mode: bool = True, **kwargs):
+    def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
         """
         img: If None, user can upload image after the interface is launched.
             If defined, the interface will be launched with the image already there
@@ -76,7 +77,7 @@ class BaseInterface(ABC):
             **kwargs,
         )
 
-    def clear(self):
+    def clear(self) -> None:
         """Used to reset outputs with the clear button"""
         return None
 
diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py
index 1a20f91ea093a1c65daa0aefd73117445289c588..c7d4620ff1ebdf0df84255cf9a733de047b54d9e 100644
--- a/qim3d/gui/iso3d.py
+++ b/qim3d/gui/iso3d.py
@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
         self.img = img
         self.plot_height = plot_height
 
-    def load_data(self, gradiofile):
+    def load_data(self, gradiofile: gr.File):
         try:
             self.vol = load(gradiofile.name)
             assert self.vol.ndim == 3
@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
         except AssertionError:
             raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
 
-    def resize_vol(self, display_size):
+    def resize_vol(self, display_size: int):
         """Resizes the loaded volume to the display size"""
 
         # Get original size
@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
                 f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
             )
 
-    def save_fig(self, fig, filename):
+    def save_fig(self, fig: go.Figure, filename: str):
         # Write Plotly figure to disk
         fig.write_html(filename)
 
     def create_fig(self, 
-        gradio_file,
-        display_size,
-        opacity,
-        opacityscale,
-        only_wireframe,
-        min_value,
-        max_value,
-        surface_count,
-        colormap,
-        show_colorbar,
-        reversescale,
-        flip_z,
-        show_axis,
-        show_ticks,
-        show_caps,
-        show_z_slice,
-        slice_z_location,
-        show_y_slice,
-        slice_y_location,
-        show_x_slice,
-        slice_x_location,):
+        gradio_file: gr.File,
+        display_size: int ,
+        opacity: float,
+        opacityscale: str,
+        only_wireframe: bool,
+        min_value: float,
+        max_value: float,
+        surface_count: int,
+        colormap: str,
+        show_colorbar: bool,
+        reversescale: bool,
+        flip_z: bool,
+        show_axis: bool,
+        show_ticks: bool,
+        show_caps: bool,
+        show_z_slice: bool,
+        slice_z_location: int,
+        show_y_slice: bool,
+        slice_y_location: int,
+        show_x_slice: bool,
+        slice_x_location: int,
+        ) -> tuple[go.Figure, str]:
 
         # Load volume
         self.load_data(gradio_file)
diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py
index 26d9639f889a4323460ccd94580385d57493e718..febd16a506bcba477eed1d0f42ba6e20b763d610 100644
--- a/qim3d/gui/layers2d.py
+++ b/qim3d/gui/layers2d.py
@@ -29,6 +29,7 @@ from qim3d.processing import segment_layers, get_lines
 from qim3d.operations import overlay_rgb_images
 from qim3d.io import load
 from qim3d.viz._layers2d import image_with_lines
+from typing import Dict, Any
 
 #TODO figure out how not update anything and go through processing when there are no data loaded
 # So user could play with the widgets but it doesnt throw error
@@ -303,14 +304,14 @@ class Interface(BaseInterface):
             
         
 
-    def change_plot_type(self, plot_type, ):
+    def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
         self.plot_type = plot_type
         if plot_type == 'Segmentation lines':
             return gr.update(visible = False), gr.update(visible = True)
         else:  
             return gr.update(visible = True), gr.update(visible = False)
         
-    def change_plot_size(self, x_check, y_check, z_check):
+    def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
         also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
@@ -320,10 +321,10 @@ class Interface(BaseInterface):
         height = self.heights[index] # also used to define heights of plots in the begining
         return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
 
-    def change_row_visibility(self, x_check, y_check, z_check):
+    def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
         return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
     
-    def update_explorer(self, new_path):
+    def update_explorer(self, new_path: str):
         # Refresh the file explorer object
         new_path = os.path.expanduser(new_path)
 
@@ -342,13 +343,13 @@ class Interface(BaseInterface):
     def set_relaunch_button(self):
         return gr.update(value=f"Relaunch", interactive=True)
 
-    def set_spinner(self, message):
+    def set_spinner(self, message: str):
         if self.error:
             return gr.Button()
         # spinner icon/shows the user something is happeing
         return gr.update(value=f"{message}", interactive=False)
     
-    def load_data(self, base_path, explorer):
+    def load_data(self, base_path: str, explorer: str):
         if base_path and os.path.isfile(base_path):
             file_path = base_path
         elif explorer and os.path.isfile(explorer):
diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py
index 0c76f1aad4181b5e23d5e2e054009075778f71e4..461f3e1774fedb3ac61e0f916b3c72bdafff6f92 100644
--- a/qim3d/gui/local_thickness.py
+++ b/qim3d/gui/local_thickness.py
@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples
 
 class Interface(InterfaceWithExamples):
     def __init__(self,
-                 img = None,
+                 img: np.ndarray = None,
                  verbose:bool = False,
                  plot_height:int = 768,
                  figsize:int = 6): 
@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
     #
     #######################################################
 
-    def process_input(self, data, dark_objects):
+    def process_input(self, data: np.ndarray, dark_objects: bool):
         # Load volume
         try:
             self.vol = load(data.name)
@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
         self.vmin = np.min(self.vol)
         self.vmax = np.max(self.vol)
 
-    def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"):
+    def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
         plt.close()
         z_idx = int(zpos * (vol.shape[0] - 1))
         fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
 
         return fig
 
-    def make_binary(self, threshold):
+    def make_binary(self, threshold: float):
         # Make a binary volume
         # Nothing fancy, but we could add new features here
         self.vol_binary = self.vol > (threshold * np.max(self.vol))
     
-    def compute_localthickness(self, lt_scale):
+    def compute_localthickness(self, lt_scale: float):
         self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
 
         # Valus for visualization
         self.vmin_lt = np.min(self.vol_thickness)
         self.vmax_lt = np.max(self.vol_thickness)
 
-    def thickness_histogram(self, nbins):
+    def thickness_histogram(self, nbins: int):
         # Ignore zero thickness
         non_zero_values = self.vol_thickness[self.vol_thickness > 0]
 
diff --git a/qim3d/gui/qim_theme.py b/qim3d/gui/qim_theme.py
index f13594e4c03bcac851d4e5416e8675cd3fd5f837..5cdc844b695a0fb64e6ff6588ea6b48f0115088e 100644
--- a/qim3d/gui/qim_theme.py
+++ b/qim3d/gui/qim_theme.py
@@ -7,7 +7,7 @@ class QimTheme(gr.themes.Default):
     there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
     in comments.
     """
-    def __init__(self, force_light_mode:bool = True):
+    def __init__(self, force_light_mode: bool = True):
         """
         Parameters:
         -----------
diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py
index 4ab6c0a99e71eff51949309359dd42c1276f96dd..0d776822689430b733dfa1467ecdff0b48d87bb1 100644
--- a/qim3d/io/_convert.py
+++ b/qim3d/io/_convert.py
@@ -7,6 +7,7 @@ import numpy as np
 import tifffile as tiff
 import zarr
 from tqdm import tqdm
+import zarr.core
 
 from qim3d.utils._misc import stringify_path
 from qim3d.io._saving import save
@@ -21,7 +22,7 @@ class Convert:
         """
         self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
 
-    def convert(self, input_path, output_path):
+    def convert(self, input_path: str, output_path: str):
         def get_file_extension(file_path):
             root, ext = os.path.splitext(file_path)
             if ext in ['.gz', '.bz2', '.xz']:  # handle common compressed extensions
@@ -67,7 +68,7 @@ class Convert:
             else:
                 raise ValueError("Invalid path")
 
-    def convert_tif_to_zarr(self, tif_path, zarr_path):
+    def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a tiff file to a zarr file
 
         Args:
@@ -97,7 +98,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_tif(self, zarr_path, tif_path):
+    def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
         """Convert a zarr file to a tiff file
 
         Args:
@@ -110,7 +111,7 @@ class Convert:
         z = zarr.open(zarr_path)
         save(tif_path, z)
 
-    def convert_nifti_to_zarr(self, nifti_path, zarr_path):
+    def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a nifti file to a zarr file
 
         Args:
@@ -139,7 +140,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_nifti(self, zarr_path, nifti_path, compression=False):
+    def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
         """Convert a zarr file to a nifti file
 
         Args:
@@ -153,7 +154,7 @@ class Convert:
         save(nifti_path, z, compression=compression)
         
 
-def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
+def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
     """Convert a file to another format without loading the entire file into memory
 
     Args:
diff --git a/qim3d/io/_downloader.py b/qim3d/io/_downloader.py
index fe42e410c120b5ab6efe77037d1071579add55de..18e36ca591ad2bf0859cc2b05f18d0033de3f8ee 100644
--- a/qim3d/io/_downloader.py
+++ b/qim3d/io/_downloader.py
@@ -76,7 +76,7 @@ class Downloader:
             [file_name_n](load_file,optional): Function to download file number n in the given folder.
         """
 
-        def __init__(self, folder):
+        def __init__(self, folder: str):
             files = _extract_names(folder)
 
             for idx, file in enumerate(files):
@@ -88,7 +88,7 @@ class Downloader:
 
                 setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
 
-        def _make_fn(self, folder, file):
+        def _make_fn(self, folder: str, file: str):
             """Private method that returns a function. The function downloads the chosen file from the folder.
 
             Args:
@@ -101,7 +101,7 @@ class Downloader:
 
             url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
 
-            def _download(load_file=False, virtual_stack=True):
+            def _download(load_file: bool = False, virtual_stack: bool = True):
                 """Downloads the file and optionally also loads it.
 
                 Args:
@@ -121,7 +121,7 @@ class Downloader:
             return _download
 
 
-def _update_progress(pbar, blocknum, bs):
+def _update_progress(pbar: tqdm, blocknum: int, bs: int):
     """
     Helper function for the ´download_file()´ function. Updates the progress bar.
     """
@@ -129,7 +129,7 @@ def _update_progress(pbar, blocknum, bs):
     pbar.update(blocknum * bs - pbar.n)
 
 
-def _get_file_size(url):
+def _get_file_size(url: str):
     """
     Helper function for the ´download_file()´ function. Finds the size of the file.
     """
@@ -137,7 +137,7 @@ def _get_file_size(url):
     return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
 
 
-def download_file(path, name, file):
+def download_file(path: str, name: str, file: str):
     """Downloads the file from path / name / file.
 
     Args:
@@ -177,7 +177,7 @@ def download_file(path, name, file):
         )
 
 
-def _extract_html(url):
+def _extract_html(url: str):
     """Extracts the html content of a webpage in "utf-8"
 
     Args:
@@ -198,7 +198,7 @@ def _extract_html(url):
     return html_content
 
 
-def _extract_names(name=None):
+def _extract_names(name: str = None):
     """Extracts the names of the folders and files.
 
     Finds the names of either the folders if no name is given,
diff --git a/qim3d/io/_loading.py b/qim3d/io/_loading.py
index e5f2191e8f71fd276e9b9460d6dabf4c6e8d4995..e3d7d45bc220a04c0e96e4efd5fbb2f94cc3c99a 100644
--- a/qim3d/io/_loading.py
+++ b/qim3d/io/_loading.py
@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
 from qim3d.utils._progress_bar import FileLoadingProgressBar
 import trimesh
 
+from typing import Optional, Dict
+
 dask.config.set(scheduler="processes") 
 
 
@@ -76,7 +78,7 @@ class DataLoader:
         self.dim_order = kwargs.get("dim_order", (2, 1, 0))
         self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
 
-    def load_tiff(self, path):
+    def load_tiff(self, path: str|os.PathLike):
         """Load a TIFF file from the specified path.
 
         Args:
@@ -100,7 +102,7 @@ class DataLoader:
 
         return vol
 
-    def load_h5(self, path):
+    def load_h5(self, path: str|os.PathLike) -> tuple[np.ndarray, Optional[Dict]]:
         """Load an HDF5 file from the specified path.
 
         Args:
@@ -183,7 +185,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_tiff_stack(self, path):
+    def load_tiff_stack(self, path: str|os.PathLike) -> np.ndarray|np.memmap:
         """Load a stack of TIFF files from the specified path.
 
         Args:
@@ -237,7 +239,7 @@ class DataLoader:
 
         return vol
 
-    def load_txrm(self, path):
+    def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]:
         """Load a TXRM/XRM/TXM file from the specified path.
 
         Args:
@@ -308,7 +310,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_nifti(self, path):
+    def load_nifti(self, path: str|os.PathLike):
         """Load a NIfTI file from the specified path.
 
         Args:
@@ -338,7 +340,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_pil(self, path):
+    def load_pil(self, path: str|os.PathLike):
         """Load a PIL image from the specified path
 
         Args:
@@ -349,7 +351,7 @@ class DataLoader:
         """
         return np.array(Image.open(path))
 
-    def load_PIL_stack(self, path):
+    def load_PIL_stack(self, path: str|os.PathLike):
         """Load a stack of PIL files from the specified path.
 
         Args:
@@ -433,7 +435,7 @@ class DataLoader:
 
       
 
-    def _load_vgi_metadata(self, path):
+    def _load_vgi_metadata(self, path: str|os.PathLike):
         """Helper functions that loads metadata from a VGI file
 
         Args:
@@ -482,7 +484,7 @@ class DataLoader:
 
         return meta_data
 
-    def load_vol(self, path):
+    def load_vol(self, path: str|os.PathLike):
         """Load a VOL filed based on the VGI metadata file
 
         Args:
@@ -548,7 +550,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_dicom(self, path):
+    def load_dicom(self, path: str|os.PathLike):
         """Load a DICOM file
 
         Args:
@@ -563,7 +565,7 @@ class DataLoader:
         else:
             return dcm_data.pixel_array
 
-    def load_dicom_dir(self, path):
+    def load_dicom_dir(self, path: str|os.PathLike):
         """Load a directory of DICOM files into a numpy 3d array
 
         Args:
@@ -605,7 +607,7 @@ class DataLoader:
             return vol
         
 
-    def load_zarr(self, path: str):
+    def load_zarr(self, path: str|os.PathLike):
         """ Loads a Zarr array from disk.
 
         Args:
@@ -654,7 +656,7 @@ class DataLoader:
                     message + " Set 'force_load=True' to ignore this error."
                 )
 
-    def load(self, path):
+    def load(self, path: str|os.PathLike):
         """
         Load a file or directory based on the given path.
 
@@ -757,16 +759,16 @@ def _get_ole_offsets(ole):
 
 
 def load(
-    path,
-    virtual_stack=False,
-    dataset_name=None,
-    return_metadata=False,
-    contains=None,
-    progress_bar:bool = True,
+    path: str|os.PathLike,
+    virtual_stack: bool = False,
+    dataset_name: bool = None,
+    return_metadata: bool = False,
+    contains: bool = None,
+    progress_bar: bool = True,
     force_load: bool = False,
-    dim_order=(2, 1, 0),
+    dim_order: tuple = (2, 1, 0),
     **kwargs,
-):
+) -> np.ndarray:
     """
     Load data from the specified file or directory.
 
@@ -854,7 +856,7 @@ def load(
 
     return data
 
-def load_mesh(filename):
+def load_mesh(filename: str) -> trimesh.Trimesh:
     """
     Load a mesh from an .obj file using trimesh.
 
diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py
index ae5173157927dbf4555b05b3a9bacc53f193380a..36976316b8811380d68cbe1b7f5f43e5c001c4f9 100644
--- a/qim3d/io/_ome_zarr.py
+++ b/qim3d/io/_ome_zarr.py
@@ -46,13 +46,13 @@ class OMEScaler(
     """Scaler in the style of OME-Zarr.
     This is needed because their current zoom implementation is broken."""
 
-    def __init__(self, order=0, downscale=2, max_layer=5, method="scaleZYXdask"):
+    def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
         self.order = order
         self.downscale = downscale
         self.max_layer = max_layer
         self.method = method
 
-    def scaleZYX(self, base):
+    def scaleZYX(self, base: da.core.Array):
         """Downsample using :func:`scipy.ndimage.zoom`."""
         rv = [base]
         log.info(f"- Scale 0: {rv[-1].shape}")
@@ -63,7 +63,7 @@ class OMEScaler(
 
         return list(rv)
 
-    def scaleZYXdask(self, base):
+    def scaleZYXdask(self, base: da.core.Array):
         """
         Downsample a 3D volume using Dask and scipy.ndimage.zoom.
 
@@ -82,7 +82,7 @@ class OMEScaler(
 
 
         """
-        def resize_zoom(vol, scale_factors, order, scaled_shape):
+        def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
 
             # Get the chunksize needed so that all the blocks match the new shape
             # This snippet comes from the original OME-Zarr-python library
@@ -181,16 +181,16 @@ class OMEScaler(
 
 
 def export_ome_zarr(
-    path,
-    data,
-    chunk_size=256,
-    downsample_rate=2,
-    order=1,
-    replace=False,
-    method="scaleZYX",
+    path: str|os.PathLike,
+    data: np.ndarray|da.core.Array,
+    chunk_size: int = 256,
+    downsample_rate: int = 2,
+    order: int = 1,
+    replace: bool = False,
+    method: str = "scaleZYX",
     progress_bar: bool = True,
-    progress_bar_repeat_time="auto",
-):
+    progress_bar_repeat_time: str = "auto",
+) -> None:
     """
     Export 3D image data to OME-Zarr format with pyramidal downsampling.
 
@@ -299,7 +299,11 @@ def export_ome_zarr(
     return
 
 
-def import_ome_zarr(path, scale=0, load=True):
+def import_ome_zarr(
+        path: str|os.PathLike, 
+        scale: int = 0, 
+        load: bool = True
+        ) -> np.ndarray:
     """
     Import image data from an OME-Zarr file.
 
diff --git a/qim3d/io/_saving.py b/qim3d/io/_saving.py
index d7721ee2a082a0a7db3f993c58ec9207eefcc5f4..a7053c00fde4b7d4541b1a01286f18ad669e811d 100644
--- a/qim3d/io/_saving.py
+++ b/qim3d/io/_saving.py
@@ -76,7 +76,7 @@ class DataSaver:
         self.sliced_dim = kwargs.get("sliced_dim", 0)
         self.chunk_shape = kwargs.get("chunk_shape", "auto")
 
-    def save_tiff(self, path, data):
+    def save_tiff(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a TIFF file to the given path.
 
         Args:
@@ -85,7 +85,7 @@ class DataSaver:
         """
         tifffile.imwrite(path, data, compression=self.compression)
 
-    def save_tiff_stack(self, path, data):
+    def save_tiff_stack(self, path: str|os.PathLike, data: np.ndarray):
         """Save data as a TIFF stack containing slices in separate files to the given path.
         The slices will be named according to the basename plus a suffix with a zero-filled
         value corresponding to the slice number
@@ -124,7 +124,7 @@ class DataSaver:
                 f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
             )
 
-    def save_nifti(self, path, data):
+    def save_nifti(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a NIfTI file to the given path.
 
         Args:
@@ -154,7 +154,7 @@ class DataSaver:
         # Save image
         nib.save(img, path)
 
-    def save_vol(self, path, data):
+    def save_vol(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a VOL file to the given path.
 
         Args:
@@ -200,7 +200,7 @@ class DataSaver:
                 "dataset", data=data, compression="gzip" if self.compression else None
             )
 
-    def save_dicom(self, path, data):
+    def save_dicom(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a DICOM file to the given path.
 
         Args:
@@ -255,7 +255,7 @@ class DataSaver:
 
         ds.save_as(path)
 
-    def save_to_zarr(self, path, data):
+    def save_to_zarr(self, path: str|os.PathLike, data: da.core.Array):
         """Saves a Dask array to a Zarr array on disk.
 
         Args:
@@ -284,7 +284,7 @@ class DataSaver:
             )
             zarr_array[:] = data
 
-    def save_PIL(self, path, data):
+    def save_PIL(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a PIL file to the given path.
 
         Args:
@@ -303,7 +303,7 @@ class DataSaver:
         # Save image
         img.save(path)
 
-    def save(self, path, data):
+    def save(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to the given path.
 
         Args:
@@ -401,15 +401,15 @@ class DataSaver:
 
 
 def save(
-    path,
-    data,
-    replace=False,
-    compression=False,
-    basename=None,
-    sliced_dim=0,
-    chunk_shape="auto",
+    path: str|os.PathLike,
+    data: np.ndarray,
+    replace: bool = False,
+    compression: bool = False,
+    basename: bool = None,
+    sliced_dim: int = 0,
+    chunk_shape: str = "auto",
     **kwargs,
-):
+) -> None:
     """Save data to a specified file path.
 
     Args:
@@ -464,7 +464,10 @@ def save(
     ).save(path, data)
 
 
-def save_mesh(filename, mesh):
+def save_mesh(
+        filename: str, 
+        mesh: trimesh.Trimesh
+        ) -> None:
     """
     Save a trimesh object to an .obj file.
 
diff --git a/qim3d/io/_sync.py b/qim3d/io/_sync.py
index 9085ae88b3a4e60b6ce2a0744c8983c825b635a3..953bf15b2e5ad6690bd6a077cd29b7b430b22c40 100644
--- a/qim3d/io/_sync.py
+++ b/qim3d/io/_sync.py
@@ -28,7 +28,7 @@ class Sync:
 
             return False
 
-    def check_destination(self, source, destination, checksum=False, verbose=True):
+    def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
         """Check if all files from 'source' are in 'destination'
 
         This function compares the files in the 'source' directory to those in
@@ -80,7 +80,7 @@ class Sync:
 
         return diff_files
 
-    def compare_dirs(self, source, destination, checksum=False, verbose=True):
+    def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
         """Checks whether 'source' and 'destination' directories are synchronized.
 
         This function compares the contents of two directories
@@ -168,7 +168,7 @@ class Sync:
             )
         return
 
-    def count_files_and_dirs(self, path, verbose=True):
+    def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
         """Count the number of files and directories in the given path.
 
         This function recursively counts the number of files and
diff --git a/qim3d/mesh/_common_mesh_methods.py b/qim3d/mesh/_common_mesh_methods.py
index 535fcae886ce081c20c5bb01628341beba5163c3..87cb9d8bfc792ca939d2e0afab00eff4510c2a91 100644
--- a/qim3d/mesh/_common_mesh_methods.py
+++ b/qim3d/mesh/_common_mesh_methods.py
@@ -8,8 +8,8 @@ from qim3d.utils._logger import log
 def from_volume(
     volume: np.ndarray,
     level: float = None,
-    step_size=1,
-    allow_degenerate=False,
+    step_size: int = 1,
+    allow_degenerate: bool = False,
     padding: Tuple[int, int, int] = (2, 2, 2),
     **kwargs: Any,
 ) -> trimesh.Trimesh:
diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index ea81e53ab039744663eac65d46ea85efbd0c0495..e12b1a51849593e0c14287f0c2ab1795eb86ace9 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -20,10 +20,10 @@ class Augmentation:
     """
     
     def __init__(self, 
-                 resize = 'crop', 
-                 transform_train = 'moderate', 
-                 transform_validation = None,
-                 transform_test = None,
+                 resize: str = 'crop', 
+                 transform_train: str = 'moderate', 
+                 transform_validation: str | None = None,
+                 transform_test: str | None = None,
                  mean: float = 0.5, 
                  std: float = 0.5
                 ):
@@ -38,7 +38,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h: int, im_w: int, level: bool | None = None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
         A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 38fbdac79c3fb7d055c457376351a29b1ad0bf79..d54f1bf6d7a457457b8fea638fd8057410d2a853 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -4,7 +4,9 @@ from PIL import Image
 from qim3d.utils._logger import log
 import torch
 import numpy as np
-
+from typing import Optional, Callable
+import torch.nn as nn
+from ._data import Augmentation
 
 class Dataset(torch.utils.data.Dataset):
     """
@@ -36,7 +38,7 @@ class Dataset(torch.utils.data.Dataset):
             transform=albumentations.Compose([ToTensorV2()]))
         image, target = dataset[idx]
     """
-    def __init__(self, root_path: str, split="train", transform=None):
+    def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None):
         super().__init__()
 
         # Check if split is valid
@@ -58,7 +60,7 @@ class Dataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.sample_images)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
@@ -76,7 +78,7 @@ class Dataset(torch.utils.data.Dataset):
 
 
     # TODO: working with images of different sizes
-    def check_shape_consistency(self,sample_images):
+    def check_shape_consistency(self,sample_images: tuple[str]):
         image_shapes= []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
@@ -99,7 +101,7 @@ class Dataset(torch.utils.data.Dataset):
         return Image.open(str(image_path)).size
 
 
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
+def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
     """
     Checks the compatibility of the image shape with the depth of the model.
     If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
@@ -131,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
     return h_adjust, w_adjust 
 
 
-def prepare_datasets(path: str, val_fraction: float, model, augmentation):
+def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
     Splits and augments the train/validation/test datasets.
 
@@ -172,7 +174,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
     return train_set, val_set, test_set
 
 
-def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):  
+def prepare_dataloaders(train_set: torch.utils.data, 
+                        val_set: torch.utils.data, 
+                        test_set: torch.utils.data, 
+                        batch_size: int, 
+                        shuffle_train: bool = True, 
+                        num_workers: int = 8, 
+                        pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:  
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index 7c8a3b86d005e467c15a378e3216011cbf3ac60c..f46a7481fe93cd299eb43566bd947c7ea3355daf 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -3,24 +3,24 @@
 import torch
 import numpy as np
 
-from torchinfo import summary
+from torchinfo import summary, ModelStatistics
 from qim3d.utils._logger import log
 from qim3d.viz._metrics import plot_metrics
 
 from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
-
+from models._unet import Hyperparameters
 
 def train_model(
-    model,
-    hyperparameters,
-    train_loader,
-    val_loader,
-    eval_every=1,
-    print_every=5,
-    plot=True,
-    return_loss=False,
-):
+    model: torch.nn.Module,
+    hyperparameters: Hyperparameters,
+    train_loader: torch.utils.data.DataLoader,
+    val_loader: torch.utils.data.DataLoader,
+    eval_every: int = 1,
+    print_every: int = 5,
+    plot: bool = True,
+    return_loss: bool = False,
+) -> tuple[tuple[float], tuple[float]]:
     """Function for training Neural Network models.
 
     Args:
@@ -137,7 +137,7 @@ def train_model(
         return train_loss, val_loss
 
 
-def model_summary(dataloader, model):
+def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module) -> ModelStatistics:
     """Prints the summary of a PyTorch model.
 
     Args:
@@ -160,7 +160,7 @@ def model_summary(dataloader, model):
     return model_s
 
 
-def inference(data, model):
+def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.
 
     Performs inference on the input data using the provided model. The input data should be in the form of a list,
@@ -242,7 +242,7 @@ def inference(data, model):
     return inputs, targets, preds
 
 
-def volume_inference(volume, model, threshold=0.5):
+def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float = 0.5) -> np.ndarray:
     """
     Compute on the entire volume
     Args:
diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py
index 353d10bde5dc08f1b44044b9e482d2adf66a346d..27ee78a4009b235253662c8af50c76baca4870f4 100644
--- a/qim3d/ml/models/_unet.py
+++ b/qim3d/ml/models/_unet.py
@@ -119,13 +119,13 @@ class Hyperparameters:
 
     def __init__(
         self,
-        model,
-        n_epochs=10,
-        learning_rate=1e-3,
-        optimizer="Adam",
-        momentum=0,
-        weight_decay=0,
-        loss_function="Focal",
+        model: nn.Module,
+        n_epochs: int = 10,
+        learning_rate: float = 1e-3,
+        optimizer: str = "Adam",
+        momentum: float = 0,
+        weight_decay: float = 0,
+        loss_function: str = "Focal",
     ):
 
         # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
@@ -168,13 +168,13 @@ class Hyperparameters:
 
     def model_params(
         self,
-        model,
-        n_epochs,
-        optimizer,
-        learning_rate,
-        weight_decay,
-        momentum,
-        loss_function,
+        model: nn.Module,
+        n_epochs: int,
+        optimizer: str,
+        learning_rate: float,
+        weight_decay: float,
+        momentum: float,
+        loss_function: str,
     ):
 
         optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
@@ -188,7 +188,12 @@ class Hyperparameters:
         return hyper_dict
 
     # selecting the optimizer
-    def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
+    def _optimizer(self, 
+        model: nn.Module, 
+        optimizer: str, 
+        learning_rate: float, 
+        weight_decay: float, 
+        momentum: float):
         from torch.optim import Adam, SGD, RMSprop
 
         if optimizer == "Adam":
@@ -212,7 +217,7 @@ class Hyperparameters:
         return optim
 
     # selecting the loss function
-    def _loss_functions(self, loss_function):
+    def _loss_functions(self, loss_function: str):
         from monai.losses import FocalLoss, DiceLoss, DiceCELoss
         from torch.nn import BCEWithLogitsLoss
 
diff --git a/qim3d/operations/_common_operations_methods.py b/qim3d/operations/_common_operations_methods.py
index 1c14e16329dbf0a9cce14695c56f927c5b69d820..5975c30ea3c43f95995559597c8a61f02790809a 100644
--- a/qim3d/operations/_common_operations_methods.py
+++ b/qim3d/operations/_common_operations_methods.py
@@ -153,7 +153,7 @@ def fade_mask(
 
 
 def overlay_rgb_images(
-    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black:bool = True,
+    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black: bool = True,
 ) -> np.ndarray:
     """
     Overlay an RGB foreground onto an RGB background using alpha blending.
diff --git a/qim3d/processing/_layers.py b/qim3d/processing/_layers.py
index d91ef0a42f1ba8715793a674321219cbbf1d634b..306a12844f8be1356ff1d48f4cd832532d6acb75 100644
--- a/qim3d/processing/_layers.py
+++ b/qim3d/processing/_layers.py
@@ -2,7 +2,14 @@ import numpy as np
 from slgbuilder import GraphObject 
 from slgbuilder import MaxflowBuilder
 
-def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, delta:float = 1, min_margin:int = 10, max_margin:int = None, wrap:bool = False):
+def segment_layers(data: np.ndarray, 
+                   inverted: bool = False, 
+                   n_layers: int = 1, 
+                   delta: float = 1, 
+                   min_margin: int = 10, 
+                   max_margin: int = None, 
+                   wrap: bool = False
+                   ) -> list:
     """
     Works on 2D and 3D data.
     Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation
@@ -82,7 +89,7 @@ def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, del
 
     return segmentations
 
-def get_lines(segmentations:list|np.ndarray) -> list:
+def get_lines(segmentations:list[np.ndarray]) -> list:
     """
     Expects list of arrays where each array is 2D segmentation with only 2 classes. This function gets the border between those two
     so it could be plotted. Used with qim3d.processing.segment_layers
diff --git a/qim3d/processing/_local_thickness.py b/qim3d/processing/_local_thickness.py
index a7e877d6d01b51c1aa3af9d403dd5447b7a42f6d..947f8865cfdbdbed4438d0f1e416498502de0e4a 100644
--- a/qim3d/processing/_local_thickness.py
+++ b/qim3d/processing/_local_thickness.py
@@ -10,7 +10,7 @@ def local_thickness(
     image: np.ndarray,
     scale: float = 1,
     mask: Optional[np.ndarray] = None,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> np.ndarray:
     """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness)
diff --git a/qim3d/processing/_structure_tensor.py b/qim3d/processing/_structure_tensor.py
index 10eba89076a3c1158c810f167ec7049735b15237..f6dc98cd253f05bd89c4a8a8c5cdf9a8fd228c2d 100644
--- a/qim3d/processing/_structure_tensor.py
+++ b/qim3d/processing/_structure_tensor.py
@@ -12,7 +12,7 @@ def structure_tensor(
     rho: float = 6.0,
     base_noise: bool = True,
     full: bool = False,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/).
diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py
index f43f0dc216a1c49f2d35719e8a6a397340092107..45725feec526e0a31de67a0737f99d6f955d2333 100644
--- a/qim3d/segmentation/_connected_components.py
+++ b/qim3d/segmentation/_connected_components.py
@@ -4,7 +4,7 @@ from qim3d.utils._logger import log
 
 
 class CC:
-    def __init__(self, connected_components, num_connected_components):
+    def __init__(self, connected_components: np.ndarray, num_connected_components: int):
         """
         Initializes a ConnectedComponents object.
 
@@ -23,7 +23,7 @@ class CC:
         """
         return self.cc_count
 
-    def get_cc(self, index=None, crop=False):
+    def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray:
         """
         Get the connected component with the given index, if index is None selects a random component.
 
@@ -52,7 +52,7 @@ class CC:
         
         return volume
 
-    def get_bounding_box(self, index=None):
+    def get_bounding_box(self, index: int|None =None)-> list[tuple]:
         """Get the bounding boxes of the connected components.
 
         Args:
diff --git a/qim3d/utils/_doi.py b/qim3d/utils/_doi.py
index b3ece6636b062d24d2002405a08abb1726242f4f..02095f335610b266b0dd57088787eab2697190fd 100644
--- a/qim3d/utils/_doi.py
+++ b/qim3d/utils/_doi.py
@@ -4,7 +4,7 @@ import requests
 from qim3d.utils._logger import log
 
 
-def _validate_response(response):
+def _validate_response(response: requests.Response) -> bool:
     # Check if we got a good response
     if not response.ok:
         log.error(f"Could not read the provided DOI ({response.reason})")
@@ -13,7 +13,7 @@ def _validate_response(response):
     return True
 
 
-def _doi_to_url(doi):
+def _doi_to_url(doi: str) -> str:
     if doi[:3] != "http":
         url = "https://doi.org/" + doi
     else:
@@ -22,7 +22,7 @@ def _doi_to_url(doi):
     return url
 
 
-def _make_request(doi, header):
+def _make_request(doi: str, header: str) -> requests.Response:
     # Get url from doi
     url = _doi_to_url(doi)
 
@@ -35,7 +35,7 @@ def _make_request(doi, header):
     return response
 
 
-def _log_and_get_text(doi, header):
+def _log_and_get_text(doi, header) -> str:
     response = _make_request(doi, header)
 
     if response and response.encoding:
@@ -50,13 +50,13 @@ def _log_and_get_text(doi, header):
         return text
 
 
-def get_bibtex(doi):
+def get_bibtex(doi: str):
     """Generates bibtex from doi"""
     header = {"Accept": "application/x-bibtex"}
 
     return _log_and_get_text(doi, header)
 
-def cusom_header(doi, header):
+def custom_header(doi: str, header: str) -> str:
     """Allows a custom header to be passed
 
     For example:
@@ -67,7 +67,7 @@ def cusom_header(doi, header):
     """
     return _log_and_get_text(doi, header)
 
-def get_metadata(doi):
+def get_metadata(doi: str) -> dict:
     """Generates a metadata dictionary from doi"""
     header = {"Accept": "application/vnd.citationstyles.csl+json"}
     response = _make_request(doi, header)
@@ -76,7 +76,7 @@ def get_metadata(doi):
 
     return metadata
 
-def get_reference(doi):
+def get_reference(doi: str) -> str:
     """Generates a metadata dictionary from doi and use it to build a reference string"""
 
     metadata = get_metadata(doi)
@@ -84,7 +84,7 @@ def get_reference(doi):
 
     return reference_string
 
-def build_reference_string(metadata):
+def build_reference_string(metadata: dict) -> str:
     """Generates a reference string from metadata"""
     authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']])
     year = metadata['issued']['date-parts'][0][0]
diff --git a/qim3d/utils/_misc.py b/qim3d/utils/_misc.py
index a754ecb8840b889eed2b3132a98c1c8ea1e9f1f5..3d8660b445a65577ef3b4c7d16af613b44785aac 100644
--- a/qim3d/utils/_misc.py
+++ b/qim3d/utils/_misc.py
@@ -12,7 +12,7 @@ import difflib
 import qim3d
 
 
-def get_local_ip():
+def get_local_ip() -> str:
     """Retrieves the local IP address of the current machine.
 
     The function uses a socket to determine the local IP address.
@@ -42,7 +42,7 @@ def get_local_ip():
     return ip_address
 
 
-def port_from_str(s):
+def port_from_str(s: str) -> int:
     """
     Generates a port number from a given string.
 
@@ -65,7 +65,7 @@ def port_from_str(s):
     return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4)
 
 
-def gradio_header(title, port):
+def gradio_header(title: str, port: int) -> None:
     """Display the header for a Gradio server.
 
     Displays a formatted header containing the provided title,
@@ -99,7 +99,7 @@ def gradio_header(title, port):
     ouf.showlist(details, style="box", title="Starting gradio server")
 
 
-def sizeof(num, suffix="B"):
+def sizeof(num: float, suffix: str = "B") -> str:
     """Converts a number to a human-readable string representing its size.
 
     Converts the given number to a human-readable string representing its size in
@@ -131,7 +131,7 @@ def sizeof(num, suffix="B"):
     return f"{num:.1f} Y{suffix}"
 
 
-def find_similar_paths(path):
+def find_similar_paths(path: str) -> list[str]:
     parent_dir = os.path.dirname(path) or "."
     parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
     valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
@@ -165,14 +165,14 @@ def get_file_size(file_path: str) -> int:
     return file_size
 
 
-def stringify_path(path):
+def stringify_path(path: os.PathLike) -> str:
     """Converts an os.PathLike object to a string"""
     if isinstance(path, os.PathLike):
         path = path.__fspath__()
     return path
 
 
-def get_port_dict():
+def get_port_dict() -> dict:
     # Gets user and port
     username = getpass.getuser()
     url = f"https://platform.qim.dk/qim-api/get-port/{username}"
@@ -189,7 +189,7 @@ def get_port_dict():
     return port_dict
 
 
-def get_css():
+def get_css() -> str:
 
     current_directory = os.path.dirname(os.path.abspath(__file__))
     parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
@@ -201,7 +201,7 @@ def get_css():
     return css_content
 
 
-def downscale_img(img, max_voxels=512**3):
+def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
     """Downscale image if total number of voxels exceeds 512³.
 
     Args:
@@ -226,7 +226,7 @@ def downscale_img(img, max_voxels=512**3):
     return zoom(img, zoom_factor, order=0)
 
 
-def scale_to_float16(arr: np.ndarray):
+def scale_to_float16(arr: np.ndarray) -> np.ndarray:
     """
     Scale the input array to the float16 data type.
 
diff --git a/qim3d/utils/_ome_zarr.py b/qim3d/utils/_ome_zarr.py
index 452997f14be060f9fbafe29891022f373b79f797..db7ad145ca96d0f4b0fade1c5558a6eb07aac01e 100644
--- a/qim3d/utils/_ome_zarr.py
+++ b/qim3d/utils/_ome_zarr.py
@@ -1,7 +1,7 @@
 from zarr.util import normalize_chunks, normalize_dtype, normalize_shape
 import numpy as np
 
-def get_chunk_size(shape:tuple, dtype):
+def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]:
     """
     How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use 
     in qim3d.io.export_ome_zarr function
@@ -20,7 +20,7 @@ def get_chunk_size(shape:tuple, dtype):
     return chunks
 
 
-def get_n_chunks(shapes:tuple, dtypes:tuple):
+def get_n_chunks(shapes: tuple, dtypes: tuple) -> int:
     """
     Estimates how many chunks we will use in advence so we can pass this number to a progress bar and track how many
     have been already written to disk
diff --git a/qim3d/utils/_progress_bar.py b/qim3d/utils/_progress_bar.py
index 629026f4b08a84e664925ce7a0a1a09a9c214434..57f33738911a3c8493f5da855a8a3863f800f0c0 100644
--- a/qim3d/utils/_progress_bar.py
+++ b/qim3d/utils/_progress_bar.py
@@ -24,7 +24,7 @@ class RepeatTimer(Timer):
             self.function(*self.args, **self.kwargs)
 
 class ProgressBar(ABC):
-    def __init__(self,tqdm_kwargs:dict, repeat_time: float,  *args, **kwargs):
+    def __init__(self, tqdm_kwargs: dict, repeat_time: float,  *args, **kwargs):
         """
         Context manager for ('with' statement) to track progress during a long progress over 
         which we don't have control (like loading a file) and thus can not insert the tqdm
@@ -98,7 +98,7 @@ class FileLoadingProgressBar(ProgressBar):
         super().__init__( tqdm_kwargs, repeat_time)
         self.process = psutil.Process()
 
-    def get_new_update(self):
+    def get_new_update(self) -> int:
         counters = self.process.io_counters()
         try:
             memory = counters.read_chars
@@ -107,7 +107,7 @@ class FileLoadingProgressBar(ProgressBar):
         return memory
 
 class OmeZarrExportProgressBar(ProgressBar):
-    def __init__(self,path:str, n_chunks:int, reapeat_time="auto"):
+    def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"):
         """
         Context manager to track the exporting of OmeZarr files.
 
@@ -152,7 +152,7 @@ class OmeZarrExportProgressBar(ProgressBar):
         self.last_update = 0
 
     def get_new_update(self):
-        def file_count(folder_path:str):
+        def file_count(folder_path: str) -> int:
             """
             Goes recursively through the folders and counts how many files are there, 
             Doesn't count metadata json files
diff --git a/qim3d/utils/_server.py b/qim3d/utils/_server.py
index 9d20fccd76c47043ea50c0c74a4c2bee11097d5e..f61651585428ad293353bd9bdc837d29a6275acc 100644
--- a/qim3d/utils/_server.py
+++ b/qim3d/utils/_server.py
@@ -14,7 +14,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")
         super().end_headers()
 
-    def list_directory(self, path):
+    def list_directory(self, path: str|os.PathLike):
         """Helper to produce a directory listing, includes hidden files."""
         try:
             file_list = os.listdir(path)
@@ -49,7 +49,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         # Write the encoded HTML directly to the response
         self.wfile.write(encoded)
 
-def start_http_server(directory, port=8000):
+def start_http_server(directory: str, port: int = 8000) -> HTTPServer:
     """
     Starts an HTTP server serving the specified directory on the given port with CORS enabled.
     
diff --git a/qim3d/utils/_system.py b/qim3d/utils/_system.py
index 197926270bcc3a073308b5e18f6a09ba904b92fb..ccccb6724812cc162d948b805cde89844ce47431 100644
--- a/qim3d/utils/_system.py
+++ b/qim3d/utils/_system.py
@@ -36,7 +36,7 @@ class Memory:
             round(self.free_pct, 1),
         )
 
-def _test_disk_speed(file_size_bytes=1024, ntimes=10):
+def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]:
     '''
     Test the write and read speed of the disk by writing a file of a given size
     and then reading it back.
@@ -95,7 +95,7 @@ def _test_disk_speed(file_size_bytes=1024, ntimes=10):
     return avg_write_speed, write_speed_std, avg_read_speed, read_speed_std
 
 
-def disk_report(file_size=1024 * 1024 * 100, ntimes=10):
+def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None:
     '''
     Report the average write and read speed of the disk.
 
diff --git a/qim3d/viz/_cc.py b/qim3d/viz/_cc.py
index 1c0ed56098696de99ccb1706d58321102317aeb8..d07024b304d22fcb815312e70eabc54a88b54fde 100644
--- a/qim3d/viz/_cc.py
+++ b/qim3d/viz/_cc.py
@@ -2,18 +2,18 @@ import matplotlib.pyplot as plt
 import numpy as np
 import qim3d
 from qim3d.utils._logger import log
-
+from qim3d.segmentation._connected_components import CC
 
 def plot_cc(
-    connected_components,
+    connected_components: CC,
     component_indexs: list | tuple = None,
-    max_cc_to_plot=32,
-    overlay=None,
-    crop=False,
-    show=True,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
+    max_cc_to_plot: int = 32,
+    overlay: np.ndarray = None,
+    crop: bool = False,
+    display_figure: bool = True,
+    color_map: str = "viridis",
+    value_min: float = None,
+    value_max: float = None,
     **kwargs,
 ) -> list[plt.Figure]:
     """
@@ -75,7 +75,7 @@ def plot_cc(
                 cc = connected_components.get_cc(component, crop=False)
                 overlay_crop = np.where(cc == 0, 0, overlay)
             fig = qim3d.viz.slices_grid(
-                overlay_crop, show=show, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs
+                overlay_crop, display_figure=display_figure, color_map=color_map, value_min=value_min, value_max=value_max, **kwargs
             )
         else:
             # assigns discrete color map to each connected component if not given
@@ -89,7 +89,7 @@ def plot_cc(
 
         figs.append(fig)
 
-    if not show:
+    if not display_figure:
         return figs
 
     return
diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py
index 2198cd94a0eb703b7a3b95f15a4407397d642a06..e03f08526b176ff61d878cf11876473c7a36577c 100644
--- a/qim3d/viz/_data_exploration.py
+++ b/qim3d/viz/_data_exploration.py
@@ -9,7 +9,9 @@ from typing import List, Optional, Union
 
 import dask.array as da
 import ipywidgets as widgets
+import matplotlib.figure
 import matplotlib.pyplot as plt
+import matplotlib
 from IPython.display import SVG, display
 import matplotlib
 import numpy as np
@@ -29,7 +31,7 @@ def slices_grid(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-    image_size=None,
+    image_size: int = None,
     image_height: int = 2,
     image_width: int = 2,
     display_figure: bool = False,
@@ -38,7 +40,7 @@ def slices_grid(
     color_bar: bool = False,
     color_bar_style: str = "small",
     **matplotlib_imshow_kwargs,
-) -> plt.Figure:
+) -> matplotlib.figure.Figure:
     """Displays one or several slices from a 3d volume.
 
     By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
@@ -296,7 +298,7 @@ def slices_grid(
     return fig
 
 
-def _get_slice_range(position: int, num_slices: int, n_total):
+def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray:
     """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
     start_idx = position - num_slices // 2
     end_idx = (
@@ -324,7 +326,7 @@ def slicer(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
+    image_size: int = None,
     color_bar: bool = False,
     **matplotlib_imshow_kwargs,
 ) -> widgets.interactive:
@@ -401,8 +403,8 @@ def slicer_orthogonal(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
-):
+    image_size: int = None,
+)-> widgets.interactive:
     """Interactive widget for visualizing orthogonal slices of a 3D volume.
 
     Args:
@@ -461,7 +463,7 @@ def fade_mask(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-):
+)-> widgets.interactive:
     """Interactive widget for visualizing the effect of edge fading on a 3D volume.
 
     This can be used to select the best parameters before applying the mask.
@@ -596,7 +598,7 @@ def fade_mask(
     return slicer_obj
 
 
-def chunks(zarr_path: str, **kwargs):
+def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
     """
     Function to visualize chunks of a Zarr dataset using the specified visualization method.
 
@@ -854,14 +856,14 @@ def histogram(
     log_scale: bool = False,
     despine: bool = True,
     show_title: bool = True,
-    color="qim3d",
-    edgecolor=None,
-    figsize=(8, 4.5),
-    element="step",
-    return_fig=False,
-    show=True,
+    color: str = "qim3d",
+    edgecolor: str|None = None,
+    figsize: tuple[float, float] = (8, 4.5),
+    element: str = "step",
+    return_fig: bool = False,
+    show: bool = True,
     **sns_kwargs,
-):
+) -> None|matplotlib.figure.Figure:
     """
     Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
 
diff --git a/qim3d/viz/_detection.py b/qim3d/viz/_detection.py
index a0926105e4c01cb99ab708d1a42c395246411d3e..14ba42bff77169a2a18f24591b92e353e479d32c 100644
--- a/qim3d/viz/_detection.py
+++ b/qim3d/viz/_detection.py
@@ -6,7 +6,7 @@ from IPython.display import clear_output, display
 import qim3d
 
 
-def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs):
+def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive:
     """
     Plots the blobs found on a slice of the volume.
 
diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py
index b80e9556c22ce8d2fa19806b64f1cecc1f6be625..b45a138ccaf80682851dc53bff770a29e11c6172 100644
--- a/qim3d/viz/_k3d.py
+++ b/qim3d/viz/_k3d.py
@@ -15,18 +15,18 @@ from qim3d.utils._misc import downscale_img, scale_to_float16
 
 
 def volumetric(
-    img,
-    aspectmode="data",
-    show=True,
-    save=False,
-    grid_visible=False,
-    color_map='magma',
-    constant_opacity=False,
-    vmin=None,
-    vmax=None,
-    samples="auto",
-    max_voxels=512**3,
-    data_type="scaled_float16",
+    img: np.ndarray,
+    aspectmode: str = "data",
+    show: bool = True,
+    save: bool = False,
+    grid_visible: bool = False,
+    color_map: str = 'magma',
+    constant_opacity: bool = False,
+    vmin: float|None = None,
+    vmax: float|None = None,
+    samples: int|str = "auto",
+    max_voxels: int = 512**3,
+    data_type: str = "scaled_float16",
     **kwargs,
 ):
     """
@@ -175,13 +175,13 @@ def volumetric(
 
 
 def mesh(
-    verts,
-    faces,
-    wireframe=True,
-    flat_shading=True,
-    grid_visible=False,
-    show=True,
-    save=False,
+    verts: np.ndarray,
+    faces: np.ndarray,
+    wireframe: bool = True,
+    flat_shading: bool = True,
+    grid_visible: bool = False,
+    show: bool = True,
+    save: bool = False,
     **kwargs,
 ):
     """
diff --git a/qim3d/viz/_layers2d.py b/qim3d/viz/_layers2d.py
index 1e284461f2eb86bf5883891cd9205911f151d835..676845a57180964107c880c08b32ff2e1c5bc2ed 100644
--- a/qim3d/viz/_layers2d.py
+++ b/qim3d/viz/_layers2d.py
@@ -8,7 +8,7 @@ import numpy as np
 from PIL import Image
 
 
-def image_with_lines(image:np.ndarray, lines: list, line_thickness:float|int) -> Image:
+def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> Image:
     """
     Plots the image and plots the lines on top of it. Then extracts it as PIL.Image and in the same size as the input image was.
     Paramters:
diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py
index 53a7f3a61178534e703c545287ddcd204d26b940..16d7c7ce94bf427e82c20221d6226333f11b1bd5 100644
--- a/qim3d/viz/_metrics.py
+++ b/qim3d/viz/_metrics.py
@@ -1,19 +1,21 @@
 """Visualization tools"""
 
+import matplotlib.figure
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 from qim3d.utils._logger import log
-
+import torch
+import matplotlib
 
 def plot_metrics(
-    *metrics,
-    linestyle="-",
-    batch_linestyle="dotted",
-    labels: list = None,
+    *metrics: tuple[dict[str, float]],
+    linestyle: str = "-",
+    batch_linestyle: str = "dotted",
+    labels: list|None = None,
     figsize: tuple = (16, 6),
-    show=False
+    show: bool = False
 ):
     """
     Plots the metrics over epochs and batches.
@@ -79,8 +81,13 @@ def plot_metrics(
 
 
 def grid_overview(
-    data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
-):
+    data: list|torch.utils.data.Dataset, 
+    num_images: int = 7, 
+    cmap_im: str = "gray", 
+    cmap_segm: str = "viridis", 
+    alpha: float = 0.5, 
+    show: bool = False
+)-> matplotlib.figure.Figure:
     """Displays an overview grid of images, labels, and masks (if they exist).
 
     Labels are the annotated target segmentations
@@ -174,13 +181,13 @@ def grid_overview(
 
 
 def grid_pred(
-    in_targ_preds,
-    num_images=7,
-    cmap_im="gray",
-    cmap_segm="viridis",
-    alpha=0.5,
-    show=False,
-):
+    in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray],
+    num_images: int = 7,
+    cmap_im: str = "gray",
+    cmap_segm: str = "viridis",
+    alpha: float = 0.5,
+    show: bool = False,
+)-> matplotlib.figure.Figure:
     """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
 
     Displays a grid of subplots representing different aspects of the input images and segmentations.
@@ -282,7 +289,7 @@ def grid_pred(
     return fig
 
 
-def vol_masked(vol, vol_mask, viz_delta=128):
+def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray:
     """
     Applies masking to a volume based on a binary volume mask.
 
diff --git a/qim3d/viz/_structure_tensor.py b/qim3d/viz/_structure_tensor.py
index 63d141c41ae9dd365e89a9d82c959d08173d76a6..8148810287fe206334cf1f1e72103e7b3db8b74e 100644
--- a/qim3d/viz/_structure_tensor.py
+++ b/qim3d/viz/_structure_tensor.py
@@ -19,9 +19,9 @@ def vectors(
     vec: np.ndarray,
     axis: int = 0,
     volume_cmap:str = 'grey',
-    vmin:float = None,
-    vmax:float = None,
-    slice_idx: Optional[Union[int, float]] = None,
+    vmin: float|None = None,
+    vmax: float|None = None,
+    slice_idx: Union[int, float]|None = None,
     grid_size: int = 10,
     interactive: bool = True,
     figsize: Tuple[int, int] = (10, 5),