From 58c2f5018e66d394cc6607d88b1fbabf276884df Mon Sep 17 00:00:00 2001
From: s214735 <s214735@dtu.dk>
Date: Thu, 26 Dec 2024 10:44:58 +0100
Subject: [PATCH] Checked rest of files

---
 qim3d/ml/_augmentations.py                    |  2 +-
 qim3d/ml/_data.py                             | 20 +++++++---
 qim3d/ml/_ml_utils.py                         | 26 ++++++-------
 qim3d/ml/models/_unet.py                      | 37 ++++++++++--------
 .../operations/_common_operations_methods.py  |  2 +-
 qim3d/processing/_layers.py                   |  9 ++++-
 qim3d/processing/_local_thickness.py          |  2 +-
 qim3d/processing/_structure_tensor.py         |  2 +-
 qim3d/segmentation/_connected_components.py   |  2 +-
 qim3d/utils/_doi.py                           | 18 ++++-----
 qim3d/utils/_misc.py                          | 20 +++++-----
 qim3d/utils/_ome_zarr.py                      |  4 +-
 qim3d/utils/_progress_bar.py                  |  8 ++--
 qim3d/utils/_server.py                        |  4 +-
 qim3d/utils/_system.py                        |  4 +-
 qim3d/viz/_cc.py                              | 12 +++---
 qim3d/viz/_data_exploration.py                | 20 +++++-----
 qim3d/viz/_detection.py                       |  2 +-
 qim3d/viz/_k3d.py                             | 38 +++++++++----------
 qim3d/viz/_layers2d.py                        |  2 +-
 qim3d/viz/_metrics.py                         | 37 ++++++++++--------
 21 files changed, 150 insertions(+), 121 deletions(-)

diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index ea81e53a..8cb0b472 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -38,7 +38,7 @@ class Augmentation:
         self.transform_validation = transform_validation
         self.transform_test = transform_test
     
-    def augment(self, im_h, im_w, level=None):
+    def augment(self, im_h: int, im_w: int, level: bool = None):
         """
         Returns an albumentations.core.composition.Compose class depending on the augmentation level.
         A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 38fbdac7..17f9a281 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -4,7 +4,9 @@ from PIL import Image
 from qim3d.utils._logger import log
 import torch
 import numpy as np
-
+from typing import Optional, Callable
+import torch.nn as nn
+from ._data import Augmentation
 
 class Dataset(torch.utils.data.Dataset):
     """
@@ -36,7 +38,7 @@ class Dataset(torch.utils.data.Dataset):
             transform=albumentations.Compose([ToTensorV2()]))
         image, target = dataset[idx]
     """
-    def __init__(self, root_path: str, split="train", transform=None):
+    def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None):
         super().__init__()
 
         # Check if split is valid
@@ -58,7 +60,7 @@ class Dataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.sample_images)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         image_path = self.sample_images[idx]
         target_path = self.sample_targets[idx]
 
@@ -76,7 +78,7 @@ class Dataset(torch.utils.data.Dataset):
 
 
     # TODO: working with images of different sizes
-    def check_shape_consistency(self,sample_images):
+    def check_shape_consistency(self,sample_images: tuple[str]):
         image_shapes= []
         for image_path in sample_images:
             image_shape = self._get_shape(image_path)
@@ -131,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
     return h_adjust, w_adjust 
 
 
-def prepare_datasets(path: str, val_fraction: float, model, augmentation):
+def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation):
     """
     Splits and augments the train/validation/test datasets.
 
@@ -172,7 +174,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
     return train_set, val_set, test_set
 
 
-def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):  
+def prepare_dataloaders(train_set: torch.utils.data, 
+                        val_set: torch.utils.data, 
+                        test_set: torch.utils.data, 
+                        batch_size: int, 
+                        shuffle_train: bool = True, 
+                        num_workers: int = 8, 
+                        pin_memory: bool = False):  
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index 7c8a3b86..1c442791 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -9,18 +9,18 @@ from qim3d.viz._metrics import plot_metrics
 
 from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
