Skip to content
Snippets Groups Projects
Commit a11f6b3d authored by fima's avatar fima :beers:
Browse files

moving code to new branch

parent 909918e3
No related branches found
No related tags found
1 merge request!157Threshold for v1
......@@ -4,13 +4,11 @@ from ._data_exploration import (
chunks,
fade_mask,
histogram,
line_profile,
slicer,
slicer_orthogonal,
slices_grid,
chunks,
histogram,
line_profile,
threshold
threshold,
)
from ._detection import circles
from ._k3d import mesh, volumetric
......
......@@ -4,20 +4,17 @@ Provides a collection of visualization functions.
import math
import warnings
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Tuple, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib
import matplotlib.figure
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import SVG, display, clear_output
import matplotlib
import numpy as np
import seaborn as sns
import skimage.measure
from IPython.display import clear_output, display
import qim3d
from qim3d.utils._logger import log
......@@ -1009,8 +1006,18 @@ def histogram(
if return_fig:
return fig
class _LineProfile:
def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range):
def __init__(
self,
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
):
self.volume = volume
self.slice_axis = slice_axis
......@@ -1039,18 +1046,27 @@ class _LineProfile:
def initialize_widgets(self):
layout = widgets.Layout(width='300px', height='auto')
self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout)
self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout)
self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout)
self.x_widget = widgets.IntSlider(
min=self.pad, step=1, description='', layout=layout
)
self.y_widget = widgets.IntSlider(
min=self.pad, step=1, description='', layout=layout
)
self.angle_widget = widgets.IntSlider(
min=0, max=360, step=1, value=0, description='', layout=layout
)
self.line_fraction_widget = widgets.FloatRangeSlider(
min=0, max=1, step=0.01, value=[0, 1],
description="", layout=layout
min=0, max=1, step=0.01, value=[0, 1], description='', layout=layout
)
self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis')
self.slice_axis_widget = widgets.Dropdown(
options=[0, 1, 2], value=self.slice_axis, description='Slice axis'
)
self.slice_axis_widget.layout.width = '250px'
self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout)
self.slice_index_widget = widgets.IntSlider(
min=0, step=1, description='Slice index', layout=layout
)
self.slice_index_widget.layout.width = '400px'
def calculate_line_endpoints(self, x, y, angle):
......@@ -1091,7 +1107,10 @@ class _LineProfile:
image = np.take(self.volume, slice_index, slice_axis)
angle = np.radians(angle_deg)
src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)]
src, dst = (
np.array(point, dtype='float32')
for point in self.calculate_line_endpoints(x, y, angle)
)
# Rescale endpoints
line_vec = dst - src
......@@ -1106,8 +1125,13 @@ class _LineProfile:
num_segments = 100
x_seg = np.linspace(src[0], dst[0], num_segments)
y_seg = np.linspace(src[1], dst[1], num_segments)
segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]),
np.column_stack([y_seg[2:], x_seg[2:]])], axis=1)
segments = np.stack(
[
np.column_stack([y_seg[:-2], x_seg[:-2]]),
np.column_stack([y_seg[2:], x_seg[2:]]),
],
axis=1,
)
norm = plt.Normalize(vmin=0, vmax=num_segments - 1)
colors = self.cmap(norm(np.arange(num_segments - 1)))
lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2)
......@@ -1124,7 +1148,9 @@ class _LineProfile:
x_pline = np.arange(len(y_pline))
points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :]
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2)
lc = matplotlib.collections.LineCollection(
segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2
)
ax[1].add_collection(lc)
ax[1].autoscale()
......@@ -1135,35 +1161,51 @@ class _LineProfile:
def build_interactive(self):
# Group widgets into two columns
title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;"
title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>")
title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>")
title_style = (
'text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;'
)
title_column1 = widgets.HTML(
f"<div style='{title_style}'>Line parameterization</div>"
)
title_column2 = widgets.HTML(
f"<div style='{title_style}'>Slice selection</div>"
)
# Make label widgets instead of descriptions which have different lengths.
label_layout = widgets.Layout(width='120px')
label_x = widgets.Label("Vertical position", layout=label_layout)
label_y = widgets.Label("Horizontal position", layout=label_layout)
label_angle = widgets.Label("Angle (°)", layout=label_layout)
label_fraction = widgets.Label("Fraction range", layout=label_layout)
label_x = widgets.Label('Vertical position', layout=label_layout)
label_y = widgets.Label('Horizontal position', layout=label_layout)
label_angle = widgets.Label('Angle (°)', layout=label_layout)
label_fraction = widgets.Label('Fraction range', layout=label_layout)
row_x = widgets.HBox([label_x, self.x_widget])
row_y = widgets.HBox([label_y, self.y_widget])
row_angle = widgets.HBox([label_angle, self.angle_widget])
row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget])
controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction])
controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget])
controls_column1 = widgets.VBox(
[title_column1, row_x, row_y, row_angle, row_fraction]
)
controls_column2 = widgets.VBox(
[title_column2, self.slice_axis_widget, self.slice_index_widget]
)
controls = widgets.HBox([controls_column1, controls_column2])
interactive_plot = widgets.interactive_output(
self.update,
{'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget,
'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget,
'fraction_range': self.line_fraction_widget}
{
'slice_axis': self.slice_axis_widget,
'slice_index': self.slice_index_widget,
'x': self.x_widget,
'y': self.y_widget,
'angle_deg': self.angle_widget,
'fraction_range': self.line_fraction_widget,
},
)
return widgets.VBox([controls, interactive_plot])
def line_profile(
volume: np.ndarray,
slice_axis: int = 0,
......@@ -1171,9 +1213,10 @@ def line_profile(
vertical_position: int | str = 'middle',
horizontal_position: int | str = 'middle',
angle: int = 0,
fraction_range: Tuple[float,float]=(0.00, 1.00)
fraction_range: Tuple[float, float] = (0.00, 1.00),
) -> widgets.interactive:
"""Returns an interactive widget for visualizing the intensity profiles of lines on slices.
"""
Returns an interactive widget for visualizing the intensity profiles of lines on slices.
Args:
volume (np.ndarray): The 3D volume of interest.
......@@ -1198,45 +1241,70 @@ def line_profile(
![viz histogram](../../assets/screenshots/viz-line_profile.gif)
"""
def parse_position(pos, pos_range, name):
if isinstance(pos, int):
if not pos_range[0] <= pos < pos_range[1]:
raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]')
raise ValueError(
f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]'
)
return pos
elif isinstance(pos, str):
pos = pos.lower()
if pos == 'start': return pos_range[0]
elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
elif pos == 'end': return pos_range[1]
if pos == 'start':
return pos_range[0]
elif pos == 'middle':
return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
elif pos == 'end':
return pos_range[1]
else:
raise ValueError(
f"Invalid string '{pos}' for {name}. "
"Must be 'start', 'middle', or 'end'."
)
else:
raise TypeError(f'Axis position must be of type int or str.')
raise TypeError('Axis position must be of type int or str.')
if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError("Data type for volume not supported.")
raise ValueError('Data type for volume not supported.')
if volume.ndim != 3:
raise ValueError("Volume must be 3D.")
raise ValueError('Volume must be 3D.')
dims = volume.shape
slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index')
# the omission of the ends for the pivot point is due to border issues.
vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position')
horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position')
vertical_position = parse_position(
vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position'
)
horizontal_position = parse_position(
horizontal_position,
(1, np.delete(dims, slice_axis)[1] - 2),
'horizontal_position',
)
if not isinstance(angle, int | float):
raise ValueError("Invalid type for angle.")
raise ValueError('Invalid type for angle.')
angle = round(angle) % 360
if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]):
raise ValueError("Invalid values for fraction_range.")
if not (
0.0 <= fraction_range[0] <= 1.0
and 0.0 <= fraction_range[1] <= 1.0
and fraction_range[0] <= fraction_range[1]
):
raise ValueError('Invalid values for fraction_range.')
lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range)
lp = _LineProfile(
volume,
slice_axis,
slice_index,
vertical_position,
horizontal_position,
angle,
fraction_range,
)
return lp.build_interactive()
def threshold(
volume: np.ndarray,
cmap_image: str = 'viridis',
......@@ -1290,6 +1358,7 @@ def threshold(
qim3d.viz.threshold(vol)
```
![interactive threshold](../../assets/screenshots/interactive_thresholding.gif)
"""
# Centralized state dictionary to track current parameters
......@@ -1348,10 +1417,14 @@ def threshold(
# Original image
new_vmin = (
None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title('Original')
......@@ -1405,7 +1478,16 @@ def threshold(
)
method_dropdown = widgets.Dropdown(
options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
options=[
'Manual',
'Otsu',
'Isodata',
'Li',
'Mean',
'Minimum',
'Triangle',
'Yen',
],
value=state['method'],
description='Method',
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment