Skip to content
Snippets Groups Projects
Commit 8dbe4ece authored by fima's avatar fima :beers:
Browse files

Merge branch 'filters_dask' into 'main'

Option to apply filters using dask

See merge request !109
parents d3bac131 e17e23c3
No related branches found
No related tags found
1 merge request!109Option to apply filters using dask
...@@ -5,6 +5,8 @@ from typing import Type, Union ...@@ -5,6 +5,8 @@ from typing import Type, Union
import numpy as np import numpy as np
from scipy import ndimage from scipy import ndimage
from skimage import morphology from skimage import morphology
import dask.array as da
import dask_image.ndfilters as dask_ndfilters
from qim3d.utils.logger import log from qim3d.utils.logger import log
...@@ -24,7 +26,7 @@ __all__ = [ ...@@ -24,7 +26,7 @@ __all__ = [
class FilterBase: class FilterBase:
def __init__(self, *args, **kwargs): def __init__(self, dask=False, chunks="auto", *args, **kwargs):
""" """
Base class for image filters. Base class for image filters.
...@@ -33,9 +35,10 @@ class FilterBase: ...@@ -33,9 +35,10 @@ class FilterBase:
**kwargs: Additional keyword arguments for filter initialization. **kwargs: Additional keyword arguments for filter initialization.
""" """
self.args = args self.args = args
self.dask = dask
self.chunks = chunks
self.kwargs = kwargs self.kwargs = kwargs
class Gaussian(FilterBase): class Gaussian(FilterBase):
def __call__(self, input): def __call__(self, input):
""" """
...@@ -47,7 +50,7 @@ class Gaussian(FilterBase): ...@@ -47,7 +50,7 @@ class Gaussian(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return gaussian(input, *self.args, **self.kwargs) return gaussian(input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs)
class Median(FilterBase): class Median(FilterBase):
...@@ -61,7 +64,7 @@ class Median(FilterBase): ...@@ -61,7 +64,7 @@ class Median(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return median(input, **self.kwargs) return median(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Maximum(FilterBase): class Maximum(FilterBase):
...@@ -75,7 +78,7 @@ class Maximum(FilterBase): ...@@ -75,7 +78,7 @@ class Maximum(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return maximum(input, **self.kwargs) return maximum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Minimum(FilterBase): class Minimum(FilterBase):
...@@ -89,7 +92,7 @@ class Minimum(FilterBase): ...@@ -89,7 +92,7 @@ class Minimum(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return minimum(input, **self.kwargs) return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase): class Tophat(FilterBase):
def __call__(self, input): def __call__(self, input):
...@@ -102,7 +105,7 @@ class Tophat(FilterBase): ...@@ -102,7 +105,7 @@ class Tophat(FilterBase):
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return tophat(input, **self.kwargs) return tophat(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Pipeline: class Pipeline:
...@@ -121,7 +124,7 @@ class Pipeline: ...@@ -121,7 +124,7 @@ class Pipeline:
# Create filter pipeline # Create filter pipeline
pipeline = Pipeline( pipeline = Pipeline(
Median(size=5), Median(size=5),
Gaussian(sigma=3) Gaussian(sigma=3, dask = True)
) )
# Append a third filter to the pipeline # Append a third filter to the pipeline
...@@ -183,7 +186,7 @@ class Pipeline: ...@@ -183,7 +186,7 @@ class Pipeline:
# Create filter pipeline # Create filter pipeline
pipeline = Pipeline( pipeline = Pipeline(
Maximum(size=3) Maximum(size=3, dask=True),
) )
# Append a second filter to the pipeline # Append a second filter to the pipeline
...@@ -207,77 +210,125 @@ class Pipeline: ...@@ -207,77 +210,125 @@ class Pipeline:
return input return input
def gaussian(vol, *args, **kwargs): def gaussian(vol, dask=False, chunks='auto', *args, **kwargs):
""" """
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter. Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
Args: Args:
vol: The input image or volume. vol: The input image or volume.
dask: Whether to use Dask for the Gaussian filter.
chunks: Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing.
*args: Additional positional arguments for the Gaussian filter. *args: Additional positional arguments for the Gaussian filter.
**kwargs: Additional keyword arguments for the Gaussian filter. **kwargs: Additional keyword arguments for the Gaussian filter.
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return ndimage.gaussian_filter(vol, *args, **kwargs)
if dask:
if not isinstance(vol, da.Array):
vol = da.from_array(vol, chunks=chunks)
dask_vol = dask_ndfilters.gaussian_filter(vol, *args, **kwargs)
res = dask_vol.compute()
return res
else:
res = ndimage.gaussian_filter(vol, *args, **kwargs)
return res
def median(vol, **kwargs): def median(vol, dask=False, chunks='auto', **kwargs):
""" """
Applies a median filter to the input volume using scipy.ndimage.median_filter. Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
Args: Args:
vol: The input image or volume. vol: The input image or volume.
dask: Whether to use Dask for the median filter.
chunks: Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing.
**kwargs: Additional keyword arguments for the median filter. **kwargs: Additional keyword arguments for the median filter.
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return ndimage.median_filter(vol, **kwargs) if dask:
if not isinstance(vol, da.Array):
vol = da.from_array(vol, chunks=chunks)
dask_vol = dask_ndfilters.median_filter(vol, **kwargs)
res = dask_vol.compute()
return res
else:
res = ndimage.median_filter(vol, **kwargs)
return res
def maximum(vol, **kwargs): def maximum(vol, dask=False, chunks='auto', **kwargs):
""" """
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter. Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
Args: Args:
vol: The input image or volume. vol: The input image or volume.
dask: Whether to use Dask for the maximum filter.
chunks: Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing.
**kwargs: Additional keyword arguments for the maximum filter. **kwargs: Additional keyword arguments for the maximum filter.
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return ndimage.maximum_filter(vol, **kwargs) if dask:
if not isinstance(vol, da.Array):
vol = da.from_array(vol, chunks=chunks)
dask_vol = dask_ndfilters.maximum_filter(vol, **kwargs)
res = dask_vol.compute()
return res
else:
res = ndimage.maximum_filter(vol, **kwargs)
return res
def minimum(vol, **kwargs): def minimum(vol, dask=False, chunks='auto', **kwargs):
""" """
Applies a minimum filter to the input volume using scipy.ndimage.mainimum_filter. Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
Args: Args:
vol: The input image or volume. vol: The input image or volume.
dask: Whether to use Dask for the minimum filter.
chunks: Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing.
**kwargs: Additional keyword arguments for the minimum filter. **kwargs: Additional keyword arguments for the minimum filter.
Returns: Returns:
The filtered image or volume. The filtered image or volume.
""" """
return ndimage.minimum_filter(vol, **kwargs) if dask:
if not isinstance(vol, da.Array):
vol = da.from_array(vol, chunks=chunks)
dask_vol = dask_ndfilters.minimum_filter(vol, **kwargs)
res = dask_vol.compute()
return res
else:
res = ndimage.minimum_filter(vol, **kwargs)
return res
def tophat(vol, **kwargs): def tophat(vol, dask=False, chunks='auto', **kwargs):
""" """
Remove background from the volume. Remove background from the volume.
Args: Args:
vol: The volume to remove background from vol: The volume to remove background from.
radius: The radius of the structuring element (default: 3) radius: The radius of the structuring element (default: 3).
background: color of the background, 'dark' or 'bright' (default: 'dark'). If 'bright', volume will be inverted. background: Color of the background, 'dark' or 'bright' (default: 'dark'). If 'bright', volume will be inverted.
dask: Whether to use Dask for the tophat filter (not supported, will default to SciPy).
chunks: Defines how to divide the array into blocks when using Dask. Can be an integer, tuple, size in bytes, or "auto" for automatic sizing.
**kwargs: Additional keyword arguments.
Returns: Returns:
vol: The volume with background removed vol: The volume with background removed.
""" """
radius = kwargs["radius"] if "radius" in kwargs else 3 radius = kwargs["radius"] if "radius" in kwargs else 3
background = kwargs["background"] if "background" in kwargs else "dark" background = kwargs["background"] if "background" in kwargs else "dark"
if dask:
log.info("Dask not supported for tophat filter, switching to scipy.")
if background == "bright": if background == "bright":
log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat") log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat")
vol = np.invert(vol) vol = np.invert(vol)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment