From b4d2a9ebd6891a7a784273f78166a5cb5dae4926 Mon Sep 17 00:00:00 2001
From: s214735 <s214735@student.dtu.dk>
Date: Fri, 10 Jan 2025 11:35:36 +0100
Subject: [PATCH] Type hints

---
 qim3d/detection/_common_detection_methods.py  |  2 +-
 qim3d/features/_common_features_methods.py    | 12 +++--
 qim3d/filters/_common_filter_methods.py       | 40 ++++++++++-----
 qim3d/generate/_aggregators.py                |  2 +-
 qim3d/gui/annotation_tool.py                  |  8 +--
 qim3d/gui/data_explorer.py                    | 23 +++++----
 qim3d/gui/interface.py                        |  7 +--
 qim3d/gui/iso3d.py                            | 49 ++++++++++---------
 qim3d/gui/layers2d.py                         | 13 ++---
 qim3d/gui/local_thickness.py                  | 12 ++---
 qim3d/gui/qim_theme.py                        |  2 +-
 qim3d/io/_convert.py                          | 13 ++---
 qim3d/io/_downloader.py                       | 16 +++---
 qim3d/io/_loading.py                          | 46 ++++++++---------
 qim3d/io/_ome_zarr.py                         | 32 ++++++------
 qim3d/io/_saving.py                           | 37 +++++++-------
 qim3d/io/_sync.py                             |  6 +--
 qim3d/mesh/_common_mesh_methods.py            |  4 +-
 qim3d/ml/_augmentations.py                    | 10 ++--
 qim3d/ml/_data.py                             | 22 ++++++---
 qim3d/ml/_ml_utils.py                         | 28 +++++------
 qim3d/ml/models/_unet.py                      | 37 ++++++++------
 .../operations/_common_operations_methods.py  |  2 +-
 qim3d/processing/_layers.py                   | 11 ++++-
 qim3d/processing/_local_thickness.py          |  2 +-
 qim3d/processing/_structure_tensor.py         |  2 +-
 qim3d/segmentation/_connected_components.py   |  6 +--
 qim3d/utils/_doi.py                           | 18 +++----
 qim3d/utils/_misc.py                          | 20 ++++----
 qim3d/utils/_ome_zarr.py                      |  4 +-
 qim3d/utils/_progress_bar.py                  |  8 +--
 qim3d/utils/_server.py                        |  4 +-
 qim3d/utils/_system.py                        |  4 +-
 qim3d/viz/_cc.py                              | 22 ++++-----
 qim3d/viz/_data_exploration.py                | 32 ++++++------
 qim3d/viz/_detection.py                       |  2 +-
 qim3d/viz/_k3d.py                             | 38 +++++++-------
 qim3d/viz/_layers2d.py                        |  2 +-
 qim3d/viz/_metrics.py                         | 39 +++++++++------
 qim3d/viz/_structure_tensor.py                |  6 +--
 40 files changed, 356 insertions(+), 287 deletions(-)

diff --git a/qim3d/detection/_common_detection_methods.py b/qim3d/detection/_common_detection_methods.py
index 0a985d85..131ec1e4 100644
--- a/qim3d/detection/_common_detection_methods.py
+++ b/qim3d/detection/_common_detection_methods.py
@@ -15,7 +15,7 @@ def blobs(
     overlap: float = 0.5,
     threshold_rel: float = None,
     exclude_border: bool = False,
-) -> np.ndarray:
+) -> tuple[np.ndarray, np.ndarray]:
     """
     Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True
 
diff --git a/qim3d/features/_common_features_methods.py b/qim3d/features/_common_features_methods.py
index e69b6cba..63838ef1 100644
--- a/qim3d/features/_common_features_methods.py
+++ b/qim3d/features/_common_features_methods.py
@@ -5,7 +5,9 @@ import trimesh
 import qim3d
 
 
-def volume(obj, **mesh_kwargs) -> float:
+def volume(obj: np.ndarray|trimesh.Trimesh, 
+           **mesh_kwargs
+           ) -> float:
     """
     Compute the volume of a 3D volume or mesh.
 
@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
     return obj.volume
 
 
-def area(obj, **mesh_kwargs) -> float:
+def area(obj: np.ndarray|trimesh.Trimesh, 
+         **mesh_kwargs
+         ) -> float:
     """
     Compute the surface area of a 3D volume or mesh.
 
@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
     return obj.area
 
 
-def sphericity(obj, **mesh_kwargs) -> float:
+def sphericity(obj: np.ndarray|trimesh.Trimesh, 
+               **mesh_kwargs
+               ) -> float:
     """
     Compute the sphericity of a 3D volume or mesh.
 
diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py
index f2d42197..59b0be12 100644
--- a/qim3d/filters/_common_filter_methods.py
+++ b/qim3d/filters/_common_filter_methods.py
@@ -26,7 +26,10 @@ __all__ = [
 
 
 class FilterBase:
-    def __init__(self, dask=False, chunks="auto", *args, **kwargs):
+    def __init__(self, 
+                 dask: bool = False, 
+                 chunks: str = "auto", 
+                 *args, **kwargs):
         """
         Base class for image filters.
 
@@ -40,7 +43,7 @@ class FilterBase:
         self.kwargs = kwargs
 
 class Gaussian(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a Gaussian filter to the input.
 
@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
 
 
 class Median(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a median filter to the input.
 
@@ -68,7 +71,7 @@ class Median(FilterBase):
 
 
 class Maximum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a maximum filter to the input.
 
@@ -82,7 +85,7 @@ class Maximum(FilterBase):
 
 
 class Minimum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a minimum filter to the input.
 
@@ -95,7 +98,7 @@ class Minimum(FilterBase):
         return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
 
 class Tophat(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray) -> np.ndarray:
         """
         Applies a tophat filter to the input.
 
@@ -210,7 +213,10 @@ class Pipeline:
         return input
 
 
-def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
+def gaussian(vol: np.ndarray, 
+             dask: bool = False, 
+             chunks: str = 'auto',
+             *args, **kwargs) -> np.ndarray:
     """
     Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
 
@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
         return res
 
 
-def median(vol, dask=False, chunks='auto', **kwargs):
+def median(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str ='auto', 
+           **kwargs) -> np.ndarray:
     """
     Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
 
@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def maximum(vol, dask=False, chunks='auto', **kwargs):
+def maximum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
 
@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def minimum(vol, dask=False, chunks='auto', **kwargs):
+def minimum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
 
@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
         res = ndimage.minimum_filter(vol, **kwargs)
         return res
 
-def tophat(vol, dask=False, chunks='auto', **kwargs):
+def tophat(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str = 'auto', 
+           **kwargs) -> np.ndarray:
     """
     Remove background from the volume.
 
diff --git a/qim3d/generate/_aggregators.py b/qim3d/generate/_aggregators.py
index 9d10df37..262b047e 100644
--- a/qim3d/generate/_aggregators.py
+++ b/qim3d/generate/_aggregators.py
@@ -142,7 +142,7 @@ def noise_object_collection(
     object_shape: str = None,
     seed: int = 0,
     verbose: bool = False,
-) -> tuple[np.ndarray, object]:
+) -> tuple[np.ndarray, np.ndarray]:
     """
     Generate a 3D volume of multiple synthetic objects using Perlin noise.
 
diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py
index 4498e1ad..5ebae5f7 100644
--- a/qim3d/gui/annotation_tool.py
+++ b/qim3d/gui/annotation_tool.py
@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
 
 
 class Interface(BaseInterface):
-    def __init__(self, name_suffix: str = "", verbose: bool = False, img=None):
+    def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
         super().__init__(
             title="Annotation Tool",
             height=768,
@@ -55,7 +55,7 @@ class Interface(BaseInterface):
         self.masks_rgb = None
         self.temp_files = []
 
-    def get_result(self):
+    def get_result(self) -> dict:
         # Get the temporary files from gradio
         temp_path_list = []
         for filename in os.listdir(self.temp_dir):
@@ -95,13 +95,13 @@ class Interface(BaseInterface):
         except FileNotFoundError:
             files = None
 
-    def create_preview(self, img_editor):
+    def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
         background = img_editor["background"]
         masks = img_editor["layers"][0]
         overlay_image = overlay_rgb_images(background, masks)
         return overlay_image
 
-    def cerate_download_list(self, img_editor):
+    def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
         masks_rgb = img_editor["layers"][0]
         mask_threshold = 200  # This value is based
 
diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py
index ecabeb81..ee3ba16d 100644
--- a/qim3d/gui/data_explorer.py
+++ b/qim3d/gui/data_explorer.py
@@ -20,6 +20,7 @@ import os
 import re
 
 import gradio as gr
+import matplotlib.figure
 import matplotlib.pyplot as plt
 import numpy as np
 import outputformat as ouf
@@ -29,6 +30,8 @@ from qim3d.utils._logger import log
 from qim3d.utils import _misc
 
 from qim3d.gui.interface import BaseInterface
+from typing import Callable, Any, Dict
+import matplotlib
 
 
 class Interface(BaseInterface):
@@ -271,7 +274,7 @@ class Interface(BaseInterface):
         operations.change(fn=self.show_results, inputs = operations, outputs = results)
         cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
 
-    def update_explorer(self, new_path):
+    def update_explorer(self, new_path: str):
         new_path = os.path.expanduser(new_path)
 
         # In case we have a directory
@@ -367,7 +370,7 @@ class Interface(BaseInterface):
         except Exception as error_message:
             self.error_message = F"Error when loading data: {error_message}"
     
-    def run_operations(self, operations, *args):
+    def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
         outputs = []
         self.calculated_operations = []
         for operation in self.all_operations:
@@ -411,7 +414,7 @@ class Interface(BaseInterface):
             case _:
                 raise NotImplementedError(F"Operation '{operation} is not defined")
 
-    def show_results(self, operations):
+    def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
         update_list = []
         for operation in self.all_operations:
             if operation in operations and operation in self.calculated_operations:
@@ -426,7 +429,7 @@ class Interface(BaseInterface):
 #
 #######################################################
 
-    def create_img_fig(self, img, **kwargs):
+    def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
         fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
 
         ax.imshow(img, interpolation="nearest", **kwargs)
@@ -437,8 +440,8 @@ class Interface(BaseInterface):
 
         return fig
 
-    def update_slice_wrapper(self, letter):
-        def update_slice(position_slider:float, cmap:str):
+    def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
+        def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]:
             """
             position_slider: float from gradio slider, saying which relative slice we want to see
             cmap: string gradio drop down menu, saying what cmap we want to use for display
@@ -465,7 +468,7 @@ class Interface(BaseInterface):
             return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
         return update_slice
     
-    def vol_histogram(self, nbins, min_value, max_value):
+    def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
         # Start histogram
         vol_hist = np.zeros(nbins)
 
@@ -478,7 +481,7 @@ class Interface(BaseInterface):
 
         return vol_hist, bin_edges
 
-    def plot_histogram(self):
+    def plot_histogram(self) -> matplotlib.figure.Figure:
         # The Histogram needs results from the projections
         if not self.projections_calculated:
             _ = self.get_projections()
@@ -498,7 +501,7 @@ class Interface(BaseInterface):
 
         return fig
     
-    def create_projections_figs(self):
+    def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
         if not self.projections_calculated:
             projections = self.get_projections()
             self.max_projection = projections[0]
@@ -519,7 +522,7 @@ class Interface(BaseInterface):
         self.projections_calculated = True
         return max_projection_fig, min_projection_fig
 
-    def get_projections(self):
+    def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
         # Create arrays for iteration
         max_projection = np.zeros(np.shape(self.vol[0]))
         min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py
index c5596850..36d820dd 100644
--- a/qim3d/gui/interface.py
+++ b/qim3d/gui/interface.py
@@ -6,6 +6,7 @@ import gradio as gr
 
 from .qim_theme import QimTheme
 import qim3d.gui
+import numpy as np
 
 
 # TODO: when offline it throws an error in cli
@@ -48,10 +49,10 @@ class BaseInterface(ABC):
     def set_invisible(self):
         return gr.update(visible=False)
     
-    def change_visibility(self, is_visible):
+    def change_visibility(self, is_visible: bool):
         return gr.update(visible = is_visible)
 
-    def launch(self, img=None, force_light_mode: bool = True, **kwargs):
+    def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
         """
         img: If None, user can upload image after the interface is launched.
             If defined, the interface will be launched with the image already there
@@ -76,7 +77,7 @@ class BaseInterface(ABC):
             **kwargs,
         )
 
-    def clear(self):
+    def clear(self) -> None:
         """Used to reset outputs with the clear button"""
         return None
 
diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py
index 1a20f91e..c7d4620f 100644
--- a/qim3d/gui/iso3d.py
+++ b/qim3d/gui/iso3d.py
@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
         self.img = img
         self.plot_height = plot_height
 
-    def load_data(self, gradiofile):
+    def load_data(self, gradiofile: gr.File):
         try:
             self.vol = load(gradiofile.name)
             assert self.vol.ndim == 3
@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
         except AssertionError:
             raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
 
-    def resize_vol(self, display_size):
+    def resize_vol(self, display_size: int):
         """Resizes the loaded volume to the display size"""
 
         # Get original size
@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
                 f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
             )
 
-    def save_fig(self, fig, filename):
+    def save_fig(self, fig: go.Figure, filename: str):
         # Write Plotly figure to disk
         fig.write_html(filename)
 
     def create_fig(self, 
-        gradio_file,
-        display_size,
-        opacity,
-        opacityscale,
-        only_wireframe,
-        min_value,
-        max_value,
-        surface_count,
-        colormap,
-        show_colorbar,
-        reversescale,
-        flip_z,
-        show_axis,
-        show_ticks,
-        show_caps,
-        show_z_slice,
-        slice_z_location,
-        show_y_slice,
-        slice_y_location,
-        show_x_slice,
-        slice_x_location,):
+        gradio_file: gr.File,
+        display_size: int ,
+        opacity: float,
+        opacityscale: str,
+        only_wireframe: bool,
+        min_value: float,
+        max_value: float,
+        surface_count: int,
+        colormap: str,
+        show_colorbar: bool,
+        reversescale: bool,
+        flip_z: bool,
+        show_axis: bool,
+        show_ticks: bool,
+        show_caps: bool,
+        show_z_slice: bool,
+        slice_z_location: int,
+        show_y_slice: bool,
+        slice_y_location: int,
+        show_x_slice: bool,
+        slice_x_location: int,
+        ) -> tuple[go.Figure, str]:
 
         # Load volume
         self.load_data(gradio_file)
diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py
index 26d9639f..febd16a5 100644
--- a/qim3d/gui/layers2d.py
+++ b/qim3d/gui/layers2d.py
@@ -29,6 +29,7 @@ from qim3d.processing import segment_layers, get_lines
 from qim3d.operations import overlay_rgb_images
 from qim3d.io import load
 from qim3d.viz._layers2d import image_with_lines
+from typing import Dict, Any
 
 #TODO figure out how not update anything and go through processing when there are no data loaded
 # So user could play with the widgets but it doesnt throw error
@@ -303,14 +304,14 @@ class Interface(BaseInterface):
             
         
 
-    def change_plot_type(self, plot_type, ):
+    def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
         self.plot_type = plot_type
         if plot_type == 'Segmentation lines':
             return gr.update(visible = False), gr.update(visible = True)
         else:  
             return gr.update(visible = True), gr.update(visible = False)
         
-    def change_plot_size(self, x_check, y_check, z_check):
+    def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
         also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
@@ -320,10 +321,10 @@ class Interface(BaseInterface):
         height = self.heights[index] # also used to define heights of plots in the begining
         return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
 
-    def change_row_visibility(self, x_check, y_check, z_check):
+    def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
         return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
     
-    def update_explorer(self, new_path):
+    def update_explorer(self, new_path: str):
         # Refresh the file explorer object
         new_path = os.path.expanduser(new_path)
 
@@ -342,13 +343,13 @@ class Interface(BaseInterface):
     def set_relaunch_button(self):
         return gr.update(value=f"Relaunch", interactive=True)
 
-    def set_spinner(self, message):
+    def set_spinner(self, message: str):
         if self.error:
             return gr.Button()
         # spinner icon/shows the user something is happeing
         return gr.update(value=f"{message}", interactive=False)
     
-    def load_data(self, base_path, explorer):
+    def load_data(self, base_path: str, explorer: str):
         if base_path and os.path.isfile(base_path):
             file_path = base_path
         elif explorer and os.path.isfile(explorer):
diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py
index 0c76f1aa..461f3e17 100644
--- a/qim3d/gui/local_thickness.py
+++ b/qim3d/gui/local_thickness.py
@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples
 
 class Interface(InterfaceWithExamples):
     def __init__(self,
-                 img = None,
+                 img: np.ndarray = None,
                  verbose:bool = False,
                  plot_height:int = 768,
                  figsize:int = 6): 
@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
     #
     #######################################################
 
-    def process_input(self, data, dark_objects):
+    def process_input(self, data: np.ndarray, dark_objects: bool):
         # Load volume
         try:
             self.vol = load(data.name)
@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
         self.vmin = np.min(self.vol)
         self.vmax = np.max(self.vol)
 
-    def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"):
+    def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
         plt.close()
         z_idx = int(zpos * (vol.shape[0] - 1))
         fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
 
         return fig
 
-    def make_binary(self, threshold):
+    def make_binary(self, threshold: float):
         # Make a binary volume
         # Nothing fancy, but we could add new features here
         self.vol_binary = self.vol > (threshold * np.max(self.vol))
     
-    def compute_localthickness(self, lt_scale):
+    def compute_localthickness(self, lt_scale: float):
         self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
 
         # Valus for visualization
         self.vmin_lt = np.min(self.vol_thickness)
         self.vmax_lt = np.max(self.vol_thickness)
 
-    def thickness_histogram(self, nbins):
+    def thickness_histogram(self, nbins: int):
         # Ignore zero thickness
         non_zero_values = self.vol_thickness[self.vol_thickness > 0]
 
diff --git a/qim3d/gui/qim_theme.py b/qim3d/gui/qim_theme.py
index f13594e4..5cdc844b 100644
--- a/qim3d/gui/qim_theme.py
+++ b/qim3d/gui/qim_theme.py
@@ -7,7 +7,7 @@ class QimTheme(gr.themes.Default):
     there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
     in comments.
     """
-    def __init__(self, force_light_mode:bool = True):
+    def __init__(self, force_light_mode: bool = True):
         """
         Parameters:
         -----------
diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py
index 4ab6c0a9..0d776822 100644
--- a/qim3d/io/_convert.py
+++ b/qim3d/io/_convert.py
@@ -7,6 +7,7 @@ import numpy as np
 import tifffile as tiff
 import zarr
 from tqdm import tqdm
+import zarr.core
 
 from qim3d.utils._misc import stringify_path
 from qim3d.io._saving import save
@@ -21,7 +22,7 @@ class Convert:
         """
         self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
 
-    def convert(self, input_path, output_path):
+    def convert(self, input_path: str, output_path: str):
         def get_file_extension(file_path):
             root, ext = os.path.splitext(file_path)
             if ext in ['.gz', '.bz2', '.xz']:  # handle common compressed extensions
@@ -67,7 +68,7 @@ class Convert:
             else:
                 raise ValueError("Invalid path")
 
-    def convert_tif_to_zarr(self, tif_path, zarr_path):
+    def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a tiff file to a zarr file
 
         Args:
@@ -97,7 +98,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_tif(self, zarr_path, tif_path):
+    def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
         """Convert a zarr file to a tiff file
 
         Args:
@@ -110,7 +111,7 @@ class Convert:
         z = zarr.open(zarr_path)
         save(tif_path, z)
 
-    def convert_nifti_to_zarr(self, nifti_path, zarr_path):
+    def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a nifti file to a zarr file
 
         Args:
@@ -139,7 +140,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_nifti(self, zarr_path, nifti_path, compression=False):
+    def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
         """Convert a zarr file to a nifti file
 
         Args:
@@ -153,7 +154,7 @@ class Convert:
         save(nifti_path, z, compression=compression)
         
 
-def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
+def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
     """Convert a file to another format without loading the entire file into memory
 
     Args:
diff --git a/qim3d/io/_downloader.py b/qim3d/io/_downloader.py
index fe42e410..18e36ca5 100644
--- a/qim3d/io/_downloader.py
+++ b/qim3d/io/_downloader.py
@@ -76,7 +76,7 @@ class Downloader:
             [file_name_n](load_file,optional): Function to download file number n in the given folder.
         """
 
-        def __init__(self, folder):
+        def __init__(self, folder: str):
             files = _extract_names(folder)
 
             for idx, file in enumerate(files):
@@ -88,7 +88,7 @@ class Downloader:
 
                 setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
 
-        def _make_fn(self, folder, file):
+        def _make_fn(self, folder: str, file: str):
             """Private method that returns a function. The function downloads the chosen file from the folder.
 
             Args:
@@ -101,7 +101,7 @@ class Downloader:
 
             url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
 
-            def _download(load_file=False, virtual_stack=True):
+            def _download(load_file: bool = False, virtual_stack: bool = True):
                 """Downloads the file and optionally also loads it.
 
                 Args:
@@ -121,7 +121,7 @@ class Downloader:
             return _download
 
 
-def _update_progress(pbar, blocknum, bs):
+def _update_progress(pbar: tqdm, blocknum: int, bs: int):
     """
     Helper function for the ´download_file()´ function. Updates the progress bar.
     """
@@ -129,7 +129,7 @@ def _update_progress(pbar, blocknum, bs):
     pbar.update(blocknum * bs - pbar.n)
 
 
-def _get_file_size(url):
+def _get_file_size(url: str):
     """
     Helper function for the ´download_file()´ function. Finds the size of the file.
     """
@@ -137,7 +137,7 @@ def _get_file_size(url):
     return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
 
 
-def download_file(path, name, file):
+def download_file(path: str, name: str, file: str):
     """Downloads the file from path / name / file.
 
     Args:
@@ -177,7 +177,7 @@ def download_file(path, name, file):
         )
 
 
-def _extract_html(url):
+def _extract_html(url: str):
     """Extracts the html content of a webpage in "utf-8"
 
     Args:
@@ -198,7 +198,7 @@ def _extract_html(url):
     return html_content
 
 
-def _extract_names(name=None):
+def _extract_names(name: str = None):
     """Extracts the names of the folders and files.
 
     Finds the names of either the folders if no name is given,
diff --git a/qim3d/io/_loading.py b/qim3d/io/_loading.py
index e5f2191e..e3d7d45b 100644
--- a/qim3d/io/_loading.py
+++ b/qim3d/io/_loading.py
@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
 from qim3d.utils._progress_bar import FileLoadingProgressBar
 import trimesh
 
+from typing import Optional, Dict
+
 dask.config.set(scheduler="processes") 
 
 
@@ -76,7 +78,7 @@ class DataLoader:
         self.dim_order = kwargs.get("dim_order", (2, 1, 0))
         self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
 
-    def load_tiff(self, path):
+    def load_tiff(self, path: str|os.PathLike):
         """Load a TIFF file from the specified path.
 
         Args:
@@ -100,7 +102,7 @@ class DataLoader:
 
         return vol
 
-    def load_h5(self, path):
+    def load_h5(self, path: str|os.PathLike) -> tuple[np.ndarray, Optional[Dict]]:
         """Load an HDF5 file from the specified path.
 
         Args:
@@ -183,7 +185,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_tiff_stack(self, path):
+    def load_tiff_stack(self, path: str|os.PathLike) -> np.ndarray|np.memmap:
         """Load a stack of TIFF files from the specified path.
 
         Args:
@@ -237,7 +239,7 @@ class DataLoader:
 
         return vol
 
-    def load_txrm(self, path):
+    def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]:
         """Load a TXRM/XRM/TXM file from the specified path.
 
         Args:
@@ -308,7 +310,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_nifti(self, path):
+    def load_nifti(self, path: str|os.PathLike):
         """Load a NIfTI file from the specified path.
 
         Args:
@@ -338,7 +340,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_pil(self, path):
+    def load_pil(self, path: str|os.PathLike):
         """Load a PIL image from the specified path
 
         Args:
@@ -349,7 +351,7 @@ class DataLoader:
         """
         return np.array(Image.open(path))
 
-    def load_PIL_stack(self, path):
+    def load_PIL_stack(self, path: str|os.PathLike):
         """Load a stack of PIL files from the specified path.
 
         Args:
@@ -433,7 +435,7 @@ class DataLoader:
 
       
 
-    def _load_vgi_metadata(self, path):
+    def _load_vgi_metadata(self, path: str|os.PathLike):
         """Helper functions that loads metadata from a VGI file
 
         Args:
@@ -482,7 +484,7 @@ class DataLoader:
 
         return meta_data
 
-    def load_vol(self, path):
+    def load_vol(self, path: str|os.PathLike):
         """Load a VOL filed based on the VGI metadata file
 
         Args:
@@ -548,7 +550,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_dicom(self, path):
+    def load_dicom(self, path: str|os.PathLike):
         """Load a DICOM file
 
         Args:
@@ -563,7 +565,7 @@ class DataLoader:
         else:
             return dcm_data.pixel_array
 
-    def load_dicom_dir(self, path):
+    def load_dicom_dir(self, path: str|os.PathLike):
         """Load a directory of DICOM files into a numpy 3d array
 
         Args:
@@ -605,7 +607,7 @@ class DataLoader:
             return vol
         
 
-    def load_zarr(self, path: str):
+    def load_zarr(self, path: str|os.PathLike):
         """ Loads a Zarr array from disk.
 
         Args:
@@ -654,7 +656,7 @@ class DataLoader:
                     message + " Set 'force_load=True' to ignore this error."
                 )
 
-    def load(self, path):
+    def load(self, path: str|os.PathLike):
         """
         Load a file or directory based on the given path.
 
@@ -757,16 +759,16 @@ def _get_ole_offsets(ole):
 
 
 def load(
-    path,
-    virtual_stack=False,
-    dataset_name=None,
-    return_metadata=False,
-    contains=None,
-    progress_bar:bool = True,
+    path: str|os.PathLike,
+    virtual_stack: bool = False,
+    dataset_name: bool = None,
+    return_metadata: bool = False,
+    contains: bool = None,
+    progress_bar: bool = True,
     force_load: bool = False,
-    dim_order=(2, 1, 0),
+    dim_order: tuple = (2, 1, 0),
     **kwargs,
-):
+) -> np.ndarray:
     """
     Load data from the specified file or directory.
 
@@ -854,7 +856,7 @@ def load(
 
     return data
 
-def load_mesh(filename):
+def load_mesh(filename: str) -> trimesh.Trimesh:
     """
     Load a mesh from an .obj file using trimesh.
 
diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py
index ae517315..36976316 100644
--- a/qim3d/io/_ome_zarr.py
+++ b/qim3d/io/_ome_zarr.py
@@ -46,13 +46,13 @@ class OMEScaler(
     """Scaler in the style of OME-Zarr.
     This is needed because their current zoom implementation is broken."""
 
-    def __init__(self, order=0, downscale=2, max_layer=5, method="scaleZYXdask"):
+    def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
         self.order = order
         self.downscale = downscale
         self.max_layer = max_layer
         self.method = method
 
-    def scaleZYX(self, base):
+    def scaleZYX(self, base: da.core.Array):
         """Downsample using :func:`scipy.ndimage.zoom`."""
         rv = [base]
         log.info(f"- Scale 0: {rv[-1].shape}")
@@ -63,7 +63,7 @@ class OMEScaler(
 
         return list(rv)
 
-    def scaleZYXdask(self, base):
+    def scaleZYXdask(self, base: da.core.Array):
         """
         Downsample a 3D volume using Dask and scipy.ndimage.zoom.
 
@@ -82,7 +82,7 @@ class OMEScaler(
 
 
         """
-        def resize_zoom(vol, scale_factors, order, scaled_shape):
+        def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
 
             # Get the chunksize needed so that all the blocks match the new shape
             # This snippet comes from the original OME-Zarr-python library
@@ -181,16 +181,16 @@ class OMEScaler(
 
 
 def export_ome_zarr(
-    path,
-    data,
-    chunk_size=256,
-    downsample_rate=2,
-    order=1,
-    replace=False,
-    method="scaleZYX",
+    path: str|os.PathLike,
+    data: np.ndarray|da.core.Array,
+    chunk_size: int = 256,
+    downsample_rate: int = 2,
+    order: int = 1,
+    replace: bool = False,
+    method: str = "scaleZYX",
     progress_bar: bool = True,
-    progress_bar_repeat_time="auto",
-):
+    progress_bar_repeat_time: str = "auto",
+) -> None:
     """
     Export 3D image data to OME-Zarr format with pyramidal downsampling.
 
@@ -299,7 +299,11 @@ def export_ome_zarr(
     return
 
 
-def import_ome_zarr(path, scale=0, load=True):
+def import_ome_zarr(
+        path: str|os.PathLike, 
+        scale: int = 0, 
+        load: bool = True
+        ) -> np.ndarray:
     """
     Import image data from an OME-Zarr file.
 
diff --git a/qim3d/io/_saving.py b/qim3d/io/_saving.py
index d7721ee2..a7053c00 100644
--- a/qim3d/io/_saving.py
+++ b/qim3d/io/_saving.py
@@ -76,7 +76,7 @@ class DataSaver:
         self.sliced_dim = kwargs.get("sliced_dim", 0)
         self.chunk_shape = kwargs.get("chunk_shape", "auto")
 
-    def save_tiff(self, path, data):
+    def save_tiff(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a TIFF file to the given path.
 
         Args:
@@ -85,7 +85,7 @@ class DataSaver:
         """
         tifffile.imwrite(path, data, compression=self.compression)
 
-    def save_tiff_stack(self, path, data):
+    def save_tiff_stack(self, path: str|os.PathLike, data: np.ndarray):
         """Save data as a TIFF stack containing slices in separate files to the given path.
         The slices will be named according to the basename plus a suffix with a zero-filled
         value corresponding to the slice number
@@ -124,7 +124,7 @@ class DataSaver:
                 f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
             )
 
-    def save_nifti(self, path, data):
+    def save_nifti(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a NIfTI file to the given path.
 
         Args:
@@ -154,7 +154,7 @@ class DataSaver:
         # Save image
         nib.save(img, path)
 
-    def save_vol(self, path, data):
+    def save_vol(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a VOL file to the given path.
 
         Args:
@@ -200,7 +200,7 @@ class DataSaver:
                 "dataset", data=data, compression="gzip" if self.compression else None
             )
 
-    def save_dicom(self, path, data):
+    def save_dicom(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a DICOM file to the given path.
 
         Args:
@@ -255,7 +255,7 @@ class DataSaver:
 
         ds.save_as(path)
 
-    def save_to_zarr(self, path, data):
+    def save_to_zarr(self, path: str|os.PathLike, data: da.core.Array):
         """Saves a Dask array to a Zarr array on disk.
 
         Args:
@@ -284,7 +284,7 @@ class DataSaver:
             )
             zarr_array[:] = data
 
-    def save_PIL(self, path, data):
+    def save_PIL(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to a PIL file to the given path.
 
         Args:
@@ -303,7 +303,7 @@ class DataSaver:
         # Save image
         img.save(path)
 
-    def save(self, path, data):
+    def save(self, path: str|os.PathLike, data: np.ndarray):
         """Save data to the given path.
 
         Args:
@@ -401,15 +401,15 @@ class DataSaver:
 
 
 def save(
-    path,
-    data,
-    replace=False,
-    compression=False,
-    basename=None,
-    sliced_dim=0,
-    chunk_shape="auto",
+    path: str|os.PathLike,
+    data: np.ndarray,
+    replace: bool = False,
+    compression: bool = False,
+    basename: bool = None,
+    sliced_dim: int = 0,
+    chunk_shape: str = "auto",
     **kwargs,
-):
+) -> None:
     """Save data to a specified file path.
 
     Args:
@@ -464,7 +464,10 @@ def save(
     ).save(path, data)
 
 
-def save_mesh(filename, mesh):
+def save_mesh(
+        filename: str, 
+        mesh: trimesh.Trimesh
+        ) -> None:
     """
     Save a trimesh object to an .obj file.
 
diff --git a/qim3d/io/_sync.py b/qim3d/io/_sync.py
index 9085ae88..953bf15b 100644
--- a/qim3d/io/_sync.py
+++ b/qim3d/io/_sync.py
@@ -28,7 +28,7 @@ class Sync:
 
             return False
 
-    def check_destination(self, source, destination, checksum=False, verbose=True):
+    def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
         """Check if all files from 'source' are in 'destination'
 
         This function compares the files in the 'source' directory to those in
@@ -80,7 +80,7 @@ class Sync:
 
         return diff_files
 
-    def compare_dirs(self, source, destination, checksum=False, verbose=True):
+    def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
         """Checks whether 'source' and 'destination' directories are synchronized.
 
         This function compares the contents of two directories
@@ -168,7 +168,7 @@ class Sync:
             )
         return
 
-    def count_files_and_dirs(self, path, verbose=True):
+    def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
         """Count the number of files and directories in the given path.
 
         This function recursively counts the number of files and
diff --git a/qim3d/mesh/_common_mesh_methods.py b/qim3d/mesh/_common_mesh_methods.py
index 535fcae8..87cb9d8b 100644
--- a/qim3d/mesh/_common_mesh_methods.py
+++ b/qim3d/mesh/_common_mesh_methods.py
@@ -8,8 +8,8 @@ from qim3d.utils._logger import log
 def from_volume(
     volume: np.ndarray,
     level: float = None,
-    step_size=1,
-    allow_degenerate=False,
+    step_size: int = 1,
+    allow_degenerate: bool = False,
     padding: Tuple[int, int, int] = (2, 2, 2),
     **kwargs: Any,
 ) -> trimesh.Trimesh:
diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index ea81e53a..e12b1a51 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -20,10 +20,10 @@ class Augmentation:
     """
     
     def __init__(self, 
-                 resize = 'crop', 
-                 transform_train = 'moderate', 
-                 transform_validation = None,
-                 transform_test = None,
+                 resize: str = 'crop', 
+                 transform_train: str = 'moderate', 
+                 transform_validation: str | None = None,
+                 transform_test: str | None = None,
                  mean: float = 0.5, 
                  std: float = 0.5
                 ):
@@ -38,7 +38,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h: int, im_w: int, level: bool | None = None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
         A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 38fbdac7..d54f1bf6 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -4,7 +4,9 @@ from PIL import Image
 from qim3d.utils._logger import log
 import torch
 import numpy as np
-
+from typing import Optional, Callable
+import torch.nn as nn
+from ._data import Augmentation
 
 class Dataset(torch.utils.data.Dataset):
     """
@@ -36,7 +38,7 @@ class Dataset(torch.utils.data.Dataset):
             transform=albumentations.Compose([ToTensorV2()]))
         image, target = dataset[idx]
     """
-    def __init__(self, root_path: str, split="train", transform=None):
+    def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None):
         super().__init__()
 
         # Check if split is valid
@@ -58,7 +60,7 @@ class Dataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.sample_images)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
@@ -76,7 +78,7 @@ class Dataset(torch.utils.data.Dataset):
 
 
     # TODO: working with images of different sizes
-    def check_shape_consistency(self,sample_images):
+    def check_shape_consistency(self,sample_images: tuple[str]):
         image_shapes= []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
@@ -99,7 +101,7 @@ class Dataset(torch.utils.data.Dataset):
         return Image.open(str(image_path)).size
 
 
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
+def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
     """
     Checks the compatibility of the image shape with the depth of the model.
     If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
@@ -131,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
     return h_adjust, w_adjust 
 
 
-def prepare_datasets(path: str, val_fraction: float, model, augmentation):
+def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
     Splits and augments the train/validation/test datasets.
 
@@ -172,7 +174,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
     return train_set, val_set, test_set
 
 
-def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):  
+def prepare_dataloaders(train_set: torch.utils.data, 
+                        val_set: torch.utils.data, 
+                        test_set: torch.utils.data, 
+                        batch_size: int, 
+                        shuffle_train: bool = True, 
+                        num_workers: int = 8, 
+                        pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:  
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index 7c8a3b86..f46a7481 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -3,24 +3,24 @@
 import torch
 import numpy as np
 
-from torchinfo import summary
+from torchinfo import summary, ModelStatistics
 from qim3d.utils._logger import log
 from qim3d.viz._metrics import plot_metrics
 
 from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
-
+from models._unet import Hyperparameters
 
 def train_model(
-    model,
-    hyperparameters,
-    train_loader,
-    val_loader,
-    eval_every=1,
-    print_every=5,
-    plot=True,
-    return_loss=False,
-):
+    model: torch.nn.Module,
+    hyperparameters: Hyperparameters,
+    train_loader: torch.utils.data.DataLoader,
+    val_loader: torch.utils.data.DataLoader,
+    eval_every: int = 1,
+    print_every: int = 5,
+    plot: bool = True,
+    return_loss: bool = False,
+) -> tuple[tuple[float], tuple[float]]:
     """Function for training Neural Network models.
 
     Args:
@@ -137,7 +137,7 @@ def train_model(
         return train_loss, val_loss
 
 
-def model_summary(dataloader, model):
+def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module) -> ModelStatistics:
     """Prints the summary of a PyTorch model.
 
     Args:
@@ -160,7 +160,7 @@ def model_summary(dataloader, model):
     return model_s
 
 
-def inference(data, model):
+def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.
 
     Performs inference on the input data using the provided model. The input data should be in the form of a list,
@@ -242,7 +242,7 @@ def inference(data, model):
     return inputs, targets, preds
 
 
-def volume_inference(volume, model, threshold=0.5):
+def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float = 0.5) -> np.ndarray:
     """
     Compute on the entire volume
     Args:
diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py
index 353d10bd..27ee78a4 100644
--- a/qim3d/ml/models/_unet.py
+++ b/qim3d/ml/models/_unet.py
@@ -119,13 +119,13 @@ class Hyperparameters:
 
     def __init__(
         self,
-        model,
-        n_epochs=10,
-        learning_rate=1e-3,
-        optimizer="Adam",
-        momentum=0,
-        weight_decay=0,
-        loss_function="Focal",
+        model: nn.Module,
+        n_epochs: int = 10,
+        learning_rate: float = 1e-3,
+        optimizer: str = "Adam",
+        momentum: float = 0,
+        weight_decay: float = 0,
+        loss_function: str = "Focal",
     ):
 
         # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
@@ -168,13 +168,13 @@ class Hyperparameters:
 
     def model_params(
         self,
-        model,
-        n_epochs,
-        optimizer,
-        learning_rate,
-        weight_decay,
-        momentum,
-        loss_function,
+        model: nn.Module,
+        n_epochs: int,
+        optimizer: str,
+        learning_rate: float,
+        weight_decay: float,
+        momentum: float,
+        loss_function: str,
     ):
 
         optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
@@ -188,7 +188,12 @@ class Hyperparameters:
         return hyper_dict
 
     # selecting the optimizer
-    def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
+    def _optimizer(self, 
+        model: nn.Module, 
+        optimizer: str, 
+        learning_rate: float, 
+        weight_decay: float, 
+        momentum: float):
         from torch.optim import Adam, SGD, RMSprop
 
         if optimizer == "Adam":
@@ -212,7 +217,7 @@ class Hyperparameters:
         return optim
 
     # selecting the loss function
-    def _loss_functions(self, loss_function):
+    def _loss_functions(self, loss_function: str):
         from monai.losses import FocalLoss, DiceLoss, DiceCELoss
         from torch.nn import BCEWithLogitsLoss
 
diff --git a/qim3d/operations/_common_operations_methods.py b/qim3d/operations/_common_operations_methods.py
index 1c14e163..5975c30e 100644
--- a/qim3d/operations/_common_operations_methods.py
+++ b/qim3d/operations/_common_operations_methods.py
@@ -153,7 +153,7 @@ def fade_mask(
 
 
 def overlay_rgb_images(
-    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black:bool = True,
+    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black: bool = True,
 ) -> np.ndarray:
     """
     Overlay an RGB foreground onto an RGB background using alpha blending.
diff --git a/qim3d/processing/_layers.py b/qim3d/processing/_layers.py
index d91ef0a4..306a1284 100644
--- a/qim3d/processing/_layers.py
+++ b/qim3d/processing/_layers.py
@@ -2,7 +2,14 @@ import numpy as np
 from slgbuilder import GraphObject 
 from slgbuilder import MaxflowBuilder
 
-def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, delta:float = 1, min_margin:int = 10, max_margin:int = None, wrap:bool = False):
+def segment_layers(data: np.ndarray, 
+                   inverted: bool = False, 
+                   n_layers: int = 1, 
+                   delta: float = 1, 
+                   min_margin: int = 10, 
+                   max_margin: int = None, 
+                   wrap: bool = False
+                   ) -> list:
     """
     Works on 2D and 3D data.
     Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation
@@ -82,7 +89,7 @@ def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, del
 
     return segmentations
 
-def get_lines(segmentations:list|np.ndarray) -> list:
+def get_lines(segmentations:list[np.ndarray]) -> list:
     """
     Expects list of arrays where each array is 2D segmentation with only 2 classes. This function gets the border between those two
     so it could be plotted. Used with qim3d.processing.segment_layers
diff --git a/qim3d/processing/_local_thickness.py b/qim3d/processing/_local_thickness.py
index a7e877d6..947f8865 100644
--- a/qim3d/processing/_local_thickness.py
+++ b/qim3d/processing/_local_thickness.py
@@ -10,7 +10,7 @@ def local_thickness(
     image: np.ndarray,
     scale: float = 1,
     mask: Optional[np.ndarray] = None,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> np.ndarray:
     """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness)
diff --git a/qim3d/processing/_structure_tensor.py b/qim3d/processing/_structure_tensor.py
index 10eba890..f6dc98cd 100644
--- a/qim3d/processing/_structure_tensor.py
+++ b/qim3d/processing/_structure_tensor.py
@@ -12,7 +12,7 @@ def structure_tensor(
     rho: float = 6.0,
     base_noise: bool = True,
     full: bool = False,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/).
diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py
index f43f0dc2..45725fee 100644
--- a/qim3d/segmentation/_connected_components.py
+++ b/qim3d/segmentation/_connected_components.py
@@ -4,7 +4,7 @@ from qim3d.utils._logger import log
 
 
 class CC:
-    def __init__(self, connected_components, num_connected_components):
+    def __init__(self, connected_components: np.ndarray, num_connected_components: int):
         """
         Initializes a ConnectedComponents object.
 
@@ -23,7 +23,7 @@ class CC:
         """
         return self.cc_count
 
-    def get_cc(self, index=None, crop=False):
+    def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray:
         """
         Get the connected component with the given index, if index is None selects a random component.
 
@@ -52,7 +52,7 @@ class CC:
         
         return volume
 
-    def get_bounding_box(self, index=None):
+    def get_bounding_box(self, index: int|None =None)-> list[tuple]:
         """Get the bounding boxes of the connected components.
 
         Args:
diff --git a/qim3d/utils/_doi.py b/qim3d/utils/_doi.py
index b3ece663..02095f33 100644
--- a/qim3d/utils/_doi.py
+++ b/qim3d/utils/_doi.py
@@ -4,7 +4,7 @@ import requests
 from qim3d.utils._logger import log
 
 
-def _validate_response(response):
+def _validate_response(response: requests.Response) -> bool:
     # Check if we got a good response
     if not response.ok:
         log.error(f"Could not read the provided DOI ({response.reason})")
@@ -13,7 +13,7 @@ def _validate_response(response):
     return True
 
 
-def _doi_to_url(doi):
+def _doi_to_url(doi: str) -> str:
     if doi[:3] != "http":
         url = "https://doi.org/" + doi
     else:
@@ -22,7 +22,7 @@ def _doi_to_url(doi):
     return url
 
 
-def _make_request(doi, header):
+def _make_request(doi: str, header: str) -> requests.Response:
     # Get url from doi
     url = _doi_to_url(doi)
 
@@ -35,7 +35,7 @@ def _make_request(doi, header):
     return response
 
 
-def _log_and_get_text(doi, header):
+def _log_and_get_text(doi, header) -> str:
     response = _make_request(doi, header)
 
     if response and response.encoding:
@@ -50,13 +50,13 @@ def _log_and_get_text(doi, header):
         return text
 
 
-def get_bibtex(doi):
+def get_bibtex(doi: str):
     """Generates bibtex from doi"""
     header = {"Accept": "application/x-bibtex"}
 
     return _log_and_get_text(doi, header)
 
-def cusom_header(doi, header):
+def custom_header(doi: str, header: str) -> str:
     """Allows a custom header to be passed
 
     For example:
@@ -67,7 +67,7 @@ def cusom_header(doi, header):
     """
     return _log_and_get_text(doi, header)
 
-def get_metadata(doi):
+def get_metadata(doi: str) -> dict:
     """Generates a metadata dictionary from doi"""
     header = {"Accept": "application/vnd.citationstyles.csl+json"}
     response = _make_request(doi, header)
@@ -76,7 +76,7 @@ def get_metadata(doi):
 
     return metadata
 
-def get_reference(doi):
+def get_reference(doi: str) -> str:
     """Generates a metadata dictionary from doi and use it to build a reference string"""
 
     metadata = get_metadata(doi)
@@ -84,7 +84,7 @@ def get_reference(doi):
 
     return reference_string
 
-def build_reference_string(metadata):
+def build_reference_string(metadata: dict) -> str:
     """Generates a reference string from metadata"""
     authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']])
     year = metadata['issued']['date-parts'][0][0]
diff --git a/qim3d/utils/_misc.py b/qim3d/utils/_misc.py
index a754ecb8..3d8660b4 100644
--- a/qim3d/utils/_misc.py
+++ b/qim3d/utils/_misc.py
@@ -12,7 +12,7 @@ import difflib
 import qim3d
 
 
-def get_local_ip():
+def get_local_ip() -> str:
     """Retrieves the local IP address of the current machine.
 
     The function uses a socket to determine the local IP address.
@@ -42,7 +42,7 @@ def get_local_ip():
     return ip_address
 
 
-def port_from_str(s):
+def port_from_str(s: str) -> int:
     """
     Generates a port number from a given string.
 
@@ -65,7 +65,7 @@ def port_from_str(s):
     return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4)
 
 
-def gradio_header(title, port):
+def gradio_header(title: str, port: int) -> None:
     """Display the header for a Gradio server.
 
     Displays a formatted header containing the provided title,
@@ -99,7 +99,7 @@ def gradio_header(title, port):
     ouf.showlist(details, style="box", title="Starting gradio server")
 
 
-def sizeof(num, suffix="B"):
+def sizeof(num: float, suffix: str = "B") -> str:
     """Converts a number to a human-readable string representing its size.
 
     Converts the given number to a human-readable string representing its size in
@@ -131,7 +131,7 @@ def sizeof(num, suffix="B"):
     return f"{num:.1f} Y{suffix}"
 
 
-def find_similar_paths(path):
+def find_similar_paths(path: str) -> list[str]:
     parent_dir = os.path.dirname(path) or "."
     parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
     valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
@@ -165,14 +165,14 @@ def get_file_size(file_path: str) -> int:
     return file_size
 
 
-def stringify_path(path):
+def stringify_path(path: os.PathLike) -> str:
     """Converts an os.PathLike object to a string"""
     if isinstance(path, os.PathLike):
         path = path.__fspath__()
     return path
 
 
-def get_port_dict():
+def get_port_dict() -> dict:
     # Gets user and port
     username = getpass.getuser()
     url = f"https://platform.qim.dk/qim-api/get-port/{username}"
@@ -189,7 +189,7 @@ def get_port_dict():
     return port_dict
 
 
-def get_css():
+def get_css() -> str:
 
     current_directory = os.path.dirname(os.path.abspath(__file__))
     parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
@@ -201,7 +201,7 @@ def get_css():
     return css_content
 
 
-def downscale_img(img, max_voxels=512**3):
+def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
     """Downscale image if total number of voxels exceeds 512³.
 
     Args:
@@ -226,7 +226,7 @@ def downscale_img(img, max_voxels=512**3):
     return zoom(img, zoom_factor, order=0)
 
 
-def scale_to_float16(arr: np.ndarray):
+def scale_to_float16(arr: np.ndarray) -> np.ndarray:
     """
     Scale the input array to the float16 data type.
 
diff --git a/qim3d/utils/_ome_zarr.py b/qim3d/utils/_ome_zarr.py
index 452997f1..db7ad145 100644
--- a/qim3d/utils/_ome_zarr.py
+++ b/qim3d/utils/_ome_zarr.py
@@ -1,7 +1,7 @@
 from zarr.util import normalize_chunks, normalize_dtype, normalize_shape
 import numpy as np
 
-def get_chunk_size(shape:tuple, dtype):
+def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]:
     """
     How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use 
     in qim3d.io.export_ome_zarr function
@@ -20,7 +20,7 @@ def get_chunk_size(shape:tuple, dtype):
     return chunks
 
 
-def get_n_chunks(shapes:tuple, dtypes:tuple):
+def get_n_chunks(shapes: tuple, dtypes: tuple) -> int:
     """
     Estimates how many chunks we will use in advence so we can pass this number to a progress bar and track how many
     have been already written to disk
diff --git a/qim3d/utils/_progress_bar.py b/qim3d/utils/_progress_bar.py
index 629026f4..57f33738 100644
--- a/qim3d/utils/_progress_bar.py
+++ b/qim3d/utils/_progress_bar.py
@@ -24,7 +24,7 @@ class RepeatTimer(Timer):
             self.function(*self.args, **self.kwargs)
 
 class ProgressBar(ABC):
-    def __init__(self,tqdm_kwargs:dict, repeat_time: float,  *args, **kwargs):
+    def __init__(self, tqdm_kwargs: dict, repeat_time: float,  *args, **kwargs):
         """
         Context manager for ('with' statement) to track progress during a long progress over 
         which we don't have control (like loading a file) and thus can not insert the tqdm
@@ -98,7 +98,7 @@ class FileLoadingProgressBar(ProgressBar):
         super().__init__( tqdm_kwargs, repeat_time)
         self.process = psutil.Process()
 
-    def get_new_update(self):
+    def get_new_update(self) -> int:
         counters = self.process.io_counters()
         try:
             memory = counters.read_chars
@@ -107,7 +107,7 @@ class FileLoadingProgressBar(ProgressBar):
         return memory
 
 class OmeZarrExportProgressBar(ProgressBar):
-    def __init__(self,path:str, n_chunks:int, reapeat_time="auto"):
+    def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"):
         """
         Context manager to track the exporting of OmeZarr files.
 
@@ -152,7 +152,7 @@ class OmeZarrExportProgressBar(ProgressBar):
         self.last_update = 0
 
     def get_new_update(self):
-        def file_count(folder_path:str):
+        def file_count(folder_path: str) -> int:
             """
             Goes recursively through the folders and counts how many files are there, 
             Doesn't count metadata json files
diff --git a/qim3d/utils/_server.py b/qim3d/utils/_server.py
index 9d20fccd..f6165158 100644
--- a/qim3d/utils/_server.py
+++ b/qim3d/utils/_server.py
@@ -14,7 +14,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")
         super().end_headers()
 
-    def list_directory(self, path):
+    def list_directory(self, path: str|os.PathLike):
         """Helper to produce a directory listing, includes hidden files."""
         try:
             file_list = os.listdir(path)
@@ -49,7 +49,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         # Write the encoded HTML directly to the response
         self.wfile.write(encoded)
 
-def start_http_server(directory, port=8000):
+def start_http_server(directory: str, port: int = 8000) -> HTTPServer:
     """
     Starts an HTTP server serving the specified directory on the given port with CORS enabled.
     
diff --git a/qim3d/utils/_system.py b/qim3d/utils/_system.py
index 19792627..ccccb672 100644
--- a/qim3d/utils/_system.py
+++ b/qim3d/utils/_system.py
@@ -36,7 +36,7 @@ class Memory:
             round(self.free_pct, 1),
         )
 
-def _test_disk_speed(file_size_bytes=1024, ntimes=10):
+def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]:
     '''
     Test the write and read speed of the disk by writing a file of a given size
     and then reading it back.
@@ -95,7 +95,7 @@ def _test_disk_speed(file_size_bytes=1024, ntimes=10):
     return avg_write_speed, write_speed_std, avg_read_speed, read_speed_std
 
 
-def disk_report(file_size=1024 * 1024 * 100, ntimes=10):
+def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None:
     '''
     Report the average write and read speed of the disk.
 
diff --git a/qim3d/viz/_cc.py b/qim3d/viz/_cc.py
index 1c0ed560..d07024b3 100644
--- a/qim3d/viz/_cc.py
+++ b/qim3d/viz/_cc.py
@@ -2,18 +2,18 @@ import matplotlib.pyplot as plt
 import numpy as np
 import qim3d
 from qim3d.utils._logger import log
-
+from qim3d.segmentation._connected_components import CC
 
 def plot_cc(
-    connected_components,
+    connected_components: CC,
     component_indexs: list | tuple = None,
-    max_cc_to_plot=32,
-    overlay=None,
-    crop=False,
-    show=True,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
+    max_cc_to_plot: int = 32,
+    overlay: np.ndarray = None,
+    crop: bool = False,
+    display_figure: bool = True,
+    color_map: str = "viridis",
+    value_min: float = None,
+    value_max: float = None,
     **kwargs,
 ) -> list[plt.Figure]:
     """
@@ -75,7 +75,7 @@ def plot_cc(
                 cc = connected_components.get_cc(component, crop=False)
                 overlay_crop = np.where(cc == 0, 0, overlay)
             fig = qim3d.viz.slices_grid(
-                overlay_crop, show=show, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs
+                overlay_crop, display_figure=display_figure, color_map=color_map, value_min=value_min, value_max=value_max, **kwargs
             )
         else:
             # assigns discrete color map to each connected component if not given
@@ -89,7 +89,7 @@ def plot_cc(
 
         figs.append(fig)
 
-    if not show:
+    if not display_figure:
         return figs
 
     return
diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py
index 2198cd94..e03f0852 100644
--- a/qim3d/viz/_data_exploration.py
+++ b/qim3d/viz/_data_exploration.py
@@ -9,7 +9,9 @@ from typing import List, Optional, Union
 
 import dask.array as da
 import ipywidgets as widgets
+import matplotlib.figure
 import matplotlib.pyplot as plt
+import matplotlib
 from IPython.display import SVG, display
 import matplotlib
 import numpy as np
@@ -29,7 +31,7 @@ def slices_grid(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-    image_size=None,
+    image_size: int = None,
     image_height: int = 2,
     image_width: int = 2,
     display_figure: bool = False,
@@ -38,7 +40,7 @@ def slices_grid(
     color_bar: bool = False,
     color_bar_style: str = "small",
     **matplotlib_imshow_kwargs,
-) -> plt.Figure:
+) -> matplotlib.figure.Figure:
     """Displays one or several slices from a 3d volume.
 
     By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
@@ -296,7 +298,7 @@ def slices_grid(
     return fig
 
 
-def _get_slice_range(position: int, num_slices: int, n_total):
+def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray:
     """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
     start_idx = position - num_slices // 2
     end_idx = (
@@ -324,7 +326,7 @@ def slicer(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
+    image_size: int = None,
     color_bar: bool = False,
     **matplotlib_imshow_kwargs,
 ) -> widgets.interactive:
@@ -401,8 +403,8 @@ def slicer_orthogonal(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
-):
+    image_size: int = None,
+)-> widgets.interactive:
     """Interactive widget for visualizing orthogonal slices of a 3D volume.
 
     Args:
@@ -461,7 +463,7 @@ def fade_mask(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-):
+)-> widgets.interactive:
     """Interactive widget for visualizing the effect of edge fading on a 3D volume.
 
     This can be used to select the best parameters before applying the mask.
@@ -596,7 +598,7 @@ def fade_mask(
     return slicer_obj
 
 
-def chunks(zarr_path: str, **kwargs):
+def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
     """
     Function to visualize chunks of a Zarr dataset using the specified visualization method.
 
@@ -854,14 +856,14 @@ def histogram(
     log_scale: bool = False,
     despine: bool = True,
     show_title: bool = True,
-    color="qim3d",
-    edgecolor=None,
-    figsize=(8, 4.5),
-    element="step",
-    return_fig=False,
-    show=True,
+    color: str = "qim3d",
+    edgecolor: str|None = None,
+    figsize: tuple[float, float] = (8, 4.5),
+    element: str = "step",
+    return_fig: bool = False,
+    show: bool = True,
     **sns_kwargs,
-):
+) -> None|matplotlib.figure.Figure:
     """
     Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
 
diff --git a/qim3d/viz/_detection.py b/qim3d/viz/_detection.py
index a0926105..14ba42bf 100644
--- a/qim3d/viz/_detection.py
+++ b/qim3d/viz/_detection.py
@@ -6,7 +6,7 @@ from IPython.display import clear_output, display
 import qim3d
 
 
-def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs):
+def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive:
     """
     Plots the blobs found on a slice of the volume.
 
diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py
index b80e9556..b45a138c 100644
--- a/qim3d/viz/_k3d.py
+++ b/qim3d/viz/_k3d.py
@@ -15,18 +15,18 @@ from qim3d.utils._misc import downscale_img, scale_to_float16
 
 
 def volumetric(
-    img,
-    aspectmode="data",
-    show=True,
-    save=False,
-    grid_visible=False,
-    color_map='magma',
-    constant_opacity=False,
-    vmin=None,
-    vmax=None,
-    samples="auto",
-    max_voxels=512**3,
-    data_type="scaled_float16",
+    img: np.ndarray,
+    aspectmode: str = "data",
+    show: bool = True,
+    save: bool = False,
+    grid_visible: bool = False,
+    color_map: str = 'magma',
+    constant_opacity: bool = False,
+    vmin: float|None = None,
+    vmax: float|None = None,
+    samples: int|str = "auto",
+    max_voxels: int = 512**3,
+    data_type: str = "scaled_float16",
     **kwargs,
 ):
     """
@@ -175,13 +175,13 @@ def volumetric(
 
 
 def mesh(
-    verts,
-    faces,
-    wireframe=True,
-    flat_shading=True,
-    grid_visible=False,
-    show=True,
-    save=False,
+    verts: np.ndarray,
+    faces: np.ndarray,
+    wireframe: bool = True,
+    flat_shading: bool = True,
+    grid_visible: bool = False,
+    show: bool = True,
+    save: bool = False,
     **kwargs,
 ):
     """
diff --git a/qim3d/viz/_layers2d.py b/qim3d/viz/_layers2d.py
index 1e284461..676845a5 100644
--- a/qim3d/viz/_layers2d.py
+++ b/qim3d/viz/_layers2d.py
@@ -8,7 +8,7 @@ import numpy as np
 from PIL import Image
 
 
-def image_with_lines(image:np.ndarray, lines: list, line_thickness:float|int) -> Image:
+def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> Image:
     """
     Plots the image and plots the lines on top of it. Then extracts it as PIL.Image and in the same size as the input image was.
     Paramters:
diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py
index 53a7f3a6..16d7c7ce 100644
--- a/qim3d/viz/_metrics.py
+++ b/qim3d/viz/_metrics.py
@@ -1,19 +1,21 @@
 """Visualization tools"""
 
+import matplotlib.figure
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 from qim3d.utils._logger import log
-
+import torch
+import matplotlib
 
 def plot_metrics(
-    *metrics,
-    linestyle="-",
-    batch_linestyle="dotted",
-    labels: list = None,
+    *metrics: tuple[dict[str, float]],
+    linestyle: str = "-",
+    batch_linestyle: str = "dotted",
+    labels: list|None = None,
     figsize: tuple = (16, 6),
-    show=False
+    show: bool = False
 ):
     """
     Plots the metrics over epochs and batches.
@@ -79,8 +81,13 @@ def plot_metrics(
 
 
 def grid_overview(
-    data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
-):
+    data: list|torch.utils.data.Dataset, 
+    num_images: int = 7, 
+    cmap_im: str = "gray", 
+    cmap_segm: str = "viridis", 
+    alpha: float = 0.5, 
+    show: bool = False
+)-> matplotlib.figure.Figure:
     """Displays an overview grid of images, labels, and masks (if they exist).
 
     Labels are the annotated target segmentations
@@ -174,13 +181,13 @@ def grid_overview(
 
 
 def grid_pred(
-    in_targ_preds,
-    num_images=7,
-    cmap_im="gray",
-    cmap_segm="viridis",
-    alpha=0.5,
-    show=False,
-):
+    in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray],
+    num_images: int = 7,
+    cmap_im: str = "gray",
+    cmap_segm: str = "viridis",
+    alpha: float = 0.5,
+    show: bool = False,
+)-> matplotlib.figure.Figure:
     """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
 
     Displays a grid of subplots representing different aspects of the input images and segmentations.
@@ -282,7 +289,7 @@ def grid_pred(
     return fig
 
 
-def vol_masked(vol, vol_mask, viz_delta=128):
+def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray:
     """
     Applies masking to a volume based on a binary volume mask.
 
diff --git a/qim3d/viz/_structure_tensor.py b/qim3d/viz/_structure_tensor.py
index 63d141c4..81488102 100644
--- a/qim3d/viz/_structure_tensor.py
+++ b/qim3d/viz/_structure_tensor.py
@@ -19,9 +19,9 @@ def vectors(
     vec: np.ndarray,
     axis: int = 0,
     volume_cmap:str = 'grey',
-    vmin:float = None,
-    vmax:float = None,
-    slice_idx: Optional[Union[int, float]] = None,
+    vmin: float|None = None,
+    vmax: float|None = None,
+    slice_idx: Union[int, float]|None = None,
     grid_size: int = 10,
     interactive: bool = True,
     figsize: Tuple[int, int] = (10, 5),
-- 
GitLab