Skip to content
Snippets Groups Projects
Commit 6b3f808f authored by s214735's avatar s214735
Browse files

Doublechecking and minor fixes

parent 58c2f501
No related branches found
No related tags found
1 merge request!143Type hints
......@@ -15,7 +15,7 @@ def blobs(
overlap: float = 0.5,
threshold_rel: float = None,
exclude_border: bool = False,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray]:
"""
Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True
......
......@@ -5,7 +5,7 @@ import trimesh
import qim3d
def volume(obj: np.ndarray,
def volume(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
......@@ -51,7 +51,7 @@ def volume(obj: np.ndarray,
return obj.volume
def area(obj: np.ndarray,
def area(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
......@@ -96,7 +96,7 @@ def area(obj: np.ndarray,
return obj.area
def sphericity(obj: np.ndarray,
def sphericity(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
......
......@@ -43,7 +43,7 @@ class FilterBase:
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input: np.ndarray):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a Gaussian filter to the input.
......@@ -57,7 +57,7 @@ class Gaussian(FilterBase):
class Median(FilterBase):
def __call__(self, input: np.ndarray):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a median filter to the input.
......@@ -71,7 +71,7 @@ class Median(FilterBase):
class Maximum(FilterBase):
def __call__(self, input: np.ndarray):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a maximum filter to the input.
......@@ -85,7 +85,7 @@ class Maximum(FilterBase):
class Minimum(FilterBase):
def __call__(self, input: np.ndarray):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a minimum filter to the input.
......@@ -98,7 +98,7 @@ class Minimum(FilterBase):
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase):
def __call__(self, input: np.ndarray):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a tophat filter to the input.
......
......@@ -142,7 +142,7 @@ def noise_object_collection(
object_shape: str = None,
seed: int = 0,
verbose: bool = False,
) -> tuple[np.ndarray, object]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Generate a 3D volume of multiple synthetic objects using Perlin noise.
......
......@@ -55,7 +55,7 @@ class Interface(BaseInterface):
self.masks_rgb = None
self.temp_files = []
def get_result(self):
def get_result(self) -> dict:
# Get the temporary files from gradio
temp_path_list = []
for filename in os.listdir(self.temp_dir):
......@@ -95,13 +95,13 @@ class Interface(BaseInterface):
except FileNotFoundError:
files = None
def create_preview(self, img_editor: gr.ImageEditor):
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"]
masks = img_editor["layers"][0]
overlay_image = overlay_rgb_images(background, masks)
return overlay_image
def cerate_download_list(self, img_editor: gr.ImageEditor):
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0]
mask_threshold = 200 # This value is based
......
......@@ -20,6 +20,7 @@ import os
import re
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import outputformat as ouf
......@@ -29,6 +30,8 @@ from qim3d.utils._logger import log
from qim3d.utils import _misc
from qim3d.gui.interface import BaseInterface
from typing import Callable, Any, Dict
import matplotlib
class Interface(BaseInterface):
......@@ -271,7 +274,7 @@ class Interface(BaseInterface):
operations.change(fn=self.show_results, inputs = operations, outputs = results)
cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
def update_explorer(self, new_path):
def update_explorer(self, new_path: str):
new_path = os.path.expanduser(new_path)
# In case we have a directory
......@@ -367,7 +370,7 @@ class Interface(BaseInterface):
except Exception as error_message:
self.error_message = F"Error when loading data: {error_message}"
def run_operations(self, operations, *args):
def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
outputs = []
self.calculated_operations = []
for operation in self.all_operations:
......@@ -411,7 +414,7 @@ class Interface(BaseInterface):
case _:
raise NotImplementedError(F"Operation '{operation} is not defined")
def show_results(self, operations):
def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
update_list = []
for operation in self.all_operations:
if operation in operations and operation in self.calculated_operations:
......@@ -426,7 +429,7 @@ class Interface(BaseInterface):
#
#######################################################
def create_img_fig(self, img, **kwargs):
def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(img, interpolation="nearest", **kwargs)
......@@ -437,8 +440,8 @@ class Interface(BaseInterface):
return fig
def update_slice_wrapper(self, letter):
def update_slice(position_slider:float, cmap:str):
def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]:
"""
position_slider: float from gradio slider, saying which relative slice we want to see
cmap: string gradio drop down menu, saying what cmap we want to use for display
......@@ -465,7 +468,7 @@ class Interface(BaseInterface):
return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
return update_slice
def vol_histogram(self, nbins, min_value, max_value):
def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
# Start histogram
vol_hist = np.zeros(nbins)
......@@ -478,7 +481,7 @@ class Interface(BaseInterface):
return vol_hist, bin_edges
def plot_histogram(self):
def plot_histogram(self) -> matplotlib.figure.Figure:
# The Histogram needs results from the projections
if not self.projections_calculated:
_ = self.get_projections()
......@@ -498,7 +501,7 @@ class Interface(BaseInterface):
return fig
def create_projections_figs(self):
def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
if not self.projections_calculated:
projections = self.get_projections()
self.max_projection = projections[0]
......@@ -519,7 +522,7 @@ class Interface(BaseInterface):
self.projections_calculated = True
return max_projection_fig, min_projection_fig
def get_projections(self):
def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
# Create arrays for iteration
max_projection = np.zeros(np.shape(self.vol[0]))
min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
......
......@@ -6,6 +6,7 @@ import gradio as gr
from .qim_theme import QimTheme
import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli
......@@ -51,7 +52,7 @@ class BaseInterface(ABC):
def change_visibility(self, is_visible: bool):
return gr.update(visible = is_visible)
def launch(self, img=None, force_light_mode: bool = True, **kwargs):
def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
"""
img: If None, user can upload image after the interface is launched.
If defined, the interface will be launched with the image already there
......@@ -76,7 +77,7 @@ class BaseInterface(ABC):
**kwargs,
)
def clear(self):
def clear(self) -> None:
"""Used to reset outputs with the clear button"""
return None
......
......@@ -28,6 +28,7 @@ from .interface import BaseInterface
from qim3d.processing import overlay_rgb_images, segment_layers, get_lines
from qim3d.io import load
from qim3d.viz._layers2d import image_with_lines
from typing import Dict, Any
#TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error
......@@ -302,14 +303,14 @@ class Interface(BaseInterface):
def change_plot_type(self, plot_type: str, ):
def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
self.plot_type = plot_type
if plot_type == 'Segmentation lines':
return gr.update(visible = False), gr.update(visible = True)
else:
return gr.update(visible = True), gr.update(visible = False)
def change_plot_size(self, x_check: int, y_check: int, z_check: int):
def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""
Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
......
......@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples
class Interface(InterfaceWithExamples):
def __init__(self,
img = None,
img: np.ndarray = None,
verbose:bool = False,
plot_height:int = 768,
figsize:int = 6):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment