From 58c2f5018e66d394cc6607d88b1fbabf276884df Mon Sep 17 00:00:00 2001 From: s214735 <s214735@dtu.dk> Date: Thu, 26 Dec 2024 10:44:58 +0100 Subject: [PATCH] Checked rest of files --- qim3d/ml/_augmentations.py | 2 +- qim3d/ml/_data.py | 20 +++++++--- qim3d/ml/_ml_utils.py | 26 ++++++------- qim3d/ml/models/_unet.py | 37 ++++++++++-------- .../operations/_common_operations_methods.py | 2 +- qim3d/processing/_layers.py | 9 ++++- qim3d/processing/_local_thickness.py | 2 +- qim3d/processing/_structure_tensor.py | 2 +- qim3d/segmentation/_connected_components.py | 2 +- qim3d/utils/_doi.py | 18 ++++----- qim3d/utils/_misc.py | 20 +++++----- qim3d/utils/_ome_zarr.py | 4 +- qim3d/utils/_progress_bar.py | 8 ++-- qim3d/utils/_server.py | 4 +- qim3d/utils/_system.py | 4 +- qim3d/viz/_cc.py | 12 +++--- qim3d/viz/_data_exploration.py | 20 +++++----- qim3d/viz/_detection.py | 2 +- qim3d/viz/_k3d.py | 38 +++++++++---------- qim3d/viz/_layers2d.py | 2 +- qim3d/viz/_metrics.py | 37 ++++++++++-------- 21 files changed, 150 insertions(+), 121 deletions(-) diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index ea81e53a..8cb0b472 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -38,7 +38,7 @@ class Augmentation: self.transform_validation = transform_validation self.transform_test = transform_test - def augment(self, im_h, im_w, level=None): + def augment(self, im_h: int, im_w: int, level: bool = None): """ Returns an albumentations.core.composition.Compose class depending on the augmentation level. A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level. diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py index 38fbdac7..17f9a281 100644 --- a/qim3d/ml/_data.py +++ b/qim3d/ml/_data.py @@ -4,7 +4,9 @@ from PIL import Image from qim3d.utils._logger import log import torch import numpy as np - +from typing import Optional, Callable +import torch.nn as nn +from ._data import Augmentation class Dataset(torch.utils.data.Dataset): """ @@ -36,7 +38,7 @@ class Dataset(torch.utils.data.Dataset): transform=albumentations.Compose([ToTensorV2()])) image, target = dataset[idx] """ - def __init__(self, root_path: str, split="train", transform=None): + def __init__(self, root_path: str, split: str = "train", transform: Optional[Callable] = None): super().__init__() # Check if split is valid @@ -58,7 +60,7 @@ class Dataset(torch.utils.data.Dataset): def __len__(self): return len(self.sample_images) - def __getitem__(self, idx): + def __getitem__(self, idx: int): image_path = self.sample_images[idx] target_path = self.sample_targets[idx] @@ -76,7 +78,7 @@ class Dataset(torch.utils.data.Dataset): # TODO: working with images of different sizes - def check_shape_consistency(self,sample_images): + def check_shape_consistency(self,sample_images: tuple[str]): image_shapes= [] for image_path in sample_images: image_shape = self._get_shape(image_path) @@ -131,7 +133,7 @@ def check_resize(im_height: int, im_width: int, resize: str, n_channels: int): return h_adjust, w_adjust -def prepare_datasets(path: str, val_fraction: float, model, augmentation): +def prepare_datasets(path: str, val_fraction: float, model: nn.Module, augmentation: Augmentation): """ Splits and augments the train/validation/test datasets. @@ -172,7 +174,13 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation): return train_set, val_set, test_set -def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False): +def prepare_dataloaders(train_set: torch.utils.data, + val_set: torch.utils.data, + test_set: torch.utils.data, + batch_size: int, + shuffle_train: bool = True, + num_workers: int = 8, + pin_memory: bool = False): """ Prepares the dataloaders for model training. diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index 7c8a3b86..1c442791 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -9,18 +9,18 @@ from qim3d.viz._metrics import plot_metrics from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm - +from models._unet import Hyperparameters def train_model( - model, - hyperparameters, - train_loader, - val_loader, - eval_every=1, - print_every=5, - plot=True, - return_loss=False, -): + model: torch.nn.Module, + hyperparameters: Hyperparameters, + train_loader: torch.utils.data.DataLoader, + val_loader: torch.utils.data.DataLoader, + eval_every: int = 1, + print_every: int = 5, + plot: bool = True, + return_loss: bool = False, +) -> tuple[tuple[float], tuple[float]]: """Function for training Neural Network models. Args: @@ -137,7 +137,7 @@ def train_model( return train_loss, val_loss -def model_summary(dataloader, model): +def model_summary(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module): """Prints the summary of a PyTorch model. Args: @@ -160,7 +160,7 @@ def model_summary(dataloader, model): return model_s -def inference(data, model): +def inference(data: torch.utils.data.Dataset, model: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Performs inference on input data using the specified model. Performs inference on the input data using the provided model. The input data should be in the form of a list, @@ -242,7 +242,7 @@ def inference(data, model): return inputs, targets, preds -def volume_inference(volume, model, threshold=0.5): +def volume_inference(volume: np.ndarray, model: torch.nn.Module, threshold:float = 0.5) -> np.ndarray: """ Compute on the entire volume Args: diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py index 353d10bd..27ee78a4 100644 --- a/qim3d/ml/models/_unet.py +++ b/qim3d/ml/models/_unet.py @@ -119,13 +119,13 @@ class Hyperparameters: def __init__( self, - model, - n_epochs=10, - learning_rate=1e-3, - optimizer="Adam", - momentum=0, - weight_decay=0, - loss_function="Focal", + model: nn.Module, + n_epochs: int = 10, + learning_rate: float = 1e-3, + optimizer: str = "Adam", + momentum: float = 0, + weight_decay: float = 0, + loss_function: str = "Focal", ): # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation. @@ -168,13 +168,13 @@ class Hyperparameters: def model_params( self, - model, - n_epochs, - optimizer, - learning_rate, - weight_decay, - momentum, - loss_function, + model: nn.Module, + n_epochs: int, + optimizer: str, + learning_rate: float, + weight_decay: float, + momentum: float, + loss_function: str, ): optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum) @@ -188,7 +188,12 @@ class Hyperparameters: return hyper_dict # selecting the optimizer - def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum): + def _optimizer(self, + model: nn.Module, + optimizer: str, + learning_rate: float, + weight_decay: float, + momentum: float): from torch.optim import Adam, SGD, RMSprop if optimizer == "Adam": @@ -212,7 +217,7 @@ class Hyperparameters: return optim # selecting the loss function - def _loss_functions(self, loss_function): + def _loss_functions(self, loss_function: str): from monai.losses import FocalLoss, DiceLoss, DiceCELoss from torch.nn import BCEWithLogitsLoss diff --git a/qim3d/operations/_common_operations_methods.py b/qim3d/operations/_common_operations_methods.py index 1c14e163..5975c30e 100644 --- a/qim3d/operations/_common_operations_methods.py +++ b/qim3d/operations/_common_operations_methods.py @@ -153,7 +153,7 @@ def fade_mask( def overlay_rgb_images( - background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black:bool = True, + background: np.ndarray, foreground: np.ndarray, alpha: float = 0.5, hide_black: bool = True, ) -> np.ndarray: """ Overlay an RGB foreground onto an RGB background using alpha blending. diff --git a/qim3d/processing/_layers.py b/qim3d/processing/_layers.py index d91ef0a4..ed90133f 100644 --- a/qim3d/processing/_layers.py +++ b/qim3d/processing/_layers.py @@ -2,7 +2,14 @@ import numpy as np from slgbuilder import GraphObject from slgbuilder import MaxflowBuilder -def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, delta:float = 1, min_margin:int = 10, max_margin:int = None, wrap:bool = False): +def segment_layers(data: np.ndarray, + inverted: bool = False, + n_layers: int = 1, + delta: float = 1, + min_margin: int = 10, + max_margin: int = None, + wrap: bool = False + ) -> list: """ Works on 2D and 3D data. Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation diff --git a/qim3d/processing/_local_thickness.py b/qim3d/processing/_local_thickness.py index a7e877d6..947f8865 100644 --- a/qim3d/processing/_local_thickness.py +++ b/qim3d/processing/_local_thickness.py @@ -10,7 +10,7 @@ def local_thickness( image: np.ndarray, scale: float = 1, mask: Optional[np.ndarray] = None, - visualize=False, + visualize: bool = False, **viz_kwargs ) -> np.ndarray: """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness) diff --git a/qim3d/processing/_structure_tensor.py b/qim3d/processing/_structure_tensor.py index 10eba890..f6dc98cd 100644 --- a/qim3d/processing/_structure_tensor.py +++ b/qim3d/processing/_structure_tensor.py @@ -12,7 +12,7 @@ def structure_tensor( rho: float = 6.0, base_noise: bool = True, full: bool = False, - visualize=False, + visualize: bool = False, **viz_kwargs ) -> Tuple[np.ndarray, np.ndarray]: """Wrapper for the 3D structure tensor implementation from the [structure_tensor package](https://github.com/Skielex/structure-tensor/). diff --git a/qim3d/segmentation/_connected_components.py b/qim3d/segmentation/_connected_components.py index f43f0dc2..1f29b235 100644 --- a/qim3d/segmentation/_connected_components.py +++ b/qim3d/segmentation/_connected_components.py @@ -4,7 +4,7 @@ from qim3d.utils._logger import log class CC: - def __init__(self, connected_components, num_connected_components): + def __init__(self, connected_components: np.ndarray, num_connected_components: int): """ Initializes a ConnectedComponents object. diff --git a/qim3d/utils/_doi.py b/qim3d/utils/_doi.py index b3ece663..6c4ae626 100644 --- a/qim3d/utils/_doi.py +++ b/qim3d/utils/_doi.py @@ -4,7 +4,7 @@ import requests from qim3d.utils._logger import log -def _validate_response(response): +def _validate_response(response: requests.Response) -> bool: # Check if we got a good response if not response.ok: log.error(f"Could not read the provided DOI ({response.reason})") @@ -13,7 +13,7 @@ def _validate_response(response): return True -def _doi_to_url(doi): +def _doi_to_url(doi: str) -> str: if doi[:3] != "http": url = "https://doi.org/" + doi else: @@ -22,7 +22,7 @@ def _doi_to_url(doi): return url -def _make_request(doi, header): +def _make_request(doi: str, header: str) -> requests.Response: # Get url from doi url = _doi_to_url(doi) @@ -35,7 +35,7 @@ def _make_request(doi, header): return response -def _log_and_get_text(doi, header): +def _log_and_get_text(doi, header) -> str: response = _make_request(doi, header) if response and response.encoding: @@ -50,13 +50,13 @@ def _log_and_get_text(doi, header): return text -def get_bibtex(doi): +def get_bibtex(doi: str): """Generates bibtex from doi""" header = {"Accept": "application/x-bibtex"} return _log_and_get_text(doi, header) -def cusom_header(doi, header): +def cusom_header(doi: str, header: str) -> str: """Allows a custom header to be passed For example: @@ -67,7 +67,7 @@ def cusom_header(doi, header): """ return _log_and_get_text(doi, header) -def get_metadata(doi): +def get_metadata(doi: str) -> dict: """Generates a metadata dictionary from doi""" header = {"Accept": "application/vnd.citationstyles.csl+json"} response = _make_request(doi, header) @@ -76,7 +76,7 @@ def get_metadata(doi): return metadata -def get_reference(doi): +def get_reference(doi: str) -> str: """Generates a metadata dictionary from doi and use it to build a reference string""" metadata = get_metadata(doi) @@ -84,7 +84,7 @@ def get_reference(doi): return reference_string -def build_reference_string(metadata): +def build_reference_string(metadata: dict) -> str: """Generates a reference string from metadata""" authors = ", ".join([f"{author['family']} {author['given']}" for author in metadata['author']]) year = metadata['issued']['date-parts'][0][0] diff --git a/qim3d/utils/_misc.py b/qim3d/utils/_misc.py index a754ecb8..3d8660b4 100644 --- a/qim3d/utils/_misc.py +++ b/qim3d/utils/_misc.py @@ -12,7 +12,7 @@ import difflib import qim3d -def get_local_ip(): +def get_local_ip() -> str: """Retrieves the local IP address of the current machine. The function uses a socket to determine the local IP address. @@ -42,7 +42,7 @@ def get_local_ip(): return ip_address -def port_from_str(s): +def port_from_str(s: str) -> int: """ Generates a port number from a given string. @@ -65,7 +65,7 @@ def port_from_str(s): return int(hashlib.sha1(s.encode("utf-8")).hexdigest(), 16) % (10**4) -def gradio_header(title, port): +def gradio_header(title: str, port: int) -> None: """Display the header for a Gradio server. Displays a formatted header containing the provided title, @@ -99,7 +99,7 @@ def gradio_header(title, port): ouf.showlist(details, style="box", title="Starting gradio server") -def sizeof(num, suffix="B"): +def sizeof(num: float, suffix: str = "B") -> str: """Converts a number to a human-readable string representing its size. Converts the given number to a human-readable string representing its size in @@ -131,7 +131,7 @@ def sizeof(num, suffix="B"): return f"{num:.1f} Y{suffix}" -def find_similar_paths(path): +def find_similar_paths(path: str) -> list[str]: parent_dir = os.path.dirname(path) or "." parent_files = os.listdir(parent_dir) if os.path.isdir(parent_dir) else "" valid_paths = [os.path.join(parent_dir, file) for file in parent_files] @@ -165,14 +165,14 @@ def get_file_size(file_path: str) -> int: return file_size -def stringify_path(path): +def stringify_path(path: os.PathLike) -> str: """Converts an os.PathLike object to a string""" if isinstance(path, os.PathLike): path = path.__fspath__() return path -def get_port_dict(): +def get_port_dict() -> dict: # Gets user and port username = getpass.getuser() url = f"https://platform.qim.dk/qim-api/get-port/{username}" @@ -189,7 +189,7 @@ def get_port_dict(): return port_dict -def get_css(): +def get_css() -> str: current_directory = os.path.dirname(os.path.abspath(__file__)) parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir)) @@ -201,7 +201,7 @@ def get_css(): return css_content -def downscale_img(img, max_voxels=512**3): +def downscale_img(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray: """Downscale image if total number of voxels exceeds 512³. Args: @@ -226,7 +226,7 @@ def downscale_img(img, max_voxels=512**3): return zoom(img, zoom_factor, order=0) -def scale_to_float16(arr: np.ndarray): +def scale_to_float16(arr: np.ndarray) -> np.ndarray: """ Scale the input array to the float16 data type. diff --git a/qim3d/utils/_ome_zarr.py b/qim3d/utils/_ome_zarr.py index 452997f1..db7ad145 100644 --- a/qim3d/utils/_ome_zarr.py +++ b/qim3d/utils/_ome_zarr.py @@ -1,7 +1,7 @@ from zarr.util import normalize_chunks, normalize_dtype, normalize_shape import numpy as np -def get_chunk_size(shape:tuple, dtype): +def get_chunk_size(shape:tuple, dtype: tuple) -> tuple[int, ...]: """ How the chunk size is computed in zarr.storage.init_array_metadata which is ran in the chain of functions we use in qim3d.io.export_ome_zarr function @@ -20,7 +20,7 @@ def get_chunk_size(shape:tuple, dtype): return chunks -def get_n_chunks(shapes:tuple, dtypes:tuple): +def get_n_chunks(shapes: tuple, dtypes: tuple) -> int: """ Estimates how many chunks we will use in advence so we can pass this number to a progress bar and track how many have been already written to disk diff --git a/qim3d/utils/_progress_bar.py b/qim3d/utils/_progress_bar.py index 629026f4..57f33738 100644 --- a/qim3d/utils/_progress_bar.py +++ b/qim3d/utils/_progress_bar.py @@ -24,7 +24,7 @@ class RepeatTimer(Timer): self.function(*self.args, **self.kwargs) class ProgressBar(ABC): - def __init__(self,tqdm_kwargs:dict, repeat_time: float, *args, **kwargs): + def __init__(self, tqdm_kwargs: dict, repeat_time: float, *args, **kwargs): """ Context manager for ('with' statement) to track progress during a long progress over which we don't have control (like loading a file) and thus can not insert the tqdm @@ -98,7 +98,7 @@ class FileLoadingProgressBar(ProgressBar): super().__init__( tqdm_kwargs, repeat_time) self.process = psutil.Process() - def get_new_update(self): + def get_new_update(self) -> int: counters = self.process.io_counters() try: memory = counters.read_chars @@ -107,7 +107,7 @@ class FileLoadingProgressBar(ProgressBar): return memory class OmeZarrExportProgressBar(ProgressBar): - def __init__(self,path:str, n_chunks:int, reapeat_time="auto"): + def __init__(self,path: str, n_chunks: int, reapeat_time: str = "auto"): """ Context manager to track the exporting of OmeZarr files. @@ -152,7 +152,7 @@ class OmeZarrExportProgressBar(ProgressBar): self.last_update = 0 def get_new_update(self): - def file_count(folder_path:str): + def file_count(folder_path: str) -> int: """ Goes recursively through the folders and counts how many files are there, Doesn't count metadata json files diff --git a/qim3d/utils/_server.py b/qim3d/utils/_server.py index 9d20fccd..e4a46673 100644 --- a/qim3d/utils/_server.py +++ b/qim3d/utils/_server.py @@ -14,7 +14,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler): self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type") super().end_headers() - def list_directory(self, path): + def list_directory(self, path: str): """Helper to produce a directory listing, includes hidden files.""" try: file_list = os.listdir(path) @@ -49,7 +49,7 @@ class CustomHTTPRequestHandler(SimpleHTTPRequestHandler): # Write the encoded HTML directly to the response self.wfile.write(encoded) -def start_http_server(directory, port=8000): +def start_http_server(directory: str, port: int = 8000) -> HTTPServer: """ Starts an HTTP server serving the specified directory on the given port with CORS enabled. diff --git a/qim3d/utils/_system.py b/qim3d/utils/_system.py index 19792627..ccccb672 100644 --- a/qim3d/utils/_system.py +++ b/qim3d/utils/_system.py @@ -36,7 +36,7 @@ class Memory: round(self.free_pct, 1), ) -def _test_disk_speed(file_size_bytes=1024, ntimes=10): +def _test_disk_speed(file_size_bytes: int = 1024, ntimes: int =10) -> tuple[float, float, float, float]: ''' Test the write and read speed of the disk by writing a file of a given size and then reading it back. @@ -95,7 +95,7 @@ def _test_disk_speed(file_size_bytes=1024, ntimes=10): return avg_write_speed, write_speed_std, avg_read_speed, read_speed_std -def disk_report(file_size=1024 * 1024 * 100, ntimes=10): +def disk_report(file_size: int = 1024 * 1024 * 100, ntimes: int = 10) -> None: ''' Report the average write and read speed of the disk. diff --git a/qim3d/viz/_cc.py b/qim3d/viz/_cc.py index 1c0ed560..14ecf1f4 100644 --- a/qim3d/viz/_cc.py +++ b/qim3d/viz/_cc.py @@ -2,15 +2,15 @@ import matplotlib.pyplot as plt import numpy as np import qim3d from qim3d.utils._logger import log - +from segmentation._connected_components import CC def plot_cc( - connected_components, + connected_components: CC, component_indexs: list | tuple = None, - max_cc_to_plot=32, - overlay=None, - crop=False, - show=True, + max_cc_to_plot: int = 32, + overlay: np.ndarray = None, + crop: bool = False, + show: bool = True, cmap: str = "viridis", vmin: float = None, vmax: float = None, diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index 2198cd94..766aae15 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -9,7 +9,9 @@ from typing import List, Optional, Union import dask.array as da import ipywidgets as widgets +import matplotlib.figure import matplotlib.pyplot as plt +import matplotlib from IPython.display import SVG, display import matplotlib import numpy as np @@ -29,7 +31,7 @@ def slices_grid( color_map: str = "magma", value_min: float = None, value_max: float = None, - image_size=None, + image_size: int = None, image_height: int = 2, image_width: int = 2, display_figure: bool = False, @@ -38,7 +40,7 @@ def slices_grid( color_bar: bool = False, color_bar_style: str = "small", **matplotlib_imshow_kwargs, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Displays one or several slices from a 3d volume. By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices. @@ -296,7 +298,7 @@ def slices_grid( return fig -def _get_slice_range(position: int, num_slices: int, n_total): +def _get_slice_range(position: int, num_slices: int, n_total) -> np.ndarray: """Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" start_idx = position - num_slices // 2 end_idx = ( @@ -324,7 +326,7 @@ def slicer( image_width: int = 3, display_positions: bool = False, interpolation: Optional[str] = None, - image_size=None, + image_size: int = None, color_bar: bool = False, **matplotlib_imshow_kwargs, ) -> widgets.interactive: @@ -401,8 +403,8 @@ def slicer_orthogonal( image_width: int = 3, display_positions: bool = False, interpolation: Optional[str] = None, - image_size=None, -): + image_size: int = None, +)-> widgets.interactive: """Interactive widget for visualizing orthogonal slices of a 3D volume. Args: @@ -461,7 +463,7 @@ def fade_mask( color_map: str = "magma", value_min: float = None, value_max: float = None, -): +)-> widgets.interactive: """Interactive widget for visualizing the effect of edge fading on a 3D volume. This can be used to select the best parameters before applying the mask. @@ -596,7 +598,7 @@ def fade_mask( return slicer_obj -def chunks(zarr_path: str, **kwargs): +def chunks(zarr_path: str, **kwargs)-> widgets.interactive: """ Function to visualize chunks of a Zarr dataset using the specified visualization method. @@ -861,7 +863,7 @@ def histogram( return_fig=False, show=True, **sns_kwargs, -): +) -> Optional[matplotlib.figure.Figure]: """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. diff --git a/qim3d/viz/_detection.py b/qim3d/viz/_detection.py index a0926105..b4ab9098 100644 --- a/qim3d/viz/_detection.py +++ b/qim3d/viz/_detection.py @@ -6,7 +6,7 @@ from IPython.display import clear_output, display import qim3d -def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs): +def circles(blobs: np.ndarray, vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive: """ Plots the blobs found on a slice of the volume. diff --git a/qim3d/viz/_k3d.py b/qim3d/viz/_k3d.py index b80e9556..8b2c23ac 100644 --- a/qim3d/viz/_k3d.py +++ b/qim3d/viz/_k3d.py @@ -15,18 +15,18 @@ from qim3d.utils._misc import downscale_img, scale_to_float16 def volumetric( - img, - aspectmode="data", - show=True, - save=False, - grid_visible=False, - color_map='magma', - constant_opacity=False, - vmin=None, - vmax=None, - samples="auto", - max_voxels=512**3, - data_type="scaled_float16", + img: np.ndarray, + aspectmode: str = "data", + show: bool = True, + save: bool = False, + grid_visible: bool = False, + color_map: str = 'magma', + constant_opacity: bool = False, + vmin: float = None, + vmax: float = None, + samples: int|str = "auto", + max_voxels: int = 512**3, + data_type: str = "scaled_float16", **kwargs, ): """ @@ -175,13 +175,13 @@ def volumetric( def mesh( - verts, - faces, - wireframe=True, - flat_shading=True, - grid_visible=False, - show=True, - save=False, + verts: np.ndarray, + faces: np.ndarray, + wireframe: bool = True, + flat_shading: bool = True, + grid_visible: bool = False, + show: bool = True, + save: bool = False, **kwargs, ): """ diff --git a/qim3d/viz/_layers2d.py b/qim3d/viz/_layers2d.py index 1e284461..676845a5 100644 --- a/qim3d/viz/_layers2d.py +++ b/qim3d/viz/_layers2d.py @@ -8,7 +8,7 @@ import numpy as np from PIL import Image -def image_with_lines(image:np.ndarray, lines: list, line_thickness:float|int) -> Image: +def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> Image: """ Plots the image and plots the lines on top of it. Then extracts it as PIL.Image and in the same size as the input image was. Paramters: diff --git a/qim3d/viz/_metrics.py b/qim3d/viz/_metrics.py index 53a7f3a6..ded37d24 100644 --- a/qim3d/viz/_metrics.py +++ b/qim3d/viz/_metrics.py @@ -1,19 +1,21 @@ """Visualization tools""" +import matplotlib.figure import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap from matplotlib import colormaps from qim3d.utils._logger import log - +import torch +import matplotlib def plot_metrics( - *metrics, - linestyle="-", - batch_linestyle="dotted", + *metrics: tuple[dict[str, float]], + linestyle: str = "-", + batch_linestyle: str = "dotted", labels: list = None, figsize: tuple = (16, 6), - show=False + show: bool = False ): """ Plots the metrics over epochs and batches. @@ -79,8 +81,13 @@ def plot_metrics( def grid_overview( - data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha=0.5, show=False -): + data: list|torch.utils.data.Dataset, + num_images: int = 7, + cmap_im: str = "gray", + cmap_segm: str = "viridis", + alpha: float = 0.5, + show: bool = False +)-> matplotlib.figure.Figure: """Displays an overview grid of images, labels, and masks (if they exist). Labels are the annotated target segmentations @@ -174,13 +181,13 @@ def grid_overview( def grid_pred( - in_targ_preds, - num_images=7, - cmap_im="gray", - cmap_segm="viridis", - alpha=0.5, - show=False, -): + in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray], + num_images: int = 7, + cmap_im: str = "gray", + cmap_segm: str = "viridis", + alpha: float = 0.5, + show: bool = False, +)-> matplotlib.figure.Figure: """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. Displays a grid of subplots representing different aspects of the input images and segmentations. @@ -282,7 +289,7 @@ def grid_pred( return fig -def vol_masked(vol, vol_mask, viz_delta=128): +def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray: """ Applies masking to a volume based on a binary volume mask. -- GitLab