Skip to content
Snippets Groups Projects
Commit ed4dfcf8 authored by fima's avatar fima :beers:
Browse files

Refactor tests for processing and adapt it to new library structure, plus fix...

parent 76e3a4c7
No related branches found
No related tags found
1 merge request!145Refactor tests for processing and adapt it to new library structure, plus fix...
...@@ -27,10 +27,7 @@ __all__ = [ ...@@ -27,10 +27,7 @@ __all__ = [
class FilterBase: class FilterBase:
def __init__(self, def __init__(self, *args, dask: bool = False, chunks: str = "auto", **kwargs):
dask: bool = False,
chunks: str = "auto",
*args, **kwargs):
""" """
Base class for image filters. Base class for image filters.
...@@ -43,6 +40,7 @@ class FilterBase: ...@@ -43,6 +40,7 @@ class FilterBase:
self.chunks = chunks self.chunks = chunks
self.kwargs = kwargs self.kwargs = kwargs
class Gaussian(FilterBase): class Gaussian(FilterBase):
def __call__(self, input: np.ndarray) -> np.ndarray: def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
...@@ -54,7 +52,9 @@ class Gaussian(FilterBase): ...@@ -54,7 +52,9 @@ class Gaussian(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return gaussian(input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs) return gaussian(
input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs
)
class Median(FilterBase): class Median(FilterBase):
...@@ -98,6 +98,7 @@ class Minimum(FilterBase): ...@@ -98,6 +98,7 @@ class Minimum(FilterBase):
""" """
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs) return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase): class Tophat(FilterBase):
def __call__(self, input: np.ndarray) -> np.ndarray: def __call__(self, input: np.ndarray) -> np.ndarray:
""" """
...@@ -142,8 +143,9 @@ class Pipeline: ...@@ -142,8 +143,9 @@ class Pipeline:
``` ```
![original volume](assets/screenshots/filter_original.png) ![original volume](assets/screenshots/filter_original.png)
![filtered volume](assets/screenshots/filter_processed.png) ![filtered volume](assets/screenshots/filter_processed.png)
""" """
def __init__(self, *args: Type[FilterBase]): def __init__(self, *args: Type[FilterBase]):
""" """
Represents a sequence of image filters. Represents a sequence of image filters.
...@@ -182,7 +184,7 @@ class Pipeline: ...@@ -182,7 +184,7 @@ class Pipeline:
Args: Args:
fn: An instance of a FilterBase subclass to be appended. fn: An instance of a FilterBase subclass to be appended.
Example: Example:
```python ```python
import qim3d import qim3d
...@@ -214,10 +216,9 @@ class Pipeline: ...@@ -214,10 +216,9 @@ class Pipeline:
return input return input
def gaussian(vol: np.ndarray, def gaussian(
dask: bool = False, vol: np.ndarray, dask: bool = False, chunks: str = "auto", *args, **kwargs
chunks: str = 'auto', ) -> np.ndarray:
*args, **kwargs) -> np.ndarray:
""" """
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter. Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
...@@ -231,7 +232,7 @@ def gaussian(vol: np.ndarray, ...@@ -231,7 +232,7 @@ def gaussian(vol: np.ndarray,
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
if dask: if dask:
if not isinstance(vol, da.Array): if not isinstance(vol, da.Array):
vol = da.from_array(vol, chunks=chunks) vol = da.from_array(vol, chunks=chunks)
...@@ -243,10 +244,9 @@ def gaussian(vol: np.ndarray, ...@@ -243,10 +244,9 @@ def gaussian(vol: np.ndarray,
return res return res
def median(vol: np.ndarray, def median(
dask: bool = False, vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
chunks: str ='auto', ) -> np.ndarray:
**kwargs) -> np.ndarray:
""" """
Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter. Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
...@@ -270,10 +270,9 @@ def median(vol: np.ndarray, ...@@ -270,10 +270,9 @@ def median(vol: np.ndarray,
return res return res
def maximum(vol: np.ndarray, def maximum(
dask: bool = False, vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
chunks: str = 'auto', ) -> np.ndarray:
**kwargs) -> np.ndarray:
""" """
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter. Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
...@@ -297,10 +296,9 @@ def maximum(vol: np.ndarray, ...@@ -297,10 +296,9 @@ def maximum(vol: np.ndarray,
return res return res
def minimum(vol: np.ndarray, def minimum(
dask: bool = False, vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
chunks: str = 'auto', ) -> np.ndarray:
**kwargs) -> np.ndarray:
""" """
Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter. Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
...@@ -323,10 +321,10 @@ def minimum(vol: np.ndarray, ...@@ -323,10 +321,10 @@ def minimum(vol: np.ndarray,
res = ndimage.minimum_filter(vol, **kwargs) res = ndimage.minimum_filter(vol, **kwargs)
return res return res
def tophat(vol: np.ndarray,
dask: bool = False, def tophat(
chunks: str = 'auto', vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
**kwargs) -> np.ndarray: ) -> np.ndarray:
""" """
Remove background from the volume. Remove background from the volume.
...@@ -347,15 +345,17 @@ def tophat(vol: np.ndarray, ...@@ -347,15 +345,17 @@ def tophat(vol: np.ndarray,
if dask: if dask:
log.info("Dask not supported for tophat filter, switching to scipy.") log.info("Dask not supported for tophat filter, switching to scipy.")
if background == "bright": if background == "bright":
log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat") log.info(
"Bright background selected, volume will be temporarily inverted when applying white_tophat"
)
vol = np.invert(vol) vol = np.invert(vol)
selem = morphology.ball(radius) selem = morphology.ball(radius)
vol = vol - morphology.white_tophat(vol, selem) vol = vol - morphology.white_tophat(vol, selem)
if background == "bright": if background == "bright":
vol = np.invert(vol) vol = np.invert(vol)
return vol return vol
\ No newline at end of file
...@@ -28,6 +28,7 @@ import gradio as gr ...@@ -28,6 +28,7 @@ import gradio as gr
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import qim3d import qim3d
from qim3d.gui.interface import BaseInterface
# TODO: img in launch should be self.img # TODO: img in launch should be self.img
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import numpy as np import numpy as np
from typing import Optional, Callable from typing import Optional, Callable
import torch.nn as nn import torch.nn as nn
from ._data import Augmentation from ._augmentations import Augmentation
class Dataset(torch.utils.data.Dataset): class Dataset(torch.utils.data.Dataset):
""" """
......
...@@ -9,7 +9,7 @@ from qim3d.viz._metrics import plot_metrics ...@@ -9,7 +9,7 @@ from qim3d.viz._metrics import plot_metrics
from tqdm.auto import tqdm from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm from tqdm.contrib.logging import logging_redirect_tqdm
from models._unet import Hyperparameters from .models._unet import Hyperparameters
def train_model( def train_model(
model: torch.nn.Module, model: torch.nn.Module,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment