Skip to content
Snippets Groups Projects
Commit 427ecd0b authored by s214735's avatar s214735
Browse files

Added to classes as well

parent 1b6b2c2e
Branches
No related tags found
1 merge request!143Type hints
...@@ -5,7 +5,9 @@ import trimesh ...@@ -5,7 +5,9 @@ import trimesh
import qim3d import qim3d
def volume(obj, **mesh_kwargs) -> float: def volume(obj: np.ndarray,
**mesh_kwargs
) -> float:
""" """
Compute the volume of a 3D volume or mesh. Compute the volume of a 3D volume or mesh.
...@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float: ...@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
return obj.volume return obj.volume
def area(obj, **mesh_kwargs) -> float: def area(obj: np.ndarray,
**mesh_kwargs
) -> float:
""" """
Compute the surface area of a 3D volume or mesh. Compute the surface area of a 3D volume or mesh.
...@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float: ...@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
return obj.area return obj.area
def sphericity(obj, **mesh_kwargs) -> float: def sphericity(obj: np.ndarray,
**mesh_kwargs
) -> float:
""" """
Compute the sphericity of a 3D volume or mesh. Compute the sphericity of a 3D volume or mesh.
......
...@@ -26,7 +26,10 @@ __all__ = [ ...@@ -26,7 +26,10 @@ __all__ = [
class FilterBase: class FilterBase:
def __init__(self, dask=False, chunks="auto", *args, **kwargs): def __init__(self,
dask: bool = False,
chunks: str = "auto",
*args, **kwargs):
""" """
Base class for image filters. Base class for image filters.
...@@ -40,7 +43,7 @@ class FilterBase: ...@@ -40,7 +43,7 @@ class FilterBase:
self.kwargs = kwargs self.kwargs = kwargs
class Gaussian(FilterBase): class Gaussian(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray):
""" """
Applies a Gaussian filter to the input. Applies a Gaussian filter to the input.
...@@ -54,7 +57,7 @@ class Gaussian(FilterBase): ...@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
class Median(FilterBase): class Median(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray):
""" """
Applies a median filter to the input. Applies a median filter to the input.
...@@ -68,7 +71,7 @@ class Median(FilterBase): ...@@ -68,7 +71,7 @@ class Median(FilterBase):
class Maximum(FilterBase): class Maximum(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray):
""" """
Applies a maximum filter to the input. Applies a maximum filter to the input.
...@@ -82,7 +85,7 @@ class Maximum(FilterBase): ...@@ -82,7 +85,7 @@ class Maximum(FilterBase):
class Minimum(FilterBase): class Minimum(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray):
""" """
Applies a minimum filter to the input. Applies a minimum filter to the input.
...@@ -95,7 +98,7 @@ class Minimum(FilterBase): ...@@ -95,7 +98,7 @@ class Minimum(FilterBase):
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs) return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase): class Tophat(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray):
""" """
Applies a tophat filter to the input. Applies a tophat filter to the input.
...@@ -210,7 +213,10 @@ class Pipeline: ...@@ -210,7 +213,10 @@ class Pipeline:
return input return input
def gaussian(vol, dask=False, chunks='auto', *args, **kwargs): def gaussian(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
*args, **kwargs) -> np.ndarray:
""" """
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter. Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
...@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs): ...@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
return res return res
def median(vol, dask=False, chunks='auto', **kwargs): def median(vol: np.ndarray,
dask: bool = False,
chunks: str ='auto',
**kwargs) -> np.ndarray:
""" """
Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter. Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
...@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs): ...@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
return res return res
def maximum(vol, dask=False, chunks='auto', **kwargs): def maximum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter. Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
...@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs): ...@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
return res return res
def minimum(vol, dask=False, chunks='auto', **kwargs): def minimum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter. Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
...@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs): ...@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
res = ndimage.minimum_filter(vol, **kwargs) res = ndimage.minimum_filter(vol, **kwargs)
return res return res
def tophat(vol, dask=False, chunks='auto', **kwargs): def tophat(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Remove background from the volume. Remove background from the volume.
......
...@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface ...@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface): class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img=None): def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
super().__init__( super().__init__(
title="Annotation Tool", title="Annotation Tool",
height=768, height=768,
...@@ -95,13 +95,13 @@ class Interface(BaseInterface): ...@@ -95,13 +95,13 @@ class Interface(BaseInterface):
except FileNotFoundError: except FileNotFoundError:
files = None files = None
def create_preview(self, img_editor): def create_preview(self, img_editor: gr.ImageEditor):
background = img_editor["background"] background = img_editor["background"]
masks = img_editor["layers"][0] masks = img_editor["layers"][0]
overlay_image = overlay_rgb_images(background, masks) overlay_image = overlay_rgb_images(background, masks)
return overlay_image return overlay_image
def cerate_download_list(self, img_editor): def cerate_download_list(self, img_editor: gr.ImageEditor):
masks_rgb = img_editor["layers"][0] masks_rgb = img_editor["layers"][0]
mask_threshold = 200 # This value is based mask_threshold = 200 # This value is based
......
...@@ -48,7 +48,7 @@ class BaseInterface(ABC): ...@@ -48,7 +48,7 @@ class BaseInterface(ABC):
def set_invisible(self): def set_invisible(self):
return gr.update(visible=False) return gr.update(visible=False)
def change_visibility(self, is_visible): def change_visibility(self, is_visible: bool):
return gr.update(visible = is_visible) return gr.update(visible = is_visible)
def launch(self, img=None, force_light_mode: bool = True, **kwargs): def launch(self, img=None, force_light_mode: bool = True, **kwargs):
......
...@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples): ...@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
self.img = img self.img = img
self.plot_height = plot_height self.plot_height = plot_height
def load_data(self, gradiofile): def load_data(self, gradiofile: gr.File):
try: try:
self.vol = load(gradiofile.name) self.vol = load(gradiofile.name)
assert self.vol.ndim == 3 assert self.vol.ndim == 3
...@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples): ...@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
except AssertionError: except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
def resize_vol(self, display_size): def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size""" """Resizes the loaded volume to the display size"""
# Get original size # Get original size
...@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples): ...@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}" f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
) )
def save_fig(self, fig, filename): def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk # Write Plotly figure to disk
fig.write_html(filename) fig.write_html(filename)
def create_fig(self, def create_fig(self,
gradio_file, gradio_file: gr.File,
display_size, display_size: int ,
opacity, opacity: float,
opacityscale, opacityscale: str,
only_wireframe, only_wireframe: bool,
min_value, min_value: float,
max_value, max_value: float,
surface_count, surface_count: int,
colormap, colormap: str,
show_colorbar, show_colorbar: bool,
reversescale, reversescale: bool,
flip_z, flip_z: bool,
show_axis, show_axis: bool,
show_ticks, show_ticks: bool,
show_caps, show_caps: bool,
show_z_slice, show_z_slice: bool,
slice_z_location, slice_z_location: int,
show_y_slice, show_y_slice: bool,
slice_y_location, slice_y_location: int,
show_x_slice, show_x_slice: bool,
slice_x_location,): slice_x_location: int,
) -> tuple[go.Figure, str]:
# Load volume # Load volume
self.load_data(gradio_file) self.load_data(gradio_file)
......
...@@ -302,14 +302,14 @@ class Interface(BaseInterface): ...@@ -302,14 +302,14 @@ class Interface(BaseInterface):
def change_plot_type(self, plot_type, ): def change_plot_type(self, plot_type: str, ):
self.plot_type = plot_type self.plot_type = plot_type
if plot_type == 'Segmentation lines': if plot_type == 'Segmentation lines':
return gr.update(visible = False), gr.update(visible = True) return gr.update(visible = False), gr.update(visible = True)
else: else:
return gr.update(visible = True), gr.update(visible = False) return gr.update(visible = True), gr.update(visible = False)
def change_plot_size(self, x_check, y_check, z_check): def change_plot_size(self, x_check: int, y_check: int, z_check: int):
""" """
Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
also their height because gradio doesn't do it automatically. The values of heights were set just by eye. also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
...@@ -319,10 +319,10 @@ class Interface(BaseInterface): ...@@ -319,10 +319,10 @@ class Interface(BaseInterface):
height = self.heights[index] # also used to define heights of plots in the begining height = self.heights[index] # also used to define heights of plots in the begining
return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check) return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
def change_row_visibility(self, x_check, y_check, z_check): def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check) return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
def update_explorer(self, new_path): def update_explorer(self, new_path: str):
# Refresh the file explorer object # Refresh the file explorer object
new_path = os.path.expanduser(new_path) new_path = os.path.expanduser(new_path)
...@@ -341,13 +341,13 @@ class Interface(BaseInterface): ...@@ -341,13 +341,13 @@ class Interface(BaseInterface):
def set_relaunch_button(self): def set_relaunch_button(self):
return gr.update(value=f"Relaunch", interactive=True) return gr.update(value=f"Relaunch", interactive=True)
def set_spinner(self, message): def set_spinner(self, message: str):
if self.error: if self.error:
return gr.Button() return gr.Button()
# spinner icon/shows the user something is happeing # spinner icon/shows the user something is happeing
return gr.update(value=f"{message}", interactive=False) return gr.update(value=f"{message}", interactive=False)
def load_data(self, base_path, explorer): def load_data(self, base_path: str, explorer: str):
if base_path and os.path.isfile(base_path): if base_path and os.path.isfile(base_path):
file_path = base_path file_path = base_path
elif explorer and os.path.isfile(explorer): elif explorer and os.path.isfile(explorer):
......
...@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples): ...@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
# #
####################################################### #######################################################
def process_input(self, data, dark_objects): def process_input(self, data: np.ndarray, dark_objects: bool):
# Load volume # Load volume
try: try:
self.vol = load(data.name) self.vol = load(data.name)
...@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples): ...@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
self.vmin = np.min(self.vol) self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol) self.vmax = np.max(self.vol)
def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"): def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
plt.close() plt.close()
z_idx = int(zpos * (vol.shape[0] - 1)) z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
...@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples): ...@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
return fig return fig
def make_binary(self, threshold): def make_binary(self, threshold: float):
# Make a binary volume # Make a binary volume
# Nothing fancy, but we could add new features here # Nothing fancy, but we could add new features here
self.vol_binary = self.vol > (threshold * np.max(self.vol)) self.vol_binary = self.vol > (threshold * np.max(self.vol))
def compute_localthickness(self, lt_scale): def compute_localthickness(self, lt_scale: float):
self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale) self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
# Valus for visualization # Valus for visualization
self.vmin_lt = np.min(self.vol_thickness) self.vmin_lt = np.min(self.vol_thickness)
self.vmax_lt = np.max(self.vol_thickness) self.vmax_lt = np.max(self.vol_thickness)
def thickness_histogram(self, nbins): def thickness_histogram(self, nbins: int):
# Ignore zero thickness # Ignore zero thickness
non_zero_values = self.vol_thickness[self.vol_thickness > 0] non_zero_values = self.vol_thickness[self.vol_thickness > 0]
......
...@@ -181,16 +181,16 @@ class OMEScaler( ...@@ -181,16 +181,16 @@ class OMEScaler(
def export_ome_zarr( def export_ome_zarr(
path, path: str,
data, data: np.ndarray,
chunk_size=256, chunk_size: int = 256,
downsample_rate=2, downsample_rate: int = 2,
order=1, order: int = 1,
replace=False, replace: bool = False,
method="scaleZYX", method: str = "scaleZYX",
progress_bar: bool = True, progress_bar: bool = True,
progress_bar_repeat_time="auto", progress_bar_repeat_time: str = "auto",
): ) -> None:
""" """
Export 3D image data to OME-Zarr format with pyramidal downsampling. Export 3D image data to OME-Zarr format with pyramidal downsampling.
...@@ -299,7 +299,11 @@ def export_ome_zarr( ...@@ -299,7 +299,11 @@ def export_ome_zarr(
return return
def import_ome_zarr(path, scale=0, load=True): def import_ome_zarr(
path: str,
scale: int = 0,
load: bool = True
) -> np.ndarray:
""" """
Import image data from an OME-Zarr file. Import image data from an OME-Zarr file.
......
...@@ -401,15 +401,15 @@ class DataSaver: ...@@ -401,15 +401,15 @@ class DataSaver:
def save( def save(
path, path: str,
data, data: np.ndarray,
replace=False, replace: bool = False,
compression=False, compression: bool = False,
basename=None, basename: bool = None,
sliced_dim=0, sliced_dim: int = 0,
chunk_shape="auto", chunk_shape: str = "auto",
**kwargs, **kwargs,
): ) -> None:
"""Save data to a specified file path. """Save data to a specified file path.
Args: Args:
...@@ -464,7 +464,10 @@ def save( ...@@ -464,7 +464,10 @@ def save(
).save(path, data) ).save(path, data)
def save_mesh(filename, mesh): def save_mesh(
filename: str,
mesh: trimesh.Trimesh
) -> None:
""" """
Save a trimesh object to an .obj file. Save a trimesh object to an .obj file.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment