Skip to content
Snippets Groups Projects
Commit 995ab1e0 authored by s184058's avatar s184058 Committed by fima
Browse files

3D image filters

parent d2a07ec8
Branches
No related tags found
1 merge request!493D image filters
%% Cell type:code id: tags:
``` python
import qim3d
import qim3d.processing.filters as filters
import numpy as np
from scipy import ndimage
```
%% Cell type:code id: tags:
``` python
vol = qim3d.examples.fly_150x256x256
```
%% Cell type:markdown id: tags:
## Using the filter functions directly
%% Cell type:code id: tags:
``` python
### Gaussian filter
out1_gauss = filters.gaussian(vol,3)
# or
out2_gauss = filters.gaussian(vol,sigma=3) # sigma is positional, but can be passed as a kwarg
### Median filter
out_median = filters.median(vol,size=5)
```
%% Cell type:markdown id: tags:
## Using filter classes
%% Cell type:code id: tags:
``` python
gaussian_fn = filters.Gaussian(sigma=3)
out3_gauss = gaussian_fn(vol)
```
%% Cell type:markdown id: tags:
## Using filter classes to construct a pipeline of filters
%% Cell type:code id: tags:
``` python
pipeline = filters.Pipeline(
filters.Gaussian(sigma=3),
filters.Median(size=10))
out_seq = pipeline(vol)
```
%% Cell type:markdown id: tags:
Filter functions can also be appended to the sequence after defining the class instance:
%% Cell type:code id: tags:
``` python
pipeline.append(filters.Maximum(size=5))
out_seq2 = pipeline(vol)
```
%% Cell type:markdown id: tags:
The filter objects are stored in the `filters` dictionary:
%% Cell type:code id: tags:
``` python
print(pipeline.filters)
```
%% Output
{'0': <qim3d.processing.filters.Gaussian object at 0x7b3fbdad7bb0>, '1': <qim3d.processing.filters.Median object at 0x7b3fbdad52a0>, '2': <qim3d.processing.filters.Maximum object at 0x7b40f7d3f6d0>}
......@@ -3,6 +3,7 @@ import qim3d.gui as gui
import qim3d.viz as viz
import qim3d.utils as utils
import qim3d.models as models
import qim3d.processing as processing
import logging
examples = io.ImgExamples()
from .filters import *
\ No newline at end of file
"""Provides filter functions and classes for image processing"""
from typing import Union, Type
import numpy as np
from scipy import ndimage
__all__ = ['Gaussian','Median','Maximum','Minimum','Pipeline','gaussian','median','maximum','minimum']
class FilterBase:
def __init__(self, *args, **kwargs):
"""
Base class for image filters.
Args:
*args: Additional positional arguments for filter initialization.
**kwargs: Additional keyword arguments for filter initialization.
"""
self.args = args
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input):
"""
Applies a Gaussian filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return gaussian(input, *self.args, **self.kwargs)
class Median(FilterBase):
def __call__(self, input):
"""
Applies a median filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return median(input, **self.kwargs)
class Maximum(FilterBase):
def __call__(self, input):
"""
Applies a maximum filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return maximum(input, **self.kwargs)
class Minimum(FilterBase):
def __call__(self, input):
"""
Applies a minimum filter to the input.
Args:
input: The input image or volume.
Returns:
The filtered image or volume.
"""
return minimum(input, **self.kwargs)
class Pipeline:
def __init__(self, *args: Type[FilterBase]):
"""
Represents a sequence of image filters.
Args:
*args: Variable number of filter instances to be applied sequentially.
"""
self.filters = {}
for idx, fn in enumerate(args):
self._add_filter(str(idx), fn)
def _add_filter(self, name: str, fn: Type[FilterBase]):
"""
Adds a filter to the sequence.
Args:
name: A string representing the name or identifier of the filter.
fn: An instance of a FilterBase subclass.
Raises:
AssertionError: If `fn` is not an instance of the FilterBase class.
"""
if not isinstance(fn,FilterBase):
filter_names = [subclass.__name__ for subclass in FilterBase.__subclasses__()]
raise AssertionError(f'filters should be instances of one of the following classes: {filter_names}')
self.filters[name] = fn
def append(self, fn: Type[FilterBase]):
"""
Appends a filter to the end of the sequence.
Args:
fn: An instance of a FilterBase subclass to be appended.
"""
self._add_filter(str(len(self.filters)), fn)
def __call__(self, input):
"""
Applies the sequential filters to the input in order.
Args:
input: The input image or volume.
Returns:
The filtered image or volume after applying all sequential filters.
"""
for fn in self.filters.values():
input = fn(input)
return input
def gaussian(vol, *args, **kwargs):
"""
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter.
Args:
vol: The input image or volume.
*args: Additional positional arguments for the Gaussian filter.
**kwargs: Additional keyword arguments for the Gaussian filter.
Returns:
The filtered image or volume.
"""
return ndimage.gaussian_filter(vol, *args, **kwargs)
def median(vol, **kwargs):
"""
Applies a median filter to the input volume using scipy.ndimage.median_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the median filter.
Returns:
The filtered image or volume.
"""
return ndimage.median_filter(vol, **kwargs)
def maximum(vol, **kwargs):
"""
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the maximum filter.
Returns:
The filtered image or volume.
"""
return ndimage.maximum_filter(vol, **kwargs)
def minimum(vol, **kwargs):
"""
Applies a minimum filter to the input volume using scipy.ndimage.mainimum_filter.
Args:
vol: The input image or volume.
**kwargs: Additional keyword arguments for the minimum filter.
Returns:
The filtered image or volume.
"""
return ndimage.minimum_filter(vol, **kwargs)
\ No newline at end of file
import qim3d
from qim3d.processing.filters import *
import numpy as np
import pytest
import re
def test_filter_base_initialization():
filter_base = qim3d.processing.filters.FilterBase(3,size=2)
assert filter_base.args == (3,)
assert filter_base.kwargs == {'size': 2}
def test_gaussian_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = gaussian(input_image,sigma=1.5)
# Testing the class method
gaussian_filter_cls = Gaussian(sigma=1.5)
filtered_image_cls = gaussian_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn,filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_median_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = median(input_image, size=3)
# Testing the class method
median_filter_cls = Median(size=3)
filtered_image_cls = median_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_maximum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = maximum(input_image, size=3)
# Testing the class method
maximum_filter_cls = Maximum(size=3)
filtered_image_cls = maximum_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_minimum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = minimum(input_image, size=3)
# Testing the class method
minimum_filter_cls = Minimum(size=3)
filtered_image_cls = minimum_filter_cls(input_image)
# Assertions
assert filtered_image_cls.shape == filtered_image_fn.shape == input_image.shape
assert np.array_equal(filtered_image_fn, filtered_image_cls)
assert not np.array_equal(filtered_image_fn, input_image)
def test_sequential_filter_pipeline():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
# Testing the sequential pipeline
sequential_pipeline = Sequential(gaussian_filter, median_filter, maximum_filter)
filtered_image_pipeline = sequential_pipeline(input_image)
# Testing the equivalence to maximum(median(gaussian(input,**kwargs),**kwargs),**kwargs)
expected_output = maximum(median(gaussian(input_image, sigma=1.5), size=3), size=3)
# Assertions
assert filtered_image_pipeline.shape == expected_output.shape == input_image.shape
assert not np.array_equal(filtered_image_pipeline, input_image)
assert np.array_equal(filtered_image_pipeline, expected_output)
def test_sequential_filter_appending():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
# Sequential pipeline with filter initialized at the beginning
sequential_pipeline_initial = Sequential(gaussian_filter, median_filter, maximum_filter)
filtered_image_initial = sequential_pipeline_initial(input_image)
# Sequential pipeline with filter appended
sequential_pipeline_appended = Sequential(gaussian_filter, median_filter)
sequential_pipeline_appended.append(maximum_filter)
filtered_image_appended = sequential_pipeline_appended(input_image)
# Assertions
assert filtered_image_initial.shape == filtered_image_appended.shape == input_image.shape
assert not np.array_equal(filtered_image_appended,input_image)
assert np.array_equal(filtered_image_initial, filtered_image_appended)
def test_assertion_error_not_filterbase_subclass():
# Get valid filter classes
valid_filters = [subclass.__name__ for subclass in qim3d.processing.filters.FilterBase.__subclasses__()]
# Create invalid object
invalid_filter = object() # An object that is not an instance of FilterBase
# Construct error message
message = f"filters should be instances of one of the following classes: {valid_filters}"
# Use pytest.raises to catch the AssertionError
with pytest.raises(AssertionError, match=re.escape(message)):
sequential_pipeline = Sequential(invalid_filter)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment