diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py
index 9244ef88e8aa73e19ad98c237d37299e85460b0c..0d776822689430b733dfa1467ecdff0b48d87bb1 100644
--- a/qim3d/io/_convert.py
+++ b/qim3d/io/_convert.py
@@ -7,6 +7,7 @@ import numpy as np
 import tifffile as tiff
 import zarr
 from tqdm import tqdm
+import zarr.core
 
 from qim3d.utils._misc import stringify_path
 from qim3d.io._saving import save
@@ -67,7 +68,7 @@ class Convert:
             else:
                 raise ValueError("Invalid path")
 
-    def convert_tif_to_zarr(self, tif_path: str, zarr_path: str):
+    def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a tiff file to a zarr file
 
         Args:
@@ -97,7 +98,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_tif(self, zarr_path: str, tif_path: str):
+    def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
         """Convert a zarr file to a tiff file
 
         Args:
@@ -110,7 +111,7 @@ class Convert:
         z = zarr.open(zarr_path)
         save(tif_path, z)
 
-    def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str):
+    def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
         """Convert a nifti file to a zarr file
 
         Args:
@@ -139,7 +140,7 @@ class Convert:
 
         return z
 
-    def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False):
+    def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
         """Convert a zarr file to a nifti file
 
         Args:
@@ -153,7 +154,7 @@ class Convert:
         save(nifti_path, z, compression=compression)
         
 
-def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
+def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
     """Convert a file to another format without loading the entire file into memory
 
     Args:
diff --git a/qim3d/io/_loading.py b/qim3d/io/_loading.py
index b558c3f4eabf7a64f819b26abb9c073eb709cda6..0f4641e7708bb36178eb98d1041f30e9ee420862 100644
--- a/qim3d/io/_loading.py
+++ b/qim3d/io/_loading.py
@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
 from qim3d.utils._progress_bar import FileLoadingProgressBar
 import trimesh
 
+from typing import Optional, Dict
+
 dask.config.set(scheduler="processes") 
 
 
@@ -100,7 +102,7 @@ class DataLoader:
 
         return vol
 
-    def load_h5(self, path: str):
+    def load_h5(self, path: str) -> tuple[np.ndarray, Optional[Dict]]:
         """Load an HDF5 file from the specified path.
 
         Args:
@@ -183,7 +185,7 @@ class DataLoader:
         else:
             return vol
 
-    def load_tiff_stack(self, path: str):
+    def load_tiff_stack(self, path: str) -> np.ndarray|np.memmap:
         """Load a stack of TIFF files from the specified path.
 
         Args:
@@ -237,7 +239,7 @@ class DataLoader:
 
         return vol
 
-    def load_txrm(self, path: str):
+    def load_txrm(self, path: str) -> tuple[dask.array|np.ndarray, Optional[Dict]]:
         """Load a TXRM/XRM/TXM file from the specified path.
 
         Args:
@@ -766,7 +768,7 @@ def load(
     force_load: bool = False,
     dim_order: tuple = (2, 1, 0),
     **kwargs,
-):
+) -> np.ndarray:
     """
     Load data from the specified file or directory.
 
@@ -854,7 +856,7 @@ def load(
 
     return data
 
-def load_mesh(filename: str):
+def load_mesh(filename: str) -> trimesh.Trimesh:
     """
     Load a mesh from an .obj file using trimesh.
 
diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py
index 36ec5b800f5bad7e91dc679cec6a86e97a6456f0..2481f2ba09d68b77d88ffc7448c21de84d9bdb92 100644
--- a/qim3d/io/_ome_zarr.py
+++ b/qim3d/io/_ome_zarr.py
@@ -82,7 +82,7 @@ class OMEScaler(
 
 
         """
-        def resize_zoom(vol, scale_factors, order, scaled_shape):
+        def resize_zoom(vol: dask.array, scale_factors, order, scaled_shape):
 
             # Get the chunksize needed so that all the blocks match the new shape
             # This snippet comes from the original OME-Zarr-python library
@@ -182,7 +182,7 @@ class OMEScaler(
 
 def export_ome_zarr(
     path: str,
-    data: np.ndarray,
+    data: np.ndarray|dask.array,
     chunk_size: int = 256,
     downsample_rate: int = 2,
     order: int = 1,
diff --git a/qim3d/io/_sync.py b/qim3d/io/_sync.py
index 5cc4b5e566490f3efa586356f7e21d0a1c920c57..1119907206e66742fd0d9a740bda47ca65de2743 100644
--- a/qim3d/io/_sync.py
+++ b/qim3d/io/_sync.py
@@ -28,7 +28,7 @@ class Sync:
 
             return False
 
-    def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True):
+    def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
         """Check if all files from 'source' are in 'destination'
 
         This function compares the files in the 'source' directory to those in
@@ -80,7 +80,7 @@ class Sync:
 
         return diff_files
 
-    def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True):
+    def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
         """Checks whether 'source' and 'destination' directories are synchronized.
 
         This function compares the contents of two directories
@@ -168,7 +168,7 @@ class Sync:
             )
         return
 
-    def count_files_and_dirs(self, path: str, verbose: bool = True):
+    def count_files_and_dirs(self, path: str, verbose: bool = True) -> tuple[int, int]:
         """Count the number of files and directories in the given path.
 
         This function recursively counts the number of files and
diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index 8cb0b472af009b9801c88d1b0c18c0b56c78a415..e12b1a51849593e0c14287f0c2ab1795eb86ace9 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -20,10 +20,10 @@ class Augmentation:
     """
     
     def __init__(self, 
-                 resize = 'crop', 
-                 transform_train = 'moderate', 
-                 transform_validation = None,
-                 transform_test = None,
+                 resize: str = 'crop', 
+                 transform_train: str = 'moderate', 
+                 transform_validation: str | None = None,
+                 transform_test: str | None = None,
                  mean: float = 0.5, 
                  std: float = 0.5
                 ):
@@ -38,7 +38,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h: int, im_w: int, level: bool = None):
+    def augment(self, im_h: int, im_w: int, level: bool | None = None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
         A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 17f9a281eb8b5dd71388d3438efbdd103d42b2cf..d54f1bf6d7a457457b8fea638fd8057410d2a853 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -101,7 +101,7 @@ class Dataset(torch.utils.data.Dataset):
         return Image.open(str(image_path)).size
 
 
-def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
+def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
     """
     Checks the compatibility of the image shape with the depth of the model.
     If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
@@ -133,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
     return h_adjust, w_adjust 
 
 
-def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation):
+def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
     """
     Splits and augments the train/validation/test datasets.
 
@@ -180,7 +180,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
                         batch_size: int, 
                         shuffle_train: bool = True, 
                         num_workers: int = 8, 
-                        pin_memory: bool = False):  
+                        pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:  
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index 1c4427919492b7dfff37372df94a7086bc1eac10..f46a7481fe93cd299eb43566bd947c7ea3355daf 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 
-from torchinfo import summary
+from torchinfo import summary, ModelStatistics
 from qim3d.utils._logger import log
 from qim3d.viz._metrics import plot_metrics
 
@@ -137,7 +137,7 @@ def train_model(
         return train_loss, val_loss
 
 
-def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module):
+def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module) -> ModelStatistics:
     """Prints the summary of a PyTorch model.
 
     Args:
diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py
index 1f29b23579e4aa9fbe5fa92e2a3045ff4d50317d..45725feec526e0a31de67a0737f99d6f955d2333 100644
--- a/qim3d/segmentation/_connected_components.py
+++ b/qim3d/segmentation/_connected_components.py
@@ -23,7 +23,7 @@ class CC:
         """
         return self.cc_count
 
-    def get_cc(self, index=None, crop=False):
+    def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray:
         """
         Get the connected component with the given index, if index is None selects a random component.
 
@@ -52,7 +52,7 @@ class CC:
         
         return volume
 
-    def get_bounding_box(self, index=None):
+    def get_bounding_box(self, index: int|None =None)-> list[tuple]:
         """Get the bounding boxes of the connected components.
 
         Args:
diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py
index 766aae157505fa7bccf72e6bfaa57b5adee02387..e03f08526b176ff61d878cf11876473c7a36577c 100644
--- a/qim3d/viz/_data_exploration.py
+++ b/qim3d/viz/_data_exploration.py
@@ -298,7 +298,7 @@ def slices_grid(
     return fig
 
 
-def _get_slice_range(position: int, num_slices: int, n_total) -> np.ndarray:
+def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray:
     """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
     start_idx = position - num_slices // 2
     end_idx = (
@@ -856,14 +856,14 @@ def histogram(
     log_scale: bool = False,
     despine: bool = True,
     show_title: bool = True,
-    color="qim3d",
-    edgecolor=None,
-    figsize=(8, 4.5),
-    element="step",
-    return_fig=False,
-    show=True,
+    color: str = "qim3d",
+    edgecolor: str|None = None,
+    figsize: tuple[float, float] = (8, 4.5),
+    element: str = "step",
+    return_fig: bool = False,
+    show: bool = True,
     **sns_kwargs,
-) -> Optional[matplotlib.figure.Figure]:
+) -> None|matplotlib.figure.Figure:
     """
     Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
 
diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py
index 8b2c23accd93cdd19f7c263e273031f7764c1463..b45a138ccaf80682851dc53bff770a29e11c6172 100644
--- a/qim3d/viz/_k3d.py
+++ b/qim3d/viz/_k3d.py
@@ -22,8 +22,8 @@ def volumetric(
     grid_visible: bool = False,
     color_map: str = 'magma',
     constant_opacity: bool = False,
-    vmin: float = None,
-    vmax: float = None,
+    vmin: float|None = None,
+    vmax: float|None = None,
     samples: int|str = "auto",
     max_voxels: int = 512**3,
     data_type: str = "scaled_float16",
diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py
index ded37d240add0f7a9add57fc1f19ebb39dabfe9b..16d7c7ce94bf427e82c20221d6226333f11b1bd5 100644
--- a/qim3d/viz/_metrics.py
+++ b/qim3d/viz/_metrics.py
@@ -13,7 +13,7 @@ def plot_metrics(
     *metrics: tuple[dict[str, float]],
     linestyle: str = "-",
     batch_linestyle: str = "dotted",
-    labels: list = None,
+    labels: list|None = None,
     figsize: tuple = (16, 6),
     show: bool = False
 ):
diff --git a/qim3d/viz/_structure_tensor.py b/qim3d/viz/_structure_tensor.py
index 63d141c41ae9dd365e89a9d82c959d08173d76a6..8148810287fe206334cf1f1e72103e7b3db8b74e 100644
--- a/qim3d/viz/_structure_tensor.py
+++ b/qim3d/viz/_structure_tensor.py
@@ -19,9 +19,9 @@ def vectors(
     vec: np.ndarray,
     axis: int = 0,
     volume_cmap:str = 'grey',
-    vmin:float = None,
-    vmax:float = None,
-    slice_idx: Optional[Union[int, float]] = None,
+    vmin: float|None = None,
+    vmax: float|None = None,
+    slice_idx: Union[int, float]|None = None,
     grid_size: int = 10,
     interactive: bool = True,
     figsize: Tuple[int, int] = (10, 5),