diff --git a/docs/assets/screenshots/viz-line_profile.gif b/docs/assets/screenshots/viz-line_profile.gif new file mode 100644 index 0000000000000000000000000000000000000000..74b322faa292ad685d1d304c1e0f276114064d1a Binary files /dev/null and b/docs/assets/screenshots/viz-line_profile.gif differ diff --git a/docs/doc/visualization/viz.md b/docs/doc/visualization/viz.md index 0f22a927ac31aaab6a0940fdbaa075caab389d0f..7a4b39a8ec49409fe51b190a079971a7568f852b 100644 --- a/docs/doc/visualization/viz.md +++ b/docs/doc/visualization/viz.md @@ -23,6 +23,7 @@ The `qim3d` library aims to provide easy ways to explore and get insights from v - plot_cc - colormaps - fade_mask + - line_profile ::: qim3d.viz.colormaps options: diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 8a25ebe10d1d756e6a7d177a7c565bc73732e21b..a2be34e4c2addf8b4641d9d9c7d3a292e9423581 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -7,6 +7,9 @@ from ._data_exploration import ( slicer, slicer_orthogonal, slices_grid, + chunks, + histogram, + line_profile ) from ._detection import circles from ._k3d import mesh, volumetric diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index dd85f8cfb0fe02d2d255ba0a60e468d74a411513..b6ae03ac8f3e4494969ad84339cd3bad1bdd9de4 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -4,17 +4,20 @@ Provides a collection of visualization functions. import math import warnings -from typing import List, Optional, Union + +from typing import List, Optional, Union, Tuple import dask.array as da import ipywidgets as widgets import matplotlib import matplotlib.figure import matplotlib.pyplot as plt +import matplotlib +from IPython.display import SVG, display, clear_output +import matplotlib import numpy as np import seaborn as sns -import zarr -from IPython.display import display +import skimage.measure import qim3d from qim3d.utils._logger import log @@ -1005,3 +1008,231 @@ def histogram( if return_fig: return fig + +class _LineProfile: + def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range): + self.volume = volume + self.slice_axis = slice_axis + + self.dims = np.array(volume.shape) + self.pad = 1 # Padding on pivot point to avoid border issues + self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1] + + self.initialize_widgets() + self.update_slice_axis(slice_axis) + self.slice_index_widget.value = slice_index + self.x_widget.value = horizontal_position + self.y_widget.value = vertical_position + self.angle_widget.value = angle + self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]] + + def update_slice_axis(self, slice_axis): + self.slice_axis = slice_axis + self.slice_index_widget.max = self.volume.shape[slice_axis] - 1 + self.slice_index_widget.value = self.volume.shape[slice_axis] // 2 + + self.x_max, self.y_max = np.delete(self.dims, self.slice_axis) - 1 + self.x_widget.max = self.x_max - self.pad + self.x_widget.value = self.x_max // 2 + self.y_widget.max = self.y_max - self.pad + self.y_widget.value = self.y_max // 2 + + def initialize_widgets(self): + layout = widgets.Layout(width='300px', height='auto') + self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) + self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) + self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout) + self.line_fraction_widget = widgets.FloatRangeSlider( + min=0, max=1, step=0.01, value=[0, 1], + description="", layout=layout + ) + + self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis') + self.slice_axis_widget.layout.width = '250px' + + self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout) + self.slice_index_widget.layout.width = '400px' + + def calculate_line_endpoints(self, x, y, angle): + """ + Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)] + """ + if np.isclose(angle, 0): + return [0, y], [self.x_max, y] + elif np.isclose(angle, np.pi/2): + return [x, 0], [x, self.y_max] + elif np.isclose(angle, np.pi): + return [self.x_max, y], [0, y] + elif np.isclose(angle, 3*np.pi/2): + return [x, self.y_max], [x, 0] + elif np.isclose(angle, 2*np.pi): + return [0, y], [self.x_max, y] + + t_left = -x / np.cos(angle) + t_bottom = -y / np.sin(angle) + t_right = (self.x_max - x) / np.cos(angle) + t_top = (self.y_max - y) / np.sin(angle) + t_values = np.array([t_left, t_top, t_right, t_bottom]) + t_pos = np.min(t_values[t_values > 0]) + t_neg = np.max(t_values[t_values < 0]) + + src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)] + dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)] + return src, dst + + def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range): + if slice_axis != self.slice_axis: + self.update_slice_axis(slice_axis) + x = self.x_widget.value + y = self.y_widget.value + slice_index = self.slice_index_widget.value + + clear_output(wait=True) + + image = np.take(self.volume, slice_index, slice_axis) + angle = np.radians(angle_deg) + src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)] + + # Rescale endpoints + line_vec = dst - src + dst = src + fraction_range[1] * line_vec + src = src + fraction_range[0] * line_vec + + y_pline = skimage.measure.profile_line(image, src, dst) + + fig, ax = plt.subplots(1, 2, figsize=(10, 5)) + + # Image with color-gradiented line + num_segments = 100 + x_seg = np.linspace(src[0], dst[0], num_segments) + y_seg = np.linspace(src[1], dst[1], num_segments) + segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]), + np.column_stack([y_seg[2:], x_seg[2:]])], axis=1) + norm = plt.Normalize(vmin=0, vmax=num_segments-1) + colors = self.cmap(norm(np.arange(num_segments - 1))) + lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2) + + ax[0].imshow(image,cmap='gray') + ax[0].add_collection(lc) + # pivot point + ax[0].plot(y,x,marker='s', linestyle='', color='cyan', markersize=4) + ax[0].set_xlabel(f'axis {np.delete(np.arange(3), self.slice_axis)[1]}') + ax[0].set_ylabel(f'axis {np.delete(np.arange(3), self.slice_axis)[0]}') + + # Profile intensity plot + norm = plt.Normalize(0, vmax=len(y_pline) - 1) + x_pline = np.arange(len(y_pline)) + points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :] + segments = np.concatenate([points[:-1], points[1:]], axis=1) + lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2) + + ax[1].add_collection(lc) + ax[1].autoscale() + ax[1].set_xlabel('Distance along line') + ax[1].grid(True) + plt.tight_layout() + plt.show() + + def build_interactive(self): + # Group widgets into two columns + title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" + title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>") + title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>") + + # Make label widgets instead of descriptions which have different lengths. + label_layout = widgets.Layout(width='120px') + label_x = widgets.Label("Vertical position", layout=label_layout) + label_y = widgets.Label("Horizontal position", layout=label_layout) + label_angle = widgets.Label("Angle (°)", layout=label_layout) + label_fraction = widgets.Label("Fraction range", layout=label_layout) + + row_x = widgets.HBox([label_x, self.x_widget]) + row_y = widgets.HBox([label_y, self.y_widget]) + row_angle = widgets.HBox([label_angle, self.angle_widget]) + row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget]) + + controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction]) + controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget]) + controls = widgets.HBox([controls_column1, controls_column2]) + + interactive_plot = widgets.interactive_output( + self.update, + {'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget, + 'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget, + 'fraction_range': self.line_fraction_widget} + ) + + return widgets.VBox([controls, interactive_plot]) + +def line_profile( + volume: np.ndarray, + slice_axis: int=0, + slice_index: int | str='middle', + vertical_position: int | str='middle', + horizontal_position: int | str='middle', + angle: int=0, + fraction_range: Tuple[float,float]=(0.00, 1.00) + ) -> widgets.interactive: + """Returns an interactive widget for visualizing the intensity profiles of lines on slices. + + Args: + volume (np.ndarray): The 3D volume of interest. + slice_axis (int, optional): Specifies the initial axis along which to slice. + slice_index (int or str, optional): Specifies the initial slice index along slice_axis. + vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point. + horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point. + angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. + fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0]. + + Returns: + widget (widgets.widget_box.VBox): The interactive widget. + + + Example: + ```python + import qim3d + + vol = qim3d.examples.bone_128x128x128 + qim3d.viz.line_profile(vol) + ``` +  + + """ + def parse_position(pos, pos_range, name): + if isinstance(pos, int): + if not pos_range[0] <= pos < pos_range[1]: + raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]') + return pos + elif isinstance(pos, str): + pos = pos.lower() + if pos == 'start': return pos_range[0] + elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 + elif pos == 'end': return pos_range[1] + else: + raise ValueError( + f"Invalid string '{pos}' for {name}. " + "Must be 'start', 'middle', or 'end'." + ) + else: + raise TypeError(f'Axis position must be of type int or str.') + + if not isinstance(volume, (np.ndarray, da.core.Array)): + raise ValueError("Data type for volume not supported.") + if volume.ndim != 3: + raise ValueError("Volume must be 3D.") + + dims = volume.shape + slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index') + # the omission of the ends for the pivot point is due to border issues. + vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position') + horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position') + + if not isinstance(angle, int | float): + raise ValueError("Invalid type for angle.") + angle = round(angle) % 360 + + if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]): + raise ValueError("Invalid values for fraction_range.") + + lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range) + return lp.build_interactive()