Skip to content
Snippets Groups Projects
Commit 995ab1e0 authored by s184058's avatar s184058 Committed by fima
Browse files

3D image filters

parent d2a07ec8
No related branches found
No related tags found
1 merge request!493D image filters
%% Cell type:code id: tags:
``` python
import qim3d
import qim3d.processing.filters as filters
import numpy as np
from scipy import ndimage
```
%% Cell type:code id: tags:
``` python
vol = qim3d.examples.fly_150x256x256
```
%% Cell type:markdown id: tags:
## Using the filter functions directly
%% Cell type:code id: tags:
``` python
### Gaussian filter
out1_gauss = filters.gaussian(vol,3)
# or
out2_gauss = filters.gaussian(vol,sigma=3) # sigma is positional, but can be passed as a kwarg
### Median filter
out_median = filters.median(vol,size=5)
```
%% Cell type:markdown id: tags:
## Using filter classes
%% Cell type:code id: tags:
``` python
gaussian_fn = filters.Gaussian(sigma=3)
out3_gauss = gaussian_fn(vol)
```
%% Cell type:markdown id: tags:
## Using filter classes to construct a pipeline of filters
%% Cell type:code id: tags:
``` python
pipeline = filters.Pipeline(
filters.Gaussian(sigma=3),
filters.Median(size=10))
out_seq = pipeline(vol)
```
%% Cell type:markdown id: tags:
Filter functions can also be appended to the sequence after defining the class instance:
%% Cell type:code id: tags:
``` python
pipeline.append(filters.Maximum(size=5))
out_seq2 = pipeline(vol)
```
%% Cell type:markdown id: tags:
The filter objects are stored in the `filters` dictionary:
%% Cell type:code id: tags:
``` python
print(pipeline.filters)
```
%% Output
{'0': <qim3d.processing.filters.Gaussian object at 0x7b3fbdad7bb0>, '1': <qim3d.processing.filters.Median object at 0x7b3fbdad52a0>, '2': <qim3d.processing.filters.Maximum object at 0x7b40f7d3f6d0>}
......@@ -3,6 +3,7 @@ import qim3d.gui as gui
import qim3d.viz as viz
import qim3d.utils as utils
import qim3d.models as models
import qim3d.processing as processing
import logging
examples = io.ImgExamples()
from .filters import *
\ No newline at end of file
"""Provides filter functions and classes for image processing"""
from typing import Union, Type
import numpy as np
from scipy import ndimage
__all__ = ['Gaussian','Median','Maximum','Minimum','Pipeline','gaussian','median','maximum','minimum']
class FilterBase:
def __init__(self, *args, **kwargs):
"""
Base class for image filters.
Args:
*args: Additional positional arguments for filter initialization.
**kwargs: Additional keyword arguments for filter initialization.
"""
self.args = args
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input):
"""
Applies a Gaussian filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return gaussian(input, *self.args, **self.kwargs)
class Median(FilterBase):
def __call__(self, input):
"""
Applies a median filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return median(input, **self.kwargs)
class Maximum(FilterBase):
def __call__(self, input):
"""
Applies a maximum filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return maximum(input, **self.kwargs)
class Minimum(FilterBase):
def __call__(self, input):
"""
Applies a minimum filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return minimum(input, **self.kwargs)
class Pipeline:
def __init__(self, *args: Type[FilterBase]):
"""
Represents a sequence of image filters.
Args:
*args: Variable number of filter instances to be applied sequentially.
"""
self.filters = {}
for idx, fn in enumerate(args):
self._add_filter(str(idx), fn)
def _add_filter(self, name: str, fn: Type[FilterBase]):
"""
Adds a filter to the sequence.
Args:
name: A string representing the name or identifier of the filter.
fn: An instance of a FilterBase subclass.
Raises:
AssertionError: If `fn` is not an instance of the FilterBase class.
"""
if not isinstance(fn,FilterBase):
filter_names = [subclass.__name__ for subclass in FilterBase.__subclasses__()]
raise AssertionError(f'filters should be instances of one of the following classes: {filter_names}')
self.filters[name] = fn
def append(self, fn: Type[FilterBase]):
"""
Appends a filter to the end of the sequence.
Args:
fn: An instance of a FilterBase subclass to be appended.
"""
self._add_filter(str(len(self.filters)), fn)
def __call__(self, input):
"""
Applies the sequential filters to the input in order.
Args:
input: The input image or volume.
Returns:
The filtered image or volume after applying all sequential filters.
"""
for fn in self.filters.values():
input = fn(input)
return input
def gaussian(vol, *args, **kwargs):
"""
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter.
Args:
vol: The input image or volume.
*args: Additional positional arguments for the Gaussian filter.
**kwargs: Additional keyword arguments for the Gaussian filter.
Returns:
The filtered image or volume.
"""
return ndimage.gaussian_filter(vol, *args, **kwargs)
def median(vol, **kwargs):
"""
Applies a median filter to the input volume using scipy.ndimage.median_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the median filter.
Returns:
The filtered image or volume.
"""
return ndimage.median_filter(vol, **kwargs)
def maximum(vol, **kwargs):
"""
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the maximum filter.
Returns:
The filtered image or volume.
"""
return ndimage.maximum_filter(vol, **kwargs)
def minimum(vol, **kwargs):
"""
Applies a minimum filter to the input volume using scipy.ndimage.mainimum_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the minimum filter.
Returns:
The filtered image or volume.
"""
return ndimage.minimum_filter(vol, **kwargs)
\ No newline at end of file
import qim3d
from qim3d.processing.filters import *
import numpy as np
import pytest
import re
def test_filter_base_initialization():
filter_base = qim3d.processing.filters.FilterBase(3,size=2)
assert filter_base.args == (3,)
assert filter_base.kwargs == {'size': 2}
def test_gaussian_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = gaussian(input_image,sigma=1.5)
# Testing the class method
gaussian_filter_cls = Gaussian(sigma=1.5)
filtered_image_cls = gaussian_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn,filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_median_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = median(input_image, size=3)
# Testing the class method
median_filter_cls = Median(size=3)
filtered_image_cls = median_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_maximum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = maximum(input_image, size=3)
# Testing the class method
maximum_filter_cls = Maximum(size=3)
filtered_image_cls = maximum_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_minimum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = minimum(input_image, size=3)
# Testing the class method
minimum_filter_cls = Minimum(size=3)
filtered_image_cls = minimum_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_sequential_filter_pipeline():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
# Testing the sequential pipeline
sequential_pipeline = Sequential(gaussian_filter, median_filter, maximum_filter)
filtered_image_pipeline = sequential_pipeline(input_image)
# Testing the equivalence to maximum(median(gaussian(input,**kwargs),**kwargs),**kwargs)
expected_output = maximum(median(gaussian(input_image, sigma=1.5), size=3), size=3)
# Assertions
assert filtered_image_pipeline.shape == expected_output.shape == input_image.shape
assert not np.array_equal(filtered_image_pipeline, input_image)
assert np.array_equal(filtered_image_pipeline, expected_output)
def test_sequential_filter_appending():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
# Sequential pipeline with filter initialized at the beginning
sequential_pipeline_initial = Sequential(gaussian_filter, median_filter, maximum_filter)
filtered_image_initial = sequential_pipeline_initial(input_image)
# Sequential pipeline with filter appended
sequential_pipeline_appended = Sequential(gaussian_filter, median_filter)
sequential_pipeline_appended.append(maximum_filter)
filtered_image_appended = sequential_pipeline_appended(input_image)
# Assertions
assert filtered_image_initial.shape == filtered_image_appended.shape == input_image.shape
assert not np.array_equal(filtered_image_appended,input_image)
assert np.array_equal(filtered_image_initial, filtered_image_appended)
def test_assertion_error_not_filterbase_subclass():
# Get valid filter classes
valid_filters = [subclass.__name__ for subclass in qim3d.processing.filters.FilterBase.__subclasses__()]
# Create invalid object
invalid_filter = object() # An object that is not an instance of FilterBase
# Construct error message
message = f"filters should be instances of one of the following classes: {valid_filters}"
# Use pytest.raises to catch the AssertionError
with pytest.raises(AssertionError, match=re.escape(message)):
sequential_pipeline = Sequential(invalid_filter)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment