From 98a35b34008526d9c3dea79d85045c0f34289eb7 Mon Sep 17 00:00:00 2001
From: Christian <s224389@dtu.dk>
Date: Thu, 16 Jan 2025 18:38:28 +0100
Subject: [PATCH] Added interactive kernel size tuning

---
 GUI_draft_live.py | 365 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 269 insertions(+), 96 deletions(-)

diff --git a/GUI_draft_live.py b/GUI_draft_live.py
index f5f8bf9..e4af0fe 100644
--- a/GUI_draft_live.py
+++ b/GUI_draft_live.py
@@ -8,14 +8,191 @@ from scipy.signal import savgol_filter
 from PyQt5.QtWidgets import (
     QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
     QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
-    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
+    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
+    QSlider, QLabel
 )
 from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
-from PyQt5.QtCore import Qt, QRectF
+from PyQt5.QtCore import Qt, QRectF, QSize
 
 from live_wire import compute_cost_image, find_path
 
 
+# ------------------------------------------------------------------------
+# A pan & zoom QGraphicsView
+# ------------------------------------------------------------------------
+class PanZoomGraphicsView(QGraphicsView):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setDragMode(QGraphicsView.NoDrag)  # We'll handle panning manually
+        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
+        self._panning = False
+        self._pan_start = None
+
+    def wheelEvent(self, event):
+        """
+        Zoom in/out with mouse wheel.
+        """
+        zoom_in_factor = 1.25
+        zoom_out_factor = 1 / zoom_in_factor
+        if event.angleDelta().y() > 0:
+            self.scale(zoom_in_factor, zoom_in_factor)
+        else:
+            self.scale(zoom_out_factor, zoom_out_factor)
+        event.accept()
+
+    def mousePressEvent(self, event):
+        """
+        If left button: Start panning (unless overridden in a subclass).
+        """
+        if event.button() == Qt.LeftButton:
+            self._panning = True
+            self._pan_start = event.pos()
+            self.setCursor(Qt.ClosedHandCursor)
+        super().mousePressEvent(event)
+
+    def mouseMoveEvent(self, event):
+        """
+        If panning, translate the scene.
+        """
+        if self._panning and self._pan_start is not None:
+            delta = event.pos() - self._pan_start
+            self._pan_start = event.pos()
+            self.translate(delta.x(), delta.y())
+        super().mouseMoveEvent(event)
+
+    def mouseReleaseEvent(self, event):
+        """
+        End panning.
+        """
+        if event.button() == Qt.LeftButton:
+            self._panning = False
+            self.setCursor(Qt.ArrowCursor)
+        super().mouseReleaseEvent(event)
+
+
+# ------------------------------------------------------------------------
+# A specialized PanZoomGraphicsView for the circle editor
+#    Only pan if user did NOT click on the draggable circle
+# ------------------------------------------------------------------------
+class CircleEditorGraphicsView(PanZoomGraphicsView):
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:
+            # Check if the user clicked on the circle item
+            clicked_item = self.itemAt(event.pos().x(), event.pos().y())
+            if clicked_item is not None:
+                # Walk up parent chain to see if it is our DraggableCircleItem
+                it = clicked_item
+                while it is not None and not hasattr(it, "boundingRect"):
+                    it = it.parentItem()
+                from PyQt5.QtWidgets import QGraphicsEllipseItem
+                if isinstance(it, DraggableCircleItem):
+                    # Let normal item-dragging occur, don't initiate panning
+                    return QGraphicsView.mousePressEvent(self, event)
+        # Otherwise proceed with normal panning logic
+        super().mousePressEvent(event)
+
+
+# ------------------------------------------------------------------------
+# Draggable circle item (centered at (x, y) with radius)
+# ------------------------------------------------------------------------
+class DraggableCircleItem(QGraphicsEllipseItem):
+    def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
+        super().__init__(0, 0, 2*radius, 2*radius, parent)
+        self._r = radius
+
+        pen = QPen(color)
+        brush = QBrush(color)
+        self.setPen(pen)
+        self.setBrush(brush)
+
+        # Enable item-based dragging
+        self.setFlags(QGraphicsEllipseItem.ItemIsMovable |
+                      QGraphicsEllipseItem.ItemIsSelectable |
+                      QGraphicsEllipseItem.ItemSendsScenePositionChanges)
+
+        # Position so that (x, y) is the center
+        self.setPos(x - radius, y - radius)
+
+    def set_radius(self, r):
+        # Keep the same center, just change radius
+        old_center = self.sceneBoundingRect().center()
+        self._r = r
+        self.setRect(0, 0, 2*r, 2*r)
+        new_center = self.sceneBoundingRect().center()
+        diff_x = old_center.x() - new_center.x()
+        diff_y = old_center.y() - new_center.y()
+        self.moveBy(diff_x, diff_y)
+
+    def radius(self):
+        return self._r
+
+
+# ------------------------------------------------------------------------
+# Circle editor widget with slider + done
+# ------------------------------------------------------------------------
+class CircleEditorWidget(QWidget):
+    def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
+        super().__init__(parent)
+        self._pixmap = pixmap
+        self._done_callback = done_callback
+        self._init_radius = init_radius
+
+        layout = QVBoxLayout(self)
+        self.setLayout(layout)
+
+        # Use specialized CircleEditorGraphicsView
+        self._graphics_view = CircleEditorGraphicsView()
+        self._scene = QGraphicsScene(self)
+        self._graphics_view.setScene(self._scene)
+        layout.addWidget(self._graphics_view)
+
+        self._image_item = QGraphicsPixmapItem(self._pixmap)
+        self._scene.addItem(self._image_item)
+
+        # Put circle in center
+        cx = self._pixmap.width() / 2
+        cy = self._pixmap.height() / 2
+        self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
+        self._scene.addItem(self._circle_item)
+
+        # Fit in view
+        self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
+        self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)
+
+        # Bottom controls (slider + done)
+        bottom_layout = QHBoxLayout()
+        layout.addLayout(bottom_layout)
+
+        lbl = QLabel("size:")
+        bottom_layout.addWidget(lbl)
+
+        self._slider = QSlider(Qt.Horizontal)
+        self._slider.setRange(1, 200)
+        self._slider.setValue(self._init_radius)
+        bottom_layout.addWidget(self._slider)
+
+        self._btn_done = QPushButton("Done")
+        bottom_layout.addWidget(self._btn_done)
+
+        # Connect signals
+        self._slider.valueChanged.connect(self._on_slider_changed)
+        self._btn_done.clicked.connect(self._on_done_clicked)
+
+    def _on_slider_changed(self, value):
+        self._circle_item.set_radius(value)
+
+    def _on_done_clicked(self):
+        final_radius = self._circle_item.radius()
+        if self._done_callback is not None:
+            self._done_callback(final_radius)
+
+    def sizeHint(self):
+        return QSize(800, 600)
+
+
+# ------------------------------------------------------------------------
+# Labeled point item
+# ------------------------------------------------------------------------
 class LabeledPointItem(QGraphicsEllipseItem):
     def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
         super().__init__(0, 0, 2*radius, 2*radius, parent)
@@ -84,25 +261,23 @@ class LabeledPointItem(QGraphicsEllipseItem):
         return self._removable
 
 
-class ImageGraphicsView(QGraphicsView):
+# ------------------------------------------------------------------------
+# The original ImageGraphicsView with pan & zoom
+# ------------------------------------------------------------------------
+class ImageGraphicsView(PanZoomGraphicsView):
     def __init__(self, parent=None):
         super().__init__(parent)
         self.scene = QGraphicsScene(self)
         self.setScene(self.scene)
 
-        # Zoom around mouse pointer
-        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
-
         # Image display
         self.image_item = QGraphicsPixmapItem()
         self.scene.addItem(self.image_item)
 
         self.anchor_points = []    # List[(x, y)]
-        self.point_items = []      # LabeledPointItem objects
-        self.full_path_points = [] # QGraphicsEllipseItems for the path
-
-        # We'll store the entire path coords (smoothed) for reference
-        self._full_path_xy = []
+        self.point_items = []      # LabeledPointItem
+        self.full_path_points = [] # QGraphicsEllipseItems for path
+        self._full_path_xy = []    # entire path coords (smoothed)
 
         self.dot_radius = 4
         self.path_radius = 1
@@ -110,10 +285,6 @@ class ImageGraphicsView(QGraphicsView):
         self._img_w = 0
         self._img_h = 0
 
-        # Pan/Drag
-        self.setDragMode(QGraphicsView.ScrollHandDrag)
-        self.viewport().setCursor(Qt.ArrowCursor)
-
         self._mouse_pressed = False
         self._press_view_pos = None
         self._drag_threshold = 5
@@ -185,33 +356,25 @@ class ImageGraphicsView(QGraphicsView):
         self.scene.addItem(item)
 
     def _add_guide_point(self, x, y):
-        """User clicked => find the correct sub-path, insert the point in that sub-path."""
         x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
         y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)
 
         self._revert_cost_to_original()
 
         if not self._full_path_xy:
-            # If there's no existing path built, just insert at the end
             self._insert_anchor_point(-1, x_clamped, y_clamped,
                                       label="", removable=True, z_val=1, radius=self.dot_radius)
         else:
-            # Insert the new anchor in between the correct anchors,
-            # by finding nearest coordinate in _full_path_xy, then
-            # walking left+right until we find bounding anchors.
             self._insert_anchor_between_subpath(x_clamped, y_clamped)
 
         self._apply_all_guide_points_to_cost()
         self._rebuild_full_path()
 
     def _insert_anchor_between_subpath(self, x_new, y_new):
-        """Find the subpath bounding (x_new,y_new) and insert the new anchor accordingly."""
         if not self._full_path_xy:
-            # Fallback if no path
             self._insert_anchor_point(-1, x_new, y_new)
             return
 
-        # 1) Find nearest coordinate in the path
         best_idx = None
         best_d2 = float('inf')
         for i, (px, py) in enumerate(self._full_path_xy):
@@ -223,7 +386,6 @@ class ImageGraphicsView(QGraphicsView):
                 best_idx = i
 
         if best_idx is None:
-            # fallback
             self._insert_anchor_point(-1, x_new, y_new)
             return
 
@@ -237,7 +399,7 @@ class ImageGraphicsView(QGraphicsView):
                     return True
             return False
 
-        # 2) Walk left
+        # Walk left
         left_anchor_pt = None
         iL = best_idx
         while iL >= 0:
@@ -247,7 +409,7 @@ class ImageGraphicsView(QGraphicsView):
                 break
             iL -= 1
 
-        # 3) Walk right
+        # Walk right
         right_anchor_pt = None
         iR = best_idx
         while iR < len(self._full_path_xy):
@@ -257,17 +419,14 @@ class ImageGraphicsView(QGraphicsView):
                 break
             iR += 1
 
-        # fallback if missing anchors
         if not left_anchor_pt or not right_anchor_pt:
             self._insert_anchor_point(-1, x_new, y_new)
             return
 
-        # If they happen to be the same anchor, fallback
         if left_anchor_pt == right_anchor_pt:
             self._insert_anchor_point(-1, x_new, y_new)
             return
 
-        # 4) Map these anchor coords to indices in self.anchor_points
         left_idx = None
         right_idx = None
         for i, (ax, ay) in enumerate(self.anchor_points):
@@ -280,7 +439,6 @@ class ImageGraphicsView(QGraphicsView):
             self._insert_anchor_point(-1, x_new, y_new)
             return
 
-        # 5) Insert new point in between
         if left_idx < right_idx:
             insert_idx = left_idx + 1
         else:
@@ -326,7 +484,6 @@ class ImageGraphicsView(QGraphicsView):
     # PATH BUILDING
     # --------------------------------------------------------------------
     def _rebuild_full_path(self):
-        # Clear old path visuals
         for item in self.full_path_points:
             self.scene.removeItem(item)
         self.full_path_points.clear()
@@ -343,28 +500,19 @@ class ImageGraphicsView(QGraphicsView):
             if i == 0:
                 big_xy.extend(sub_xy)
             else:
-                # Avoid repeating the shared anchor
                 if len(sub_xy) > 1:
                     big_xy.extend(sub_xy[1:])
 
-        # Smooth if we have enough points
         if len(big_xy) >= 7:
             arr_xy = np.array(big_xy)
             smoothed = savgol_filter(arr_xy, window_length=7, polyorder=1, axis=0)
             big_xy = smoothed.tolist()
 
-        # Store the entire path
         self._full_path_xy = big_xy[:]
 
-        # Draw the path
         n_points = len(big_xy)
         for i, (px, py) in enumerate(big_xy):
-            if n_points > 1:
-                fraction = i / (n_points - 1)
-            else:
-                fraction = 0
-
-            # If rainbow is on, use the rainbow color; else use a constant color
+            fraction = i / (n_points - 1) if n_points > 1 else 0
             if self._rainbow_enabled:
                 color = self._rainbow_color(fraction)
             else:
@@ -378,13 +526,12 @@ class ImageGraphicsView(QGraphicsView):
             self.full_path_points.append(path_item)
             self.scene.addItem(path_item)
 
-        # Keep S/E on top if they have labels
+        # Keep anchor labels on top
         for p_item in self.point_items:
             if p_item._text_item:
                 p_item.setZValue(100)
 
     def _compute_subpath_xy(self, xA, yA, xB, yB):
-        """Return the raw path from (xA,yA)->(xB,yB)."""
         if self.cost_image is None:
             return []
         h, w = self.cost_image.shape
@@ -402,18 +549,13 @@ class ImageGraphicsView(QGraphicsView):
         return [(c, r) for (r, c) in path_rc]
 
     def _rainbow_color(self, fraction):
-        """
-        fraction: 0..1
-        Returns a QColor whose hue is fraction * 300 (for example),
-        at full saturation and full brightness.
-        """
-        hue = int(300 * fraction)  # up to 300 degrees
+        hue = int(300 * fraction)
         saturation = 255
         value = 255
         return QColor.fromHsv(hue, saturation, value)
 
     # --------------------------------------------------------------------
-    # MOUSE EVENTS
+    # MOUSE EVENTS (with pan & zoom from PanZoomGraphicsView)
     # --------------------------------------------------------------------
     def mousePressEvent(self, event):
         if event.button() == Qt.LeftButton:
@@ -421,37 +563,27 @@ class ImageGraphicsView(QGraphicsView):
             self._was_dragging = False
             self._press_view_pos = event.pos()
 
-            # See if user is clicking near an existing anchor => drag it
             idx = self._find_item_near(event.pos(), threshold=10)
             if idx is not None:
                 self._dragging_idx = idx
                 self._drag_counter = 0
-
                 scene_pos = self.mapToScene(event.pos())
                 px, py = self.point_items[idx].get_pos()
                 self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
-                self.setDragMode(QGraphicsView.NoDrag)
-                self.viewport().setCursor(Qt.ClosedHandCursor)
+                self.setCursor(Qt.ClosedHandCursor)
                 return
-            else:
-                # No anchor => we may add a new point
-                self.setDragMode(QGraphicsView.ScrollHandDrag)
-                self.viewport().setCursor(Qt.ClosedHandCursor)
 
         elif event.button() == Qt.RightButton:
-            # Right-click => remove anchor if removable
             self._remove_point_by_click(event.pos())
 
         super().mousePressEvent(event)
 
     def mouseMoveEvent(self, event):
         if self._dragging_idx is not None:
-            # Dragging anchor
             scene_pos = self.mapToScene(event.pos())
             x_new = scene_pos.x() - self._drag_offset[0]
             y_new = scene_pos.y() - self._drag_offset[1]
 
-            # clamp so user can't drag outside
             r = self.point_items[self._dragging_idx]._r
             x_clamped = self._clamp(x_new, r, self._img_w - r)
             y_clamped = self._clamp(y_new, r, self._img_h - r)
@@ -459,33 +591,29 @@ class ImageGraphicsView(QGraphicsView):
 
             self._drag_counter += 1
             if self._drag_counter >= 4:
-                # partial path update
                 self._drag_counter = 0
                 self._revert_cost_to_original()
                 self._apply_all_guide_points_to_cost()
                 self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
                 self._rebuild_full_path()
-
         else:
             if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                 dist = (event.pos() - self._press_view_pos).manhattanLength()
                 if dist > self._drag_threshold:
                     self._was_dragging = True
 
-            super().mouseMoveEvent(event)
+        super().mouseMoveEvent(event)
 
     def mouseReleaseEvent(self, event):
         super().mouseReleaseEvent(event)
         if event.button() == Qt.LeftButton and self._mouse_pressed:
             self._mouse_pressed = False
-            self.viewport().setCursor(Qt.ArrowCursor)
+            self.setCursor(Qt.ArrowCursor)
 
             if self._dragging_idx is not None:
-                # finished dragging => final update
                 idx = self._dragging_idx
                 self._dragging_idx = None
                 self._drag_offset = (0, 0)
-                self.setDragMode(QGraphicsView.ScrollHandDrag)
 
                 newX, newY = self.point_items[idx].get_pos()
                 self.anchor_points[idx] = (newX, newY)
@@ -493,9 +621,8 @@ class ImageGraphicsView(QGraphicsView):
                 self._revert_cost_to_original()
                 self._apply_all_guide_points_to_cost()
                 self._rebuild_full_path()
-
             else:
-                # If user wasn't dragging => add new guide point
+                # No drag => add point
                 if not self._was_dragging:
                     scene_pos = self.mapToScene(event.pos())
                     x, y = scene_pos.x(), scene_pos.y()
@@ -507,7 +634,6 @@ class ImageGraphicsView(QGraphicsView):
         idx = self._find_item_near(view_pos, threshold=10)
         if idx is None:
             return
-        # skip if S/E
         if not self.point_items[idx].is_removable():
             return
 
@@ -534,18 +660,6 @@ class ImageGraphicsView(QGraphicsView):
             return closest_idx
         return None
 
-    # --------------------------------------------------------------------
-    # ZOOM
-    # --------------------------------------------------------------------
-    def wheelEvent(self, event):
-        zoom_in_factor = 1.25
-        zoom_out_factor = 1 / zoom_in_factor
-        if event.angleDelta().y() > 0:
-            self.scale(zoom_in_factor, zoom_in_factor)
-        else:
-            self.scale(zoom_out_factor, zoom_out_factor)
-        event.accept()
-
     # --------------------------------------------------------------------
     # UTILS
     # --------------------------------------------------------------------
@@ -564,7 +678,6 @@ class ImageGraphicsView(QGraphicsView):
         self._full_path_xy.clear()
 
     def clear_guide_points(self):
-        """Remove all removable anchors, keep S/E. Rebuild path."""
         i = 0
         while i < len(self.anchor_points):
             if self.point_items[i].is_removable():
@@ -584,50 +697,106 @@ class ImageGraphicsView(QGraphicsView):
         self._rebuild_full_path()
 
     def get_full_path_xy(self):
-        """Return the entire path (x,y) array after smoothing."""
         return self._full_path_xy
 
 
+# ------------------------------------------------------------------------
+# Main Window
+# ------------------------------------------------------------------------
 class MainWindow(QMainWindow):
     def __init__(self):
         super().__init__()
         self.setWindowTitle("Test GUI")
 
