Skip to content
Snippets Groups Projects
GUI_draft_live.py 39.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • import sys
    import math
    import numpy as np
    
    # For smoothing the path
    from scipy.signal import savgol_filter
    
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    
    Christian's avatar
    Christian committed
        QSlider, QLabel, QCheckBox, QGridLayout, QSizePolicy
    
    Christian's avatar
    Christian committed
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont, QImage
    
    from PyQt5.QtCore import Qt, QRectF, QSize
    
    Christian's avatar
    Christian committed
    # Make sure the following imports exist in live_wire.py (or similar):
    #   from skimage import exposure
    #   from skimage.filters import gaussian
    #   def preprocess_image(image, sigma=3, clip_limit=0.01): ...
    
    from live_wire import compute_cost_image, find_path, preprocess_image
    
    # ------------------------------------------------------------------------
    # A pan & zoom QGraphicsView
    # ------------------------------------------------------------------------
    class PanZoomGraphicsView(QGraphicsView):
        """QGraphicsView with mouse-wheel zoom and left-button drag panning.

        Subclasses can override the mouse handlers and call ``super()`` to keep
        the pan/zoom behaviour, or call ``QGraphicsView``'s handler directly to
        bypass it.
        """

        # Scale factor applied per wheel event when zooming in (its inverse
        # is used when zooming out).
        ZOOM_STEP = 1.25

        def __init__(self, parent=None):
            super().__init__(parent)
            self.setDragMode(QGraphicsView.NoDrag)  # panning is handled manually below
            # Zoom around the cursor position rather than the view centre.
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
            self._panning = False
            self._pan_start = None  # last mouse position (view coords) while panning

            # Let it expand in layouts
            self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        def wheelEvent(self, event):
            """
            Zoom in (wheel up) or out (wheel down) by ZOOM_STEP.

            Events without a vertical component (e.g. a purely horizontal
            trackpad scroll) are passed on to QGraphicsView instead of being
            treated as a zoom-out.
            """
            delta_y = event.angleDelta().y()
            if delta_y == 0:
                super().wheelEvent(event)
                return
            factor = self.ZOOM_STEP if delta_y > 0 else 1 / self.ZOOM_STEP
            self.scale(factor, factor)
            event.accept()

        def mousePressEvent(self, event):
            """
            Left button: start panning (unless a subclass bypasses this handler).
            """
            if event.button() == Qt.LeftButton:
                self._panning = True
                self._pan_start = event.pos()
                self.setCursor(Qt.ClosedHandCursor)
            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """
            While panning, scroll the view by the mouse movement.

            Moving the scroll bars shifts the content 1:1 with the cursor at any
            zoom level, whereas QGraphicsView.translate() would scale the offset
            by the current zoom factor.
            """
            if self._panning and self._pan_start is not None:
                delta = event.pos() - self._pan_start
                self._pan_start = event.pos()
                hbar = self.horizontalScrollBar()
                vbar = self.verticalScrollBar()
                # Dragging right/down should move the content right/down,
                # i.e. scroll the viewport left/up.
                hbar.setValue(hbar.value() - delta.x())
                vbar.setValue(vbar.value() - delta.y())
            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """
            End panning on left-button release.
            """
            if event.button() == Qt.LeftButton:
                self._panning = False
                self._pan_start = None
                self.setCursor(Qt.ArrowCursor)
            super().mouseReleaseEvent(event)
    
    
    # ------------------------------------------------------------------------
    # A specialized PanZoomGraphicsView for the circle editor
    # ------------------------------------------------------------------------
    class CircleEditorGraphicsView(PanZoomGraphicsView):
        """Pan/zoom view used by the circle editor.

        A left click on the circle drags the circle (instead of panning the
        view), and scrolling while the cursor is over the circle changes its
        radius (instead of zooming). Everywhere else, normal pan/zoom applies.
        """

        def __init__(self, circle_editor_widget, parent=None):
            super().__init__(parent)
            # Owning CircleEditorWidget; notified when the wheel changes the radius.
            self._circle_editor_widget = circle_editor_widget

        def _circle_item_at(self, view_pos):
            """Return the DraggableCircleItem under *view_pos* (view coords), or None.

            Walks up the parent chain so that a hit on a child item of the circle
            also counts as a hit on the circle.
            """
            item = self.itemAt(view_pos)
            while item is not None and not isinstance(item, DraggableCircleItem):
                item = item.parentItem()
            return item

        def mousePressEvent(self, event):
            if event.button() == Qt.LeftButton and self._circle_item_at(event.pos()) is not None:
                # Skip PanZoomGraphicsView's panning so Qt's own item dragging
                # (the circle is ItemIsMovable) takes over.
                QGraphicsView.mousePressEvent(self, event)
                return
            # Otherwise proceed with normal panning logic
            super().mousePressEvent(event)

        def wheelEvent(self, event):
            """
            If the mouse is over the circle, change its radius by 1 per wheel
            event (up = larger, down = smaller, minimum 1) and sync the editor's
            slider; otherwise zoom the image as usual.
            """
            circle = self._circle_item_at(event.pos())
            if circle is None:
                super().wheelEvent(event)
                return

            delta = event.angleDelta().y()
            # +1 / -1 per event regardless of the delta's size; a zero delta
            # (horizontal scroll) leaves the radius unchanged.
            step = (delta > 0) - (delta < 0)
            if step:
                new_r = max(1, circle.radius() + step)
                circle.set_radius(new_r)
                # Keep the slider in the parent CircleEditorWidget in sync.
                self._circle_editor_widget.update_slider_value(new_r)
            event.accept()
    
    
    
    # ------------------------------------------------------------------------
    # Draggable circle item (centered at (x, y) with radius)
    # ------------------------------------------------------------------------
    class DraggableCircleItem(QGraphicsEllipseItem):
        """Filled, movable circle addressed by its centre point.

        The ellipse rect sits at (0, 0) in item-local coordinates with a size of
        2 * radius, and the item's pos() is offset so that the circle is centred
        on the requested (x, y).
        """

        def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
            diameter = 2 * radius
            super().__init__(0, 0, diameter, diameter, parent)
            self._r = radius

            # Outline and fill use the same colour.
            self.setPen(QPen(color))
            self.setBrush(QBrush(color))

            # Let Qt handle dragging/selection of the item itself.
            item_flags = (
                QGraphicsEllipseItem.ItemIsMovable
                | QGraphicsEllipseItem.ItemIsSelectable
                | QGraphicsEllipseItem.ItemSendsScenePositionChanges
            )
            self.setFlags(item_flags)

            # Shift the top-left corner so (x, y) ends up at the centre.
            self.setPos(x - radius, y - radius)

        def set_radius(self, r):
            """Resize to radius *r* while keeping the circle's scene centre fixed."""
            centre_before = self.sceneBoundingRect().center()
            self._r = r
            self.setRect(0, 0, 2 * r, 2 * r)
            centre_after = self.sceneBoundingRect().center()
            # The rect grows from its (0, 0) corner, so move back by the drift.
            self.moveBy(centre_before.x() - centre_after.x(),
                        centre_before.y() - centre_after.y())

        def radius(self):
            """Current radius of the circle."""
            return self._r
    
    
    # ------------------------------------------------------------------------
    # Circle editor widget with slider + done
    # ------------------------------------------------------------------------
    class CircleEditorWidget(QWidget):
        """Editor showing an image with a draggable, resizable red circle.

        The radius is set with the "size" slider or by scrolling over the circle.
        Clicking "Done" calls ``done_callback(final_radius)`` if a callback was
        given.
        """

        # Default upper bound of the size slider; it grows if a larger radius
        # is requested.
        DEFAULT_MAX_RADIUS = 200

        def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
            super().__init__(parent)
            self._pixmap = pixmap
            self._done_callback = done_callback
            self._init_radius = init_radius
            self._fitted_once = False

            # Passing `self` already installs the layout on this widget.
            layout = QVBoxLayout(self)

            # Use specialized CircleEditorGraphicsView
            self._graphics_view = CircleEditorGraphicsView(circle_editor_widget=self)
            self._scene = QGraphicsScene(self)
            self._graphics_view.setScene(self._scene)
            layout.addWidget(self._graphics_view)

            self._image_item = QGraphicsPixmapItem(self._pixmap)
            self._scene.addItem(self._image_item)

            # Start with the circle centred on the image.
            cx = self._pixmap.width() / 2
            cy = self._pixmap.height() / 2
            self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
            self._scene.addItem(self._circle_item)

            # Fit in view (refitted once in showEvent, when the real size is known)
            self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
            self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

            # Bottom controls (slider + done)
            bottom_layout = QHBoxLayout()
            layout.addLayout(bottom_layout)
            bottom_layout.addWidget(QLabel("size:"))

            self._slider = QSlider(Qt.Horizontal)
            # Widen the range if the initial radius would otherwise be clamped,
            # which would leave the slider out of sync with the circle.
            self._slider.setRange(1, max(self.DEFAULT_MAX_RADIUS, self._init_radius))
            self._slider.setValue(self._init_radius)
            bottom_layout.addWidget(self._slider)

            self._btn_done = QPushButton("Done")
            bottom_layout.addWidget(self._btn_done)

            # Connect signals
            self._slider.valueChanged.connect(self._on_slider_changed)
            self._btn_done.clicked.connect(self._on_done_clicked)

        def showEvent(self, event):
            """Fit the image to the view the first time the widget is shown.

            The fitInView() call in __init__ runs before layout, when the view
            does not yet have its final size.
            """
            super().showEvent(event)
            if not self._fitted_once:
                self._fitted_once = True
                self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

        def _on_slider_changed(self, value):
            self._circle_item.set_radius(value)

        def _on_done_clicked(self):
            final_radius = self._circle_item.radius()
            if self._done_callback is not None:
                self._done_callback(final_radius)

        def update_slider_value(self, new_radius):
            """
            Called by CircleEditorGraphicsView when the user scrolls on the circle item.
            Syncs the slider to *new_radius* without re-triggering _on_slider_changed.
            """
            self._slider.blockSignals(True)
            try:
                # Extend the range so the slider never clamps the real radius.
                if new_radius > self._slider.maximum():
                    self._slider.setMaximum(new_radius)
                self._slider.setValue(new_radius)
            finally:
                self._slider.blockSignals(False)

        def sizeHint(self):
            return QSize(800, 600)
    
    
    # ------------------------------------------------------------------------
    # Labeled point item
    # ------------------------------------------------------------------------
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """Filled circle centred at (x, y), optionally with a bold text label.

        A non-empty label is rendered in black 14pt bold Arial, scaled down
        (never up) so it fits inside the circle, and centred in it.
        ``removable`` is a flag reported back by is_removable().
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
            super().__init__(0, 0, 2*radius, 2*radius, parent)
            self._x = x
            self._y = y
            self._r = radius
            self._removable = removable

            pen = QPen(color)
            brush = QBrush(color)
            self.setPen(pen)
            self.setBrush(brush)
            self.setZValue(z_value)

            self._text_item = None
            if label:
                # The text is a child item, so it moves with the circle.
                self._text_item = QGraphicsTextItem(self)
                self._text_item.setPlainText(label)
                self._text_item.setDefaultTextColor(QColor("black"))
                font = QFont("Arial", 14)
                font.setBold(True)
                self._text_item.setFont(font)
                self._scale_text_to_fit()

            self.set_pos(x, y)

        def _scale_text_to_fit(self):
            """Shrink the label (if needed) to fit the circle's diameter, then centre it."""
            if not self._text_item:
                return
            self._text_item.setScale(1.0)
            circle_diam = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            text_w = raw_rect.width()
            text_h = raw_rect.height()
            if text_w > circle_diam or text_h > circle_diam:
                # Both sizes are > 0 here because one exceeds circle_diam.
                scale_factor = min(circle_diam / text_w, circle_diam / text_h)
                self._text_item.setScale(scale_factor)
            self._center_label()

        def _center_label(self):
            """Position the (scaled) label so it is centred inside the ellipse rect."""
            if not self._text_item:
                return
            ellipse_w = 2 * self._r
            ellipse_h = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            scale_factor = self._text_item.scale()
            scaled_w = raw_rect.width() * scale_factor
            scaled_h = raw_rect.height() * scale_factor
            tx = (ellipse_w - scaled_w) * 0.5
            ty = (ellipse_h - scaled_h) * 0.5
            self._text_item.setPos(tx, ty)

        def set_pos(self, x, y):
            """Positions the circle so its center is at (x, y)."""
            self._x = x
            self._y = y
            self.setPos(x - self._r, y - self._r)

        def get_pos(self):
            """Return the circle's centre as an (x, y) tuple."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Euclidean distance from the circle's centre to (x_other, y_other)."""
            return math.hypot(self._x - x_other, self._y - y_other)

        def is_removable(self):
            return self._removable
    
    
    
    # ------------------------------------------------------------------------
    # The original ImageGraphicsView with pan & zoom
    # ------------------------------------------------------------------------
    class ImageGraphicsView(PanZoomGraphicsView):
    
        def __init__(self, parent=None):
            super().__init__(parent)
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)
    
    
    Christian's avatar
    Christian committed
            # Image display
    
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)
    
    
    Christian's avatar
    Christian committed
            self.anchor_points = []    # List[(x, y)]
    
            self.point_items = []      # LabeledPointItem
            self.full_path_points = [] # QGraphicsEllipseItems for path
            self._full_path_xy = []    # entire path coords (smoothed)
    
    
            self.dot_radius = 4
            self.path_radius = 1
    
    Christian's avatar
    Christian committed
            self.radius_cost_image = 2
    
            self._img_w = 0
            self._img_h = 0
    
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5
            self._was_dragging = False
            self._dragging_idx = None
            self._drag_offset = (0, 0)
    
    Christian's avatar
    Christian committed
            self._drag_counter = 0
    
    Christian's avatar
    Christian committed
            # Cost images
    
            self.cost_image_original = None
            self.cost_image = None
    
    
    Christian's avatar
    Christian committed
            # Rainbow toggle => start with OFF
            self._rainbow_enabled = False
    
            # Smoothing parameters
            self._savgol_window_length = 7
            self._savgol_polyorder = 1
    
    
        def set_rainbow_enabled(self, enabled: bool):
            """Enable/disable rainbow mode, then rebuild the path."""
            self._rainbow_enabled = enabled
            self._rebuild_full_path()
    
        def toggle_rainbow(self):
            """Flip the rainbow mode and rebuild path."""
            self._rainbow_enabled = not self._rainbow_enabled
            self._rebuild_full_path()
    
    
    Christian's avatar
    Christian committed
        def set_savgol_window_length(self, wlen: int):
            """
            Set the window length for the Savitzky-Golay filter and update polyorder
            based on window_length_polyorder_ratio.
            """
            # SavGol requires window_length to be odd and >= 3
            if wlen < 3:
                wlen = 3
            if wlen % 2 == 0:
                wlen += 1
    
            self._savgol_window_length = wlen
    
            # polyorder is nearest integer to (window_length / 7)
            # but must be >= 1 and < window_length
            p = round(wlen / 7.0)
            p = max(1, p)
            if p >= wlen:
                p = wlen - 1
            self._savgol_polyorder = p
    
            self._rebuild_full_path()
    
    
        # --------------------------------------------------------------------
        # LOADING
        # --------------------------------------------------------------------
        def load_image(self, path):
            pixmap = QPixmap(path)
            if not pixmap.isNull():
                self.image_item.setPixmap(pixmap)
                self.setSceneRect(QRectF(pixmap.rect()))
    
                self._img_w = pixmap.width()
                self._img_h = pixmap.height()
    
                self._clear_all_points()
                self.resetTransform()
                self.fitInView(self.image_item, Qt.KeepAspectRatio)
    
    
    Christian's avatar
    Christian committed
                # By default, add S/E
    
                s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
                e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
    
                self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
                self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
    
        # --------------------------------------------------------------------
        # ANCHOR POINTS
        # --------------------------------------------------------------------
        def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
    
    Christian's avatar
    Christian committed
            """Insert anchor at index=idx (or -1 => before E). Clamps x,y to image bounds."""
    
            x_clamped = self._clamp(x, radius, self._img_w - radius)
            y_clamped = self._clamp(y, radius, self._img_h - radius)
    
    
            if idx < 0:
                if len(self.anchor_points) >= 2:
                    idx = len(self.anchor_points) - 1
                else:
                    idx = len(self.anchor_points)
    
    
            self.anchor_points.insert(idx, (x_clamped, y_clamped))
    
            color = Qt.green if label in ("S", "E") else Qt.red
    
    Christian's avatar
    Christian committed
            item = LabeledPointItem(x_clamped, y_clamped,
                                    label=label, radius=radius, color=color,
                                    removable=removable, z_value=z_val)
    
            self.point_items.insert(idx, item)
            self.scene.addItem(item)
    
        def _add_guide_point(self, x, y):
    
            x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
    
    Christian's avatar
    Christian committed
            y_clamped = self._clamp(y, self.dot_radius, self._img_w - self.dot_radius)
    
            self._revert_cost_to_original()
    
    Christian's avatar
    Christian committed
            if not self._full_path_xy:
                self._insert_anchor_point(-1, x_clamped, y_clamped,
                                          label="", removable=True, z_val=1, radius=self.dot_radius)
            else:
                self._insert_anchor_between_subpath(x_clamped, y_clamped)
    
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()
    
    
    Christian's avatar
    Christian committed
        def _insert_anchor_between_subpath(self, x_new, y_new):
            if not self._full_path_xy:
                self._insert_anchor_point(-1, x_new, y_new)
                return
    
            best_idx = None
            best_d2 = float('inf')
            for i, (px, py) in enumerate(self._full_path_xy):
                dx = px - x_new
                dy = py - y_new
                d2 = dx*dx + dy*dy
                if d2 < best_d2:
                    best_d2 = d2
                    best_idx = i
    
            if best_idx is None:
                self._insert_anchor_point(-1, x_new, y_new)
                return
    
            def approx_equal(xa, ya, xb, yb, tol=1e-3):
                return (abs(xa - xb) < tol) and (abs(ya - yb) < tol)
    
            def is_anchor(coord):
                cx, cy = coord
                for (ax, ay) in self.anchor_points:
                    if approx_equal(ax, ay, cx, cy):
                        return True
                return False
    
    
            # Walk left
    
    Christian's avatar
    Christian committed
            left_anchor_pt = None
            iL = best_idx
            while iL >= 0:
                px, py = self._full_path_xy[iL]
                if is_anchor((px, py)):
                    left_anchor_pt = (px, py)
                    break
                iL -= 1
    
    
            # Walk right
    
    Christian's avatar
    Christian committed
            iR = best_idx
            while iR < len(self._full_path_xy):
                px, py = self._full_path_xy[iR]
                if is_anchor((px, py)):
                    right_anchor_pt = (px, py)
                    break
                iR += 1
    
            if not left_anchor_pt or not right_anchor_pt:
                self._insert_anchor_point(-1, x_new, y_new)
                return
    
    
            if left_anchor_pt == right_anchor_pt:
                self._insert_anchor_point(-1, x_new, y_new)
                return
    
    
    Christian's avatar
    Christian committed
            left_idx = None
            right_idx = None
            for i, (ax, ay) in enumerate(self.anchor_points):
                if approx_equal(ax, ay, left_anchor_pt[0], left_anchor_pt[1]):
                    left_idx = i
                if approx_equal(ax, ay, right_anchor_pt[0], right_anchor_pt[1]):
                    right_idx = i
    
            if left_idx is None or right_idx is None:
                self._insert_anchor_point(-1, x_new, y_new)
                return
    
            if left_idx < right_idx:
                insert_idx = left_idx + 1
            else:
    
    Christian's avatar
    Christian committed
    
            self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
                                      z_val=1, radius=self.dot_radius)
    
    
        # --------------------------------------------------------------------
        # COST IMAGE
        # --------------------------------------------------------------------
        def _revert_cost_to_original(self):
            if self.cost_image_original is not None:
                self.cost_image = self.cost_image_original.copy()
    
        def _apply_all_guide_points_to_cost(self):
            if self.cost_image is None:
                return
            for i, (ax, ay) in enumerate(self.anchor_points):
                if self.point_items[i].is_removable():
                    self._lower_cost_in_circle(ax, ay, self.radius_cost_image)
    
        def _lower_cost_in_circle(self, x_f, y_f, radius):
            if self.cost_image is None:
                return
            h, w = self.cost_image.shape
            row_c = int(round(y_f))
            col_c = int(round(x_f))
            if not (0 <= row_c < h and 0 <= col_c < w):
                return
            global_min = self.cost_image.min()
            r_s = max(0, row_c - radius)
            r_e = min(h, row_c + radius + 1)
            c_s = max(0, col_c - radius)
            c_e = min(w, col_c + radius + 1)
            for rr in range(r_s, r_e):
                for cc in range(c_s, c_e):
                    dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2)
                    if dist <= radius:
                        self.cost_image[rr, cc] = global_min
    
        # --------------------------------------------------------------------
        # PATH BUILDING
        # --------------------------------------------------------------------
        def _rebuild_full_path(self):
            for item in self.full_path_points:
                self.scene.removeItem(item)
            self.full_path_points.clear()
    
    Christian's avatar
    Christian committed
            self._full_path_xy.clear()
    
    
            if len(self.anchor_points) < 2 or self.cost_image is None:
                return
    
            big_xy = []
    
            for i in range(len(self.anchor_points) - 1):
    
                xA, yA = self.anchor_points[i]
    
                xB, yB = self.anchor_points[i + 1]
    
                sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
                if i == 0:
                    big_xy.extend(sub_xy)
                else:
                    if len(sub_xy) > 1:
                        big_xy.extend(sub_xy[1:])
    
    
    Christian's avatar
    Christian committed
            if len(big_xy) >= self._savgol_window_length:
    
    Christian's avatar
    Christian committed
                arr_xy = np.array(big_xy)
    
    Christian's avatar
    Christian committed
                smoothed = savgol_filter(
                    arr_xy,
                    window_length=self._savgol_window_length,
                    polyorder=self._savgol_polyorder,
                    axis=0
                )
    
                big_xy = smoothed.tolist()
    
    
    Christian's avatar
    Christian committed
            self._full_path_xy = big_xy[:]
    
    
            n_points = len(big_xy)
            for i, (px, py) in enumerate(big_xy):
    
                fraction = i / (n_points - 1) if n_points > 1 else 0
    
                if self._rainbow_enabled:
                    color = self._rainbow_color(fraction)
                else:
                    color = Qt.red
    
                path_item = LabeledPointItem(px, py, label="",
                                             radius=self.path_radius,
                                             color=color,
                                             removable=False,
                                             z_value=0)
    
                self.full_path_points.append(path_item)
                self.scene.addItem(path_item)
    
    
            # Keep anchor labels on top
    
            for p_item in self.point_items:
                if p_item._text_item:
                    p_item.setZValue(100)
    
        def _compute_subpath_xy(self, xA, yA, xB, yB):
            if self.cost_image is None:
                return []
            h, w = self.cost_image.shape
            rA, cA = int(round(yA)), int(round(xA))
            rB, cB = int(round(yB)), int(round(xB))
    
            rA = max(0, min(rA, h - 1))
            cA = max(0, min(cA, w - 1))
            rB = max(0, min(rB, h - 1))
            cB = max(0, min(cB, w - 1))
    
            try:
                path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
            except ValueError as e:
                print("Error in find_path:", e)
                return []
            return [(c, r) for (r, c) in path_rc]
    
    
        def _rainbow_color(self, fraction):
    
            hue = int(300 * fraction)
    
            saturation = 255
            value = 255
            return QColor.fromHsv(hue, saturation, value)
    
    
        # --------------------------------------------------------------------
    
    Christian's avatar
    Christian committed
        # MOUSE EVENTS
    
        # --------------------------------------------------------------------
        def mousePressEvent(self, event):
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()
    
    
                idx = self._find_item_near(event.pos(), threshold=10)
                if idx is not None:
                    self._dragging_idx = idx
                    self._drag_counter = 0
                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.point_items[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
    
                    self.setCursor(Qt.ClosedHandCursor)
    
    
            elif event.button() == Qt.RightButton:
    
                self._remove_point_by_click(event.pos())
    
    
            super().mousePressEvent(event)
    
        def mouseMoveEvent(self, event):
            if self._dragging_idx is not None:
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]
    
                r = self.point_items[self._dragging_idx]._r
                x_clamped = self._clamp(x_new, r, self._img_w - r)
                y_clamped = self._clamp(y_new, r, self._img_h - r)
                self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped)
    
                self._drag_counter += 1
                if self._drag_counter >= 4:
                    self._drag_counter = 0
                    self._revert_cost_to_original()
                    self._apply_all_guide_points_to_cost()
    
                    self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
                    self._rebuild_full_path()
    
            else:
                if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                    dist = (event.pos() - self._press_view_pos).manhattanLength()
                    if dist > self._drag_threshold:
                        self._was_dragging = True
    
            super().mouseMoveEvent(event)
    
    
        def mouseReleaseEvent(self, event):
            super().mouseReleaseEvent(event)
            if event.button() == Qt.LeftButton and self._mouse_pressed:
                self._mouse_pressed = False
    
                self.setCursor(Qt.ArrowCursor)
    
                if self._dragging_idx is not None:
                    idx = self._dragging_idx
                    self._dragging_idx = None
                    self._drag_offset = (0, 0)
    
                    newX, newY = self.point_items[idx].get_pos()
                    self.anchor_points[idx] = (newX, newY)
    
                    self._revert_cost_to_original()
                    self._apply_all_guide_points_to_cost()
                    self._rebuild_full_path()
                else:
    
                    # No drag => add point
    
                        scene_pos = self.mapToScene(event.pos())
                        x, y = scene_pos.x(), scene_pos.y()
                        self._add_guide_point(x, y)
    
                self._was_dragging = False
    
        def _remove_point_by_click(self, view_pos):
            idx = self._find_item_near(view_pos, threshold=10)
            if idx is None:
                return
            if not self.point_items[idx].is_removable():
    
    
            self.scene.removeItem(self.point_items[idx])
            self.point_items.pop(idx)
            self.anchor_points.pop(idx)
    
            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()
    
        def _find_item_near(self, view_pos, threshold=10):
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()
    
            closest_idx = None
    
    Christian's avatar
    Christian committed
            min_dist = float('inf')
    
            for i, itm in enumerate(self.point_items):
                d = itm.distance_to(x_click, y_click)
                if d < min_dist:
                    min_dist = d
                    closest_idx = i
            if closest_idx is not None and min_dist <= threshold:
                return closest_idx
            return None
    
        # --------------------------------------------------------------------
        # UTILS
        # --------------------------------------------------------------------
        def _clamp(self, val, mn, mx):
            """
            Limit val to the closed interval [mn, mx].

            The upper bound is applied first, so mn wins if mn > mx.
            """
            upper_capped = min(val, mx)
            return max(mn, upper_capped)
    
        def _clear_all_points(self):
            """
            Remove every point item and every drawn path item from the scene.

            Unlike clear_guide_points, this also drops non-removable points. It does
            not touch the cost image and does not rebuild the path. point_items,
            anchor_points, full_path_points and _full_path_xy are all emptied in place.
            """
            for it in self.point_items:
                self.scene.removeItem(it)
            self.point_items.clear()
            self.anchor_points.clear()

            for p in self.full_path_points:
                self.scene.removeItem(p)
            self.full_path_points.clear()

            self._full_path_xy.clear()
    
    
        def clear_guide_points(self):
            """
            Remove every removable point, then refresh the cost image and the path.

            Points whose is_removable() returns False are kept, together with their
            entries in anchor_points. All drawn path items are removed from the scene
            and the stored path coordinates are cleared. Finally
            _revert_cost_to_original, _apply_all_guide_points_to_cost and
            _rebuild_full_path are called with the remaining points.
            """
            # point_items and anchor_points are parallel lists: delete from both in
            # lockstep, and only advance i when nothing was removed at this position.
            i = 0
            while i < len(self.anchor_points):
                if self.point_items[i].is_removable():
                    self.scene.removeItem(self.point_items[i])
                    del self.point_items[i]
                    del self.anchor_points[i]
                else:
                    i += 1

            for it in self.full_path_points:
                self.scene.removeItem(it)
            self.full_path_points.clear()
            self._full_path_xy.clear()

            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()
    
    
        def get_full_path_xy(self):
            """
            Return the stored full-path coordinates (self._full_path_xy).

            The internal list object itself is returned, not a copy, so callers
            should treat it as read-only. NOTE(review): presumably a list of (x, y)
            pairs filled when the path is rebuilt; that code is not visible here.
            """
            path_xy = self._full_path_xy
            return path_xy
    
    
    # ------------------------------------------------------------------------
    # Advanced Settings Widget
    # ------------------------------------------------------------------------
    class AdvancedSettingsWidget(QWidget):
        """
        Side panel with advanced controls and two diagnostic image previews.

        The top grid holds the following controls:
        - a "Toggle Rainbow" button;
        - an "Open Circle Editor" button;
        - a line-smoothing slider (3..51, default 7);
        - a contrast slider (0..100, default 1, forwarded as clip_limit = value / 100).

        Below the grid, two scaled QLabels show the contrasted/blurred image and
        the current cost image; update_displays() refreshes them. Every action is
        forwarded to the main_window passed to the constructor.
        """

        def __init__(self, main_window, parent=None):
            super().__init__(parent)
            self._main_window = main_window  # callbacks go to e.g. main_window.open_circle_editor()

            root = QVBoxLayout()
            self.setLayout(root)

            # --- controls: two buttons on row 0, caption/slider pairs on rows 1 and 2 ---
            grid = QGridLayout()

            self.btn_toggle_rainbow = self._make_button("Toggle Rainbow", self._on_toggle_rainbow)
            grid.addWidget(self.btn_toggle_rainbow, 0, 0)

            self.btn_circle_editor = self._make_button(
                "Open Circle Editor", self._main_window.open_circle_editor
            )
            grid.addWidget(self.btn_circle_editor, 0, 1)

            smoothing_caption = QLabel("Line smoothing (SavGol window_length)")
            grid.addWidget(smoothing_caption, 1, 0)
            self.line_smoothing_slider = self._make_slider(3, 51, 7, self._on_line_smoothing_slider)
            grid.addWidget(self.line_smoothing_slider, 1, 1)

            contrast_caption = QLabel("Contrast (clip_limit)")
            grid.addWidget(contrast_caption, 2, 0)
            # One integer slider step equals a clip_limit step of 0.01, so 1 -> 0.01.
            self.contrast_slider = self._make_slider(0, 100, 1, self._on_contrast_slider)
            grid.addWidget(self.contrast_slider, 2, 1)

            root.addLayout(grid)

            # --- previews: contrasted/blurred image on the left, cost image on the right ---
            previews = QHBoxLayout()

            self.label_contrasted_blurred = self._make_preview_label("CONTRASTED BLURRED IMG")
            previews.addWidget(self.label_contrasted_blurred)

            self.label_cost_image = self._make_preview_label("Current COST IMAGE")
            previews.addWidget(self.label_cost_image)

            root.addLayout(previews)

        @staticmethod
        def _make_button(text, slot):
            """Return a QPushButton labelled *text* whose clicked signal calls *slot*."""
            button = QPushButton(text)
            button.clicked.connect(slot)
            return button

        @staticmethod
        def _make_slider(minimum, maximum, initial, slot):
            """
            Return a horizontal QSlider with the given range and initial value.

            *slot* is connected to valueChanged only after the initial value has
            been set, so building the slider never fires the callback.
            """
            slider = QSlider(Qt.Horizontal)
            slider.setRange(minimum, maximum)
            slider.setValue(initial)
            slider.valueChanged.connect(slot)
            return slider

        @staticmethod
        def _make_preview_label(placeholder):
            """Return a centered, expanding QLabel that scales its pixmap and shows *placeholder*."""
            label = QLabel()
            label.setText(placeholder)
            label.setAlignment(Qt.AlignCenter)
            label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
            label.setScaledContents(True)
            return label

        def _on_toggle_rainbow(self):
            """Forward a click on the rainbow button to main_window.toggle_rainbow()."""
            self._main_window.toggle_rainbow()

        def _on_line_smoothing_slider(self, value):
            """
            Forward the smoothing slider value to image_view.set_savgol_window_length.

            NOTE(review): the slider also yields even values; confirm the receiver
            accepts them as a Savitzky-Golay window length.
            """
            view = self._main_window.image_view
            view.set_savgol_window_length(value)

        def _on_contrast_slider(self, value):
            """Map the 0..100 slider value to a 0..1 clip_limit and forward it to main_window.update_contrast."""
            self._main_window.update_contrast(value / 100.0)

        def update_displays(self, contrasted_img_np, cost_img_np):
            """
            Refresh the two preview images in the advanced panel.

            contrasted_img_np is shown as-is after clipping to [0, 1].
            NOTE(review): it is presumably a float image in 0..1.

            cost_img_np is min-max normalized before display.

            Both arrays are converted first; then each label whose array was not
            None receives its new pixmap. A None input leaves that label unchanged.
            """
            updates = (
                (self.label_contrasted_blurred, self._np_array_to_qpixmap(contrasted_img_np)),
                (self.label_cost_image, self._np_array_to_qpixmap(cost_img_np, normalize=True)),
            )
            for target_label, pix in updates:
                if pix is not None:
                    target_label.setPixmap(pix)

        def _np_array_to_qpixmap(self, arr, normalize=False):
            """
            Convert a 2-D numpy array into an 8-bit grayscale QPixmap.

            Processing steps:
            1. Copy the array.
            2. If *normalize* is true, min-max scale it to [0, 1]. A constant
               array (range < 1e-12) becomes all zeros.
            3. Clip to [0, 1] and scale to 0..255 as uint8.
            4. Wrap the result as a QImage.Format_Grayscale8.

            Returns None when *arr* is None.
            """
            if arr is None:
                return None

            work = arr.copy()
            if normalize:
                lo, hi = work.min(), work.max()
                span = hi - lo
                if abs(span) < 1e-12:
                    work[:] = 0
                else:
                    work = (work - lo) / span
            gray = (np.clip(work, 0, 1) * 255).astype(np.uint8)

            rows, cols = gray.shape
            # bytesPerLine == cols: the uint8 rows of this fresh array are tightly packed.
            image = QImage(gray.data, cols, rows, cols, QImage.Format_Grayscale8)
            return QPixmap.fromImage(image)
    
    
    
    # ------------------------------------------------------------------------
    # Main Window
    # ------------------------------------------------------------------------
    
    class MainWindow(QMainWindow):
        def __init__(self):
            """
            Build the main window.

            The main layout is horizontal:
            - Left side: a vertical panel with an ImageGraphicsView on top and a
              row of buttons (Load Image, Export Path, Clear Points, Advanced
              Settings) below it.
            - Right side: an AdvancedSettingsWidget that starts hidden and is
              shown or hidden by the checkable "Advanced Settings" button.
            """
            super().__init__()
            self.setWindowTitle("Test GUI")

            # Last loaded image; open_circle_editor refuses to run while this is falsy.
            self._last_loaded_pixmap = None

            # Radius passed to compute_cost_image; updated when the circle editor finishes.
            self._circle_calibrated_radius = 6
            self._last_loaded_file_path = None

            # For the contrast slider (0.01 matches the slider's default position 1 -> 1/100)
            self._current_clip_limit = 0.01

            # Outer widget + layout
            self._main_widget = QWidget()
            self._main_layout = QHBoxLayout(self._main_widget)  # horizontal so we can place advanced on the right
            self._left_panel = QVBoxLayout()  # for the image & row of buttons
            self._main_layout.addLayout(self._left_panel)
            self.setCentralWidget(self._main_widget)

            # Image view
            self.image_view = ImageGraphicsView()
            self._left_panel.addWidget(self.image_view)

            # Button row
            btn_layout = QHBoxLayout()
            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            self.btn_export_path = QPushButton("Export Path")
            self.btn_export_path.clicked.connect(self.export_path)
            btn_layout.addWidget(self.btn_export_path)

            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            # "Advanced Settings" button: checkable, so clicked(checked) drives the panel's visibility
            self.btn_advanced = QPushButton("Advanced Settings")
            self.btn_advanced.setCheckable(True)
            self.btn_advanced.clicked.connect(self._toggle_advanced_settings)
            btn_layout.addWidget(self.btn_advanced)

            self._left_panel.addLayout(btn_layout)

            # Create advanced settings widget (hidden by default)
            self._advanced_widget = AdvancedSettingsWidget(self)
            self._advanced_widget.hide()
            self._main_layout.addWidget(self._advanced_widget)

            self.resize(1000, 600)

            # State used when the central widget is swapped for the circle editor.
            self._old_central_widget = None
            self._editor = None
    
    
        def _toggle_advanced_settings(self, checked):
            """
            Show or hide the advanced settings panel to match *checked*.

            Afterwards adjustSize() is called so the window can re-layout and grow
            or shrink around the panel.
            """
            self._advanced_widget.setVisible(bool(checked))
            self.adjustSize()
    
    
        def open_circle_editor(self):
            """
            Replace the current central widget with a CircleEditorWidget.

            If no image has been loaded, prints a message and returns.

            Otherwise the current central widget is detached with
            takeCentralWidget() and kept in self._old_central_widget, so that
            _on_circle_editor_done can restore it later. The editor is created on
            the last loaded pixmap, starts at the last calibrated radius, and
            reports back through _on_circle_editor_done.
            """
            if not self._last_loaded_pixmap:
                print("No image loaded yet! Cannot open circle editor.")
                return

            # takeCentralWidget() hands the widget back instead of letting
            # setCentralWidget() delete it, so it can be restored afterwards.
            self._old_central_widget = self.takeCentralWidget()

            init_radius = self._circle_calibrated_radius

            editor = CircleEditorWidget(
                pixmap=self._last_loaded_pixmap,
                init_radius=init_radius,
                done_callback=self._on_circle_editor_done
            )
            self._editor = editor

            self.setCentralWidget(editor)
    
        def _on_circle_editor_done(self, final_radius):
    
            self._circle_calibrated_radius = final_radius
    
            print(f"Circle Editor done. Radius = {final_radius}")
    
    
            if self._last_loaded_file_path:
                cost_img = compute_cost_image(
                    self._last_loaded_file_path,
                    self._circle_calibrated_radius,
                    clip_limit=self._current_clip_limit
                )
                self.image_view.cost_image_original = cost_img
                self.image_view.cost_image = cost_img.copy()
                self.image_view._apply_all_guide_points_to_cost()
                self.image_view._rebuild_full_path()
                self._update_advanced_images()
    
    
            editor_widget = self.takeCentralWidget()
            if editor_widget is not None:
                editor_widget.setParent(None)
    
            if self._old_central_widget is not None:
                self.setCentralWidget(self._old_central_widget)
                self._old_central_widget = None