diff --git a/qim3d/io/_convert.py b/qim3d/io/_convert.py index 9244ef88e8aa73e19ad98c237d37299e85460b0c..0d776822689430b733dfa1467ecdff0b48d87bb1 100644 --- a/qim3d/io/_convert.py +++ b/qim3d/io/_convert.py @@ -7,6 +7,7 @@ import numpy as np import tifffile as tiff import zarr from tqdm import tqdm +import zarr.core from qim3d.utils._misc import stringify_path from qim3d.io._saving import save @@ -67,7 +68,7 @@ class Convert: else: raise ValueError("Invalid path") - def convert_tif_to_zarr(self, tif_path: str, zarr_path: str): + def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array: """Convert a tiff file to a zarr file Args: @@ -97,7 +98,7 @@ class Convert: return z - def convert_zarr_to_tif(self, zarr_path: str, tif_path: str): + def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None: """Convert a zarr file to a tiff file Args: @@ -110,7 +111,7 @@ class Convert: z = zarr.open(zarr_path) save(tif_path, z) - def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str): + def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array: """Convert a nifti file to a zarr file Args: @@ -139,7 +140,7 @@ class Convert: return z - def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False): + def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None: """Convert a zarr file to a nifti file Args: @@ -153,7 +154,7 @@ class Convert: save(nifti_path, z, compression=compression) -def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)): +def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None: """Convert a file to another format without loading the entire file into memory Args: diff --git a/qim3d/io/_loading.py b/qim3d/io/_loading.py index b558c3f4eabf7a64f819b26abb9c073eb709cda6..0f4641e7708bb36178eb98d1041f30e9ee420862 100644 --- a/qim3d/io/_loading.py +++ b/qim3d/io/_loading.py @@ -29,6 +29,8 @@ from qim3d.utils._system import Memory from qim3d.utils._progress_bar import FileLoadingProgressBar import trimesh +from typing import Optional, Dict + dask.config.set(scheduler="processes") @@ -100,7 +102,7 @@ class DataLoader: return vol - def load_h5(self, path: str): + def load_h5(self, path: str) -> tuple[np.ndarray, Optional[Dict]]: """Load an HDF5 file from the specified path. Args: @@ -183,7 +185,7 @@ class DataLoader: else: return vol - def load_tiff_stack(self, path: str): + def load_tiff_stack(self, path: str) -> np.ndarray|np.memmap: """Load a stack of TIFF files from the specified path. Args: @@ -237,7 +239,7 @@ class DataLoader: return vol - def load_txrm(self, path: str): + def load_txrm(self, path: str) -> tuple[dask.array|np.ndarray, Optional[Dict]]: """Load a TXRM/XRM/TXM file from the specified path. Args: @@ -766,7 +768,7 @@ def load( force_load: bool = False, dim_order: tuple = (2, 1, 0), **kwargs, -): +) -> np.ndarray: """ Load data from the specified file or directory. @@ -854,7 +856,7 @@ def load( return data -def load_mesh(filename: str): +def load_mesh(filename: str) -> trimesh.Trimesh: """ Load a mesh from an .obj file using trimesh. diff --git a/qim3d/io/_ome_zarr.py b/qim3d/io/_ome_zarr.py index 36ec5b800f5bad7e91dc679cec6a86e97a6456f0..2481f2ba09d68b77d88ffc7448c21de84d9bdb92 100644 --- a/qim3d/io/_ome_zarr.py +++ b/qim3d/io/_ome_zarr.py @@ -82,7 +82,7 @@ class OMEScaler( """ - def resize_zoom(vol, scale_factors, order, scaled_shape): + def resize_zoom(vol: dask.array, scale_factors, order, scaled_shape): # Get the chunksize needed so that all the blocks match the new shape # This snippet comes from the original OME-Zarr-python library @@ -182,7 +182,7 @@ class OMEScaler( def export_ome_zarr( path: str, - data: np.ndarray, + data: np.ndarray|dask.array, chunk_size: int = 256, downsample_rate: int = 2, order: int = 1, diff --git a/qim3d/io/_sync.py b/qim3d/io/_sync.py index 5cc4b5e566490f3efa586356f7e21d0a1c920c57..1119907206e66742fd0d9a740bda47ca65de2743 100644 --- a/qim3d/io/_sync.py +++ b/qim3d/io/_sync.py @@ -28,7 +28,7 @@ class Sync: return False - def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True): + def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]: """Check if all files from 'source' are in 'destination' This function compares the files in the 'source' directory to those in @@ -80,7 +80,7 @@ class Sync: return diff_files - def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True): + def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None: """Checks whether 'source' and 'destination' directories are synchronized. This function compares the contents of two directories @@ -168,7 +168,7 @@ class Sync: ) return - def count_files_and_dirs(self, path: str, verbose: bool = True): + def count_files_and_dirs(self, path: str, verbose: bool = True) -> tuple[int, int]: """Count the number of files and directories in the given path. This function recursively counts the number of files and diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index 8cb0b472af009b9801c88d1b0c18c0b56c78a415..e12b1a51849593e0c14287f0c2ab1795eb86ace9 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -20,10 +20,10 @@ class Augmentation: """ def __init__(self, - resize = 'crop', - transform_train = 'moderate', - transform_validation = None, - transform_test = None, + resize: str = 'crop', + transform_train: str = 'moderate', + transform_validation: str | None = None, + transform_test: str | None = None, mean: float = 0.5, std: float = 0.5 ): @@ -38,7 +38,7 @@ class Augmentation: self.transform_validation = transform_validation self.transform_test = transform_test - def augment(self, im_h: int, im_w: int, level: bool = None): + def augment(self, im_h: int, im_w: int, level: bool | None = None): """ Returns an albumentations.core.composition.Compose class depending on the augmentation level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py index 17f9a281eb8b5dd71388d3438efbdd103d42b2cf..d54f1bf6d7a457457b8fea638fd8057410d2a853 100644 --- a/qim3d/ml/_data.py +++ b/qim3d/ml/_data.py @@ -101,7 +101,7 @@ class Dataset(torch.utils.data.Dataset): return Image.open(str(image_path)).size -def check_resize(im_height: int, im_width: int, resize: str, n_channels: int): +def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]: """ Checks the compatibility of the image shape with the depth of the model. If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate. @@ -133,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int): return h_adjust, w_adjust -def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation): +def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]: """ Splits and augments the train/validation/test datasets. @@ -180,7 +180,7 @@ def prepare_dataloaders(train_set: torch.utils.data, batch_size: int, shuffle_train: bool = True, num_workers: int = 8, - pin_memory: bool = False): + pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]: """ Prepares the dataloaders for model training. diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index 1c4427919492b7dfff37372df94a7086bc1eac10..f46a7481fe93cd299eb43566bd947c7ea3355daf 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -3,7 +3,7 @@ import torch import numpy as np -from torchinfo import summary +from torchinfo import summary, ModelStatistics from qim3d.utils._logger import log from qim3d.viz._metrics import plot_metrics @@ -137,7 +137,7 @@ def train_model( return train_loss, val_loss -def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module): +def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module) -> ModelStatistics: """Prints the summary of a PyTorch model. Args: diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py index 1f29b23579e4aa9fbe5fa92e2a3045ff4d50317d..45725feec526e0a31de67a0737f99d6f955d2333 100644 --- a/qim3d/segmentation/_connected_components.py +++ b/qim3d/segmentation/_connected_components.py @@ -23,7 +23,7 @@ class CC: """ return self.cc_count - def get_cc(self, index=None, crop=False): + def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray: """ Get the connected component with the given index, if index is None selects a random component. @@ -52,7 +52,7 @@ class CC: return volume - def get_bounding_box(self, index=None): + def get_bounding_box(self, index: int|None =None)-> list[tuple]: """Get the bounding boxes of the connected components. Args: diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index 766aae157505fa7bccf72e6bfaa57b5adee02387..e03f08526b176ff61d878cf11876473c7a36577c 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -298,7 +298,7 @@ def slices_grid( return fig -def _get_slice_range(position: int, num_slices: int, n_total) -> np.ndarray: +def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray: """Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" start_idx = position - num_slices // 2 end_idx = ( @@ -856,14 +856,14 @@ def histogram( log_scale: bool = False, despine: bool = True, show_title: bool = True, - color="qim3d", - edgecolor=None, - figsize=(8, 4.5), - element="step", - return_fig=False, - show=True, + color: str = "qim3d", + edgecolor: str|None = None, + figsize: tuple[float, float] = (8, 4.5), + element: str = "step", + return_fig: bool = False, + show: bool = True, **sns_kwargs, -) -> Optional[matplotlib.figure.Figure]: +) -> None|matplotlib.figure.Figure: """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py index 8b2c23accd93cdd19f7c263e273031f7764c1463..b45a138ccaf80682851dc53bff770a29e11c6172 100644 --- a/qim3d/viz/_k3d.py +++ b/qim3d/viz/_k3d.py @@ -22,8 +22,8 @@ def volumetric( grid_visible: bool = False, color_map: str = 'magma', constant_opacity: bool = False, - vmin: float = None, - vmax: float = None, + vmin: float|None = None, + vmax: float|None = None, samples: int|str = "auto", max_voxels: int = 512**3, data_type: str = "scaled_float16", diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py index ded37d240add0f7a9add57fc1f19ebb39dabfe9b..16d7c7ce94bf427e82c20221d6226333f11b1bd5 100644 --- a/qim3d/viz/_metrics.py +++ b/qim3d/viz/_metrics.py @@ -13,7 +13,7 @@ def plot_metrics( *metrics: tuple[dict[str, float]], linestyle: str = "-", batch_linestyle: str = "dotted", - labels: list = None, + labels: list|None = None, figsize: tuple = (16, 6), show: bool = False ): diff --git a/qim3d/viz/_structure_tensor.py b/qim3d/viz/_structure_tensor.py index 63d141c41ae9dd365e89a9d82c959d08173d76a6..8148810287fe206334cf1f1e72103e7b3db8b74e 100644 --- a/qim3d/viz/_structure_tensor.py +++ b/qim3d/viz/_structure_tensor.py @@ -19,9 +19,9 @@ def vectors( vec: np.ndarray, axis: int = 0, volume_cmap:str = 'grey', - vmin:float = None, - vmax:float = None, - slice_idx: Optional[Union[int, float]] = None, + vmin: float|None = None, + vmax: float|None = None, + slice_idx: Union[int, float]|None = None, grid_size: int = 10, interactive: bool = True, figsize: Tuple[int, int] = (10, 5),