-
+from models._unet import Hyperparameters
 
 def train_model(
-    model,
-    hyperparameters,
-    train_loader,
-    val_loader,
-    eval_every=1,
-    print_every=5,
-    plot=True,
-    return_loss=False,
-):
+    model: torch.nn.Module,
+    hyperparameters: Hyperparameters,
+    train_loader: torch.utils.data.DataLoader,
+    val_loader: torch.utils.data.DataLoader,
+    eval_every: int = 1,
+    print_every: int = 5,
+    plot: bool = True,
+    return_loss: bool = False,
+) -> tuple[tuple[float], tuple[float]]:
     """Function for training Neural Network models.
 
     Args:
@@ -137,7 +137,7 @@ def train_model(
         return train_loss, val_loss
 
 
-def model_summary(dataloader, model):
+def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module):
     """Prints the summary of a PyTorch model.
 
     Args:
@@ -160,7 +160,7 @@ def model_summary(dataloader, model):
     return model_s
 
 
-def inference(data, model):
+def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Performs inference on input data using the specified model.
 
     Performs inference on the input data using the provided model. The input data should be in the form of a list,
@@ -242,7 +242,7 @@ def inference(data, model):
     return inputs, targets, preds
 
 
-def volume_inference(volume, model, threshold=0.5):
+def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float = 0.5) -> np.ndarray:
     """
     Compute on the entire volume
     Args:
diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py
index 353d10bd..27ee78a4 100644
--- a/qim3d/ml/models/_unet.py
+++ b/qim3d/ml/models/_unet.py
@@ -119,13 +119,13 @@ class Hyperparameters:
 
     def __init__(
         self,
-        model,
-        n_epochs=10,
-        learning_rate=1e-3,
-        optimizer="Adam",
-        momentum=0,
-        weight_decay=0,
-        loss_function="Focal",
+        model: nn.Module,
+        n_epochs: int = 10,
+        learning_rate: float = 1e-3,
+        optimizer: str = "Adam",
+        momentum: float = 0,
+        weight_decay: float = 0,
+        loss_function: str = "Focal",
     ):
 
         # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
@@ -168,13 +168,13 @@ class Hyperparameters:
 
     def model_params(
         self,
-        model,
-        n_epochs,
-        optimizer,
-        learning_rate,
-        weight_decay,
-        momentum,
-        loss_function,
+        model: nn.Module,
+        n_epochs: int,
+        optimizer: str,
+        learning_rate: float,
+        weight_decay: float,
+        momentum: float,
+        loss_function: str,
     ):
 
         optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
@@ -188,7 +188,12 @@ class Hyperparameters:
         return hyper_dict
 
     # selecting the optimizer
-    def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
+    def _optimizer(self, 
+        model: nn.Module, 
+        optimizer: str, 
+        learning_rate: float, 
+        weight_decay: float, 
+        momentum: float):
         from torch.optim import Adam, SGD, RMSprop
 
         if optimizer == "Adam":
@@ -212,7 +217,7 @@ class Hyperparameters:
         return optim
 
     # selecting the loss function
-    def _loss_functions(self, loss_function):
+    def _loss_functions(self, loss_function: str):
         from monai.losses import FocalLoss, DiceLoss, DiceCELoss
         from torch.nn import BCEWithLogitsLoss
 
diff --git a/qim3d/operations/_common_operations_methods.py b/qim3d/operations/_common_operations_methods.py
index 1c14e163..5975c30e 100644
--- a/qim3d/operations/_common_operations_methods.py
+++ b/qim3d/operations/_common_operations_methods.py
@@ -153,7 +153,7 @@ def fade_mask(
 
 
 def overlay_rgb_images(
-    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black:bool = True,
+    background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black: bool = True,
 ) -> np.ndarray:
     """
     Overlay an RGB foreground onto an RGB background using alpha blending.
diff --git a/qim3d/processing/_layers.py b/qim3d/processing/_layers.py
index d91ef0a4..ed90133f 100644
--- a/qim3d/processing/_layers.py
+++ b/qim3d/processing/_layers.py
@@ -2,7 +2,14 @@ import numpy as np
 from slgbuilder import GraphObject 
 from slgbuilder import MaxflowBuilder
 
-def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, delta:float = 1, min_margin:int = 10, max_margin:int = None, wrap:bool = False):
+def segment_layers(data: np.ndarray, 
+                   inverted: bool = False, 
+                   n_layers: int = 1, 
+                   delta: float = 1, 
+                   min_margin: int = 10, 
+                   max_margin: int = None, 
+                   wrap: bool = False
+                   ) -> list:
     """
     Works on 2D and 3D data.
     Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation
diff --git a/qim3d/processing/_local_thickness.py b/qim3d/processing/_local_thickness.py
index a7e877d6..947f8865 100644
--- a/qim3d/processing/_local_thickness.py
+++ b/qim3d/processing/_local_thickness.py
@@ -10,7 +10,7 @@ def local_thickness(
     image: np.ndarray,
     scale: float = 1,
     mask: Optional[np.ndarray] = None,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> np.ndarray:
     """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness)
diff --git a/qim3d/processing/_structure_tensor.py b/qim3d/processing/_structure_tensor.py
index 10eba890..f6dc98cd 100644
--- a/qim3d/processing/_structure_tensor.py
+++ b/qim3d/processing/_structure_tensor.py
@@ -12,7 +12,7 @@ def structure_tensor(
     rho: float = 6.0,
     base_noise: bool = True,
     full: bool = False,
-    visualize=False,
+    visualize: bool = False,
     **viz_kwargs
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/).
diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py
index f43f0dc2..1f29b235 100644
--- a/qim3d/segmentation/_connected_components.py
+++ b/qim3d/segmentation/_connected_components.py
@@ -4,7 +4,7 @@ from qim3d.utils._logger import log
 
 
 class CC:
-    def __init__(self, connected_components, num_connected_components):
+    def __init__(self, connected_components: np.ndarray, num_connected_components: int):
         """
         Initializes a ConnectedComponents object.
 
diff --git a/qim3d/utils/_doi.py b/qim3d/utils/_doi.py
index b3ece663..6c4ae626 100644
--- a/qim3d/utils/_doi.py
+++ b/qim3d/utils/_doi.py
@@ -4,7 +4,7 @@ import requests
 from qim3d.utils._logger import log
 
 
-def _validate_response(response):
+def _validate_response(response: requests.Response) -> bool:
     # Check if we got a good response
     if not response.ok:
         log.error(f"Could not read the provided DOI ({response.reason})")
@@ -13,7 +13,7 @@ def _validate_response(response):
     return True
 
 
-def _doi_to_url(doi):
+def _doi_to_url(doi: str) -> str:
     if doi[:3] != "http":
         url = "https://doi.org/" + doi
     else:
@@ -22,7 +22,7 @@ def _doi_to_url(doi):
     return url
 
 
-def _make_request(doi, header):
+def _make_request(doi: str, header: str) -> requests.Response:
     # Get url from doi
     url = _doi_to_url(doi)
 
@@ -35,7 +35,7 @@ def _make_request(doi, header):
     return response
 
 
-def _log_and_get_text(doi, header):
+def _log_and_get_text(doi, header) -> str:
     response = _make_request(doi, header)
 
     if response and response.encoding:
@@ -50,13 +50,13 @@ def _log_and_get_text(doi, header):
         return text
 
 
-def get_bibtex(doi):
+def get_bibtex(doi: str):
     """Generates bibtex from doi"""
     header = {"Accept": "application/x-bibtex"}
 
     return _log_and_get_text(doi, header)
 
-def cusom_header(doi, header):
+def cusom_header(doi: str, header: str) -> str:
     """Allows a custom header to be passed
 
     For example:
@@ -67,7 +67,7 @@ def cusom_header(doi, header):
     """
     return _log_and_get_text(doi, header)
 
-def get_metadata(doi):
+def get_metadata(doi: str) -> dict:
     """Generates a metadata dictionary from doi"""
     header = {"Accept": "application/vnd.citationstyles.csl+json"}
     response = _make_request(doi, header)
@@ -76,7 +76,7 @@ def get_metadata(doi):
 
     return metadata
 
-def get_reference(doi):
+def get_reference(doi: str) -> str:
     """Generates a metadata dictionary from doi and use it to build a reference string"""
 
     metadata = get_metadata(doi)
@@ -84,7 +84,7 @@ def get_reference(doi):
 
     return reference_string
 
-def build_reference_string(metadata):
+def build_reference_string(metadata: dict) -> str:
     """Generates a reference string from metadata"""
     authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']])
     year = metadata['issued']['date-parts'][0][0]
diff --git a/qim3d/utils/_misc.py b/qim3d/utils/_misc.py
index a754ecb8..3d8660b4 100644
--- a/qim3d/utils/_misc.py
+++ b/qim3d/utils/_misc.py
@@ -12,7 +12,7 @@ import difflib
 import qim3d
 
 
-def get_local_ip():
+def get_local_ip() -> str:
     """Retrieves the local IP address of the current machine.
 
     The function uses a socket to determine the local IP address.
@@ -42,7 +42,7 @@ def get_local_ip():
     return ip_address
 
 
-def port_from_str(s):
+def port_from_str(s: str) -> int:
     """
     Generates a port number from a given string.
 
@@ -65,7 +65,7 @@ def port_from_str(s):
     return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4)
 
 
-def gradio_header(title, port):
+def gradio_header(title: str, port: int) -> None:
     """Display the header for a Gradio server.
 
     Displays a formatted header containing the provided title,
@@ -99,7 +99,7 @@ def gradio_header(title, port):
     ouf.showlist(details, style="box", title="Starting gradio server")
 
 
-def sizeof(num, suffix="B"):
+def sizeof(num: float, suffix: str = "B") -> str:
     """Converts a number to a human-readable string representing its size.
 
     Converts the given number to a human-readable string representing its size in
@@ -131,7 +131,7 @@ def sizeof(num, suffix="B"):
     return f"{num:.1f} Y{suffix}"
 
 
-def find_similar_paths(path):
+def find_similar_paths(path: str) -> list[str]:
     parent_dir = os.path.dirname(path) or "."
     parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else ""
     valid_paths = [os.path.join(parent_dir, file) for file in parent_files]
@@ -165,14 +165,14 @@ def get_file_size(file_path: str) -> int:
     return file_size
 
 
-def stringify_path(path):
+def stringify_path(path: os.PathLike) -> str:
     """Converts an os.PathLike object to a string"""
     if isinstance(path, os.PathLike):
         path = path.__fspath__()
     return path
 
 
-def get_port_dict():
+def get_port_dict() -> dict:
     # Gets user and port
     username = getpass.getuser()
     url = f"https://platform.qim.dk/qim-api/get-port/{username}"
@@ -189,7 +189,7 @@ def get_port_dict():
     return port_dict
 
 
-def get_css():
+def get_css() -> str:
 
     current_directory = os.path.dirname(os.path.abspath(__file__))
     parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
@@ -201,7 +201,7 @@ def get_css():
     return css_content
 
 
-def downscale_img(img, max_voxels=512**3):
+def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
     """Downscale image if total number of voxels exceeds 512³.
 
     Args:
@@ -226,7 +226,7 @@ def downscale_img(img, max_voxels=512**3):
     return zoom(img, zoom_factor, order=0)
 
 
-def scale_to_float16(arr: np.ndarray):
+def scale_to_float16(arr: np.ndarray) -> np.ndarray:
     """
     Scale the input array to the float16 data type.
 
diff --git a/qim3d/utils/_ome_zarr.py b/qim3d/utils/_ome_zarr.py
index 452997f1..db7ad145 100644
--- a/qim3d/utils/_ome_zarr.py
+++ b/qim3d/utils/_ome_zarr.py
@@ -1,7 +1,7 @@
 from zarr.util import normalize_chunks, normalize_dtype, normalize_shape
 import numpy as np
 
-def get_chunk_size(shape:tuple, dtype):
+def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]:
     """
     How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use 
     in qim3d.io.export_ome_zarr function
@@ -20,7 +20,7 @@ def get_chunk_size(shape:tuple, dtype):
     return chunks
 
 
-def get_n_chunks(shapes:tuple, dtypes:tuple):
+def get_n_chunks(shapes: tuple, dtypes: tuple) -> int:
     """
     Estimates how many chunks we will use in advence so we can pass this number to a progress bar and track how many
     have been already written to disk
diff --git a/qim3d/utils/_progress_bar.py b/qim3d/utils/_progress_bar.py
index 629026f4..57f33738 100644
--- a/qim3d/utils/_progress_bar.py
+++ b/qim3d/utils/_progress_bar.py
@@ -24,7 +24,7 @@ class RepeatTimer(Timer):
             self.function(*self.args, **self.kwargs)
 
 class ProgressBar(ABC):
-    def __init__(self,tqdm_kwargs:dict, repeat_time: float,  *args, **kwargs):
+    def __init__(self, tqdm_kwargs: dict, repeat_time: float,  *args, **kwargs):
         """
         Context manager for ('with' statement) to track progress during a long progress over 
         which we don't have control (like loading a file) and thus can not insert the tqdm
@@ -98,7 +98,7 @@ class FileLoadingProgressBar(ProgressBar):
         super().__init__( tqdm_kwargs, repeat_time)
         self.process = psutil.Process()
 
-    def get_new_update(self):
+    def get_new_update(self) -> int:
         counters = self.process.io_counters()
         try:
             memory = counters.read_chars
@@ -107,7 +107,7 @@ class FileLoadingProgressBar(ProgressBar):
         return memory
 
 class OmeZarrExportProgressBar(ProgressBar):
-    def __init__(self,path:str, n_chunks:int, reapeat_time="auto"):
+    def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"):
         """
         Context manager to track the exporting of OmeZarr files.
 
@@ -152,7 +152,7 @@ class OmeZarrExportProgressBar(ProgressBar):
         self.last_update = 0
 
     def get_new_update(self):
-        def file_count(folder_path:str):
+        def file_count(folder_path: str) -> int:
             """
             Goes recursively through the folders and counts how many files are there, 
             Doesn't count metadata json files
diff --git a/qim3d/utils/_server.py b/qim3d/utils/_server.py
index 9d20fccd..e4a46673 100644
--- a/qim3d/utils/_server.py
+++ b/qim3d/utils/_server.py
@@ -14,7 +14,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")
         super().end_headers()
 
-    def list_directory(self, path):
+    def list_directory(self, path: str):
         """Helper to produce a directory listing, includes hidden files."""
         try:
             file_list = os.listdir(path)
@@ -49,7 +49,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
         # Write the encoded HTML directly to the response
         self.wfile.write(encoded)
 
-def start_http_server(directory, port=8000):
+def start_http_server(directory: str, port: int = 8000) -> HTTPServer:
     """
     Starts an HTTP server serving the specified directory on the given port with CORS enabled.
     
diff --git a/qim3d/utils/_system.py b/qim3d/utils/_system.py
index 19792627..ccccb672 100644
--- a/qim3d/utils/_system.py
+++ b/qim3d/utils/_system.py
@@ -36,7 +36,7 @@ class Memory:
             round(self.free_pct, 1),
         )
 
-def _test_disk_speed(file_size_bytes=1024, ntimes=10):
+def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]:
     '''
     Test the write and read speed of the disk by writing a file of a given size
     and then reading it back.
@@ -95,7 +95,7 @@ def _test_disk_speed(file_size_bytes=1024, ntimes=10):
     return avg_write_speed, write_speed_std, avg_read_speed, read_speed_std
 
 
-def disk_report(file_size=1024 * 1024 * 100, ntimes=10):
+def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None:
     '''
     Report the average write and read speed of the disk.
 
diff --git a/qim3d/viz/_cc.py b/qim3d/viz/_cc.py
index 1c0ed560..14ecf1f4 100644
--- a/qim3d/viz/_cc.py
+++ b/qim3d/viz/_cc.py
@@ -2,15 +2,15 @@ import matplotlib.pyplot as plt
 import numpy as np
 import qim3d
 from qim3d.utils._logger import log
-
+from segmentation._connected_components import CC
 
 def plot_cc(
-    connected_components,
+    connected_components: CC,
     component_indexs: list | tuple = None,
-    max_cc_to_plot=32,
-    overlay=None,
-    crop=False,
-    show=True,
+    max_cc_to_plot: int = 32,
+    overlay: np.ndarray = None,
+    crop: bool = False,
+    show: bool = True,
     cmap: str = "viridis",
     vmin: float = None,
     vmax: float = None,
diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py
index 2198cd94..766aae15 100644
--- a/qim3d/viz/_data_exploration.py
+++ b/qim3d/viz/_data_exploration.py
@@ -9,7 +9,9 @@ from typing import List, Optional, Union
 
 import dask.array as da
 import ipywidgets as widgets
+import matplotlib.figure
 import matplotlib.pyplot as plt
+import matplotlib
 from IPython.display import SVG, display
 import matplotlib
 import numpy as np
@@ -29,7 +31,7 @@ def slices_grid(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-    image_size=None,
+    image_size: int = None,
     image_height: int = 2,
     image_width: int = 2,
     display_figure: bool = False,
@@ -38,7 +40,7 @@ def slices_grid(
     color_bar: bool = False,
     color_bar_style: str = "small",
     **matplotlib_imshow_kwargs,
-) -> plt.Figure:
+) -> matplotlib.figure.Figure:
     """Displays one or several slices from a 3d volume.
 
     By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
@@ -296,7 +298,7 @@ def slices_grid(
     return fig
 
 
-def _get_slice_range(position: int, num_slices: int, n_total):
+def _get_slice_range(position: int, num_slices: int, n_total) -> np.ndarray:
     """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
     start_idx = position - num_slices // 2
     end_idx = (
@@ -324,7 +326,7 @@ def slicer(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
+    image_size: int = None,
     color_bar: bool = False,
     **matplotlib_imshow_kwargs,
 ) -> widgets.interactive:
@@ -401,8 +403,8 @@ def slicer_orthogonal(
     image_width: int = 3,
     display_positions: bool = False,
     interpolation: Optional[str] = None,
-    image_size=None,
-):
+    image_size: int = None,
+)-> widgets.interactive:
     """Interactive widget for visualizing orthogonal slices of a 3D volume.
 
     Args:
@@ -461,7 +463,7 @@ def fade_mask(
     color_map: str = "magma",
     value_min: float = None,
     value_max: float = None,
-):
+)-> widgets.interactive:
     """Interactive widget for visualizing the effect of edge fading on a 3D volume.
 
     This can be used to select the best parameters before applying the mask.
@@ -596,7 +598,7 @@ def fade_mask(
     return slicer_obj
 
 
-def chunks(zarr_path: str, **kwargs):
+def chunks(zarr_path: str, **kwargs)-> widgets.interactive:
     """
     Function to visualize chunks of a Zarr dataset using the specified visualization method.
 
@@ -861,7 +863,7 @@ def histogram(
     return_fig=False,
     show=True,
     **sns_kwargs,
-):
+) -> Optional[matplotlib.figure.Figure]:
     """
     Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
 
diff --git a/qim3d/viz/_detection.py b/qim3d/viz/_detection.py
index a0926105..b4ab9098 100644
--- a/qim3d/viz/_detection.py
+++ b/qim3d/viz/_detection.py
@@ -6,7 +6,7 @@ from IPython.display import clear_output, display
 import qim3d
 
 
-def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs):
+def circles(blobs: np.ndarray, vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive:
     """
     Plots the blobs found on a slice of the volume.
 
diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py
index b80e9556..8b2c23ac 100644
--- a/qim3d/viz/_k3d.py
+++ b/qim3d/viz/_k3d.py
@@ -15,18 +15,18 @@ from qim3d.utils._misc import downscale_img, scale_to_float16
 
 
 def volumetric(
-    img,
-    aspectmode="data",
-    show=True,
-    save=False,
-    grid_visible=False,
-    color_map='magma',
-    constant_opacity=False,
-    vmin=None,
-    vmax=None,
-    samples="auto",
-    max_voxels=512**3,
-    data_type="scaled_float16",
+    img: np.ndarray,
+    aspectmode: str = "data",
+    show: bool = True,
+    save: bool = False,
+    grid_visible: bool = False,
+    color_map: str = 'magma',
+    constant_opacity: bool = False,
+    vmin: float = None,
+    vmax: float = None,
+    samples: int|str = "auto",
+    max_voxels: int = 512**3,
+    data_type: str = "scaled_float16",
     **kwargs,
 ):
     """
@@ -175,13 +175,13 @@ def volumetric(
 
 
 def mesh(
-    verts,
-    faces,
-    wireframe=True,
-    flat_shading=True,
-    grid_visible=False,
-    show=True,
-    save=False,
+    verts: np.ndarray,
+    faces: np.ndarray,
+    wireframe: bool = True,
+    flat_shading: bool = True,
+    grid_visible: bool = False,
+    show: bool = True,
+    save: bool = False,
     **kwargs,
 ):
     """
diff --git a/qim3d/viz/_layers2d.py b/qim3d/viz/_layers2d.py
index 1e284461..676845a5 100644
--- a/qim3d/viz/_layers2d.py
+++ b/qim3d/viz/_layers2d.py
@@ -8,7 +8,7 @@ import numpy as np
 from PIL import Image
 
 
-def image_with_lines(image:np.ndarray, lines: list, line_thickness:float|int) -> Image:
+def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> Image:
     """
     Plots the image and plots the lines on top of it. Then extracts it as PIL.Image and in the same size as the input image was.
     Paramters:
diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py
index 53a7f3a6..ded37d24 100644
--- a/qim3d/viz/_metrics.py
+++ b/qim3d/viz/_metrics.py
@@ -1,19 +1,21 @@
 """Visualization tools"""
 
+import matplotlib.figure
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
 from qim3d.utils._logger import log
-
+import torch
+import matplotlib
 
 def plot_metrics(
-    *metrics,
-    linestyle="-",
-    batch_linestyle="dotted",
+    *metrics: tuple[dict[str, float]],
+    linestyle: str = "-",
+    batch_linestyle: str = "dotted",
     labels: list = None,
     figsize: tuple = (16, 6),
-    show=False
+    show: bool = False
 ):
     """
     Plots the metrics over epochs and batches.
@@ -79,8 +81,13 @@ def plot_metrics(
 
 
 def grid_overview(
-    data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False
-):
+    data: list|torch.utils.data.Dataset, 
+    num_images: int = 7, 
+    cmap_im: str = "gray", 
+    cmap_segm: str = "viridis", 
+    alpha: float = 0.5, 
+    show: bool = False
+)-> matplotlib.figure.Figure:
     """Displays an overview grid of images, labels, and masks (if they exist).
 
     Labels are the annotated target segmentations
@@ -174,13 +181,13 @@ def grid_overview(
 
 
 def grid_pred(
-    in_targ_preds,
-    num_images=7,
-    cmap_im="gray",
-    cmap_segm="viridis",
-    alpha=0.5,
-    show=False,
-):
+    in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray],
+    num_images: int = 7,
+    cmap_im: str = "gray",
+    cmap_segm: str = "viridis",
+    alpha: float = 0.5,
+    show: bool = False,
+)-> matplotlib.figure.Figure:
     """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
 
     Displays a grid of subplots representing different aspects of the input images and segmentations.
@@ -282,7 +289,7 @@ def grid_pred(
     return fig
 
 
-def vol_masked(vol, vol_mask, viz_delta=128):
+def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray:
     """
     Applies masking to a volume based on a binary volume mask.
 
-- 
GitLab