From ecfd7c0875e4510cf9e8db2b43c50c1379bb86f7 Mon Sep 17 00:00:00 2001
From: Christian Kento Rasmussen <christian.kento@gmail.com>
Date: Fri, 16 Feb 2024 09:43:57 +0100
Subject: [PATCH] Add watershed segmentation algorithm

---
 qim3d/utils/watershed.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)
 create mode 100644 qim3d/utils/watershed.py

diff --git a/qim3d/utils/watershed.py b/qim3d/utils/watershed.py
new file mode 100644
index 00000000..1456ca4a
--- /dev/null
+++ b/qim3d/utils/watershed.py
@@ -0,0 +1,26 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from scipy import ndimage as ndi
+from skimage.feature import peak_local_max
+from skimage.segmentation import watershed
+
+
+def watershed_segment(volume):
+    """ Apply watershed algorithm to a 3D volume.
+
+    Args:
+        volume (np.array | torch.Tensor):  A 3D volume.
+
+    Returns:
+        np.array: Segmented watershed Connected Components.
+    """
+    volume = volume > 1
+    distance = ndi.distance_transform_edt(volume)
+    coords = peak_local_max(distance, footprint=np.ones((3,)*volume.ndim), labels=volume)
+    mask = np.zeros(distance.shape, dtype=bool)
+    mask[tuple(coords.T)] = True
+    markers, _ = ndi.label(mask)
+    labels = watershed(-distance, markers, mask=volume)
+    
+    return labels
\ No newline at end of file
-- 
GitLab