Skip to content
Snippets Groups Projects
Commit 24a63566 authored by s214735's avatar s214735
Browse files

Second check of all

parent 6b3f808f
Branches
No related tags found
1 merge request!143Type hints
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import tifffile as tiff import tifffile as tiff
import zarr import zarr
from tqdm import tqdm from tqdm import tqdm
import zarr.core
from qim3d.utils._misc import stringify_path from qim3d.utils._misc import stringify_path
from qim3d.io._saving import save from qim3d.io._saving import save
...@@ -67,7 +68,7 @@ class Convert: ...@@ -67,7 +68,7 @@ class Convert:
else: else:
raise ValueError("Invalid path") raise ValueError("Invalid path")
def convert_tif_to_zarr(self, tif_path: str, zarr_path: str): def convert_tif_to_zarr(self, tif_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a tiff file to a zarr file """Convert a tiff file to a zarr file
Args: Args:
...@@ -97,7 +98,7 @@ class Convert: ...@@ -97,7 +98,7 @@ class Convert:
return z return z
def convert_zarr_to_tif(self, zarr_path: str, tif_path: str): def convert_zarr_to_tif(self, zarr_path: str, tif_path: str) -> None:
"""Convert a zarr file to a tiff file """Convert a zarr file to a tiff file
Args: Args:
...@@ -110,7 +111,7 @@ class Convert: ...@@ -110,7 +111,7 @@ class Convert:
z = zarr.open(zarr_path) z = zarr.open(zarr_path)
save(tif_path, z) save(tif_path, z)
def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str): def convert_nifti_to_zarr(self, nifti_path: str, zarr_path: str) -> zarr.core.Array:
"""Convert a nifti file to a zarr file """Convert a nifti file to a zarr file
Args: Args:
...@@ -139,7 +140,7 @@ class Convert: ...@@ -139,7 +140,7 @@ class Convert:
return z return z
def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False): def convert_zarr_to_nifti(self, zarr_path: str, nifti_path: str, compression: bool = False) -> None:
"""Convert a zarr file to a nifti file """Convert a zarr file to a nifti file
Args: Args:
...@@ -153,7 +154,7 @@ class Convert: ...@@ -153,7 +154,7 @@ class Convert:
save(nifti_path, z, compression=compression) save(nifti_path, z, compression=compression)
def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)): def convert(input_path: str, output_path: str, chunk_shape: tuple = (64, 64, 64)) -> None:
"""Convert a file to another format without loading the entire file into memory """Convert a file to another format without loading the entire file into memory
Args: Args:
......
...@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory ...@@ -29,6 +29,8 @@ from qim3d.utils._system import Memory
from qim3d.utils._progress_bar import FileLoadingProgressBar from qim3d.utils._progress_bar import FileLoadingProgressBar
import trimesh import trimesh
from typing import Optional, Dict
dask.config.set(scheduler="processes") dask.config.set(scheduler="processes")
...@@ -100,7 +102,7 @@ class DataLoader: ...@@ -100,7 +102,7 @@ class DataLoader:
return vol return vol
def load_h5(self, path: str): def load_h5(self, path: str) -> tuple[np.ndarray, Optional[Dict]]:
"""Load an HDF5 file from the specified path. """Load an HDF5 file from the specified path.
Args: Args:
...@@ -183,7 +185,7 @@ class DataLoader: ...@@ -183,7 +185,7 @@ class DataLoader:
else: else:
return vol return vol
def load_tiff_stack(self, path: str): def load_tiff_stack(self, path: str) -> np.ndarray|np.memmap:
"""Load a stack of TIFF files from the specified path. """Load a stack of TIFF files from the specified path.
Args: Args:
...@@ -237,7 +239,7 @@ class DataLoader: ...@@ -237,7 +239,7 @@ class DataLoader:
return vol return vol
def load_txrm(self, path: str): def load_txrm(self, path: str) -> tuple[dask.array|np.ndarray, Optional[Dict]]:
"""Load a TXRM/XRM/TXM file from the specified path. """Load a TXRM/XRM/TXM file from the specified path.
Args: Args:
...@@ -766,7 +768,7 @@ def load( ...@@ -766,7 +768,7 @@ def load(
force_load: bool = False, force_load: bool = False,
dim_order: tuple = (2, 1, 0), dim_order: tuple = (2, 1, 0),
**kwargs, **kwargs,
): ) -> np.ndarray:
""" """
Load data from the specified file or directory. Load data from the specified file or directory.
...@@ -854,7 +856,7 @@ def load( ...@@ -854,7 +856,7 @@ def load(
return data return data
def load_mesh(filename: str): def load_mesh(filename: str) -> trimesh.Trimesh:
""" """
Load a mesh from an .obj file using trimesh. Load a mesh from an .obj file using trimesh.
......
...@@ -82,7 +82,7 @@ class OMEScaler( ...@@ -82,7 +82,7 @@ class OMEScaler(
""" """
def resize_zoom(vol, scale_factors, order, scaled_shape): def resize_zoom(vol: dask.array, scale_factors, order, scaled_shape):
# Get the chunksize needed so that all the blocks match the new shape # Get the chunksize needed so that all the blocks match the new shape
# This snippet comes from the original OME-Zarr-python library # This snippet comes from the original OME-Zarr-python library
...@@ -182,7 +182,7 @@ class OMEScaler( ...@@ -182,7 +182,7 @@ class OMEScaler(
def export_ome_zarr( def export_ome_zarr(
path: str, path: str,
data: np.ndarray, data: np.ndarray|dask.array,
chunk_size: int = 256, chunk_size: int = 256,
downsample_rate: int = 2, downsample_rate: int = 2,
order: int = 1, order: int = 1,
......
...@@ -28,7 +28,7 @@ class Sync: ...@@ -28,7 +28,7 @@ class Sync:
return False return False
def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True): def check_destination(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> list[str]:
"""Check if all files from 'source' are in 'destination' """Check if all files from 'source' are in 'destination'
This function compares the files in the 'source' directory to those in This function compares the files in the 'source' directory to those in
...@@ -80,7 +80,7 @@ class Sync: ...@@ -80,7 +80,7 @@ class Sync:
return diff_files return diff_files
def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True): def compare_dirs(self, source: str, destination: str, checksum: bool = False, verbose: bool = True) -> None:
"""Checks whether 'source' and 'destination' directories are synchronized. """Checks whether 'source' and 'destination' directories are synchronized.
This function compares the contents of two directories This function compares the contents of two directories
...@@ -168,7 +168,7 @@ class Sync: ...@@ -168,7 +168,7 @@ class Sync:
) )
return return
def count_files_and_dirs(self, path: str, verbose: bool = True): def count_files_and_dirs(self, path: str, verbose: bool = True) -> tuple[int, int]:
"""Count the number of files and directories in the given path. """Count the number of files and directories in the given path.
This function recursively counts the number of files and This function recursively counts the number of files and
......
...@@ -20,10 +20,10 @@ class Augmentation: ...@@ -20,10 +20,10 @@ class Augmentation:
""" """
def __init__(self, def __init__(self,
resize = 'crop', resize: str = 'crop',
transform_train = 'moderate', transform_train: str = 'moderate',
transform_validation = None, transform_validation: str | None = None,
transform_test = None, transform_test: str | None = None,
mean: float = 0.5, mean: float = 0.5,
std: float = 0.5 std: float = 0.5
): ):
...@@ -38,7 +38,7 @@ class Augmentation: ...@@ -38,7 +38,7 @@ class Augmentation:
self.transform_validation = transform_validation self.transform_validation = transform_validation
self.transform_test = transform_test self.transform_test = transform_test
def augment(self, im_h: int, im_w: int, level: bool = None): def augment(self, im_h: int, im_w: int, level: bool | None = None):
""" """
Returns an albumentations.core.composition.Compose class depending on the augmentation level. Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
......
...@@ -101,7 +101,7 @@ class Dataset(torch.utils.data.Dataset): ...@@ -101,7 +101,7 @@ class Dataset(torch.utils.data.Dataset):
return Image.open(str(image_path)).size return Image.open(str(image_path)).size
def check_resize(im_height: int, im_width: int, resize: str, n_channels: int): def check_resize(im_height: int, im_width: int, resize: str, n_channels: int) -> tuple[int, int]:
""" """
Checks the compatibility of the image shape with the depth of the model. Checks the compatibility of the image shape with the depth of the model.
If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate. If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
...@@ -133,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int): ...@@ -133,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int):
return h_adjust, w_adjust return h_adjust, w_adjust
def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation): def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
""" """
Splits and augments the train/validation/test datasets. Splits and augments the train/validation/test datasets.
...@@ -180,7 +180,7 @@ def prepare_dataloaders(train_set: torch.utils.data, ...@@ -180,7 +180,7 @@ def prepare_dataloaders(train_set: torch.utils.data,
batch_size: int, batch_size: int,
shuffle_train: bool = True, shuffle_train: bool = True,
num_workers: int = 8, num_workers: int = 8,
pin_memory: bool = False): pin_memory: bool = False) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
""" """
Prepares the dataloaders for model training. Prepares the dataloaders for model training.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
import numpy as np import numpy as np
from torchinfo import summary from torchinfo import summary, ModelStatistics
from qim3d.utils._logger import log from qim3d.utils._logger import log
from qim3d.viz._metrics import plot_metrics from qim3d.viz._metrics import plot_metrics
...@@ -137,7 +137,7 @@ def train_model( ...@@ -137,7 +137,7 @@ def train_model(
return train_loss, val_loss return train_loss, val_loss
def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module): def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module) -> ModelStatistics:
"""Prints the summary of a PyTorch model. """Prints the summary of a PyTorch model.
Args: Args:
......
...@@ -23,7 +23,7 @@ class CC: ...@@ -23,7 +23,7 @@ class CC:
""" """
return self.cc_count return self.cc_count
def get_cc(self, index=None, crop=False): def get_cc(self, index: int|None =None, crop: bool=False) -> np.ndarray:
""" """
Get the connected component with the given index, if index is None selects a random component. Get the connected component with the given index, if index is None selects a random component.
...@@ -52,7 +52,7 @@ class CC: ...@@ -52,7 +52,7 @@ class CC:
return volume return volume
def get_bounding_box(self, index=None): def get_bounding_box(self, index: int|None =None)-> list[tuple]:
"""Get the bounding boxes of the connected components. """Get the bounding boxes of the connected components.
Args: Args:
......
...@@ -298,7 +298,7 @@ def slices_grid( ...@@ -298,7 +298,7 @@ def slices_grid(
return fig return fig
def _get_slice_range(position: int, num_slices: int, n_total) -> np.ndarray: def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray:
"""Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
start_idx = position - num_slices // 2 start_idx = position - num_slices // 2
end_idx = ( end_idx = (
...@@ -856,14 +856,14 @@ def histogram( ...@@ -856,14 +856,14 @@ def histogram(
log_scale: bool = False, log_scale: bool = False,
despine: bool = True, despine: bool = True,
show_title: bool = True, show_title: bool = True,
color="qim3d", color: str = "qim3d",
edgecolor=None, edgecolor: str|None = None,
figsize=(8, 4.5), figsize: tuple[float, float] = (8, 4.5),
element="step", element: str = "step",
return_fig=False, return_fig: bool = False,
show=True, show: bool = True,
**sns_kwargs, **sns_kwargs,
) -> Optional[matplotlib.figure.Figure]: ) -> None|matplotlib.figure.Figure:
""" """
Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
......
...@@ -22,8 +22,8 @@ def volumetric( ...@@ -22,8 +22,8 @@ def volumetric(
grid_visible: bool = False, grid_visible: bool = False,
color_map: str = 'magma', color_map: str = 'magma',
constant_opacity: bool = False, constant_opacity: bool = False,
vmin: float = None, vmin: float|None = None,
vmax: float = None, vmax: float|None = None,
samples: int|str = "auto", samples: int|str = "auto",
max_voxels: int = 512**3, max_voxels: int = 512**3,
data_type: str = "scaled_float16", data_type: str = "scaled_float16",
......
...@@ -13,7 +13,7 @@ def plot_metrics( ...@@ -13,7 +13,7 @@ def plot_metrics(
*metrics: tuple[dict[str, float]], *metrics: tuple[dict[str, float]],
linestyle: str = "-", linestyle: str = "-",
batch_linestyle: str = "dotted", batch_linestyle: str = "dotted",
labels: list = None, labels: list|None = None,
figsize: tuple = (16, 6), figsize: tuple = (16, 6),
show: bool = False show: bool = False
): ):
......
...@@ -19,9 +19,9 @@ def vectors( ...@@ -19,9 +19,9 @@ def vectors(
vec: np.ndarray, vec: np.ndarray,
axis: int = 0, axis: int = 0,
volume_cmap:str = 'grey', volume_cmap:str = 'grey',
vmin:float = None, vmin: float|None = None,
vmax:float = None, vmax: float|None = None,
slice_idx: Optional[Union[int, float]] = None, slice_idx: Union[int, float]|None = None,
grid_size: int = 10, grid_size: int = 10,
interactive: bool = True, interactive: bool = True,
figsize: Tuple[int, int] = (10, 5), figsize: Tuple[int, int] = (10, 5),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment