From ed4dfcf8771fb9aa475ce2fddee727958fedd34d Mon Sep 17 00:00:00 2001
From: fima <fima@dtu.dk>
Date: Fri, 10 Jan 2025 13:36:23 +0100
Subject: [PATCH] Refactor tests for processing and adapt it to new library
 structure, plus fix...

---
 qim3d/filters/_common_filter_methods.py | 68 ++++++++++++-------------
 qim3d/gui/annotation_tool.py            |  1 +
 qim3d/ml/_data.py                       |  2 +-
 qim3d/ml/_ml_utils.py                   |  2 +-
 4 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/qim3d/filters/_common_filter_methods.py b/qim3d/filters/_common_filter_methods.py
index 026d2926..1b43bae2 100644
--- a/qim3d/filters/_common_filter_methods.py
+++ b/qim3d/filters/_common_filter_methods.py
@@ -27,10 +27,7 @@ __all__ = [
 
 
 class FilterBase:
-    def __init__(self, 
-                 dask: bool = False, 
-                 chunks: str = "auto", 
-                 *args, **kwargs):
+    def __init__(self, *args, dask: bool = False, chunks: str = "auto", **kwargs):
         """
         Base class for image filters.
 
@@ -43,6 +40,7 @@ class FilterBase:
         self.chunks = chunks
         self.kwargs = kwargs
 
+
 class Gaussian(FilterBase):
     def __call__(self, input: np.ndarray) -> np.ndarray:
         """
@@ -54,7 +52,9 @@ class Gaussian(FilterBase):
         Returns:
             The filtered image or volume.
         """
-        return gaussian(input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs)
+        return gaussian(
+            input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs
+        )
 
 
 class Median(FilterBase):
@@ -98,6 +98,7 @@ class Minimum(FilterBase):
         """
         return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
 
+
 class Tophat(FilterBase):
     def __call__(self, input: np.ndarray) -> np.ndarray:
         """
@@ -142,8 +143,9 @@ class Pipeline:
         ```
         ![original volume](assets/screenshots/filter_original.png)
         ![filtered volume](assets/screenshots/filter_processed.png)
-            
-        """
+
+    """
+
     def __init__(self, *args: Type[FilterBase]):
         """
         Represents a sequence of image filters.
@@ -182,7 +184,7 @@ class Pipeline:
 
         Args:
             fn: An instance of a FilterBase subclass to be appended.
-        
+
         Example:
             ```python
             import qim3d
@@ -214,10 +216,9 @@ class Pipeline:
         return input
 
 
-def gaussian(vol: np.ndarray, 
-             dask: bool = False, 
-             chunks: str = 'auto',
-             *args, **kwargs) -> np.ndarray:
+def gaussian(
+    vol: np.ndarray, dask: bool = False, chunks: str = "auto", *args, **kwargs
+) -> np.ndarray:
     """
     Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
 
@@ -231,7 +232,7 @@ def gaussian(vol: np.ndarray,
     Returns:
         The filtered image or volume.
     """
-    
+
     if dask:
         if not isinstance(vol, da.Array):
             vol = da.from_array(vol, chunks=chunks)
@@ -243,10 +244,9 @@ def gaussian(vol: np.ndarray,
         return res
 
 
-def median(vol: np.ndarray, 
-           dask: bool = False, 
-           chunks: str ='auto', 
-           **kwargs) -> np.ndarray:
+def median(
+    vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
+) -> np.ndarray:
     """
     Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
 
@@ -270,10 +270,9 @@ def median(vol: np.ndarray,
         return res
 
 
-def maximum(vol: np.ndarray, 
-            dask: bool = False, 
-            chunks: str = 'auto', 
-            **kwargs) -> np.ndarray:
+def maximum(
+    vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
+) -> np.ndarray:
     """
     Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
 
@@ -297,10 +296,9 @@ def maximum(vol: np.ndarray,
         return res
 
 
-def minimum(vol: np.ndarray, 
-            dask: bool = False, 
-            chunks: str = 'auto', 
-            **kwargs) -> np.ndarray:
+def minimum(
+    vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
+) -> np.ndarray:
     """
     Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
 
@@ -323,10 +321,10 @@ def minimum(vol: np.ndarray,
         res = ndimage.minimum_filter(vol, **kwargs)
         return res
 
-def tophat(vol: np.ndarray, 
-           dask: bool = False, 
-           chunks: str = 'auto', 
-           **kwargs) -> np.ndarray:
+
+def tophat(
+    vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
+) -> np.ndarray:
     """
     Remove background from the volume.
 
@@ -347,15 +345,17 @@ def tophat(vol: np.ndarray,
 
     if dask:
         log.info("Dask not supported for tophat filter, switching to scipy.")
-    
+
     if background == "bright":
-        log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat")
+        log.info(
+            "Bright background selected, volume will be temporarily inverted when applying white_tophat"
+        )
         vol = np.invert(vol)
-    
+
     selem = morphology.ball(radius)
     vol = vol - morphology.white_tophat(vol, selem)
 
     if background == "bright":
         vol = np.invert(vol)
-    
-    return vol
\ No newline at end of file
+
+    return vol
diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py
index 79910e76..a3cac14e 100644
--- a/qim3d/gui/annotation_tool.py
+++ b/qim3d/gui/annotation_tool.py
@@ -28,6 +28,7 @@ import gradio as gr
 import numpy as np
 from PIL import Image
 import qim3d
+from qim3d.gui.interface import BaseInterface
 
 # TODO: img in launch should be self.img
 
diff --git a/qim3d/ml/_data.py b/qim3d/ml/_data.py
index 6050001e..253da4a1 100644
--- a/qim3d/ml/_data.py
+++ b/qim3d/ml/_data.py
@@ -6,7 +6,7 @@ import torch
 import numpy as np
 from typing import Optional, Callable
 import torch.nn as nn
-from ._data import Augmentation
+from ._augmentations import Augmentation
 
 class Dataset(torch.utils.data.Dataset):
     """
diff --git a/qim3d/ml/_ml_utils.py b/qim3d/ml/_ml_utils.py
index f46a7481..98196f2d 100644
--- a/qim3d/ml/_ml_utils.py
+++ b/qim3d/ml/_ml_utils.py
@@ -9,7 +9,7 @@ from qim3d.viz._metrics import plot_metrics
 
 from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
-from models._unet import Hyperparameters
+from .models._unet import Hyperparameters
 
 def train_model(
     model: torch.nn.Module,
-- 
GitLab