From e2cc0aa6fa6639e13299f7d39260648f60e26f19 Mon Sep 17 00:00:00 2001
From: Alessia Saccardo <s212246@dtu.dk>
Date: Tue, 26 Nov 2024 14:11:10 +0100
Subject: [PATCH] add additional thresholding method and remove blink effect

---
 qim3d/viz/__init__.py |   1 +
 qim3d/viz/explore.py  | 144 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 144 insertions(+), 1 deletion(-)

diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index 84b2e835..0e5338c8 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -7,6 +7,7 @@ from .explore import (
     slicer,
     slices,
     chunks,
+    threshold,
 )
 from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
 from .k3d import vol, mesh
diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py
index d7392cbd..8c4452e1 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/explore.py
@@ -15,7 +15,8 @@ import matplotlib
 import numpy as np
 import zarr
 from qim3d.utils.logger import log
-
+from ipywidgets import interact, IntSlider, FloatSlider, Dropdown
+from skimage.filters import threshold_otsu, threshold_isodata, threshold_li, threshold_mean, threshold_minimum, threshold_triangle, threshold_yen
 
 import qim3d
 
@@ -775,3 +776,144 @@ def chunks(zarr_path: str, **kwargs):
 
     # Display the VBox
     display(final_layout)
+
+def threshold(
+        volume: np.ndarray,
+        axis: int = 0,
+        cmap_volume: str = 'gray',
+        cmap_threshold: str = 'gray',
+        vmin: float = None,
+        vmax: float = None,
+):
+    """
+    Interactive thresholding of a 3D volume.
+    Args:
+        volume (np.ndarray): 3D volume to threshold.
+        cmap_image (str, optional): Colormap for the original image. Defaults to 'gray'.
+        cmap_threshold (str, optional): Colormap for the thresholded image. Defaults to 'gray'.
+    Example:
+        ```python
+        import qim3d
+        vol = qim3d.examples.bone_128x128x128
+        qim3d.viz.threshold(vol)
+        ```
+    """
+    threshold_methods = {
+        'Otsu': threshold_otsu,
+        'Isodata': threshold_isodata,
+        'Li': threshold_li,
+        'Mean': threshold_mean,
+        'Minimum': threshold_minimum,
+        'Triangle': threshold_triangle,
+        'Yen': threshold_yen
+    }
+    
+    # Create the interactive widget
+    def _slicer(position, threshold, method):
+        fig, axes = plt.subplots(1, 3, figsize=(9, 3))
+
+        slice_img = volume[position, :, :]
+        # If vmin is higher than the highest value in the image ValueError is raised
+        # We don't want to override the values because next slices might be okay
+        new_vmin = (
+            None
+            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
+            else vmin
+        )
+        new_vmax = (
+            None
+            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
+            else vmax
+        )
+
+        axes[0].imshow(slice_img, cmap=cmap_volume, vmin=new_vmin, vmax=new_vmax)
+        axes[0].set_title('Original')
+        axes[0].axis('off')
+
+        if method == 'Manual':
+            threshold_slider.disabled = False
+        else:
+            # Apply the appropriate thresholding function
+            threshold_func = threshold_methods.get(method)
+            if threshold_func:
+                threshold = threshold_func(slice_img)
+                if threshold_slider.value != threshold:
+                    threshold_slider.unobserve_all()
+                    threshold_slider.value = threshold
+                threshold_slider.disabled = True
+            else:
+                raise ValueError(f"Unsupported thresholding method: {method}")
+
+        
+        mask = slice_img > threshold
+        axes[1].imshow(mask, cmap=cmap_threshold)
+        axes[1].set_title('Binary mask')
+        axes[1].axis('off')
+
+        masked_volume = qim3d.processing.operations.overlay_rgb_images(
+            background = slice_img,
+            foreground = mask,
+        )
+        # If vmin is higher than the highest value in the image ValueError is raised
+        # We don't want to override the values because next slices might be okay
+        new_vmin = (
+            None
+            if (isinstance(vmin, (float, int)) and vmin > np.max(masked_volume))
+            else vmin
+        )
+        new_vmax = (
+            None
+            if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume))
+            else vmax
+        )
+        axes[2].imshow(masked_volume, cmap=cmap_threshold, vmin=new_vmin, vmax=new_vmax)
+        axes[2].set_title('Overlay')
+        axes[2].axis('off')
+
+        return fig
+    
+    method_dropdown = widgets.Dropdown(
+        options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
+        value='Manual',  # default value
+        description='Method',
+    )
+
+    position_slider = widgets.IntSlider(
+        value=volume.shape[0] // 2,
+        min=0,
+        max=volume.shape[0] - 1,
+        description='Slice',
+        continuous_update=False,
+    )
+
+    threshold_slider = widgets.IntSlider(
+        value=int((volume.min() + volume.max()) / 2),  
+        min=volume.min(),
+        max=volume.max(), 
+        description='Threshold',
+        continuous_update=False,
+    )
+
+    slicer_obj = widgets.interactive(
+        _slicer,
+        position=position_slider,
+        threshold=threshold_slider,
+        method = method_dropdown,
+    )
+
+    #slicer_obj.layout = widgets.Layout(align_items='flex-start')
+
+    # Group sliders vertically (left column)
+    controls_left = widgets.VBox([position_slider, threshold_slider])
+
+    # Place the method dropdown in a separate column (right)
+    controls_right = widgets.VBox([method_dropdown])
+
+    # Combine the controls horizontally (sliders on the left, dropdown on the right)
+    controls_layout = widgets.HBox([controls_left, controls_right], layout=widgets.Layout(justify_content='space-between'))
+
+    # Combine the controls with the interactive slicer object
+    slicer_obj = widgets.VBox([controls_layout, slicer_obj.children[-1]])  # Add sliders + dropdown above the plot
+
+    return slicer_obj
+
-- 
GitLab