From a11f6b3d5be5c88f2d1410d021d957b607b94ab4 Mon Sep 17 00:00:00 2001 From: Felipe <fima@dtu.dk> Date: Mon, 24 Feb 2025 15:18:42 +0100 Subject: [PATCH] moving code to new branch --- qim3d/viz/__init__.py | 6 +- qim3d/viz/_data_exploration.py | 262 ++++++++++++++++++++++----------- 2 files changed, 174 insertions(+), 94 deletions(-) diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index e7ba84a..f785319 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -4,13 +4,11 @@ from ._data_exploration import ( chunks, fade_mask, histogram, + line_profile, slicer, slicer_orthogonal, slices_grid, - chunks, - histogram, - line_profile, - threshold + threshold, ) from ._detection import circles from ._k3d import mesh, volumetric diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index 855d730..84c0266 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -4,20 +4,17 @@ Provides a collection of visualization functions. import math import warnings - -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import dask.array as da import ipywidgets as widgets import matplotlib import matplotlib.figure import matplotlib.pyplot as plt -import matplotlib -from IPython.display import SVG, display, clear_output -import matplotlib import numpy as np import seaborn as sns import skimage.measure +from IPython.display import clear_output, display import qim3d from qim3d.utils._logger import log @@ -1009,13 +1006,23 @@ def histogram( if return_fig: return fig + class _LineProfile: - def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range): + def __init__( + self, + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ): self.volume = volume self.slice_axis = slice_axis self.dims = np.array(volume.shape) - self.pad = 1 # Padding on pivot point to avoid border issues + self.pad = 1 # Padding on pivot point to avoid border issues self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1] self.initialize_widgets() @@ -1025,7 +1032,7 @@ class _LineProfile: self.y_widget.value = vertical_position self.angle_widget.value = angle self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]] - + def update_slice_axis(self, slice_axis): self.slice_axis = slice_axis self.slice_index_widget.max = self.volume.shape[slice_axis] - 1 @@ -1039,35 +1046,44 @@ class _LineProfile: def initialize_widgets(self): layout = widgets.Layout(width='300px', height='auto') - self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) - self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout) - self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout) + self.x_widget = widgets.IntSlider( + min=self.pad, step=1, description='', layout=layout + ) + self.y_widget = widgets.IntSlider( + min=self.pad, step=1, description='', layout=layout + ) + self.angle_widget = widgets.IntSlider( + min=0, max=360, step=1, value=0, description='', layout=layout + ) self.line_fraction_widget = widgets.FloatRangeSlider( - min=0, max=1, step=0.01, value=[0, 1], - description="", layout=layout + min=0, max=1, step=0.01, value=[0, 1], description='', layout=layout ) - self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis') + self.slice_axis_widget = widgets.Dropdown( + options=[0, 1, 2], value=self.slice_axis, description='Slice axis' + ) self.slice_axis_widget.layout.width = '250px' - self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout) + self.slice_index_widget = widgets.IntSlider( + min=0, step=1, description='Slice index', layout=layout + ) self.slice_index_widget.layout.width = '400px' - + def calculate_line_endpoints(self, x, y, angle): """ Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)] """ if np.isclose(angle, 0): return [0, y], [self.x_max, y] - elif np.isclose(angle, np.pi/2): + elif np.isclose(angle, np.pi / 2): return [x, 0], [x, self.y_max] elif np.isclose(angle, np.pi): return [self.x_max, y], [0, y] - elif np.isclose(angle, 3*np.pi/2): + elif np.isclose(angle, 3 * np.pi / 2): return [x, self.y_max], [x, 0] - elif np.isclose(angle, 2*np.pi): + elif np.isclose(angle, 2 * np.pi): return [0, y], [self.x_max, y] - + t_left = -x / np.cos(angle) t_bottom = -y / np.sin(angle) t_right = (self.x_max - x) / np.cos(angle) @@ -1075,23 +1091,26 @@ class _LineProfile: t_values = np.array([t_left, t_top, t_right, t_bottom]) t_pos = np.min(t_values[t_values > 0]) t_neg = np.max(t_values[t_values < 0]) - + src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)] dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)] return src, dst - + def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range): if slice_axis != self.slice_axis: self.update_slice_axis(slice_axis) x = self.x_widget.value y = self.y_widget.value slice_index = self.slice_index_widget.value - + clear_output(wait=True) - + image = np.take(self.volume, slice_index, slice_axis) angle = np.radians(angle_deg) - src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)] + src, dst = ( + np.array(point, dtype='float32') + for point in self.calculate_line_endpoints(x, y, angle) + ) # Rescale endpoints line_vec = dst - src @@ -1101,30 +1120,37 @@ class _LineProfile: y_pline = skimage.measure.profile_line(image, src, dst) fig, ax = plt.subplots(1, 2, figsize=(10, 5)) - + # Image with color-gradiented line num_segments = 100 x_seg = np.linspace(src[0], dst[0], num_segments) y_seg = np.linspace(src[1], dst[1], num_segments) - segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]), - np.column_stack([y_seg[2:], x_seg[2:]])], axis=1) - norm = plt.Normalize(vmin=0, vmax=num_segments-1) + segments = np.stack( + [ + np.column_stack([y_seg[:-2], x_seg[:-2]]), + np.column_stack([y_seg[2:], x_seg[2:]]), + ], + axis=1, + ) + norm = plt.Normalize(vmin=0, vmax=num_segments - 1) colors = self.cmap(norm(np.arange(num_segments - 1))) lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2) - ax[0].imshow(image,cmap='gray') + ax[0].imshow(image, cmap='gray') ax[0].add_collection(lc) # pivot point - ax[0].plot(y,x,marker='s', linestyle='', color='cyan', markersize=4) + ax[0].plot(y, x, marker='s', linestyle='', color='cyan', markersize=4) ax[0].set_xlabel(f'axis {np.delete(np.arange(3), self.slice_axis)[1]}') ax[0].set_ylabel(f'axis {np.delete(np.arange(3), self.slice_axis)[0]}') - + # Profile intensity plot norm = plt.Normalize(0, vmax=len(y_pline) - 1) x_pline = np.arange(len(y_pline)) points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :] segments = np.concatenate([points[:-1], points[1:]], axis=1) - lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2) + lc = matplotlib.collections.LineCollection( + segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2 + ) ax[1].add_collection(lc) ax[1].autoscale() @@ -1132,48 +1158,65 @@ class _LineProfile: ax[1].grid(True) plt.tight_layout() plt.show() - + def build_interactive(self): # Group widgets into two columns - title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;" - title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>") - title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>") + title_style = ( + 'text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;' + ) + title_column1 = widgets.HTML( + f"<div style='{title_style}'>Line parameterization</div>" + ) + title_column2 = widgets.HTML( + f"<div style='{title_style}'>Slice selection</div>" + ) # Make label widgets instead of descriptions which have different lengths. label_layout = widgets.Layout(width='120px') - label_x = widgets.Label("Vertical position", layout=label_layout) - label_y = widgets.Label("Horizontal position", layout=label_layout) - label_angle = widgets.Label("Angle (°)", layout=label_layout) - label_fraction = widgets.Label("Fraction range", layout=label_layout) + label_x = widgets.Label('Vertical position', layout=label_layout) + label_y = widgets.Label('Horizontal position', layout=label_layout) + label_angle = widgets.Label('Angle (°)', layout=label_layout) + label_fraction = widgets.Label('Fraction range', layout=label_layout) row_x = widgets.HBox([label_x, self.x_widget]) row_y = widgets.HBox([label_y, self.y_widget]) row_angle = widgets.HBox([label_angle, self.angle_widget]) row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget]) - controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction]) - controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget]) + controls_column1 = widgets.VBox( + [title_column1, row_x, row_y, row_angle, row_fraction] + ) + controls_column2 = widgets.VBox( + [title_column2, self.slice_axis_widget, self.slice_index_widget] + ) controls = widgets.HBox([controls_column1, controls_column2]) interactive_plot = widgets.interactive_output( - self.update, - {'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget, - 'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget, - 'fraction_range': self.line_fraction_widget} + self.update, + { + 'slice_axis': self.slice_axis_widget, + 'slice_index': self.slice_index_widget, + 'x': self.x_widget, + 'y': self.y_widget, + 'angle_deg': self.angle_widget, + 'fraction_range': self.line_fraction_widget, + }, ) return widgets.VBox([controls, interactive_plot]) + def line_profile( - volume: np.ndarray, - slice_axis: int=0, - slice_index: int | str='middle', - vertical_position: int | str='middle', - horizontal_position: int | str='middle', - angle: int=0, - fraction_range: Tuple[float,float]=(0.00, 1.00) - ) -> widgets.interactive: - """Returns an interactive widget for visualizing the intensity profiles of lines on slices. + volume: np.ndarray, + slice_axis: int = 0, + slice_index: int | str = 'middle', + vertical_position: int | str = 'middle', + horizontal_position: int | str = 'middle', + angle: int = 0, + fraction_range: Tuple[float, float] = (0.00, 1.00), +) -> widgets.interactive: + """ + Returns an interactive widget for visualizing the intensity profiles of lines on slices. Args: volume (np.ndarray): The 3D volume of interest. @@ -1181,13 +1224,13 @@ def line_profile( slice_index (int or str, optional): Specifies the initial slice index along slice_axis. vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point. horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point. - angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. + angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0]. Returns: widget (widgets.widget_box.VBox): The interactive widget. - + Example: ```python import qim3d @@ -1198,57 +1241,82 @@ def line_profile(  """ + def parse_position(pos, pos_range, name): if isinstance(pos, int): if not pos_range[0] <= pos < pos_range[1]: - raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]') + raise ValueError( + f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]' + ) return pos elif isinstance(pos, str): pos = pos.lower() - if pos == 'start': return pos_range[0] - elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 - elif pos == 'end': return pos_range[1] + if pos == 'start': + return pos_range[0] + elif pos == 'middle': + return pos_range[0] + (pos_range[1] - pos_range[0]) // 2 + elif pos == 'end': + return pos_range[1] else: raise ValueError( f"Invalid string '{pos}' for {name}. " "Must be 'start', 'middle', or 'end'." ) else: - raise TypeError(f'Axis position must be of type int or str.') - + raise TypeError('Axis position must be of type int or str.') + if not isinstance(volume, (np.ndarray, da.core.Array)): - raise ValueError("Data type for volume not supported.") + raise ValueError('Data type for volume not supported.') if volume.ndim != 3: - raise ValueError("Volume must be 3D.") - + raise ValueError('Volume must be 3D.') + dims = volume.shape slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index') # the omission of the ends for the pivot point is due to border issues. - vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position') - horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position') - + vertical_position = parse_position( + vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position' + ) + horizontal_position = parse_position( + horizontal_position, + (1, np.delete(dims, slice_axis)[1] - 2), + 'horizontal_position', + ) + if not isinstance(angle, int | float): - raise ValueError("Invalid type for angle.") + raise ValueError('Invalid type for angle.') angle = round(angle) % 360 - if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]): - raise ValueError("Invalid values for fraction_range.") + if not ( + 0.0 <= fraction_range[0] <= 1.0 + and 0.0 <= fraction_range[1] <= 1.0 + and fraction_range[0] <= fraction_range[1] + ): + raise ValueError('Invalid values for fraction_range.') - lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range) + lp = _LineProfile( + volume, + slice_axis, + slice_index, + vertical_position, + horizontal_position, + angle, + fraction_range, + ) return lp.build_interactive() + def threshold( - volume: np.ndarray, - cmap_image: str = 'viridis', - vmin: float = None, - vmax: float = None, + volume: np.ndarray, + cmap_image: str = 'viridis', + vmin: float = None, + vmax: float = None, ) -> widgets.VBox: """ - This function provides an interactive interface to explore thresholding on a - 3D volume slice-by-slice. Users can either manually set the threshold value - using a slider or select an automatic thresholding method from `skimage`. + This function provides an interactive interface to explore thresholding on a + 3D volume slice-by-slice. Users can either manually set the threshold value + using a slider or select an automatic thresholding method from `skimage`. - The visualization includes the original image slice, a binary mask showing regions above the + The visualization includes the original image slice, a binary mask showing regions above the threshold and an overlay combining the binary mask and the original image. Args: @@ -1257,16 +1325,16 @@ def threshold( cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'. vmin (float, optional): Minimum value for the colormap. Defaults to None. vmax (float, optional): Maximum value for the colormap. Defaults to None. - + Returns: slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume. - + Interactivity: - **Manual Thresholding**: - Select 'Manual' from the dropdown menu to manually adjust the threshold + Select 'Manual' from the dropdown menu to manually adjust the threshold using the slider. - **Automatic Thresholding**: - Choose a method from the dropdown menu to apply an automatic thresholding + Choose a method from the dropdown menu to apply an automatic thresholding algorithm. Available methods include: - Otsu - Isodata @@ -1276,7 +1344,7 @@ def threshold( - Triangle - Yen - The threshold slider will display the computed value and will be disabled + The threshold slider will display the computed value and will be disabled in this mode. @@ -1290,6 +1358,7 @@ def threshold( qim3d.viz.threshold(vol) ```  + """ # Centralized state dictionary to track current parameters @@ -1348,10 +1417,14 @@ def threshold( # Original image new_vmin = ( - None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin + None + if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) + else vmin ) new_vmax = ( - None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax + None + if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) + else vmax ) axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) axes[0].set_title('Original') @@ -1405,7 +1478,16 @@ def threshold( ) method_dropdown = widgets.Dropdown( - options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'], + options=[ + 'Manual', + 'Otsu', + 'Isodata', + 'Li', + 'Mean', + 'Minimum', + 'Triangle', + 'Yen', + ], value=state['method'], description='Method', ) -- GitLab