From a11f6b3d5be5c88f2d1410d021d957b607b94ab4 Mon Sep 17 00:00:00 2001
From: Felipe <fima@dtu.dk>
Date: Mon, 24 Feb 2025 15:18:42 +0100
Subject: [PATCH] moving code to new branch

---
 qim3d/viz/__init__.py          |   6 +-
 qim3d/viz/_data_exploration.py | 262 ++++++++++++++++++++++-----------
 2 files changed, 174 insertions(+), 94 deletions(-)

diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index e7ba84a..f785319 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -4,13 +4,11 @@ from ._data_exploration import (
     chunks,
     fade_mask,
     histogram,
+    line_profile,
     slicer,
     slicer_orthogonal,
     slices_grid,
-    chunks,
-    histogram,
-    line_profile,
-    threshold
+    threshold,
 )
 from ._detection import circles
 from ._k3d import mesh, volumetric
diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py
index 855d730..84c0266 100644
--- a/qim3d/viz/_data_exploration.py
+++ b/qim3d/viz/_data_exploration.py
@@ -4,20 +4,17 @@ Provides a collection of visualization functions.
 
 import math
 import warnings
-
-from typing import List, Optional, Union, Tuple
+from typing import List, Optional, Tuple, Union
 
 import dask.array as da
 import ipywidgets as widgets
 import matplotlib
 import matplotlib.figure
 import matplotlib.pyplot as plt
-import matplotlib
-from IPython.display import SVG, display, clear_output
-import matplotlib
 import numpy as np
 import seaborn as sns
 import skimage.measure
+from IPython.display import clear_output, display
 
 import qim3d
 from qim3d.utils._logger import log
@@ -1009,13 +1006,23 @@ def histogram(
     if return_fig:
         return fig
 
+
 class _LineProfile:
-    def __init__(self, volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range):
+    def __init__(
+        self,
+        volume,
+        slice_axis,
+        slice_index,
+        vertical_position,
+        horizontal_position,
+        angle,
+        fraction_range,
+    ):
         self.volume = volume
         self.slice_axis = slice_axis
 
         self.dims = np.array(volume.shape)
-        self.pad = 1 # Padding on pivot point to avoid border issues
+        self.pad = 1  # Padding on pivot point to avoid border issues
         self.cmap = [matplotlib.cm.plasma, matplotlib.cm.spring][1]
 
         self.initialize_widgets()
@@ -1025,7 +1032,7 @@ class _LineProfile:
         self.y_widget.value = vertical_position
         self.angle_widget.value = angle
         self.line_fraction_widget.value = [fraction_range[0], fraction_range[1]]
-    
+
     def update_slice_axis(self, slice_axis):
         self.slice_axis = slice_axis
         self.slice_index_widget.max = self.volume.shape[slice_axis] - 1
@@ -1039,35 +1046,44 @@ class _LineProfile:
 
     def initialize_widgets(self):
         layout = widgets.Layout(width='300px', height='auto')
-        self.x_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout)
-        self.y_widget = widgets.IntSlider(min=self.pad, step=1, description="", layout=layout)
-        self.angle_widget = widgets.IntSlider(min=0, max=360, step=1, value=0, description="", layout=layout)
+        self.x_widget = widgets.IntSlider(
+            min=self.pad, step=1, description='', layout=layout
+        )
+        self.y_widget = widgets.IntSlider(
+            min=self.pad, step=1, description='', layout=layout
+        )
+        self.angle_widget = widgets.IntSlider(
+            min=0, max=360, step=1, value=0, description='', layout=layout
+        )
         self.line_fraction_widget = widgets.FloatRangeSlider(
-            min=0, max=1, step=0.01, value=[0, 1], 
-            description="", layout=layout
+            min=0, max=1, step=0.01, value=[0, 1], description='', layout=layout
         )
 
-        self.slice_axis_widget = widgets.Dropdown(options=[0,1,2], value=self.slice_axis, description='Slice axis')
+        self.slice_axis_widget = widgets.Dropdown(
+            options=[0, 1, 2], value=self.slice_axis, description='Slice axis'
+        )
         self.slice_axis_widget.layout.width = '250px'
 
-        self.slice_index_widget = widgets.IntSlider(min=0, step=1, description="Slice index", layout=layout)
+        self.slice_index_widget = widgets.IntSlider(
+            min=0, step=1, description='Slice index', layout=layout
+        )
         self.slice_index_widget.layout.width = '400px'
-    
+
     def calculate_line_endpoints(self, x, y, angle):
         """
         Line is parameterized as: [x + t*np.cos(angle), y + t*np.sin(angle)]
         """
         if np.isclose(angle, 0):
             return [0, y], [self.x_max, y]
-        elif np.isclose(angle, np.pi/2):
+        elif np.isclose(angle, np.pi / 2):
             return [x, 0], [x, self.y_max]
         elif np.isclose(angle, np.pi):
             return [self.x_max, y], [0, y]
-        elif np.isclose(angle, 3*np.pi/2):
+        elif np.isclose(angle, 3 * np.pi / 2):
             return [x, self.y_max], [x, 0]
-        elif np.isclose(angle, 2*np.pi):
+        elif np.isclose(angle, 2 * np.pi):
             return [0, y], [self.x_max, y]
-        
+
         t_left = -x / np.cos(angle)
         t_bottom = -y / np.sin(angle)
         t_right = (self.x_max - x) / np.cos(angle)
@@ -1075,23 +1091,26 @@ class _LineProfile:
         t_values = np.array([t_left, t_top, t_right, t_bottom])
         t_pos = np.min(t_values[t_values > 0])
         t_neg = np.max(t_values[t_values < 0])
-        
+
         src = [x + t_neg * np.cos(angle), y + t_neg * np.sin(angle)]
         dst = [x + t_pos * np.cos(angle), y + t_pos * np.sin(angle)]
         return src, dst
-    
+
     def update(self, slice_axis, slice_index, x, y, angle_deg, fraction_range):
         if slice_axis != self.slice_axis:
             self.update_slice_axis(slice_axis)
             x = self.x_widget.value
             y = self.y_widget.value
             slice_index = self.slice_index_widget.value
-        
+
         clear_output(wait=True)
-        
+
         image = np.take(self.volume, slice_index, slice_axis)
         angle = np.radians(angle_deg)
-        src, dst = [np.array(point, dtype='float32') for point in self.calculate_line_endpoints(x, y, angle)]
+        src, dst = (
+            np.array(point, dtype='float32')
+            for point in self.calculate_line_endpoints(x, y, angle)
+        )
 
         # Rescale endpoints
         line_vec = dst - src
@@ -1101,30 +1120,37 @@ class _LineProfile:
         y_pline = skimage.measure.profile_line(image, src, dst)
 
         fig, ax = plt.subplots(1, 2, figsize=(10, 5))
-        
+
         # Image with color-gradiented line
         num_segments = 100
         x_seg = np.linspace(src[0], dst[0], num_segments)
         y_seg = np.linspace(src[1], dst[1], num_segments)
-        segments = np.stack([np.column_stack([y_seg[:-2], x_seg[:-2]]), 
-                             np.column_stack([y_seg[2:], x_seg[2:]])], axis=1)
-        norm = plt.Normalize(vmin=0, vmax=num_segments-1)
+        segments = np.stack(
+            [
+                np.column_stack([y_seg[:-2], x_seg[:-2]]),
+                np.column_stack([y_seg[2:], x_seg[2:]]),
+            ],
+            axis=1,
+        )
+        norm = plt.Normalize(vmin=0, vmax=num_segments - 1)
         colors = self.cmap(norm(np.arange(num_segments - 1)))
         lc = matplotlib.collections.LineCollection(segments, colors=colors, linewidth=2)
 
-        ax[0].imshow(image,cmap='gray')
+        ax[0].imshow(image, cmap='gray')
         ax[0].add_collection(lc)
         # pivot point
-        ax[0].plot(y,x,marker='s', linestyle='', color='cyan', markersize=4)
+        ax[0].plot(y, x, marker='s', linestyle='', color='cyan', markersize=4)
         ax[0].set_xlabel(f'axis {np.delete(np.arange(3), self.slice_axis)[1]}')
         ax[0].set_ylabel(f'axis {np.delete(np.arange(3), self.slice_axis)[0]}')
-        
+
         # Profile intensity plot
         norm = plt.Normalize(0, vmax=len(y_pline) - 1)
         x_pline = np.arange(len(y_pline))
         points = np.column_stack((x_pline, y_pline))[:, np.newaxis, :]
         segments = np.concatenate([points[:-1], points[1:]], axis=1)
-        lc = matplotlib.collections.LineCollection(segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2)
+        lc = matplotlib.collections.LineCollection(
+            segments, cmap=self.cmap, norm=norm, array=x_pline[:-1], linewidth=2
+        )
 
         ax[1].add_collection(lc)
         ax[1].autoscale()
@@ -1132,48 +1158,65 @@ class _LineProfile:
         ax[1].grid(True)
         plt.tight_layout()
         plt.show()
-    
+
     def build_interactive(self):
         # Group widgets into two columns
-        title_style = "text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;"
-        title_column1 = widgets.HTML(f"<div style='{title_style}'>Line parameterization</div>")
-        title_column2 = widgets.HTML(f"<div style='{title_style}'>Slice selection</div>")
+        title_style = (
+            'text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;'
+        )
+        title_column1 = widgets.HTML(
+            f"<div style='{title_style}'>Line parameterization</div>"
+        )
+        title_column2 = widgets.HTML(
+            f"<div style='{title_style}'>Slice selection</div>"
+        )
 
         # Make label widgets instead of descriptions which have different lengths.
         label_layout = widgets.Layout(width='120px')
-        label_x = widgets.Label("Vertical position", layout=label_layout)
-        label_y = widgets.Label("Horizontal position", layout=label_layout)
-        label_angle = widgets.Label("Angle (°)", layout=label_layout)
-        label_fraction = widgets.Label("Fraction range", layout=label_layout)
+        label_x = widgets.Label('Vertical position', layout=label_layout)
+        label_y = widgets.Label('Horizontal position', layout=label_layout)
+        label_angle = widgets.Label('Angle (°)', layout=label_layout)
+        label_fraction = widgets.Label('Fraction range', layout=label_layout)
 
         row_x = widgets.HBox([label_x, self.x_widget])
         row_y = widgets.HBox([label_y, self.y_widget])
         row_angle = widgets.HBox([label_angle, self.angle_widget])
         row_fraction = widgets.HBox([label_fraction, self.line_fraction_widget])
 
-        controls_column1 = widgets.VBox([title_column1, row_x, row_y, row_angle, row_fraction])
-        controls_column2 = widgets.VBox([title_column2, self.slice_axis_widget, self.slice_index_widget])
+        controls_column1 = widgets.VBox(
+            [title_column1, row_x, row_y, row_angle, row_fraction]
+        )
+        controls_column2 = widgets.VBox(
+            [title_column2, self.slice_axis_widget, self.slice_index_widget]
+        )
         controls = widgets.HBox([controls_column1, controls_column2])
 
         interactive_plot = widgets.interactive_output(
-            self.update, 
-            {'slice_axis': self.slice_axis_widget, 'slice_index': self.slice_index_widget, 
-            'x': self.x_widget, 'y': self.y_widget, 'angle_deg': self.angle_widget,
-            'fraction_range': self.line_fraction_widget}
+            self.update,
+            {
+                'slice_axis': self.slice_axis_widget,
+                'slice_index': self.slice_index_widget,
+                'x': self.x_widget,
+                'y': self.y_widget,
+                'angle_deg': self.angle_widget,
+                'fraction_range': self.line_fraction_widget,
+            },
         )
 
         return widgets.VBox([controls, interactive_plot])
 
+
 def line_profile(
-        volume: np.ndarray,
-        slice_axis: int=0,
-        slice_index: int | str='middle',
-        vertical_position: int | str='middle',
-        horizontal_position: int | str='middle',
-        angle: int=0,
-        fraction_range: Tuple[float,float]=(0.00, 1.00)
-    ) -> widgets.interactive:
-    """Returns an interactive widget for visualizing the intensity profiles of lines on slices.
+    volume: np.ndarray,
+    slice_axis: int = 0,
+    slice_index: int | str = 'middle',
+    vertical_position: int | str = 'middle',
+    horizontal_position: int | str = 'middle',
+    angle: int = 0,
+    fraction_range: Tuple[float, float] = (0.00, 1.00),
+) -> widgets.interactive:
+    """
+    Returns an interactive widget for visualizing the intensity profiles of lines on slices.
 
     Args:
         volume (np.ndarray): The 3D volume of interest.
@@ -1181,13 +1224,13 @@ def line_profile(
         slice_index (int or str, optional): Specifies the initial slice index along slice_axis.
         vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point.
         horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point.
-        angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo. 
+        angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float will be converted to an int. A value outside the range will be wrapped modulo.
         fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both the start and the end should be in the range [0.0, 1.0].
 
     Returns:
         widget (widgets.widget_box.VBox): The interactive widget.
 
-    
+
     Example:
         ```python
         import qim3d
@@ -1198,57 +1241,82 @@ def line_profile(
         ![viz histogram](../../assets/screenshots/viz-line_profile.gif)
 
     """
+
     def parse_position(pos, pos_range, name):
         if isinstance(pos, int):
             if not pos_range[0] <= pos < pos_range[1]:
-                raise ValueError(f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]')
+                raise ValueError(
+                    f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]'
+                )
             return pos
         elif isinstance(pos, str):
             pos = pos.lower()
-            if pos == 'start': return pos_range[0]
-            elif pos == 'middle': return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
-            elif pos == 'end': return pos_range[1]
+            if pos == 'start':
+                return pos_range[0]
+            elif pos == 'middle':
+                return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
+            elif pos == 'end':
+                return pos_range[1]
             else:
                 raise ValueError(
                     f"Invalid string '{pos}' for {name}. "
                     "Must be 'start', 'middle', or 'end'."
                 )
         else:
-            raise TypeError(f'Axis position must be of type int or str.')
-    
+            raise TypeError('Axis position must be of type int or str.')
+
     if not isinstance(volume, (np.ndarray, da.core.Array)):
-        raise ValueError("Data type for volume not supported.")
+        raise ValueError('Data type for volume not supported.')
     if volume.ndim != 3:
-        raise ValueError("Volume must be 3D.")
-    
+        raise ValueError('Volume must be 3D.')
+
     dims = volume.shape
     slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index')
     # the omission of the ends for the pivot point is due to border issues.
-    vertical_position = parse_position(vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position')
-    horizontal_position = parse_position(horizontal_position, (1, np.delete(dims, slice_axis)[1] - 2), 'horizontal_position')
-    
+    vertical_position = parse_position(
+        vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position'
+    )
+    horizontal_position = parse_position(
+        horizontal_position,
+        (1, np.delete(dims, slice_axis)[1] - 2),
+        'horizontal_position',
+    )
+
     if not isinstance(angle, int | float):
-        raise ValueError("Invalid type for angle.")
+        raise ValueError('Invalid type for angle.')
     angle = round(angle) % 360
 
-    if not (0.0 <= fraction_range[0] <= 1.0 and 0.0 <= fraction_range[1] <= 1.0 and fraction_range[0] <= fraction_range[1]):
-        raise ValueError("Invalid values for fraction_range.")
+    if not (
+        0.0 <= fraction_range[0] <= 1.0
+        and 0.0 <= fraction_range[1] <= 1.0
+        and fraction_range[0] <= fraction_range[1]
+    ):
+        raise ValueError('Invalid values for fraction_range.')
 
-    lp = _LineProfile(volume, slice_axis, slice_index, vertical_position, horizontal_position, angle, fraction_range)
+    lp = _LineProfile(
+        volume,
+        slice_axis,
+        slice_index,
+        vertical_position,
+        horizontal_position,
+        angle,
+        fraction_range,
+    )
     return lp.build_interactive()
 
+
 def threshold(
-        volume: np.ndarray,
-        cmap_image: str = 'viridis',
-        vmin: float = None,
-        vmax: float = None,
+    volume: np.ndarray,
+    cmap_image: str = 'viridis',
+    vmin: float = None,
+    vmax: float = None,
 ) -> widgets.VBox:
     """
-    This function provides an interactive interface to explore thresholding on a 
-    3D volume slice-by-slice. Users can either manually set the threshold value 
-    using a slider or select an automatic thresholding method from `skimage`. 
+    This function provides an interactive interface to explore thresholding on a
+    3D volume slice-by-slice. Users can either manually set the threshold value
+    using a slider or select an automatic thresholding method from `skimage`.
 
-    The visualization includes the original image slice, a binary mask showing regions above the 
+    The visualization includes the original image slice, a binary mask showing regions above the
     threshold and an overlay combining the binary mask and the original image.
 
     Args:
@@ -1257,16 +1325,16 @@ def threshold(
         cmap_threshold (str, optional): Colormap for the binary image. Defaults to 'gray'.
         vmin (float, optional): Minimum value for the colormap. Defaults to None.
         vmax (float, optional): Maximum value for the colormap. Defaults to None.
-    
+
     Returns:
         slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume.
-    
+
     Interactivity:
         - **Manual Thresholding**:
-            Select 'Manual' from the dropdown menu to manually adjust the threshold 
+            Select 'Manual' from the dropdown menu to manually adjust the threshold
             using the slider.
         - **Automatic Thresholding**:
-            Choose a method from the dropdown menu to apply an automatic thresholding 
+            Choose a method from the dropdown menu to apply an automatic thresholding
             algorithm. Available methods include:
             - Otsu
             - Isodata
@@ -1276,7 +1344,7 @@ def threshold(
             - Triangle
             - Yen
 
-            The threshold slider will display the computed value and will be disabled 
+            The threshold slider will display the computed value and will be disabled
             in this mode.
 
 
@@ -1290,6 +1358,7 @@ def threshold(
         qim3d.viz.threshold(vol)
         ```
         ![interactive threshold](../../assets/screenshots/interactive_thresholding.gif)
+
     """
 
     # Centralized state dictionary to track current parameters
@@ -1348,10 +1417,14 @@ def threshold(
 
             # Original image
             new_vmin = (
-                None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
+                None
+                if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
+                else vmin
             )
             new_vmax = (
-                None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
+                None
+                if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
+                else vmax
             )
             axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
             axes[0].set_title('Original')
@@ -1405,7 +1478,16 @@ def threshold(
     )
 
     method_dropdown = widgets.Dropdown(
-        options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
+        options=[
+            'Manual',
+            'Otsu',
+            'Isodata',
+            'Li',
+            'Mean',
+            'Minimum',
+            'Triangle',
+            'Yen',
+        ],
         value=state['method'],
         description='Method',
     )
-- 
GitLab