Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes
Commits on Source (6)
Showing
with 197 additions and 164 deletions
...@@ -159,20 +159,12 @@ def main(): ...@@ -159,20 +159,12 @@ def main():
elif args.subcommand == "viz": elif args.subcommand == "viz":
if args.method == "itk-vtk": if args.method == "itk-vtk":
try:
# We need the full path to the file for the viewer # We need the full path to the file for the viewer
current_dir = os.getcwd() current_dir = os.getcwd()
full_path = os.path.normpath(os.path.join(current_dir, args.source)) full_path = os.path.normpath(os.path.join(current_dir, args.source))
qim3d.viz.itk_vtk(full_path, open_browser = not args.no_browser) qim3d.viz.itk_vtk(full_path, open_browser = not args.no_browser)
except qim3d.viz.NotInstalledError as err:
print(err)
message = "Itk-vtk-viewer is not installed or qim3d can not find it.\nYou can either:\n\to Use 'qim3d viz SOURCE -m k3d' to display data using different method\n\to Install itk-vtk-viewer yourself following https://kitware.github.io/itk-vtk-viewer/docs/cli.html#Installation\n\to Let qim3D install itk-vtk-viewer now (it will also install node.js in qim3d library)\nDo you want qim3D to install itk-vtk-viewer now?"
print(message)
answer = input("[Y/n]:")
if answer in "Yy":
qim3d.viz.Installer().install()
qim3d.viz.itk_vtk(full_path)
elif args.method == "k3d": elif args.method == "k3d":
volume = qim3d.io.load(str(args.source)) volume = qim3d.io.load(str(args.source))
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
import numpy as np import numpy as np
from qim3d.utils._logger import log from qim3d.utils._logger import log
__all__ = ["blob_detection"] __all__ = ["blobs"]
def blob_detection( def blobs(
vol: np.ndarray, vol: np.ndarray,
background: str = "dark", background: str = "dark",
min_sigma: float = 1, min_sigma: float = 1,
...@@ -15,7 +15,7 @@ def blob_detection( ...@@ -15,7 +15,7 @@ def blob_detection(
overlap: float = 0.5, overlap: float = 0.5,
threshold_rel: float = None, threshold_rel: float = None,
exclude_border: bool = False, exclude_border: bool = False,
) -> np.ndarray: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True
......
...@@ -5,7 +5,9 @@ import trimesh ...@@ -5,7 +5,9 @@ import trimesh
import qim3d import qim3d
def volume(obj, **mesh_kwargs) -> float: def volume(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
""" """
Compute the volume of a 3D volume or mesh. Compute the volume of a 3D volume or mesh.
...@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float: ...@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
return obj.volume return obj.volume
def area(obj, **mesh_kwargs) -> float: def area(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
""" """
Compute the surface area of a 3D volume or mesh. Compute the surface area of a 3D volume or mesh.
...@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float: ...@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
return obj.area return obj.area
def sphericity(obj, **mesh_kwargs) -> float: def sphericity(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
""" """
Compute the sphericity of a 3D volume or mesh. Compute the sphericity of a 3D volume or mesh.
......
...@@ -26,7 +26,10 @@ __all__ = [ ...@@ -26,7 +26,10 @@ __all__ = [
class FilterBase: class FilterBase:
def __init__(self, dask=False, chunks="auto", *args, **kwargs): def __init__(self,
dask: bool = False,
chunks: str = "auto",
*args, **kwargs):
""" """
Base class for image filters. Base class for image filters.
...@@ -40,7 +43,7 @@ class FilterBase: ...@@ -40,7 +43,7 @@ class FilterBase:
self.kwargs = kwargs self.kwargs = kwargs
class Gaussian(FilterBase): class Gaussian(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
Applies a Gaussian filter to the input. Applies a Gaussian filter to the input.
...@@ -54,7 +57,7 @@ class Gaussian(FilterBase): ...@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
class Median(FilterBase): class Median(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
Applies a median filter to the input. Applies a median filter to the input.
...@@ -68,7 +71,7 @@ class Median(FilterBase): ...@@ -68,7 +71,7 @@ class Median(FilterBase):
class Maximum(FilterBase): class Maximum(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
Applies a maximum filter to the input. Applies a maximum filter to the input.
...@@ -82,7 +85,7 @@ class Maximum(FilterBase): ...@@ -82,7 +85,7 @@ class Maximum(FilterBase):
class Minimum(FilterBase): class Minimum(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
Applies a minimum filter to the input. Applies a minimum filter to the input.
...@@ -95,7 +98,7 @@ class Minimum(FilterBase): ...@@ -95,7 +98,7 @@ class Minimum(FilterBase):
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs) return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase): class Tophat(FilterBase):
def __call__(self, input): def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
Applies a tophat filter to the input. Applies a tophat filter to the input.
...@@ -210,7 +213,10 @@ class Pipeline: ...@@ -210,7 +213,10 @@ class Pipeline:
return input return input
def gaussian(vol, dask=False, chunks='auto', *args, **kwargs): def gaussian(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
*args, **kwargs) -> np.ndarray:
""" """
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter. Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
...@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs): ...@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
return res return res
def median(vol, dask=False, chunks='auto', **kwargs): def median(vol: np.ndarray,
dask: bool = False,
chunks: str ='auto',
**kwargs) -> np.ndarray:
""" """
Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter. Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
...@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs): ...@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
return res return res
def maximum(vol, dask=False, chunks='auto', **kwargs): def maximum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter. Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
...@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs): ...@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
return res return res
def minimum(vol, dask=False, chunks='auto', **kwargs): def minimum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter. Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
...@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs): ...@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
res = ndimage.minimum_filter(vol, **kwargs) res = ndimage.minimum_filter(vol, **kwargs)
return res return res
def tophat(vol, dask=False, chunks='auto', **kwargs): def tophat(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
""" """
Remove background from the volume. Remove background from the volume.
......
...@@ -142,7 +142,7 @@ def noise_object_collection( ...@@ -142,7 +142,7 @@ def noise_object_collection(
object_shape: str = None, object_shape: str = None,
seed: int = 0, seed: int = 0,
verbose: bool = False, verbose: bool = False,
) -> tuple[np.ndarray, object]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Generate a 3D volume of multiple synthetic objects using Perlin noise. Generate a 3D volume of multiple synthetic objects using Perlin noise.
......
...@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface ...@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface): class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img=None): def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
super().__init__( super().__init__(
title="Annotation Tool", title="Annotation Tool",
height=768, height=768,
...@@ -55,7 +55,7 @@ class Interface(BaseInterface): ...@@ -55,7 +55,7 @@ class Interface(BaseInterface):
self.masks_rgb = None self.masks_rgb = None
self.temp_files = [] self.temp_files = []
def get_result(self): def get_result(self) -> dict:
# Get the temporary files from gradio # Get the temporary files from gradio
temp_path_list = [] temp_path_list = []
for filename in os.listdir(self.temp_dir): for filename in os.listdir(self.temp_dir):
...@@ -95,13 +95,13 @@ class Interface(BaseInterface): ...@@ -95,13 +95,13 @@ class Interface(BaseInterface):
except FileNotFoundError: except FileNotFoundError:
files = None files = None
def create_preview(self, img_editor): def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"] background = img_editor["background"]
masks = img_editor["layers"][0] masks = img_editor["layers"][0]
overlay_image = overlay_rgb_images(background, masks) overlay_image = overlay_rgb_images(background, masks)
return overlay_image return overlay_image
def cerate_download_list(self, img_editor): def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0] masks_rgb = img_editor["layers"][0]
mask_threshold = 200 # This value is based mask_threshold = 200 # This value is based
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import re import re
import gradio as gr import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import outputformat as ouf import outputformat as ouf
...@@ -29,6 +30,8 @@ from qim3d.utils._logger import log ...@@ -29,6 +30,8 @@ from qim3d.utils._logger import log
from qim3d.utils import _misc from qim3d.utils import _misc
from qim3d.gui.interface import BaseInterface from qim3d.gui.interface import BaseInterface
from typing import Callable, Any, Dict
import matplotlib
class Interface(BaseInterface): class Interface(BaseInterface):
...@@ -271,7 +274,7 @@ class Interface(BaseInterface): ...@@ -271,7 +274,7 @@ class Interface(BaseInterface):
operations.change(fn=self.show_results, inputs = operations, outputs = results) operations.change(fn=self.show_results, inputs = operations, outputs = results)
cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs) cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
def update_explorer(self, new_path): def update_explorer(self, new_path: str):
new_path = os.path.expanduser(new_path) new_path = os.path.expanduser(new_path)
# In case we have a directory # In case we have a directory
...@@ -367,7 +370,7 @@ class Interface(BaseInterface): ...@@ -367,7 +370,7 @@ class Interface(BaseInterface):
except Exception as error_message: except Exception as error_message:
self.error_message = F"Error when loading data: {error_message}" self.error_message = F"Error when loading data: {error_message}"
def run_operations(self, operations, *args): def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
outputs = [] outputs = []
self.calculated_operations = [] self.calculated_operations = []
for operation in self.all_operations: for operation in self.all_operations:
...@@ -411,7 +414,7 @@ class Interface(BaseInterface): ...@@ -411,7 +414,7 @@ class Interface(BaseInterface):
case _: case _:
raise NotImplementedError(F"Operation '{operation} is not defined") raise NotImplementedError(F"Operation '{operation} is not defined")
def show_results(self, operations): def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
update_list = [] update_list = []
for operation in self.all_operations: for operation in self.all_operations:
if operation in operations and operation in self.calculated_operations: if operation in operations and operation in self.calculated_operations:
...@@ -426,7 +429,7 @@ class Interface(BaseInterface): ...@@ -426,7 +429,7 @@ class Interface(BaseInterface):
# #
####################################################### #######################################################
def create_img_fig(self, img, **kwargs): def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(img, interpolation="nearest", **kwargs) ax.imshow(img, interpolation="nearest", **kwargs)
...@@ -437,8 +440,8 @@ class Interface(BaseInterface): ...@@ -437,8 +440,8 @@ class Interface(BaseInterface):
return fig return fig
def update_slice_wrapper(self, letter): def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
def update_slice(position_slider:float, cmap:str): def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]:
""" """
position_slider: float from gradio slider, saying which relative slice we want to see position_slider: float from gradio slider, saying which relative slice we want to see
cmap: string gradio drop down menu, saying what cmap we want to use for display cmap: string gradio drop down menu, saying what cmap we want to use for display
...@@ -465,7 +468,7 @@ class Interface(BaseInterface): ...@@ -465,7 +468,7 @@ class Interface(BaseInterface):
return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True) return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
return update_slice return update_slice
def vol_histogram(self, nbins, min_value, max_value): def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
# Start histogram # Start histogram
vol_hist = np.zeros(nbins) vol_hist = np.zeros(nbins)
...@@ -478,7 +481,7 @@ class Interface(BaseInterface): ...@@ -478,7 +481,7 @@ class Interface(BaseInterface):
return vol_hist, bin_edges return vol_hist, bin_edges
def plot_histogram(self): def plot_histogram(self) -> matplotlib.figure.Figure:
# The Histogram needs results from the projections # The Histogram needs results from the projections
if not self.projections_calculated: if not self.projections_calculated:
_ = self.get_projections() _ = self.get_projections()
...@@ -498,7 +501,7 @@ class Interface(BaseInterface): ...@@ -498,7 +501,7 @@ class Interface(BaseInterface):
return fig return fig
def create_projections_figs(self): def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
if not self.projections_calculated: if not self.projections_calculated:
projections = self.get_projections() projections = self.get_projections()
self.max_projection = projections[0] self.max_projection = projections[0]
...@@ -519,7 +522,7 @@ class Interface(BaseInterface): ...@@ -519,7 +522,7 @@ class Interface(BaseInterface):
self.projections_calculated = True self.projections_calculated = True
return max_projection_fig, min_projection_fig return max_projection_fig, min_projection_fig
def get_projections(self): def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
# Create arrays for iteration # Create arrays for iteration
max_projection = np.zeros(np.shape(self.vol[0])) max_projection = np.zeros(np.shape(self.vol[0]))
min_projection = np.ones(np.shape(self.vol[0])) * float("inf") min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
......
...@@ -6,6 +6,7 @@ import gradio as gr ...@@ -6,6 +6,7 @@ import gradio as gr
from .qim_theme import QimTheme from .qim_theme import QimTheme
import qim3d.gui import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli # TODO: when offline it throws an error in cli
...@@ -48,10 +49,10 @@ class BaseInterface(ABC): ...@@ -48,10 +49,10 @@ class BaseInterface(ABC):
def set_invisible(self): def set_invisible(self):
return gr.update(visible=False) return gr.update(visible=False)
def change_visibility(self, is_visible): def change_visibility(self, is_visible: bool):
return gr.update(visible = is_visible) return gr.update(visible = is_visible)
def launch(self, img=None, force_light_mode: bool = True, **kwargs): def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
""" """
img: If None, user can upload image after the interface is launched. img: If None, user can upload image after the interface is launched.
If defined, the interface will be launched with the image already there If defined, the interface will be launched with the image already there
...@@ -76,7 +77,7 @@ class BaseInterface(ABC): ...@@ -76,7 +77,7 @@ class BaseInterface(ABC):
**kwargs, **kwargs,
) )
def clear(self): def clear(self) -> None:
"""Used to reset outputs with the clear button""" """Used to reset outputs with the clear button"""
return None return None
......
...@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples): ...@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
self.img = img self.img = img
self.plot_height = plot_height self.plot_height = plot_height
def load_data(self, gradiofile): def load_data(self, gradiofile: gr.File):
try: try:
self.vol = load(gradiofile.name) self.vol = load(gradiofile.name)
assert self.vol.ndim == 3 assert self.vol.ndim == 3
...@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples): ...@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
except AssertionError: except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}") raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
def resize_vol(self, display_size): def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size""" """Resizes the loaded volume to the display size"""
# Get original size # Get original size
...@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples): ...@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}" f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
) )
def save_fig(self, fig, filename): def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk # Write Plotly figure to disk
fig.write_html(filename) fig.write_html(filename)
def create_fig(self, def create_fig(self,
gradio_file, gradio_file: gr.File,
display_size, display_size: int ,
opacity, opacity: float,
opacityscale, opacityscale: str,
only_wireframe, only_wireframe: bool,
min_value, min_value: float,
max_value, max_value: float,
surface_count, surface_count: int,
colormap, colormap: str,
show_colorbar, show_colorbar: bool,
reversescale, reversescale: bool,
flip_z, flip_z: bool,
show_axis, show_axis: bool,
show_ticks, show_ticks: bool,
show_caps, show_caps: bool,
show_z_slice, show_z_slice: bool,
slice_z_location, slice_z_location: int,
show_y_slice, show_y_slice: bool,
slice_y_location, slice_y_location: int,
show_x_slice, show_x_slice: bool,
slice_x_location,): slice_x_location: int,
) -> tuple[go.Figure, str]:
# Load volume # Load volume
self.load_data(gradio_file) self.load_data(gradio_file)
......
...@@ -25,9 +25,11 @@ import numpy as np ...@@ -25,9 +25,11 @@ import numpy as np
from .interface import BaseInterface from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d # from qim3d.processing import layers2d as l2d
from qim3d.processing import overlay_rgb_images, segment_layers, get_lines from qim3d.processing import segment_layers, get_lines
from qim3d.operations import overlay_rgb_images
from qim3d.io import load from qim3d.io import load
from qim3d.viz._layers2d import image_with_lines from qim3d.viz._layers2d import image_with_lines
from typing import Dict, Any
#TODO figure out how not update anything and go through processing when there are no data loaded #TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error # So user could play with the widgets but it doesnt throw error
...@@ -302,14 +304,14 @@ class Interface(BaseInterface): ...@@ -302,14 +304,14 @@ class Interface(BaseInterface):
def change_plot_type(self, plot_type, ): def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
self.plot_type = plot_type self.plot_type = plot_type
if plot_type == 'Segmentation lines': if plot_type == 'Segmentation lines':
return gr.update(visible = False), gr.update(visible = True) return gr.update(visible = False), gr.update(visible = True)
else: else:
return gr.update(visible = True), gr.update(visible = False) return gr.update(visible = True), gr.update(visible = False)
def change_plot_size(self, x_check, y_check, z_check): def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
""" """
Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
also their height because gradio doesn't do it automatically. The values of heights were set just by eye. also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
...@@ -319,10 +321,10 @@ class Interface(BaseInterface): ...@@ -319,10 +321,10 @@ class Interface(BaseInterface):
height = self.heights[index] # also used to define heights of plots in the begining height = self.heights[index] # also used to define heights of plots in the begining
return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check) return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
def change_row_visibility(self, x_check, y_check, z_check): def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check) return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
def update_explorer(self, new_path): def update_explorer(self, new_path: str):
# Refresh the file explorer object # Refresh the file explorer object
new_path = os.path.expanduser(new_path) new_path = os.path.expanduser(new_path)
...@@ -341,13 +343,13 @@ class Interface(BaseInterface): ...@@ -341,13 +343,13 @@ class Interface(BaseInterface):
def set_relaunch_button(self): def set_relaunch_button(self):
return gr.update(value=f"Relaunch", interactive=True) return gr.update(value=f"Relaunch", interactive=True)
def set_spinner(self, message): def set_spinner(self, message: str):
if self.error: if self.error:
return gr.Button() return gr.Button()
# spinner icon/shows the user something is happeing # spinner icon/shows the user something is happeing
return gr.update(value=f"{message}", interactive=False) return gr.update(value=f"{message}", interactive=False)
def load_data(self, base_path, explorer): def load_data(self, base_path: str, explorer: str):
if base_path and os.path.isfile(base_path): if base_path and os.path.isfile(base_path):
file_path = base_path file_path = base_path
elif explorer and os.path.isfile(explorer): elif explorer and os.path.isfile(explorer):
......
...@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples ...@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples
class Interface(InterfaceWithExamples): class Interface(InterfaceWithExamples):
def __init__(self, def __init__(self,
img = None, img: np.ndarray = None,
verbose:bool = False, verbose:bool = False,
plot_height:int = 768, plot_height:int = 768,
figsize:int = 6): figsize:int = 6):
...@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples): ...@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
# #
####################################################### #######################################################
def process_input(self, data, dark_objects): def process_input(self, data: np.ndarray, dark_objects: bool):
# Load volume # Load volume
try: try:
self.vol = load(data.name) self.vol = load(data.name)
...@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples): ...@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
self.vmin = np.min(self.vol) self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol) self.vmax = np.max(self.vol)
def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"): def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
plt.close() plt.close()
z_idx = int(zpos * (vol.shape[0] - 1)) z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize)) fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
...@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples): ...@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
return fig return fig
def make_binary(self, threshold): def make_binary(self, threshold: float):
# Make a binary volume # Make a binary volume
# Nothing fancy, but we could add new features here # Nothing fancy, but we could add new features here
self.vol_binary = self.vol > (threshold * np.max(self.vol)) self.vol_binary = self.vol > (threshold * np.max(self.vol))
def compute_localthickness(self, lt_scale): def compute_localthickness(self, lt_scale: float):
self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale) self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
# Valus for visualization # Valus for visualization
self.vmin_lt = np.min(self.vol_thickness) self.vmin_lt = np.min(self.vol_thickness)
self.vmax_lt = np.max(self.vol_thickness) self.vmax_lt = np.max(self.vol_thickness)
def thickness_histogram(self, nbins): def thickness_histogram(self, nbins: int):
# Ignore zero thickness # Ignore zero thickness
non_zero_values = self.vol_thickness[self.vol_thickness > 0] non_zero_values = self.vol_thickness[self.vol_thickness > 0]
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import tifffile as tiff import tifffile as tiff
import zarr import zarr
from tqdm import tqdm from tqdm import tqdm
import zarr.core
from qim3d.utils._misc import stringify_path from qim3d.utils._misc import stringify_path
from qim3d.io._saving import save from qim3d.io._saving import save
...@@ -21,7 +22,7 @@ class Convert: ...@@ -21,7 +22,7 @@ class Convert:
""" """
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64)) self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
def convert(self, input_path, output_path): def convert(self, input_path: str, output_path: str):
def get_file_extension(file_path): def get_file_extension(file_path):
root, ext = os.path.splitext(file_path) root, ext = os.path.splitext(file_path)
if ext in ['.gz', '.bz2', '.xz']: # handle common compressed extensions if ext in ['.gz', '.bz2', '.xz']: # handle common compressed extensions
...@@ -67,7 +68,7 @@ class Convert: ...@@ -67,7 +68,7 @@ class Convert:
else: else:
raise ValueError("Invalid path") raise ValueError("Invalid path")
def convert_tif_to_zarr(self, tif_path, zarr_path): def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file """Convert a tiff file to a zarr file
Args: Args:
...@@ -97,7 +98,7 @@ class Convert: ...@@ -97,7 +98,7 @@ class Convert:
return z return z
def convert_zarr_to_tif(self, zarr_path, tif_path): def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file """Convert a zarr file to a tiff file
Args: Args:
...@@ -110,7 +111,7 @@ class Convert: ...@@ -110,7 +111,7 @@ class Convert:
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(tif_path, z) save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path, zarr_path): def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file """Convert a nifti file to a zarr file
Args: Args:
...@@ -139,7 +140,7 @@ class Convert: ...@@ -139,7 +140,7 @@ class Convert:
return z return z
def convert_zarr_to_nifti(self, zarr_path, nifti_path, compression=False): def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
"""Convert a zarr file to a nifti file """Convert a zarr file to a nifti file
Args: Args:
...@@ -153,7 +154,7 @@ class Convert: ...@@ -153,7 +154,7 @@ class Convert:
save(nifti_path, z, compression=compression) save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)): def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
"""Convert a file to another format without loading the entire file into memory """Convert a file to another format without loading the entire file into memory
Args: Args:
......
...@@ -76,7 +76,7 @@ class Downloader: ...@@ -76,7 +76,7 @@ class Downloader:
[file_name_n](load_file,optional): Function to download file number n in the given folder. [file_name_n](load_file,optional): Function to download file number n in the given folder.
""" """
def __init__(self, folder): def __init__(self, folder: str):
files = _extract_names(folder) files = _extract_names(folder)
for idx, file in enumerate(files): for idx, file in enumerate(files):
...@@ -88,7 +88,7 @@ class Downloader: ...@@ -88,7 +88,7 @@ class Downloader:
setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file)) setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
def _make_fn(self, folder, file): def _make_fn(self, folder: str, file: str):
"""Private method that returns a function. The function downloads the chosen file from the folder. """Private method that returns a function. The function downloads the chosen file from the folder.
Args: Args:
...@@ -101,7 +101,7 @@ class Downloader: ...@@ -101,7 +101,7 @@ class Downloader:
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository" url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
def _download(load_file=False, virtual_stack=True): def _download(load_file: bool = False, virtual_stack: bool = True):
"""Downloads the file and optionally also loads it. """Downloads the file and optionally also loads it.
Args: Args:
...@@ -121,7 +121,7 @@ class Downloader: ...@@ -121,7 +121,7 @@ class Downloader:
return _download return _download
def _update_progress(pbar, blocknum, bs): def _update_progress(pbar: tqdm, blocknum: int, bs: int):
""" """
Helper function for the ´download_file()´ function. Updates the progress bar. Helper function for the ´download_file()´ function. Updates the progress bar.
""" """
...@@ -129,7 +129,7 @@ def _update_progress(pbar, blocknum, bs): ...@@ -129,7 +129,7 @@ def _update_progress(pbar, blocknum, bs):
pbar.update(blocknum * bs - pbar.n) pbar.update(blocknum * bs - pbar.n)
def _get_file_size(url): def _get_file_size(url: str):
""" """
Helper function for the ´download_file()´ function. Finds the size of the file. Helper function for the ´download_file()´ function. Finds the size of the file.
""" """
...@@ -137,7 +137,7 @@ def _get_file_size(url): ...@@ -137,7 +137,7 @@ def _get_file_size(url):
return int(urllib.request.urlopen(url).info().get("Content-Length", -1)) return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
def download_file(path, name, file): def download_file(path: str, name: str, file: str):
"""Downloads the file from path / name / file. """Downloads the file from path / name / file.
Args: Args:
...@@ -177,7 +177,7 @@ def download_file(path, name, file): ...@@ -177,7 +177,7 @@ def download_file(path, name, file):
) )
def _extract_html(url): def _extract_html(url: str):
"""Extracts the html content of a webpage in "utf-8" """Extracts the html content of a webpage in "utf-8"
Args: Args:
...@@ -198,7 +198,7 @@ def _extract_html(url): ...@@ -198,7 +198,7 @@ def _extract_html(url):
return html_content return html_content
def _extract_names(name=None): def _extract_names(name: str = None):
"""Extracts the names of the folders and files. """Extracts the names of the folders and files.
Finds the names of either the folders if no name is given, Finds the names of either the folders if no name is given,
......
...@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory ...@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
from qim3d.utils._progress_bar import FileLoadingProgressBar from qim3d.utils._progress_bar import FileLoadingProgressBar
import trimesh import trimesh
from typing import Optional, Dict
dask.config.set(scheduler="processes") dask.config.set(scheduler="processes")
...@@ -76,7 +78,7 @@ class DataLoader: ...@@ -76,7 +78,7 @@ class DataLoader:
self.dim_order = kwargs.get("dim_order", (2, 1, 0)) self.dim_order = kwargs.get("dim_order", (2, 1, 0))
self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp") self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
def load_tiff(self, path): def load_tiff(self, path: str|os.PathLike):
"""Load a TIFF file from the specified path. """Load a TIFF file from the specified path.
Args: Args:
...@@ -100,7 +102,7 @@ class DataLoader: ...@@ -100,7 +102,7 @@ class DataLoader:
return vol return vol
def load_h5(self, path): def load_h5(self, path: str|os.PathLike) -> tuple[np.ndarray, Optional[Dict]]:
"""Load an HDF5 file from the specified path. """Load an HDF5 file from the specified path.
Args: Args:
...@@ -183,7 +185,7 @@ class DataLoader: ...@@ -183,7 +185,7 @@ class DataLoader:
else: else:
return vol return vol
def load_tiff_stack(self, path): def load_tiff_stack(self, path: str|os.PathLike) -> np.ndarray|np.memmap:
"""Load a stack of TIFF files from the specified path. """Load a stack of TIFF files from the specified path.
Args: Args:
...@@ -237,7 +239,7 @@ class DataLoader: ...@@ -237,7 +239,7 @@ class DataLoader:
return vol return vol
def load_txrm(self, path): def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]:
"""Load a TXRM/XRM/TXM file from the specified path. """Load a TXRM/XRM/TXM file from the specified path.
Args: Args:
...@@ -308,7 +310,7 @@ class DataLoader: ...@@ -308,7 +310,7 @@ class DataLoader:
else: else:
return vol return vol
def load_nifti(self, path): def load_nifti(self, path: str|os.PathLike):
"""Load a NIfTI file from the specified path. """Load a NIfTI file from the specified path.
Args: Args:
...@@ -338,7 +340,7 @@ class DataLoader: ...@@ -338,7 +340,7 @@ class DataLoader:
else: else:
return vol return vol
def load_pil(self, path): def load_pil(self, path: str|os.PathLike):
"""Load a PIL image from the specified path """Load a PIL image from the specified path
Args: Args:
...@@ -349,7 +351,7 @@ class DataLoader: ...@@ -349,7 +351,7 @@ class DataLoader:
""" """
return np.array(Image.open(path)) return np.array(Image.open(path))
def load_PIL_stack(self, path): def load_PIL_stack(self, path: str|os.PathLike):
"""Load a stack of PIL files from the specified path. """Load a stack of PIL files from the specified path.
Args: Args:
...@@ -433,7 +435,7 @@ class DataLoader: ...@@ -433,7 +435,7 @@ class DataLoader:
def _load_vgi_metadata(self, path): def _load_vgi_metadata(self, path: str|os.PathLike):
"""Helper functions that loads metadata from a VGI file """Helper functions that loads metadata from a VGI file
Args: Args:
...@@ -482,7 +484,7 @@ class DataLoader: ...@@ -482,7 +484,7 @@ class DataLoader:
return meta_data return meta_data
def load_vol(self, path): def load_vol(self, path: str|os.PathLike):
"""Load a VOL filed based on the VGI metadata file """Load a VOL filed based on the VGI metadata file
Args: Args:
...@@ -548,7 +550,7 @@ class DataLoader: ...@@ -548,7 +550,7 @@ class DataLoader:
else: else:
return vol return vol
def load_dicom(self, path): def load_dicom(self, path: str|os.PathLike):
"""Load a DICOM file """Load a DICOM file
Args: Args:
...@@ -563,7 +565,7 @@ class DataLoader: ...@@ -563,7 +565,7 @@ class DataLoader:
else: else:
return dcm_data.pixel_array return dcm_data.pixel_array
def load_dicom_dir(self, path): def load_dicom_dir(self, path: str|os.PathLike):
"""Load a directory of DICOM files into a numpy 3d array """Load a directory of DICOM files into a numpy 3d array
Args: Args:
...@@ -605,7 +607,7 @@ class DataLoader: ...@@ -605,7 +607,7 @@ class DataLoader:
return vol return vol
def load_zarr(self, path: str): def load_zarr(self, path: str|os.PathLike):
""" Loads a Zarr array from disk. """ Loads a Zarr array from disk.
Args: Args:
...@@ -654,7 +656,7 @@ class DataLoader: ...@@ -654,7 +656,7 @@ class DataLoader:
message + " Set 'force_load=True' to ignore this error." message + " Set 'force_load=True' to ignore this error."
) )
def load(self, path): def load(self, path: str|os.PathLike):
""" """
Load a file or directory based on the given path. Load a file or directory based on the given path.
...@@ -757,16 +759,16 @@ def _get_ole_offsets(ole): ...@@ -757,16 +759,16 @@ def _get_ole_offsets(ole):
def load( def load(
path, path: str|os.PathLike,
virtual_stack=False, virtual_stack: bool = False,
dataset_name=None, dataset_name: bool = None,
return_metadata=False, return_metadata: bool = False,
contains=None, contains: bool = None,
progress_bar: bool = True, progress_bar: bool = True,
force_load: bool = False, force_load: bool = False,
dim_order=(2, 1, 0), dim_order: tuple = (2, 1, 0),
**kwargs, **kwargs,
): ) -> np.ndarray:
""" """
Load data from the specified file or directory. Load data from the specified file or directory.
...@@ -854,7 +856,7 @@ def load( ...@@ -854,7 +856,7 @@ def load(
return data return data
def load_mesh(filename): def load_mesh(filename: str) -> trimesh.Trimesh:
""" """
Load a mesh from an .obj file using trimesh. Load a mesh from an .obj file using trimesh.
......
...@@ -46,13 +46,13 @@ class OMEScaler( ...@@ -46,13 +46,13 @@ class OMEScaler(
"""Scaler in the style of OME-Zarr. """Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken.""" This is needed because their current zoom implementation is broken."""
def __init__(self, order=0, downscale=2, max_layer=5, method="scaleZYXdask"): def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
self.order = order self.order = order
self.downscale = downscale self.downscale = downscale
self.max_layer = max_layer self.max_layer = max_layer
self.method = method self.method = method
def scaleZYX(self, base): def scaleZYX(self, base: da.core.Array):
"""Downsample using :func:`scipy.ndimage.zoom`.""" """Downsample using :func:`scipy.ndimage.zoom`."""
rv = [base] rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}") log.info(f"- Scale 0: {rv[-1].shape}")
...@@ -63,7 +63,7 @@ class OMEScaler( ...@@ -63,7 +63,7 @@ class OMEScaler(
return list(rv) return list(rv)
def scaleZYXdask(self, base): def scaleZYXdask(self, base: da.core.Array):
""" """
Downsample a 3D volume using Dask and scipy.ndimage.zoom. Downsample a 3D volume using Dask and scipy.ndimage.zoom.
...@@ -82,7 +82,7 @@ class OMEScaler( ...@@ -82,7 +82,7 @@ class OMEScaler(
""" """
def resize_zoom(vol, scale_factors, order, scaled_shape): def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape # Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library # This snippet comes from the original OME-Zarr-python library
...@@ -181,16 +181,16 @@ class OMEScaler( ...@@ -181,16 +181,16 @@ class OMEScaler(
def export_ome_zarr( def export_ome_zarr(
path, path: str|os.PathLike,
data, data: np.ndarray|da.core.Array,
chunk_size=256, chunk_size: int = 256,
downsample_rate=2, downsample_rate: int = 2,
order=1, order: int = 1,
replace=False, replace: bool = False,
method="scaleZYX", method: str = "scaleZYX",
progress_bar: bool = True, progress_bar: bool = True,
progress_bar_repeat_time="auto", progress_bar_repeat_time: str = "auto",
): ) -> None:
""" """
Export 3D image data to OME-Zarr format with pyramidal downsampling. Export 3D image data to OME-Zarr format with pyramidal downsampling.
...@@ -299,7 +299,11 @@ def export_ome_zarr( ...@@ -299,7 +299,11 @@ def export_ome_zarr(
return return
def import_ome_zarr(path, scale=0, load=True): def import_ome_zarr(
path: str|os.PathLike,
scale: int = 0,
load: bool = True
) -> np.ndarray:
""" """
Import image data from an OME-Zarr file. Import image data from an OME-Zarr file.
......
...@@ -76,7 +76,7 @@ class DataSaver: ...@@ -76,7 +76,7 @@ class DataSaver:
self.sliced_dim = kwargs.get("sliced_dim", 0) self.sliced_dim = kwargs.get("sliced_dim", 0)
self.chunk_shape = kwargs.get("chunk_shape", "auto") self.chunk_shape = kwargs.get("chunk_shape", "auto")
def save_tiff(self, path, data): def save_tiff(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a TIFF file to the given path. """Save data to a TIFF file to the given path.
Args: Args:
...@@ -85,7 +85,7 @@ class DataSaver: ...@@ -85,7 +85,7 @@ class DataSaver:
""" """
tifffile.imwrite(path, data, compression=self.compression) tifffile.imwrite(path, data, compression=self.compression)
def save_tiff_stack(self, path, data): def save_tiff_stack(self, path: str|os.PathLike, data: np.ndarray):
"""Save data as a TIFF stack containing slices in separate files to the given path. """Save data as a TIFF stack containing slices in separate files to the given path.
The slices will be named according to the basename plus a suffix with a zero-filled The slices will be named according to the basename plus a suffix with a zero-filled
value corresponding to the slice number value corresponding to the slice number
...@@ -124,7 +124,7 @@ class DataSaver: ...@@ -124,7 +124,7 @@ class DataSaver:
f"Total of {no_slices} files saved following the pattern '{pattern_string}'" f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
) )
def save_nifti(self, path, data): def save_nifti(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a NIfTI file to the given path. """Save data to a NIfTI file to the given path.
Args: Args:
...@@ -154,7 +154,7 @@ class DataSaver: ...@@ -154,7 +154,7 @@ class DataSaver:
# Save image # Save image
nib.save(img, path) nib.save(img, path)
def save_vol(self, path, data): def save_vol(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a VOL file to the given path. """Save data to a VOL file to the given path.
Args: Args:
...@@ -200,7 +200,7 @@ class DataSaver: ...@@ -200,7 +200,7 @@ class DataSaver:
"dataset", data=data, compression="gzip" if self.compression else None "dataset", data=data, compression="gzip" if self.compression else None
) )
def save_dicom(self, path, data): def save_dicom(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a DICOM file to the given path. """Save data to a DICOM file to the given path.
Args: Args:
...@@ -255,7 +255,7 @@ class DataSaver: ...@@ -255,7 +255,7 @@ class DataSaver:
ds.save_as(path) ds.save_as(path)
def save_to_zarr(self, path, data): def save_to_zarr(self, path: str|os.PathLike, data: da.core.Array):
"""Saves a Dask array to a Zarr array on disk. """Saves a Dask array to a Zarr array on disk.
Args: Args:
...@@ -284,7 +284,7 @@ class DataSaver: ...@@ -284,7 +284,7 @@ class DataSaver:
) )
zarr_array[:] = data zarr_array[:] = data
def save_PIL(self, path, data): def save_PIL(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a PIL file to the given path. """Save data to a PIL file to the given path.
Args: Args:
...@@ -303,7 +303,7 @@ class DataSaver: ...@@ -303,7 +303,7 @@ class DataSaver:
# Save image # Save image
img.save(path) img.save(path)
def save(self, path, data): def save(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to the given path. """Save data to the given path.
Args: Args:
...@@ -401,15 +401,15 @@ class DataSaver: ...@@ -401,15 +401,15 @@ class DataSaver:
def save( def save(
path, path: str|os.PathLike,
data, data: np.ndarray,
replace=False, replace: bool = False,
compression=False, compression: bool = False,
basename=None, basename: bool = None,
sliced_dim=0, sliced_dim: int = 0,
chunk_shape="auto", chunk_shape: str = "auto",
**kwargs, **kwargs,
): ) -> None:
"""Save data to a specified file path. """Save data to a specified file path.
Args: Args:
...@@ -464,7 +464,10 @@ def save( ...@@ -464,7 +464,10 @@ def save(
).save(path, data) ).save(path, data)
def save_mesh(filename, mesh): def save_mesh(
filename: str,
mesh: trimesh.Trimesh
) -> None:
""" """
Save a trimesh object to an .obj file. Save a trimesh object to an .obj file.
......
...@@ -28,7 +28,7 @@ class Sync: ...@@ -28,7 +28,7 @@ class Sync:
return False return False
def check_destination(self, source, destination, checksum=False, verbose=True): def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
"""Check if all files from 'source' are in 'destination' """Check if all files from 'source' are in 'destination'
This function compares the files in the 'source' directory to those in This function compares the files in the 'source' directory to those in
...@@ -80,7 +80,7 @@ class Sync: ...@@ -80,7 +80,7 @@ class Sync:
return diff_files return diff_files
def compare_dirs(self, source, destination, checksum=False, verbose=True): def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
"""Checks whether 'source' and 'destination' directories are synchronized. """Checks whether 'source' and 'destination' directories are synchronized.
This function compares the contents of two directories This function compares the contents of two directories
...@@ -168,7 +168,7 @@ class Sync: ...@@ -168,7 +168,7 @@ class Sync:
) )
return return
def count_files_and_dirs(self, path, verbose=True): def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
"""Count the number of files and directories in the given path. """Count the number of files and directories in the given path.
This function recursively counts the number of files and This function recursively counts the number of files and
......
...@@ -8,8 +8,8 @@ from qim3d.utils._logger import log ...@@ -8,8 +8,8 @@ from qim3d.utils._logger import log
def from_volume( def from_volume(
volume: np.ndarray, volume: np.ndarray,
level: float = None, level: float = None,
step_size=1, step_size: int = 1,
allow_degenerate=False, allow_degenerate: bool = False,
padding: Tuple[int, int, int] = (2, 2, 2), padding: Tuple[int, int, int] = (2, 2, 2),
**kwargs: Any, **kwargs: Any,
) -> trimesh.Trimesh: ) -> trimesh.Trimesh:
......
...@@ -20,10 +20,10 @@ class Augmentation: ...@@ -20,10 +20,10 @@ class Augmentation:
""" """
def __init__(self, def __init__(self,
resize = 'crop', resize: str = 'crop',
transform_train = 'moderate', transform_train: str = 'moderate',
transform_validation = None, transform_validation: str | None = None,
transform_test = None, transform_test: str | None = None,
mean: float = 0.5, mean: float = 0.5,
std: float = 0.5 std: float = 0.5
): ):
...@@ -38,7 +38,7 @@ class Augmentation: ...@@ -38,7 +38,7 @@ class Augmentation:
self.transform_validation = transform_validation self.transform_validation = transform_validation
self.transform_test = transform_test self.transform_test = transform_test
def augment(self, im_h, im_w, level=None): def augment(self, im_h: int, im_w: int, level: bool | None = None):
""" """
Returns an albumentations.core.composition.Compose class depending on the augmentation level. Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
......