From ed4dfcf8771fb9aa475ce2fddee727958fedd34d Mon Sep 17 00:00:00 2001 From: fima <fima@dtu.dk> Date: Fri, 10 Jan 2025 13:36:23 +0100 Subject: [PATCH] Refactor tests for processing and adapt it to new library structure, plus fix... --- qim3d/filters/_common_filter_methods.py | 68 ++++++++++++------------- qim3d/gui/annotation_tool.py | 1 + qim3d/ml/_data.py | 2 +- qim3d/ml/_ml_utils.py | 2 +- 4 files changed, 37 insertions(+), 36 deletions(-) diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py index 026d2926..1b43bae2 100644 --- a/qim3d/filters/_common_filter_methods.py +++ b/qim3d/filters/_common_filter_methods.py @@ -27,10 +27,7 @@ __all__ = [ class FilterBase: - def __init__(self, - dask: bool = False, - chunks: str = "auto", - *args, **kwargs): + def __init__(self, *args, dask: bool = False, chunks: str = "auto", **kwargs): """ Base class for image filters. @@ -43,6 +40,7 @@ class FilterBase: self.chunks = chunks self.kwargs = kwargs + class Gaussian(FilterBase): def __call__(self, input: np.ndarray) -> np.ndarray: """ @@ -54,7 +52,9 @@ class Gaussian(FilterBase): Returns: The filtered image or volume. """ - return gaussian(input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs) + return gaussian( + input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs + ) class Median(FilterBase): @@ -98,6 +98,7 @@ class Minimum(FilterBase): """ return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs) + class Tophat(FilterBase): def __call__(self, input: np.ndarray) -> np.ndarray: """ @@ -142,8 +143,9 @@ class Pipeline: ```   - - """ + + """ + def __init__(self, *args: Type[FilterBase]): """ Represents a sequence of image filters. @@ -182,7 +184,7 @@ class Pipeline: Args: fn: An instance of a FilterBase subclass to be appended. - + Example: ```python import qim3d @@ -214,10 +216,9 @@ class Pipeline: return input -def gaussian(vol: np.ndarray, - dask: bool = False, - chunks: str = 'auto', - *args, **kwargs) -> np.ndarray: +def gaussian( + vol: np.ndarray, dask: bool = False, chunks: str = "auto", *args, **kwargs +) -> np.ndarray: """ Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter. @@ -231,7 +232,7 @@ def gaussian(vol: np.ndarray, Returns: The filtered image or volume. """ - + if dask: if not isinstance(vol, da.Array): vol = da.from_array(vol, chunks=chunks) @@ -243,10 +244,9 @@ def gaussian(vol: np.ndarray, return res -def median(vol: np.ndarray, - dask: bool = False, - chunks: str ='auto', - **kwargs) -> np.ndarray: +def median( + vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs +) -> np.ndarray: """ Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter. @@ -270,10 +270,9 @@ def median(vol: np.ndarray, return res -def maximum(vol: np.ndarray, - dask: bool = False, - chunks: str = 'auto', - **kwargs) -> np.ndarray: +def maximum( + vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs +) -> np.ndarray: """ Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter. @@ -297,10 +296,9 @@ def maximum(vol: np.ndarray, return res -def minimum(vol: np.ndarray, - dask: bool = False, - chunks: str = 'auto', - **kwargs) -> np.ndarray: +def minimum( + vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs +) -> np.ndarray: """ Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter. @@ -323,10 +321,10 @@ def minimum(vol: np.ndarray, res = ndimage.minimum_filter(vol, **kwargs) return res -def tophat(vol: np.ndarray, - dask: bool = False, - chunks: str = 'auto', - **kwargs) -> np.ndarray: + +def tophat( + vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs +) -> np.ndarray: """ Remove background from the volume. @@ -347,15 +345,17 @@ def tophat(vol: np.ndarray, if dask: log.info("Dask not supported for tophat filter, switching to scipy.") - + if background == "bright": - log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat") + log.info( + "Bright background selected, volume will be temporarily inverted when applying white_tophat" + ) vol = np.invert(vol) - + selem = morphology.ball(radius) vol = vol - morphology.white_tophat(vol, selem) if background == "bright": vol = np.invert(vol) - - return vol \ No newline at end of file + + return vol diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py index 79910e76..a3cac14e 100644 --- a/qim3d/gui/annotation_tool.py +++ b/qim3d/gui/annotation_tool.py @@ -28,6 +28,7 @@ import gradio as gr import numpy as np from PIL import Image import qim3d +from qim3d.gui.interface import BaseInterface # TODO: img in launch should be self.img diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py index 6050001e..253da4a1 100644 --- a/qim3d/ml/_data.py +++ b/qim3d/ml/_data.py @@ -6,7 +6,7 @@ import torch import numpy as np from typing import Optional, Callable import torch.nn as nn -from ._data import Augmentation +from ._augmentations import Augmentation class Dataset(torch.utils.data.Dataset): """ diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py index f46a7481..98196f2d 100644 --- a/qim3d/ml/_ml_utils.py +++ b/qim3d/ml/_ml_utils.py @@ -9,7 +9,7 @@ from qim3d.viz._metrics import plot_metrics from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -from models._unet import Hyperparameters +from .models._unet import Hyperparameters def train_model( model: torch.nn.Module, -- GitLab