From 5bd12b1e893b1cda14aa6d7f70d1b5f5929bf5fd Mon Sep 17 00:00:00 2001
From: Alessia Saccardo <s212246@dtu.dk>
Date: Thu, 12 Dec 2024 11:37:50 +0100
Subject: [PATCH] threshold exploration refactoring, add vertical line in
 histogram function

---
 qim3d/processing/operations.py |   4 +-
 qim3d/viz/explore.py           | 225 +++++++++++++++++----------------
 2 files changed, 115 insertions(+), 114 deletions(-)

diff --git a/qim3d/processing/operations.py b/qim3d/processing/operations.py
index e559a16d..565d694f 100644
--- a/qim3d/processing/operations.py
+++ b/qim3d/processing/operations.py
@@ -1,6 +1,8 @@
 import numpy as np
 import qim3d.processing.filters as filters
 from qim3d.utils.logger import log
+import skimage
+import scipy
 
 
 def remove_background(
@@ -88,8 +90,6 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i
         ![operations-watershed_after](assets/screenshots/operations-watershed_after.png)
 
     """
-    import skimage
-    import scipy
 
     if len(np.unique(bin_vol)) > 2:
         raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.")
diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py
index 67a71b86..383a23c0 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/explore.py
@@ -816,10 +816,10 @@ def chunks(zarr_path: str, **kwargs):
     display(final_layout)
 
 def histogram(
-    vol: np.ndarray,
+    volume: np.ndarray,
     bins: Union[int, str] = "auto",
     slice_idx: Union[int, str, None] = None,
-    threshold: int = None,
+    vertical_line: int = None,
     axis: int = 0,
     kde: bool = True,
     log_scale: bool = False,
@@ -840,11 +840,12 @@ def histogram(
     Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
 
     Args:
-        vol (np.ndarray): A 3D NumPy array representing the volume to be visualized.
+        volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
         bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
         axis (int, optional): Axis along which to take a slice. Default is 0.
         slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
                                                If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
+        vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
         kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
         log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
         despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
@@ -868,24 +869,24 @@ def histogram(
         ValueError: If `axis` is not a valid axis index (0, 1, or 2).
         ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
     """
-    if not (0 <= axis < vol.ndim):
-        raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
+    if not (0 <= axis < volume.ndim):
+        raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
 
     if slice_idx == "middle":
-        slice_idx = vol.shape[axis] // 2
+        slice_idx = volume.shape[axis] // 2
 
     if slice_idx is not None:
-        if 0 <= slice_idx < vol.shape[axis]:
-            img_slice = np.take(vol, indices=slice_idx, axis=axis)
+        if 0 <= slice_idx < volume.shape[axis]:
+            img_slice = np.take(volume, indices=slice_idx, axis=axis)
             data = img_slice.ravel()
             title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
         else:
             raise ValueError(
-                f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}."
+                f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
             )
     else:
-        data = vol.ravel()
-        title = f"Intensity histogram for whole volume {vol.shape}"
+        data = volume.ravel()
+        title = f"Intensity histogram for whole volume {volume.shape}"
 
     # Use provided Axes or create new figure
     if ax is None:
@@ -910,15 +911,14 @@ def histogram(
         **sns_kwargs,
     )
 
-    if threshold is not None:
+    if vertical_line is not None:
         ax.axvline(
-            x=threshold,
+            x=vertical_line,
             color='red',
             linestyle="--",
             linewidth=2,
-            label=f"Threshold = {round(threshold)}"
+
         )
-        ax.legend()
 
     if despine:
         sns.despine(
@@ -932,7 +932,6 @@ def histogram(
             trim=True,
         )
 
-
     ax.set_xlabel("Voxel Intensity")
     ax.set_ylabel("Frequency")
 
@@ -951,7 +950,6 @@ def histogram(
 def threshold(
         volume: np.ndarray,
         cmap_image: str = 'viridis',
-        cmap_overlay: str = 'gray',
         vmin: float = None,
         vmax: float = None,
 ) -> widgets.VBox:
@@ -1004,6 +1002,13 @@ def threshold(
         ![interactive threshold](assets/screenshots/interactive_thresholding.gif)
     """
 
+    # Centralized state dictionary to track current parameters
+    state = {
+        'position': volume.shape[0] // 2,
+        'threshold': int((volume.min() + volume.max()) / 2),
+        'method': 'Manual',
+    }
+
     threshold_methods = {
         'Otsu': threshold_otsu,
         'Isodata': threshold_isodata,
@@ -1011,129 +1016,125 @@ def threshold(
         'Mean': threshold_mean,
         'Minimum': threshold_minimum,
         'Triangle': threshold_triangle,
-        'Yen': threshold_yen
+        'Yen': threshold_yen,
     }
-    
-    # Create the interactive widget
-    def _slicer(position, threshold, method):
-        fig, axes = plt.subplots(1, 4, figsize=(25, 5))
 
-        slice_img = volume[position, :, :]
-        # If vmin is higher than the highest value in the image ValueError is raised
-        # We don't want to override the values because next slices might be okay
-        new_vmin = (
-            None
-            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
-            else vmin
-        )
-        new_vmax = (
-            None
-            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
-            else vmax
-        )
+    # Create an output widget to display the plot
+    output = widgets.Output()
 
-        # Add original image to the plot
-        axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
-        axes[0].set_title('Original')
-        axes[0].axis('off')        
+    # Function to update the state and trigger visualization
+    def update_state(change):
+        # Update state based on widget values
+        state['position'] = position_slider.value
+        state['method'] = method_dropdown.value
 
-        # Compute the threshold value
-        if method == 'Manual':
+        if state['method'] == 'Manual':
+            state['threshold'] = threshold_slider.value
             threshold_slider.disabled = False
         else:
-            # Apply the appropriate thresholding function
-            threshold_func = threshold_methods.get(method)
+            threshold_func = threshold_methods.get(state['method'])
             if threshold_func:
-                threshold = threshold_func(slice_img)
-                if threshold_slider.value != threshold:
-                    threshold_slider.unobserve_all()
-                    threshold_slider.value = threshold
+                slice_img = volume[state['position'], :, :]
+                computed_threshold = threshold_func(slice_img)
+                state['threshold'] = computed_threshold
+
+                # Programmatically update the slider without triggering callbacks
+                threshold_slider.unobserve_all()
+                threshold_slider.value = computed_threshold
                 threshold_slider.disabled = True
+                threshold_slider.observe(update_state, names='value')
             else:
-                raise ValueError(f"Unsupported thresholding method: {method}")
-
-        # Compute and add the histogram to the plot
-        histogram(
-            vol=volume,
-            bins=32,
-            slice_idx=position,
-            threshold=threshold,
-            axis=1,
-            kde=False,
-            ax=axes[1],
-            show=False,
-        )
+                raise ValueError(f"Unsupported thresholding method: {state['method']}")
 
-        axes[1].set_title(f'Histogram')
-        
-        # Compute and add the binary mask to the plot
-        mask = slice_img > threshold
-        axes[2].imshow(mask, cmap='grey')
-        axes[2].set_title('Binary mask')
-        axes[2].axis('off')
-
-        # both mask and img should be rgb
-        # mask data in first channel and then black the other sure --> no cmap_overlay
-        # Compute and add the overlay to the plot
-        masked_volume = qim3d.processing.operations.overlay_rgb_images(
-            background = slice_img,
-            foreground = mask,
-        )
+        # Trigger visualization
+        update_visualization()
 
-        # If vmin is higher than the highest value in the image ValueError is raised
-        # We don't want to override the values because next slices might be okay
-        new_vmin = (
-            None
-            if (isinstance(vmin, (float, int)) and vmin > np.max(masked_volume))
-            else vmin
-        )
-        new_vmax = (
-            None
-            if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume))
-            else vmax
-        )
-        axes[3].imshow(masked_volume, cmap=cmap_overlay, vmin=new_vmin, vmax=new_vmax)
-        axes[3].set_title('Overlay')
-        axes[3].axis('off')
+    # Visualization function
+    def update_visualization():
+        slice_img = volume[state['position'], :, :]
+        with output:
+            output.clear_output(wait=True)  # Clear previous plot
+            fig, axes = plt.subplots(1, 4, figsize=(25, 5))
 
-        return fig
-    
-    method_dropdown = widgets.Dropdown(
-        options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
-        value='Manual',  # default value
-        description='Method',
-    )
+            # Original image
+            new_vmin = (
+                None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
+            )
+            new_vmax = (
+                None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
+            )
+            axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
+            axes[0].set_title('Original')
+            axes[0].axis('off')
+
+            # Histogram
+            histogram(
+                volume=volume,
+                bins=32,
+                slice_idx=state['position'],
+                vertical_line=state['threshold'],
+                axis=1,
+                kde=False,
+                ax=axes[1],
+                show=False,
+            )
+            axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")
+
+            # Binary mask
+            mask = slice_img > state['threshold']
+            axes[2].imshow(mask, cmap='gray')
+            axes[2].set_title('Binary mask')
+            axes[2].axis('off')
+
+            # Overlay
+            mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+            mask_rgb[:, :, 0] = mask
+            masked_volume = qim3d.processing.operations.overlay_rgb_images(
+                background=slice_img,
+                foreground=mask_rgb,
+            )
+            axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
+            axes[3].set_title('Overlay')
+            axes[3].axis('off')
+
+            plt.show()
 
+    # Widgets
     position_slider = widgets.IntSlider(
-        value=volume.shape[0] // 2,
+        value=state['position'],
         min=0,
         max=volume.shape[0] - 1,
         description='Slice',
-        continuous_update=False,
     )
 
     threshold_slider = widgets.IntSlider(
-        value=int((volume.min() + volume.max()) / 2),  
+        value=state['threshold'],
         min=volume.min(),
-        max=volume.max(), 
+        max=volume.max(),
         description='Threshold',
-        continuous_update=False,
     )
 
-    slicer_obj = widgets.interactive(
-        _slicer,
-        position=position_slider,
-        threshold=threshold_slider,
-        method = method_dropdown,
+    method_dropdown = widgets.Dropdown(
+        options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
+        value=state['method'],
+        description='Method',
     )
 
+    # Attach the state update function to widgets
+    position_slider.observe(update_state, names='value')
+    threshold_slider.observe(update_state, names='value')
+    method_dropdown.observe(update_state, names='value')
+
+    # Layout
     controls_left = widgets.VBox([position_slider, threshold_slider])
     controls_right = widgets.VBox([method_dropdown])
-    controls_layout = widgets.HBox([controls_left, controls_right], layout=widgets.Layout(justify_content='space-between'))
-    slicer_obj = widgets.VBox([controls_layout, slicer_obj.children[-1]]) 
-    slicer_obj.layout.align_items = "flex-start"
-
+    controls_layout = widgets.HBox(
+        [controls_left, controls_right],
+        layout=widgets.Layout(justify_content='flex-start'),
+    )
+    interactive_ui = widgets.VBox([controls_layout, output])
+    update_visualization()
 
-    return slicer_obj
+    return interactive_ui
 
         
\ No newline at end of file
-- 
GitLab