-        main_widget = QWidget()
-        main_layout = QVBoxLayout(main_widget)
+        self._last_loaded_pixmap = None
+        self._circle_radius_for_later_use = 0
+
+        # Original main widget
+        self._main_widget = QWidget()
+        self._main_layout = QVBoxLayout(self._main_widget)
 
+        # Image view
         self.image_view = ImageGraphicsView()
-        main_layout.addWidget(self.image_view)
+        self._main_layout.addWidget(self.image_view)
 
-        # Buttons layout
+        # Button row
         btn_layout = QHBoxLayout()
 
-        # Load Image
         self.btn_load_image = QPushButton("Load Image")
         self.btn_load_image.clicked.connect(self.load_image)
         btn_layout.addWidget(self.btn_load_image)
 
-        # Export Path
         self.btn_export_path = QPushButton("Export Path")
         self.btn_export_path.clicked.connect(self.export_path)
         btn_layout.addWidget(self.btn_export_path)
 
-        # Clear Points
         self.btn_clear_points = QPushButton("Clear Points")
         self.btn_clear_points.clicked.connect(self.clear_points)
         btn_layout.addWidget(self.btn_clear_points)
 
-        # Toggle Rainbow
         self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
         self.btn_toggle_rainbow.clicked.connect(self.toggle_rainbow)
         btn_layout.addWidget(self.btn_toggle_rainbow)
 
-        main_layout.addLayout(btn_layout)
-        self.setCentralWidget(main_widget)
+        # New circle editor button
+        self.btn_open_editor = QPushButton("Open Circle Editor")
+        self.btn_open_editor.clicked.connect(self.open_circle_editor)
+        btn_layout.addWidget(self.btn_open_editor)
+
+        self._main_layout.addLayout(btn_layout)
+        self.setCentralWidget(self._main_widget)
+
         self.resize(900, 600)
 
+        # We keep references for old/new
+        self._old_central_widget = None
+        self._editor = None
+
+    def open_circle_editor(self):
+        """Removes the current central widget, replaces with circle editor."""
+        if not self._last_loaded_pixmap:
+            print("No image loaded yet! Cannot open circle editor.")
+            return
+
+        # Step 1: take the old widget out of QMainWindow ownership
+        old_widget = self.takeCentralWidget()
+        self._old_central_widget = old_widget
+
+        # Step 2: create the editor
+        init_radius = 20
+        editor = CircleEditorWidget(
+            pixmap=self._last_loaded_pixmap,
+            init_radius=init_radius,
+            done_callback=self._on_circle_editor_done
+        )
+        self._editor = editor
+
+        # Step 3: set the new editor as the central widget
+        self.setCentralWidget(editor)
+
+    def _on_circle_editor_done(self, final_radius):
+        self._circle_radius_for_later_use = final_radius
+        print(f"Circle Editor done. Radius = {final_radius}")
+
+        # Take the editor out
+        editor_widget = self.takeCentralWidget()
+        if editor_widget is not None:
+            editor_widget.setParent(None)
+
+        # Put back the old widget
+        if self._old_central_widget is not None:
+            self.setCentralWidget(self._old_central_widget)
+            self._old_central_widget = None
+
+        # We can delete the editor if we like
+        if self._editor is not None:
+            self._editor.deleteLater()
+            self._editor = None
+
+    # --------------------------------------------------------------------
+    # Existing Functions
+    # --------------------------------------------------------------------
     def toggle_rainbow(self):
-        """Toggle the rainbow mode in the view."""
         self.image_view.toggle_rainbow()
 
     def load_image(self):
@@ -643,8 +812,12 @@ class MainWindow(QMainWindow):
             self.image_view.cost_image_original = cost_img
             self.image_view.cost_image = cost_img.copy()
 
+            # Store a pixmap to reuse
+            pm = QPixmap(file_path)
+            if not pm.isNull():
+                self._last_loaded_pixmap = pm
+
     def export_path(self):
-        """Export the full path (x,y) as a .npy file."""
         full_xy = self.image_view.get_full_path_xy()
         if not full_xy:
             print("No path to export.")
-- 
GitLab