Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • QIM/tools/qim3d
1 result
Show changes
Commits on Source (6)
Showing
with 197 additions and 164 deletions
......@@ -159,20 +159,12 @@ def main():
elif args.subcommand == "viz":
if args.method == "itk-vtk":
try:
# We need the full path to the file for the viewer
current_dir = os.getcwd()
full_path = os.path.normpath(os.path.join(current_dir, args.source))
qim3d.viz.itk_vtk(full_path, open_browser=not args.no_browser)
except qim3d.viz.NotInstalledError as err:
print(err)
message = "Itk-vtk-viewer is not installed or qim3d can not find it.\nYou can either:\n\to Use 'qim3d viz SOURCE -m k3d' to display data using different method\n\to Install itk-vtk-viewer yourself following https://kitware.github.io/itk-vtk-viewer/docs/cli.html#Installation\n\to Let qim3D install itk-vtk-viewer now (it will also install node.js in qim3d library)\nDo you want qim3D to install itk-vtk-viewer now?"
print(message)
answer = input("[Y/n]:")
if answer in "Yy":
qim3d.viz.Installer().install()
qim3d.viz.itk_vtk(full_path)
# We need the full path to the file for the viewer
current_dir = os.getcwd()
full_path = os.path.normpath(os.path.join(current_dir, args.source))
qim3d.viz.itk_vtk(full_path, open_browser = not args.no_browser)
elif args.method == "k3d":
volume = qim3d.io.load(str(args.source))
......
......@@ -3,9 +3,9 @@
import numpy as np
from qim3d.utils._logger import log
__all__ = ["blob_detection"]
__all__ = ["blobs"]
def blob_detection(
def blobs(
vol: np.ndarray,
background: str = "dark",
min_sigma: float = 1,
......@@ -15,7 +15,7 @@ def blob_detection(
overlap: float = 0.5,
threshold_rel: float = None,
exclude_border: bool = False,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray]:
"""
Extract blobs from a volume using Difference of Gaussian (DoG) method, and retrieve a binary volume with the blobs marked as True
......
......@@ -5,7 +5,9 @@ import trimesh
import qim3d
def volume(obj, **mesh_kwargs) -> float:
def volume(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
Compute the volume of a 3D volume or mesh.
......@@ -49,7 +51,9 @@ def volume(obj, **mesh_kwargs) -> float:
return obj.volume
def area(obj, **mesh_kwargs) -> float:
def area(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
Compute the surface area of a 3D volume or mesh.
......@@ -92,7 +96,9 @@ def area(obj, **mesh_kwargs) -> float:
return obj.area
def sphericity(obj, **mesh_kwargs) -> float:
def sphericity(obj: np.ndarray|trimesh.Trimesh,
**mesh_kwargs
) -> float:
"""
Compute the sphericity of a 3D volume or mesh.
......
......@@ -26,7 +26,10 @@ __all__ = [
class FilterBase:
def __init__(self, dask=False, chunks="auto", *args, **kwargs):
def __init__(self,
dask: bool = False,
chunks: str = "auto",
*args, **kwargs):
"""
Base class for image filters.
......@@ -40,7 +43,7 @@ class FilterBase:
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a Gaussian filter to the input.
......@@ -54,7 +57,7 @@ class Gaussian(FilterBase):
class Median(FilterBase):
def __call__(self, input):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a median filter to the input.
......@@ -68,7 +71,7 @@ class Median(FilterBase):
class Maximum(FilterBase):
def __call__(self, input):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a maximum filter to the input.
......@@ -82,7 +85,7 @@ class Maximum(FilterBase):
class Minimum(FilterBase):
def __call__(self, input):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a minimum filter to the input.
......@@ -95,7 +98,7 @@ class Minimum(FilterBase):
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase):
def __call__(self, input):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
Applies a tophat filter to the input.
......@@ -210,7 +213,10 @@ class Pipeline:
return input
def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
def gaussian(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
*args, **kwargs) -> np.ndarray:
"""
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
......@@ -236,7 +242,10 @@ def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
return res
def median(vol, dask=False, chunks='auto', **kwargs):
def median(vol: np.ndarray,
dask: bool = False,
chunks: str ='auto',
**kwargs) -> np.ndarray:
"""
Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
......@@ -260,7 +269,10 @@ def median(vol, dask=False, chunks='auto', **kwargs):
return res
def maximum(vol, dask=False, chunks='auto', **kwargs):
def maximum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
"""
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
......@@ -284,7 +296,10 @@ def maximum(vol, dask=False, chunks='auto', **kwargs):
return res
def minimum(vol, dask=False, chunks='auto', **kwargs):
def minimum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
"""
Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
......@@ -307,7 +322,10 @@ def minimum(vol, dask=False, chunks='auto', **kwargs):
res = ndimage.minimum_filter(vol, **kwargs)
return res
def tophat(vol, dask=False, chunks='auto', **kwargs):
def tophat(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
"""
Remove background from the volume.
......
......@@ -142,7 +142,7 @@ def noise_object_collection(
object_shape: str = None,
seed: int = 0,
verbose: bool = False,
) -> tuple[np.ndarray, object]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Generate a 3D volume of multiple synthetic objects using Perlin noise.
......
......@@ -36,7 +36,7 @@ from qim3d.gui.interface import BaseInterface
class Interface(BaseInterface):
def __init__(self, name_suffix: str = "", verbose: bool = False, img=None):
def __init__(self, name_suffix: str = "", verbose: bool = False, img: np.ndarray = None):
super().__init__(
title="Annotation Tool",
height=768,
......@@ -55,7 +55,7 @@ class Interface(BaseInterface):
self.masks_rgb = None
self.temp_files = []
def get_result(self):
def get_result(self) -> dict:
# Get the temporary files from gradio
temp_path_list = []
for filename in os.listdir(self.temp_dir):
......@@ -95,13 +95,13 @@ class Interface(BaseInterface):
except FileNotFoundError:
files = None
def create_preview(self, img_editor):
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"]
masks = img_editor["layers"][0]
overlay_image = overlay_rgb_images(background, masks)
return overlay_image
def cerate_download_list(self, img_editor):
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
masks_rgb = img_editor["layers"][0]
mask_threshold = 200 # This value is based
......
......@@ -20,6 +20,7 @@ import os
import re
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import outputformat as ouf
......@@ -29,6 +30,8 @@ from qim3d.utils._logger import log
from qim3d.utils import _misc
from qim3d.gui.interface import BaseInterface
from typing import Callable, Any, Dict
import matplotlib
class Interface(BaseInterface):
......@@ -271,7 +274,7 @@ class Interface(BaseInterface):
operations.change(fn=self.show_results, inputs = operations, outputs = results)
cmap.change(fn=self.run_operations, inputs = pipeline_inputs, outputs = pipeline_outputs)
def update_explorer(self, new_path):
def update_explorer(self, new_path: str):
new_path = os.path.expanduser(new_path)
# In case we have a directory
......@@ -367,7 +370,7 @@ class Interface(BaseInterface):
except Exception as error_message:
self.error_message = F"Error when loading data: {error_message}"
def run_operations(self, operations, *args):
def run_operations(self, operations: list[str], *args) -> list[Dict[str, Any]]:
outputs = []
self.calculated_operations = []
for operation in self.all_operations:
......@@ -411,7 +414,7 @@ class Interface(BaseInterface):
case _:
raise NotImplementedError(F"Operation '{operation} is not defined")
def show_results(self, operations):
def show_results(self, operations: list[str]) -> list[Dict[str, Any]]:
update_list = []
for operation in self.all_operations:
if operation in operations and operation in self.calculated_operations:
......@@ -426,7 +429,7 @@ class Interface(BaseInterface):
#
#######################################################
def create_img_fig(self, img, **kwargs):
def create_img_fig(self, img: np.ndarray, **kwargs) -> matplotlib.figure.Figure:
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
ax.imshow(img, interpolation="nearest", **kwargs)
......@@ -437,8 +440,8 @@ class Interface(BaseInterface):
return fig
def update_slice_wrapper(self, letter):
def update_slice(position_slider:float, cmap:str):
def update_slice_wrapper(self, letter: str) -> Callable[[float, str], Dict[str, Any]]:
def update_slice(position_slider: float, cmap:str) -> Dict[str, Any]:
"""
position_slider: float from gradio slider, saying which relative slice we want to see
cmap: string gradio drop down menu, saying what cmap we want to use for display
......@@ -465,7 +468,7 @@ class Interface(BaseInterface):
return gr.update(value = fig_img, label = f"{letter} Slice: {slice_index}", visible = True)
return update_slice
def vol_histogram(self, nbins, min_value, max_value):
def vol_histogram(self, nbins: int, min_value: float, max_value: float) -> tuple[np.ndarray, np.ndarray]:
# Start histogram
vol_hist = np.zeros(nbins)
......@@ -478,7 +481,7 @@ class Interface(BaseInterface):
return vol_hist, bin_edges
def plot_histogram(self):
def plot_histogram(self) -> matplotlib.figure.Figure:
# The Histogram needs results from the projections
if not self.projections_calculated:
_ = self.get_projections()
......@@ -498,7 +501,7 @@ class Interface(BaseInterface):
return fig
def create_projections_figs(self):
def create_projections_figs(self) -> tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]:
if not self.projections_calculated:
projections = self.get_projections()
self.max_projection = projections[0]
......@@ -519,7 +522,7 @@ class Interface(BaseInterface):
self.projections_calculated = True
return max_projection_fig, min_projection_fig
def get_projections(self):
def get_projections(self) -> tuple[np.ndarray, np.ndarray]:
# Create arrays for iteration
max_projection = np.zeros(np.shape(self.vol[0]))
min_projection = np.ones(np.shape(self.vol[0])) * float("inf")
......
......@@ -6,6 +6,7 @@ import gradio as gr
from .qim_theme import QimTheme
import qim3d.gui
import numpy as np
# TODO: when offline it throws an error in cli
......@@ -48,10 +49,10 @@ class BaseInterface(ABC):
def set_invisible(self):
return gr.update(visible=False)
def change_visibility(self, is_visible):
def change_visibility(self, is_visible: bool):
return gr.update(visible = is_visible)
def launch(self, img=None, force_light_mode: bool = True, **kwargs):
def launch(self, img: np.ndarray = None, force_light_mode: bool = True, **kwargs):
"""
img: If None, user can upload image after the interface is launched.
If defined, the interface will be launched with the image already there
......@@ -76,7 +77,7 @@ class BaseInterface(ABC):
**kwargs,
)
def clear(self):
def clear(self) -> None:
"""Used to reset outputs with the clear button"""
return None
......
......@@ -44,7 +44,7 @@ class Interface(InterfaceWithExamples):
self.img = img
self.plot_height = plot_height
def load_data(self, gradiofile):
def load_data(self, gradiofile: gr.File):
try:
self.vol = load(gradiofile.name)
assert self.vol.ndim == 3
......@@ -55,7 +55,7 @@ class Interface(InterfaceWithExamples):
except AssertionError:
raise gr.Error(F"File has to be 3D structure. Your structure has {self.vol.ndim} dimension{'' if self.vol.ndim == 1 else 's'}")
def resize_vol(self, display_size):
def resize_vol(self, display_size: int):
"""Resizes the loaded volume to the display size"""
# Get original size
......@@ -80,32 +80,33 @@ class Interface(InterfaceWithExamples):
f"Resized volume: {self.display_size_z, self.display_size_y, self.display_size_x}"
)
def save_fig(self, fig, filename):
def save_fig(self, fig: go.Figure, filename: str):
# Write Plotly figure to disk
fig.write_html(filename)
def create_fig(self,
gradio_file,
display_size,
opacity,
opacityscale,
only_wireframe,
min_value,
max_value,
surface_count,
colormap,
show_colorbar,
reversescale,
flip_z,
show_axis,
show_ticks,
show_caps,
show_z_slice,
slice_z_location,
show_y_slice,
slice_y_location,
show_x_slice,
slice_x_location,):
gradio_file: gr.File,
display_size: int ,
opacity: float,
opacityscale: str,
only_wireframe: bool,
min_value: float,
max_value: float,
surface_count: int,
colormap: str,
show_colorbar: bool,
reversescale: bool,
flip_z: bool,
show_axis: bool,
show_ticks: bool,
show_caps: bool,
show_z_slice: bool,
slice_z_location: int,
show_y_slice: bool,
slice_y_location: int,
show_x_slice: bool,
slice_x_location: int,
) -> tuple[go.Figure, str]:
# Load volume
self.load_data(gradio_file)
......
......@@ -25,9 +25,11 @@ import numpy as np
from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d
from qim3d.processing import overlay_rgb_images, segment_layers, get_lines
from qim3d.processing import segment_layers, get_lines
from qim3d.operations import overlay_rgb_images
from qim3d.io import load
from qim3d.viz._layers2d import image_with_lines
from typing import Dict, Any
#TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error
......@@ -302,14 +304,14 @@ class Interface(BaseInterface):
def change_plot_type(self, plot_type, ):
def change_plot_type(self, plot_type: str, ) -> tuple[Dict[str, Any], Dict[str, Any]]:
self.plot_type = plot_type
if plot_type == 'Segmentation lines':
return gr.update(visible = False), gr.update(visible = True)
else:
return gr.update(visible = True), gr.update(visible = False)
def change_plot_size(self, x_check, y_check, z_check):
def change_plot_size(self, x_check: int, y_check: int, z_check: int) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""
Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
......@@ -319,10 +321,10 @@ class Interface(BaseInterface):
height = self.heights[index] # also used to define heights of plots in the begining
return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
def change_row_visibility(self, x_check, y_check, z_check):
def change_row_visibility(self, x_check: int, y_check: int, z_check: int):
return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
def update_explorer(self, new_path):
def update_explorer(self, new_path: str):
# Refresh the file explorer object
new_path = os.path.expanduser(new_path)
......@@ -341,13 +343,13 @@ class Interface(BaseInterface):
def set_relaunch_button(self):
return gr.update(value=f"Relaunch", interactive=True)
def set_spinner(self, message):
def set_spinner(self, message: str):
if self.error:
return gr.Button()
# spinner icon/shows the user something is happeing
return gr.update(value=f"{message}", interactive=False)
def load_data(self, base_path, explorer):
def load_data(self, base_path: str, explorer: str):
if base_path and os.path.isfile(base_path):
file_path = base_path
elif explorer and os.path.isfile(explorer):
......
......@@ -47,7 +47,7 @@ from qim3d.gui.interface import InterfaceWithExamples
class Interface(InterfaceWithExamples):
def __init__(self,
img = None,
img: np.ndarray = None,
verbose:bool = False,
plot_height:int = 768,
figsize:int = 6):
......@@ -248,7 +248,7 @@ class Interface(InterfaceWithExamples):
#
#######################################################
def process_input(self, data, dark_objects):
def process_input(self, data: np.ndarray, dark_objects: bool):
# Load volume
try:
self.vol = load(data.name)
......@@ -265,7 +265,7 @@ class Interface(InterfaceWithExamples):
self.vmin = np.min(self.vol)
self.vmax = np.max(self.vol)
def show_slice(self, vol, zpos, vmin=None, vmax=None, cmap="viridis"):
def show_slice(self, vol: np.ndarray, zpos: int, vmin: float = None, vmax: float = None, cmap: str = "viridis"):
plt.close()
z_idx = int(zpos * (vol.shape[0] - 1))
fig, ax = plt.subplots(figsize=(self.figsize, self.figsize))
......@@ -278,19 +278,19 @@ class Interface(InterfaceWithExamples):
return fig
def make_binary(self, threshold):
def make_binary(self, threshold: float):
# Make a binary volume
# Nothing fancy, but we could add new features here
self.vol_binary = self.vol > (threshold * np.max(self.vol))
def compute_localthickness(self, lt_scale):
def compute_localthickness(self, lt_scale: float):
self.vol_thickness = lt.local_thickness(self.vol_binary, lt_scale)
# Valus for visualization
self.vmin_lt = np.min(self.vol_thickness)
self.vmax_lt = np.max(self.vol_thickness)
def thickness_histogram(self, nbins):
def thickness_histogram(self, nbins: int):
# Ignore zero thickness
non_zero_values = self.vol_thickness[self.vol_thickness > 0]
......
......@@ -7,7 +7,7 @@ class QimTheme(gr.themes.Default):
there is a possibility to add some more css if you override _get_css_theme function as shown at the bottom
in comments.
"""
def __init__(self, force_light_mode:bool = True):
def __init__(self, force_light_mode: bool = True):
"""
Parameters:
-----------
......
......@@ -7,6 +7,7 @@ import numpy as np
import tifffile as tiff
import zarr
from tqdm import tqdm
import zarr.core
from qim3d.utils._misc import stringify_path
from qim3d.io._saving import save
......@@ -21,7 +22,7 @@ class Convert:
"""
self.chunk_shape = kwargs.get("chunk_shape", (64, 64, 64))
def convert(self, input_path, output_path):
def convert(self, input_path: str, output_path: str):
def get_file_extension(file_path):
root, ext = os.path.splitext(file_path)
if ext in ['.gz', '.bz2', '.xz']: # handle common compressed extensions
......@@ -67,7 +68,7 @@ class Convert:
else:
raise ValueError("Invalid path")
def convert_tif_to_zarr(self, tif_path, zarr_path):
def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file
Args:
......@@ -97,7 +98,7 @@ class Convert:
return z
def convert_zarr_to_tif(self, zarr_path, tif_path):
def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file
Args:
......@@ -110,7 +111,7 @@ class Convert:
z = zarr.open(zarr_path)
save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path, zarr_path):
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file
Args:
......@@ -139,7 +140,7 @@ class Convert:
return z
def convert_zarr_to_nifti(self, zarr_path, nifti_path, compression=False):
def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
"""Convert a zarr file to a nifti file
Args:
......@@ -153,7 +154,7 @@ class Convert:
save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)):
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
"""Convert a file to another format without loading the entire file into memory
Args:
......
......@@ -76,7 +76,7 @@ class Downloader:
[file_name_n](load_file,optional): Function to download file number n in the given folder.
"""
def __init__(self, folder):
def __init__(self, folder: str):
files = _extract_names(folder)
for idx, file in enumerate(files):
......@@ -88,7 +88,7 @@ class Downloader:
setattr(self, f'{file_name.split(".")[0]}', self._make_fn(folder, file))
def _make_fn(self, folder, file):
def _make_fn(self, folder: str, file: str):
"""Private method that returns a function. The function downloads the chosen file from the folder.
Args:
......@@ -101,7 +101,7 @@ class Downloader:
url_dl = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository"
def _download(load_file=False, virtual_stack=True):
def _download(load_file: bool = False, virtual_stack: bool = True):
"""Downloads the file and optionally also loads it.
Args:
......@@ -121,7 +121,7 @@ class Downloader:
return _download
def _update_progress(pbar, blocknum, bs):
def _update_progress(pbar: tqdm, blocknum: int, bs: int):
"""
Helper function for the ´download_file()´ function. Updates the progress bar.
"""
......@@ -129,7 +129,7 @@ def _update_progress(pbar, blocknum, bs):
pbar.update(blocknum * bs - pbar.n)
def _get_file_size(url):
def _get_file_size(url: str):
"""
Helper function for the ´download_file()´ function. Finds the size of the file.
"""
......@@ -137,7 +137,7 @@ def _get_file_size(url):
return int(urllib.request.urlopen(url).info().get("Content-Length", -1))
def download_file(path, name, file):
def download_file(path: str, name: str, file: str):
"""Downloads the file from path / name / file.
Args:
......@@ -177,7 +177,7 @@ def download_file(path, name, file):
)
def _extract_html(url):
def _extract_html(url: str):
"""Extracts the html content of a webpage in "utf-8"
Args:
......@@ -198,7 +198,7 @@ def _extract_html(url):
return html_content
def _extract_names(name=None):
def _extract_names(name: str = None):
"""Extracts the names of the folders and files.
Finds the names of either the folders if no name is given,
......
......@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
from qim3d.utils._progress_bar import FileLoadingProgressBar
import trimesh
from typing import Optional, Dict
dask.config.set(scheduler="processes")
......@@ -76,7 +78,7 @@ class DataLoader:
self.dim_order = kwargs.get("dim_order", (2, 1, 0))
self.PIL_extensions = (".jp2", ".jpg", "jpeg", ".png", "gif", ".bmp", ".webp")
def load_tiff(self, path):
def load_tiff(self, path: str|os.PathLike):
"""Load a TIFF file from the specified path.
Args:
......@@ -100,7 +102,7 @@ class DataLoader:
return vol
def load_h5(self, path):
def load_h5(self, path: str|os.PathLike) -> tuple[np.ndarray, Optional[Dict]]:
"""Load an HDF5 file from the specified path.
Args:
......@@ -183,7 +185,7 @@ class DataLoader:
else:
return vol
def load_tiff_stack(self, path):
def load_tiff_stack(self, path: str|os.PathLike) -> np.ndarray|np.memmap:
"""Load a stack of TIFF files from the specified path.
Args:
......@@ -237,7 +239,7 @@ class DataLoader:
return vol
def load_txrm(self, path):
def load_txrm(self, path: str|os.PathLike) -> tuple[dask.array.core.Array|np.ndarray, Optional[Dict]]:
"""Load a TXRM/XRM/TXM file from the specified path.
Args:
......@@ -308,7 +310,7 @@ class DataLoader:
else:
return vol
def load_nifti(self, path):
def load_nifti(self, path: str|os.PathLike):
"""Load a NIfTI file from the specified path.
Args:
......@@ -338,7 +340,7 @@ class DataLoader:
else:
return vol
def load_pil(self, path):
def load_pil(self, path: str|os.PathLike):
"""Load a PIL image from the specified path
Args:
......@@ -349,7 +351,7 @@ class DataLoader:
"""
return np.array(Image.open(path))
def load_PIL_stack(self, path):
def load_PIL_stack(self, path: str|os.PathLike):
"""Load a stack of PIL files from the specified path.
Args:
......@@ -433,7 +435,7 @@ class DataLoader:
def _load_vgi_metadata(self, path):
def _load_vgi_metadata(self, path: str|os.PathLike):
"""Helper functions that loads metadata from a VGI file
Args:
......@@ -482,7 +484,7 @@ class DataLoader:
return meta_data
def load_vol(self, path):
def load_vol(self, path: str|os.PathLike):
"""Load a VOL filed based on the VGI metadata file
Args:
......@@ -548,7 +550,7 @@ class DataLoader:
else:
return vol
def load_dicom(self, path):
def load_dicom(self, path: str|os.PathLike):
"""Load a DICOM file
Args:
......@@ -563,7 +565,7 @@ class DataLoader:
else:
return dcm_data.pixel_array
def load_dicom_dir(self, path):
def load_dicom_dir(self, path: str|os.PathLike):
"""Load a directory of DICOM files into a numpy 3d array
Args:
......@@ -605,7 +607,7 @@ class DataLoader:
return vol
def load_zarr(self, path: str):
def load_zarr(self, path: str|os.PathLike):
""" Loads a Zarr array from disk.
Args:
......@@ -654,7 +656,7 @@ class DataLoader:
message + " Set 'force_load=True' to ignore this error."
)
def load(self, path):
def load(self, path: str|os.PathLike):
"""
Load a file or directory based on the given path.
......@@ -757,16 +759,16 @@ def _get_ole_offsets(ole):
def load(
path,
virtual_stack=False,
dataset_name=None,
return_metadata=False,
contains=None,
progress_bar:bool = True,
path: str|os.PathLike,
virtual_stack: bool = False,
dataset_name: bool = None,
return_metadata: bool = False,
contains: bool = None,
progress_bar: bool = True,
force_load: bool = False,
dim_order=(2, 1, 0),
dim_order: tuple = (2, 1, 0),
**kwargs,
):
) -> np.ndarray:
"""
Load data from the specified file or directory.
......@@ -854,7 +856,7 @@ def load(
return data
def load_mesh(filename):
def load_mesh(filename: str) -> trimesh.Trimesh:
"""
Load a mesh from an .obj file using trimesh.
......
......@@ -46,13 +46,13 @@ class OMEScaler(
"""Scaler in the style of OME-Zarr.
This is needed because their current zoom implementation is broken."""
def __init__(self, order=0, downscale=2, max_layer=5, method="scaleZYXdask"):
def __init__(self, order: int = 0, downscale: float = 2, max_layer: int = 5, method: str = "scaleZYXdask"):
self.order = order
self.downscale = downscale
self.max_layer = max_layer
self.method = method
def scaleZYX(self, base):
def scaleZYX(self, base: da.core.Array):
"""Downsample using :func:`scipy.ndimage.zoom`."""
rv = [base]
log.info(f"- Scale 0: {rv[-1].shape}")
......@@ -63,7 +63,7 @@ class OMEScaler(
return list(rv)
def scaleZYXdask(self, base):
def scaleZYXdask(self, base: da.core.Array):
"""
Downsample a 3D volume using Dask and scipy.ndimage.zoom.
......@@ -82,7 +82,7 @@ class OMEScaler(
"""
def resize_zoom(vol, scale_factors, order, scaled_shape):
def resize_zoom(vol: da.core.Array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library
......@@ -181,16 +181,16 @@ class OMEScaler(
def export_ome_zarr(
path,
data,
chunk_size=256,
downsample_rate=2,
order=1,
replace=False,
method="scaleZYX",
path: str|os.PathLike,
data: np.ndarray|da.core.Array,
chunk_size: int = 256,
downsample_rate: int = 2,
order: int = 1,
replace: bool = False,
method: str = "scaleZYX",
progress_bar: bool = True,
progress_bar_repeat_time="auto",
):
progress_bar_repeat_time: str = "auto",
) -> None:
"""
Export 3D image data to OME-Zarr format with pyramidal downsampling.
......@@ -299,7 +299,11 @@ def export_ome_zarr(
return
def import_ome_zarr(path, scale=0, load=True):
def import_ome_zarr(
path: str|os.PathLike,
scale: int = 0,
load: bool = True
) -> np.ndarray:
"""
Import image data from an OME-Zarr file.
......
......@@ -76,7 +76,7 @@ class DataSaver:
self.sliced_dim = kwargs.get("sliced_dim", 0)
self.chunk_shape = kwargs.get("chunk_shape", "auto")
def save_tiff(self, path, data):
def save_tiff(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a TIFF file to the given path.
Args:
......@@ -85,7 +85,7 @@ class DataSaver:
"""
tifffile.imwrite(path, data, compression=self.compression)
def save_tiff_stack(self, path, data):
def save_tiff_stack(self, path: str|os.PathLike, data: np.ndarray):
"""Save data as a TIFF stack containing slices in separate files to the given path.
The slices will be named according to the basename plus a suffix with a zero-filled
value corresponding to the slice number
......@@ -124,7 +124,7 @@ class DataSaver:
f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
)
def save_nifti(self, path, data):
def save_nifti(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a NIfTI file to the given path.
Args:
......@@ -154,7 +154,7 @@ class DataSaver:
# Save image
nib.save(img, path)
def save_vol(self, path, data):
def save_vol(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a VOL file to the given path.
Args:
......@@ -200,7 +200,7 @@ class DataSaver:
"dataset", data=data, compression="gzip" if self.compression else None
)
def save_dicom(self, path, data):
def save_dicom(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a DICOM file to the given path.
Args:
......@@ -255,7 +255,7 @@ class DataSaver:
ds.save_as(path)
def save_to_zarr(self, path, data):
def save_to_zarr(self, path: str|os.PathLike, data: da.core.Array):
"""Saves a Dask array to a Zarr array on disk.
Args:
......@@ -284,7 +284,7 @@ class DataSaver:
)
zarr_array[:] = data
def save_PIL(self, path, data):
def save_PIL(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to a PIL file to the given path.
Args:
......@@ -303,7 +303,7 @@ class DataSaver:
# Save image
img.save(path)
def save(self, path, data):
def save(self, path: str|os.PathLike, data: np.ndarray):
"""Save data to the given path.
Args:
......@@ -401,15 +401,15 @@ class DataSaver:
def save(
path,
data,
replace=False,
compression=False,
basename=None,
sliced_dim=0,
chunk_shape="auto",
path: str|os.PathLike,
data: np.ndarray,
replace: bool = False,
compression: bool = False,
basename: bool = None,
sliced_dim: int = 0,
chunk_shape: str = "auto",
**kwargs,
):
) -> None:
"""Save data to a specified file path.
Args:
......@@ -464,7 +464,10 @@ def save(
).save(path, data)
def save_mesh(filename, mesh):
def save_mesh(
filename: str,
mesh: trimesh.Trimesh
) -> None:
"""
Save a trimesh object to an .obj file.
......
......@@ -28,7 +28,7 @@ class Sync:
return False
def check_destination(self, source, destination, checksum=False, verbose=True):
def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
"""Check if all files from 'source' are in 'destination'
This function compares the files in the 'source' directory to those in
......@@ -80,7 +80,7 @@ class Sync:
return diff_files
def compare_dirs(self, source, destination, checksum=False, verbose=True):
def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
"""Checks whether 'source' and 'destination' directories are synchronized.
This function compares the contents of two directories
......@@ -168,7 +168,7 @@ class Sync:
)
return
def count_files_and_dirs(self, path, verbose=True):
def count_files_and_dirs(self, path: str|os.PathLike, verbose: bool = True) -> tuple[int, int]:
"""Count the number of files and directories in the given path.
This function recursively counts the number of files and
......
......@@ -8,8 +8,8 @@ from qim3d.utils._logger import log
def from_volume(
volume: np.ndarray,
level: float = None,
step_size=1,
allow_degenerate=False,
step_size: int = 1,
allow_degenerate: bool = False,
padding: Tuple[int, int, int] = (2, 2, 2),
**kwargs: Any,
) -> trimesh.Trimesh:
......
......@@ -20,10 +20,10 @@ class Augmentation:
"""
def __init__(self,
resize = 'crop',
transform_train = 'moderate',
transform_validation = None,
transform_test = None,
resize: str = 'crop',
transform_train: str = 'moderate',
transform_validation: str | None = None,
transform_test: str | None = None,
mean: float = 0.5,
std: float = 0.5
):
......@@ -38,7 +38,7 @@ class Augmentation:
self.transform_validation = transform_validation
self.transform_test = transform_test
def augment(self, im_h, im_w, level=None):
def augment(self, im_h: int, im_w: int, level: bool | None = None):
"""
Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
......