From 427ecd0ba4c2a42ea2b262dfb56ca2cd7a1b685e Mon Sep 17 00:00:00 2001
From: s214735 <s214735@dtu.dk>
Date: Mon, 23 Dec 2024 13:48:28 +0100
Subject: [PATCH] Added to classes as well

---
 qim3d/features/_common_features_methods.py | 12 ++++--
 qim3d/filters/_common_filter_methods.py    | 40 +++++++++++++-----
 qim3d/gui/annotation_tool.py               |  6 +--
 qim3d/gui/interface.py                     |  2 +-
 qim3d/gui/iso3d.py                         | 49 +++++++++++-----------
 qim3d/gui/layers2d.py                      | 12 +++---
 qim3d/gui/local_thickness.py               | 10 ++---
 qim3d/gui/qim_theme.py                     |  2 +-
 qim3d/io/_ome_zarr.py                      | 24 ++++++-----
 qim3d/io/_saving.py                        | 21 ++++++----
 10 files changed, 105 insertions(+), 73 deletions(-)

diff --git a/qim3d/features/_common_features_methods.py b/qim3d/features/_common_features_methods.py
index e69b6cba..d7a6960e 100644
--- a/qim3d/features/_common_features_methods.py
+++ b/qim3d/features/_common_features_methods.py
@@ -5,7 +5,9 @@ import trimesh
 import qim3d
 
 
-def volume(obj, **mesh_kwargs) -> float:
+def volume(obj: np.ndarray, 
+           **mesh_kwargs
+           ) -> float:
     """
     Compute the volume of a 3D volume or mesh.
 
@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
     return obj.volume
 
 
-def area(obj, **mesh_kwargs) -> float:
+def area(obj: np.ndarray, 
+         **mesh_kwargs
+         ) -> float:
     """
     Compute the surface area of a 3D volume or mesh.
 
@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
     return obj.area
 
 
-def sphericity(obj, **mesh_kwargs) -> float:
+def sphericity(obj: np.ndarray, 
+               **mesh_kwargs
+               ) -> float:
     """
     Compute the sphericity of a 3D volume or mesh.
 
diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py
index f2d42197..bc3bab5b 100644
--- a/qim3d/filters/_common_filter_methods.py
+++ b/qim3d/filters/_common_filter_methods.py
@@ -26,7 +26,10 @@ __all__ = [
 
 
 class FilterBase:
-    def __init__(self, dask=False, chunks="auto", *args, **kwargs):
+    def __init__(self, 
+                 dask: bool = False, 
+                 chunks: str = "auto", 
+                 *args, **kwargs):
         """
         Base class for image filters.
 
@@ -40,7 +43,7 @@ class FilterBase:
         self.kwargs = kwargs
 
 class Gaussian(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray):
         """
         Applies a Gaussian filter to the input.
 
@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
 
 
 class Median(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray):
         """
         Applies a median filter to the input.
 
@@ -68,7 +71,7 @@ class Median(FilterBase):
 
 
 class Maximum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray):
         """
         Applies a maximum filter to the input.
 
@@ -82,7 +85,7 @@ class Maximum(FilterBase):
 
 
 class Minimum(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray):
         """
         Applies a minimum filter to the input.
 
@@ -95,7 +98,7 @@ class Minimum(FilterBase):
         return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
 
 class Tophat(FilterBase):
-    def __call__(self, input):
+    def __call__(self, input: np.ndarray):
         """
         Applies a tophat filter to the input.
 
@@ -210,7 +213,10 @@ class Pipeline:
         return input
 
 
-def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
+def gaussian(vol: np.ndarray, 
+             dask: bool = False, 
+             chunks: str = 'auto',
+             *args, **kwargs) -> np.ndarray:
     """
     Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
 
@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
         return res
 
 
-def median(vol, dask=False, chunks='auto', **kwargs):
+def median(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str ='auto', 
+           **kwargs) -> np.ndarray:
     """
     Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
 
@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def maximum(vol, dask=False, chunks='auto', **kwargs):
+def maximum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
 
@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
         return res
 
 
-def minimum(vol, dask=False, chunks='auto', **kwargs):
+def minimum(vol: np.ndarray, 
+            dask: bool = False, 
+            chunks: str = 'auto', 
+            **kwargs) -> np.ndarray:
     """
     Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
 
@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
         res = ndimage.minimum_filter(vol, **kwargs)
         return res
 
-def tophat(vol, dask=False, chunks='auto', **kwargs):
+def tophat(vol: np.ndarray, 
+           dask: bool = False, 
+           chunks: str = 'auto', 
+           **kwargs) -> np.ndarray:
     """
     Remove background from the volume.
 
diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py
index 4498e1ad..1ecbc5d3 100644
--- a/qim3d/gui/annotation_tool.py
+++ b/qim3d/gui/annotation_tool.py
@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
 
 
 class Interface(BaseInterface):
-    def __init__(self, name_suffix: str = "", verbose: bool = False, img=None):
+    def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
         super().__init__(
             title="Annotation Tool",
             height=768,
@@ -95,13 +95,13 @@ class Interface(BaseInterface):
         except FileNotFoundError:
             files = None
 
-    def create_preview(self, img_editor):
+    def create_preview(self, img_editor: gr.ImageEditor):
         background = img_editor["background"]
         masks = img_editor["layers"][0]
         overlay_image = overlay_rgb_images(background, masks)
         return overlay_image
 
-    def cerate_download_list(self, img_editor):
+    def cerate_download_list(self, img_editor: gr.ImageEditor):
         masks_rgb = img_editor["layers"][0]
         mask_threshold = 200  # This value is based
 
diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py
index c5596850..0241ce32 100644
--- a/qim3d/gui/interface.py
+++ b/qim3d/gui/interface.py
@@ -48,7 +48,7 @@ class BaseInterface(ABC):
     def set_invisible(self):
         return gr.update(visible=False)
     
-    def change_visibility(self, is_visible):
+    def change_visibility(self, is_visible: bool):
         return gr.update(visible = is_visible)
 
     def launch(self, img=None, force_light_mode: bool = True, **kwargs):
diff --git a/qim3d/gui/iso3d.py b/qim3d/gui/iso3d.py
index 1a20f91e..c7d4620f 100644
--- a/qim3d/gui/iso3d.py
+++ b/qim3d/gui/iso3d.py
@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
         self.img = img
         self.plot_height = plot_height
 
-    def load_data(self, gradiofile):
+    def load_data(self, gradiofile: gr.File):
         try:
             self.vol = load(gradiofile.name)
             assert self.vol.ndim == 3
@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
         except AssertionError:
             raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
 
-    def resize_vol(self, display_size):
+    def resize_vol(self, display_size: int):
         """Resizes the loaded volume to the display size"""
 
         # Get original size
@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
                 f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
             )
 
-    def save_fig(self, fig, filename):
+    def save_fig(self, fig: go.Figure, filename: str):
         # Write Plotly figure to disk
         fig.write_html(filename)
 
     def create_fig(self, 
-        gradio_file,
-        display_size,
-        opacity,
-        opacityscale,
-        only_wireframe,
-        min_value,
-        max_value,
-        surface_count,
-        colormap,
-        show_colorbar,
-        reversescale,
-        flip_z,
-        show_axis,
-        show_ticks,
-        show_caps,
-        show_z_slice,
-        slice_z_location,
-        show_y_slice,
-        slice_y_location,
-        show_x_slice,
-        slice_x_location,):
+        gradio_file: gr.File,
+        display_size: int ,
+        opacity: float,
+        opacityscale: str,
+        only_wireframe: bool,
+        min_value: float,
+        max_value: float,
+        surface_count: int,
+        colormap: str,
+        show_colorbar: bool,
+        reversescale: bool,
+        flip_z: bool,
+        show_axis: bool,
+        show_ticks: bool,
+        show_caps: bool,
+        show_z_slice: bool,
+        slice_z_location: int,
+        show_y_slice: bool,
+        slice_y_location: int,
+        show_x_slice: bool,
+        slice_x_location: int,
+        ) -> tuple[go.Figure, str]:
 
         # Load volume
         self.load_data(gradio_file)
diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py
index 780746a0..d3ba4164 100644
--- a/qim3d/gui/layers2d.py
+++ b/qim3d/gui/layers2d.py
@@ -302,14 +302,14 @@ class Interface(BaseInterface):
             
         
 
-    def change_plot_type(self, plot_type, ):
+    def change_plot_type(self, plot_type: str, ):
         self.plot_type = plot_type
         if plot_type == 'Segmentation lines':
             return gr.update(visible = False), gr.update(visible = True)
         else:  
             return gr.update(visible = True), gr.update(visible = False)
         
-    def change_plot_size(self, x_check, y_check, z_check):
+    def change_plot_size(self, x_check: int, y_check: int, z_check: int):
         """
         Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
         also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
@@ -319,10 +319,10 @@ class Interface(BaseInterface):
         height = self.heights[index] # also used to define heights of plots in the begining
         return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
 
-    def change_row_visibility(self, x_check, y_check, z_check):
+    def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
         return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
     
-    def update_explorer(self, new_path):
+    def update_explorer(self, new_path: str):
         # Refresh the file explorer object
         new_path = os.path.expanduser(new_path)
 
@@ -341,13 +341,13 @@ class Interface(BaseInterface):
     def set_relaunch_button(self):
         return gr.update(value=f"Relaunch", interactive=True)
 
-    def set_spinner(self, message):
+    def set_spinner(self, message: str):
         if self.error:
             return gr.Button()
         # spinner icon/shows the user something is happeing
         return gr.update(value=f"{message}", interactive=False)
     
-    def load_data(self, base_path, explorer):
+    def load_data(self, base_path: str, explorer: str):
         if base_path and os.path.isfile(base_path):
             file_path = base_path
         elif explorer and os.path.isfile(explorer):
diff --git a/qim3d/gui/local_thickness.py b/qim3d/gui/local_thickness.py
index 0c76f1aa..ad837b42 100644
--- a/qim3d/gui/local_thickness.py
+++ b/qim3d/gui/local_thickness.py
@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
     #
     #######################################################
 
-    def process_input(self, data, dark_objects):
+    def process_input(self, data: np.ndarray, dark_objects: bool):
         # Load volume
         try:
             self.vol = load(data.name)
@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
         self.vmin = np.min(self.vol)
         self.vmax = np.max(self.vol)
 
-    def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"):
+    def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
         plt.close()
         z_idx = int(zpos * (vol.shape[0] - 1))
         fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
 
         return fig
 
-    def make_binary(self, threshold):
+    def make_binary(self, threshold: float):
         # Make a binary volume
         # Nothing fancy, but we could add new features here
         self.vol_binary = self.vol > (threshold * np.max(self.vol))
     
-    def compute_localthickness(self, lt_scale):
+    def compute_localthickness(self, lt_scale: float):
         self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
 
         # Valus for visualization
         self.vmin_lt = np.min(self.vol_thickness)
         self.vmax_lt = np.max(self.vol_thickness)
 
-    def thickness_histogram(self, nbins):
+    def thickness_histogram(self, nbins: int):
         # Ignore zero thickness
         non_zero_values = self.vol_thickness[self.vol_thickness > 0]
 
diff --git a/qim3d/gui/qim_theme.py b/qim3d/gui/qim_theme.py
index f13594e4..5cdc844b 100644
--- a/qim3d/gui/qim_theme.py
+++ b/qim3d/gui/qim_theme.py
@@ -7,7 +7,7 @@ class QimTheme(gr.themes.Default):
     there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
     in comments.
     """
-    def __init__(self, force_light_mode:bool = True):
+    def __init__(self, force_light_mode: bool = True):
         """
         Parameters:
         -----------
diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py
index ae517315..46b6801c 100644
--- a/qim3d/io/_ome_zarr.py
+++ b/qim3d/io/_ome_zarr.py
@@ -181,16 +181,16 @@ class OMEScaler(
 
 
 def export_ome_zarr(
-    path,
-    data,
-    chunk_size=256,
-    downsample_rate=2,
-    order=1,
-    replace=False,
-    method="scaleZYX",
+    path: str,
+    data: np.ndarray,
+    chunk_size: int = 256,
+    downsample_rate: int = 2,
+    order: int = 1,
+    replace: bool = False,
+    method: str = "scaleZYX",
     progress_bar: bool = True,
-    progress_bar_repeat_time="auto",
-):
+    progress_bar_repeat_time: str = "auto",
+) -> None:
     """
     Export 3D image data to OME-Zarr format with pyramidal downsampling.
 
@@ -299,7 +299,11 @@ def export_ome_zarr(
     return
 
 
-def import_ome_zarr(path, scale=0, load=True):
+def import_ome_zarr(
+        path: str, 
+        scale: int = 0, 
+        load: bool = True
+        ) -> np.ndarray:
     """
     Import image data from an OME-Zarr file.
 
diff --git a/qim3d/io/_saving.py b/qim3d/io/_saving.py
index d7721ee2..efe7f2ce 100644
--- a/qim3d/io/_saving.py
+++ b/qim3d/io/_saving.py
@@ -401,15 +401,15 @@ class DataSaver:
 
 
 def save(
-    path,
-    data,
-    replace=False,
-    compression=False,
-    basename=None,
-    sliced_dim=0,
-    chunk_shape="auto",
+    path: str,
+    data: np.ndarray,
+    replace: bool = False,
+    compression: bool = False,
+    basename: bool = None,
+    sliced_dim: int = 0,
+    chunk_shape: str = "auto",
     **kwargs,
-):
+) -> None:
     """Save data to a specified file path.
 
     Args:
@@ -464,7 +464,10 @@ def save(
     ).save(path, data)
 
 
-def save_mesh(filename, mesh):
+def save_mesh(
+        filename: str, 
+        mesh: trimesh.Trimesh
+        ) -> None:
     """
     Save a trimesh object to an .obj file.
 
-- 
GitLab