Skip to content
Snippets Groups Projects
GUI_draft_live.py 29.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • import sys
    import math
    import numpy as np
    
    # For smoothing the path
    from scipy.signal import savgol_filter
    
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
        QSlider, QLabel
    
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
    
    from PyQt5.QtCore import Qt, QRectF, QSize
    
    
    from live_wire import compute_cost_image, find_path
    
    
    
    # ------------------------------------------------------------------------
    # A pan & zoom QGraphicsView
    # ------------------------------------------------------------------------
    class PanZoomGraphicsView(QGraphicsView):
        """QGraphicsView with mouse-wheel zoom and left-button drag panning.

        Panning moves the view's scroll bars by the mouse delta, which is in
        viewport pixels, so the image follows the cursor 1:1 at any zoom level.
        """

        ZOOM_STEP = 1.25  # scale factor applied per wheel notch (override in subclasses)

        def __init__(self, parent=None):
            super().__init__(parent)
            self.setDragMode(QGraphicsView.NoDrag)  # We'll handle panning manually
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
            self._panning = False
            self._pan_start = None

        def wheelEvent(self, event):
            """
            Zoom in (wheel up) or out (wheel down) around the mouse cursor.
            """
            dy = event.angleDelta().y()
            if dy == 0:
                # Purely horizontal wheel/trackpad motion is not a zoom request.
                super().wheelEvent(event)
                return
            factor = self.ZOOM_STEP if dy > 0 else 1 / self.ZOOM_STEP
            self.scale(factor, factor)
            event.accept()

        def mousePressEvent(self, event):
            """
            If left button: Start panning (unless overridden in a subclass).
            """
            if event.button() == Qt.LeftButton:
                self._panning = True
                self._pan_start = event.pos()
                self.setCursor(Qt.ClosedHandCursor)
            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """
            While panning, scroll the view by the mouse movement.
            """
            if self._panning and self._pan_start is not None:
                delta = event.pos() - self._pan_start
                self._pan_start = event.pos()
                # translate() works in scene units, so the pan speed would scale
                # with zoom. Scrolling by the viewport-pixel delta keeps it 1:1.
                hbar = self.horizontalScrollBar()
                vbar = self.verticalScrollBar()
                hbar.setValue(hbar.value() - delta.x())
                vbar.setValue(vbar.value() - delta.y())
            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """
            End panning.
            """
            if event.button() == Qt.LeftButton:
                self._panning = False
                self._pan_start = None
                self.setCursor(Qt.ArrowCursor)
            super().mouseReleaseEvent(event)
    
    
    # ------------------------------------------------------------------------
    # A specialized PanZoomGraphicsView for the circle editor
    #    Only pan if user did NOT click on the draggable circle
    # ------------------------------------------------------------------------
    class CircleEditorGraphicsView(PanZoomGraphicsView):
        """Pan/zoom view that does not start panning when the press hits the circle.

        A left press on a DraggableCircleItem, or on any child item of one, goes
        straight to QGraphicsView so the scene can drag the item. Any other left
        press starts panning as in PanZoomGraphicsView.
        """

        def mousePressEvent(self, event):
            if event.button() == Qt.LeftButton and self._hits_circle(event.pos()):
                # Bypass PanZoomGraphicsView so panning is NOT started.
                QGraphicsView.mousePressEvent(self, event)
                return
            # Otherwise proceed with normal panning logic
            super().mousePressEvent(event)

        def _hits_circle(self, view_pos):
            """Return True if the topmost item at view_pos is, or sits inside, a DraggableCircleItem."""
            item = self.itemAt(view_pos)
            # Walk up the parent chain in case a child item (e.g. a label) was hit.
            while item is not None:
                if isinstance(item, DraggableCircleItem):
                    return True
                item = item.parentItem()
            return False
    
    
    # ------------------------------------------------------------------------
    # Draggable circle item (centered at (x, y) with radius)
    # ------------------------------------------------------------------------
    class DraggableCircleItem(QGraphicsEllipseItem):
        """A filled, mouse-movable circle whose constructor (x, y) is its centre.

        The ellipse rect is kept at (0, 0, 2r, 2r) in item coordinates. The item
        position is offset by r so that (x, y) ends up at the visual centre.
        """

        def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
            diameter = 2 * radius
            super().__init__(0, 0, diameter, diameter, parent)
            self._r = radius

            self.setPen(QPen(color))
            self.setBrush(QBrush(color))

            # Let the scene move and select it with the mouse.
            movable_flags = (
                QGraphicsEllipseItem.ItemIsMovable
                | QGraphicsEllipseItem.ItemIsSelectable
                | QGraphicsEllipseItem.ItemSendsScenePositionChanges
            )
            self.setFlags(movable_flags)

            # The top-left corner sits one radius away from the requested centre.
            self.setPos(x - radius, y - radius)

        def set_radius(self, r):
            """Resize to radius r while keeping the circle's scene centre fixed."""
            before = self.sceneBoundingRect().center()
            self._r = r
            self.setRect(0, 0, 2 * r, 2 * r)
            after = self.sceneBoundingRect().center()
            shift = before - after
            self.moveBy(shift.x(), shift.y())

        def radius(self):
            """Return the current geometric radius (excluding pen width)."""
            return self._r
    
    
    # ------------------------------------------------------------------------
    # Circle editor widget with slider + done
    # ------------------------------------------------------------------------
    class CircleEditorWidget(QWidget):
        """Image view with a draggable circle, a radius slider and a Done button.

        Args:
            pixmap: Image shown underneath the circle.
            init_radius: Starting circle radius; also the slider's start value.
            done_callback: Called with the final radius when Done is clicked.
            parent: Optional Qt parent widget.
        """

        SLIDER_MIN = 1
        SLIDER_MAX = 200

        def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
            super().__init__(parent)
            self._pixmap = pixmap
            self._done_callback = done_callback
            self._init_radius = init_radius
            self._fitted_on_show = False

            # Passing `self` already installs the layout on this widget.
            layout = QVBoxLayout(self)

            # Specialized view: dragging the circle does not pan.
            self._graphics_view = CircleEditorGraphicsView()
            self._scene = QGraphicsScene(self)
            self._graphics_view.setScene(self._scene)
            layout.addWidget(self._graphics_view)

            self._image_item = QGraphicsPixmapItem(self._pixmap)
            self._scene.addItem(self._image_item)

            # Start with the circle centred on the image.
            cx = self._pixmap.width() / 2
            cy = self._pixmap.height() / 2
            self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
            self._scene.addItem(self._circle_item)

            self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
            # At this point the viewport has not been laid out yet. showEvent
            # repeats the fit once the widget has its real geometry.
            self._fit_image()

            # Bottom controls (slider + done)
            bottom_layout = QHBoxLayout()
            layout.addLayout(bottom_layout)

            bottom_layout.addWidget(QLabel("size:"))

            self._slider = QSlider(Qt.Horizontal)
            start_value = int(self._init_radius)
            # Widen the range if needed so the slider can represent init_radius.
            self._slider.setRange(self.SLIDER_MIN, max(self.SLIDER_MAX, start_value))
            self._slider.setValue(start_value)
            bottom_layout.addWidget(self._slider)

            self._btn_done = QPushButton("Done")
            bottom_layout.addWidget(self._btn_done)

            # Connected after setValue so the initial value does not trigger a resize.
            self._slider.valueChanged.connect(self._on_slider_changed)
            self._btn_done.clicked.connect(self._on_done_clicked)

        def _fit_image(self):
            """Scale the view so the whole image is visible, preserving aspect ratio."""
            self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

        def showEvent(self, event):
            """Refit the image the first time the widget is shown, with its real size."""
            super().showEvent(event)
            if not self._fitted_on_show:
                self._fitted_on_show = True
                self._fit_image()

        def _on_slider_changed(self, value):
            self._circle_item.set_radius(value)

        def _on_done_clicked(self):
            final_radius = self._circle_item.radius()
            if self._done_callback is not None:
                self._done_callback(final_radius)

        def sizeHint(self):
            return QSize(800, 600)
    
    
    # ------------------------------------------------------------------------
    # Labeled point item
    # ------------------------------------------------------------------------
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """Filled circle centred at (x, y) with an optional text label inside.

        Used both for anchor points (S/E and guide points) and for the small
        dots that draw the path.

        Args:
            x, y: Centre in scene coordinates.
            label: Text drawn inside the circle; empty string for no label.
            radius: Circle radius.
            color: Pen and brush color.
            removable: Whether the user may delete this point.
            z_value: Stacking order in the scene.
            parent: Optional parent item.
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
            super().__init__(0, 0, 2*radius, 2*radius, parent)
            self._x = x
            self._y = y
            self._r = radius
            self._removable = removable

            pen = QPen(color)
            brush = QBrush(color)
            self.setPen(pen)
            self.setBrush(brush)
            self.setZValue(z_value)

            self._text_item = None
            if label:
                self._text_item = QGraphicsTextItem(self)
                self._text_item.setPlainText(label)
                self._text_item.setDefaultTextColor(QColor("black"))
                font = QFont("Arial", 14)
                font.setBold(True)
                self._text_item.setFont(font)
                self._scale_text_to_fit()

            self.set_pos(x, y)

        def _scale_text_to_fit(self):
            """Shrink the label (never enlarge it) so it fits the circle, then centre it."""
            if not self._text_item:
                return
            self._text_item.setScale(1.0)
            circle_diam = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            overflow_ratios = [
                circle_diam / extent
                for extent in (raw_rect.width(), raw_rect.height())
                # Only overflowing dimensions constrain the scale. This also
                # skips zero extents, so there is no division by zero.
                if extent > circle_diam
            ]
            if overflow_ratios:
                self._text_item.setScale(min(overflow_ratios))
            self._center_label()

        def _center_label(self):
            """Place the (possibly scaled) label at the centre of the circle's rect."""
            if not self._text_item:
                return
            ellipse_w = 2 * self._r
            ellipse_h = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            scale_factor = self._text_item.scale()
            scaled_w = raw_rect.width() * scale_factor
            scaled_h = raw_rect.height() * scale_factor
            tx = (ellipse_w - scaled_w) * 0.5
            ty = (ellipse_h - scaled_h) * 0.5
            self._text_item.setPos(tx, ty)

        def set_pos(self, x, y):
            """Positions the circle so its center is at (x, y)."""
            self._x = x
            self._y = y
            self.setPos(x - self._r, y - self._r)

        def get_pos(self):
            """Return the circle centre as an (x, y) tuple."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Euclidean distance from the circle centre to (x_other, y_other)."""
            return math.hypot(self._x - x_other, self._y - y_other)

        def is_removable(self):
            return self._removable
    
    
    
    # ------------------------------------------------------------------------
    # The original ImageGraphicsView with pan & zoom
    # ------------------------------------------------------------------------
    class ImageGraphicsView(PanZoomGraphicsView):
        """Pan/zoom image view for building a live-wire path through anchor points.

        After an image is loaded, fixed "S" and "E" anchors are placed. A
        left-click (not a pan) adds a guide point, dragging a point moves it,
        and a right-click removes a removable point. Consecutive anchors are
        joined with ``find_path`` on ``cost_image``. In that image, the
        neighbourhood of every removable (guide) point is lowered to the global
        minimum cost. The concatenated path is smoothed and drawn as small dots.
        """

        def __init__(self, parent=None):
            super().__init__(parent)
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Image display
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            self.anchor_points = []     # List[(x, y)] in path order
            self.point_items = []       # LabeledPointItem per anchor (parallel to anchor_points)
            self.full_path_points = []  # LabeledPointItems drawing the path
            self._full_path_xy = []     # entire path coords (smoothed)
            self._full_path_seg = []    # per path point: index i of the segment anchor i -> i+1

            self.dot_radius = 4         # radius of user-added guide points
            self.path_radius = 1        # radius of the path dots
            self.radius_cost_image = 2  # pixel radius lowered in the cost image per guide point

            self._img_w = 0
            self._img_h = 0

            # Mouse / drag state
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5    # Manhattan distance (view px) before a press counts as a drag
            self._was_dragging = False
            self._dragging_idx = None
            self._drag_offset = (0, 0)
            self._drag_counter = 0      # throttles path rebuilds while dragging a point

            # Cost images: pristine copy and working copy with guide points applied
            self.cost_image_original = None
            self.cost_image = None

            # Rainbow toggle
            self._rainbow_enabled = True

        def set_rainbow_enabled(self, enabled: bool):
            """Enable/disable rainbow mode, then rebuild the path."""
            self._rainbow_enabled = enabled
            self._rebuild_full_path()

        def toggle_rainbow(self):
            """Flip the rainbow mode and rebuild path."""
            self._rainbow_enabled = not self._rainbow_enabled
            self._rebuild_full_path()

        # --------------------------------------------------------------------
        # LOADING
        # --------------------------------------------------------------------
        def load_image(self, path):
            """Show the image at `path`, reset all points and place default S/E anchors.

            Does nothing if the image cannot be loaded. The cost image is not
            touched here; the caller is expected to set it.
            """
            pixmap = QPixmap(path)
            if not pixmap.isNull():
                self.image_item.setPixmap(pixmap)
                self.setSceneRect(QRectF(pixmap.rect()))

                self._img_w = pixmap.width()
                self._img_h = pixmap.height()

                self._clear_all_points()
                self.resetTransform()
                self.fitInView(self.image_item, Qt.KeepAspectRatio)

                # By default, add S/E at 15% / 85% of the width, vertically centred
                s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
                e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h

                self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
                self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

        # --------------------------------------------------------------------
        # ANCHOR POINTS
        # --------------------------------------------------------------------
        def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
            """Insert an anchor at list index `idx`, clamping (x, y) inside the image.

            A negative `idx` means "just before the last anchor" (i.e. before E)
            once at least two anchors exist, otherwise "append".
            """
            x_clamped = self._clamp(x, radius, self._img_w - radius)
            y_clamped = self._clamp(y, radius, self._img_h - radius)

            if idx < 0:
                if len(self.anchor_points) >= 2:
                    idx = len(self.anchor_points) - 1
                else:
                    idx = len(self.anchor_points)

            self.anchor_points.insert(idx, (x_clamped, y_clamped))

            color = Qt.green if label in ("S", "E") else Qt.red
            item = LabeledPointItem(x_clamped, y_clamped,
                                    label=label, radius=radius, color=color,
                                    removable=removable, z_value=z_val)

            self.point_items.insert(idx, item)
            self.scene.addItem(item)

        def _add_guide_point(self, x, y):
            """Add a removable guide point at (x, y) and rebuild the path."""
            x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
            y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)

            self._revert_cost_to_original()

            if not self._full_path_xy:
                self._insert_anchor_point(-1, x_clamped, y_clamped,
                                          label="", removable=True, z_val=1, radius=self.dot_radius)
            else:
                self._insert_anchor_between_subpath(x_clamped, y_clamped)

            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _insert_anchor_between_subpath(self, x_new, y_new):
            """Insert a guide point into the anchor segment whose subpath passes closest.

            Finds the path point nearest to (x_new, y_new). If that point lies on
            the segment from anchor i to anchor i+1, the new point is inserted
            at index i+1, so the path keeps visiting the anchors in order.
            Segment indices come from ``_rebuild_full_path``; matching anchor
            coordinates is avoided because the path is rounded and smoothed.
            When no usable path exists, the point is inserted just before E.
            """
            insert_idx = -1
            path_xy = self._full_path_xy
            if path_xy and len(self._full_path_seg) == len(path_xy):
                path = np.asarray(path_xy, dtype=float)
                d2 = (path[:, 0] - x_new) ** 2 + (path[:, 1] - y_new) ** 2
                seg = self._full_path_seg[int(np.argmin(d2))]
                if 0 <= seg < len(self.anchor_points) - 1:
                    insert_idx = seg + 1

            self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
                                      z_val=1, radius=self.dot_radius)

        # --------------------------------------------------------------------
        # COST IMAGE
        # --------------------------------------------------------------------
        def _revert_cost_to_original(self):
            """Reset the working cost image to a fresh copy of the original."""
            if self.cost_image_original is not None:
                self.cost_image = self.cost_image_original.copy()

        def _apply_all_guide_points_to_cost(self):
            """Lower the cost around every removable (guide) anchor point."""
            if self.cost_image is None:
                return
            for (ax, ay), item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

        def _lower_cost_in_circle(self, x_f, y_f, radius):
            """Set cost pixels within `radius` of (x_f, y_f) to the image's minimum cost.

            Does nothing if the rounded centre lies outside the cost image.
            Expects a 2-D cost image (it is unpacked as (h, w)) and an integer
            radius.
            """
            if self.cost_image is None:
                return
            h, w = self.cost_image.shape
            row_c = int(round(y_f))
            col_c = int(round(x_f))
            if not (0 <= row_c < h and 0 <= col_c < w):
                return
            global_min = self.cost_image.min()
            r_s = max(0, row_c - radius)
            r_e = min(h, row_c + radius + 1)
            c_s = max(0, col_c - radius)
            c_e = min(w, col_c + radius + 1)
            rows = np.arange(r_s, r_e)[:, None]
            cols = np.arange(c_s, c_e)[None, :]
            # Integer squared distance <= radius**2 is equivalent to sqrt(dist) <= radius.
            inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
            # Basic slicing yields a view, so the masked assignment writes through.
            self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

        # --------------------------------------------------------------------
        # PATH BUILDING
        # --------------------------------------------------------------------
        def _rebuild_full_path(self):
            """Recompute and redraw the path through all anchors, in list order.

            Each consecutive anchor pair is joined with ``_compute_subpath_xy``.
            The point shared by adjacent subpaths is kept only once. Paths with
            at least 7 points are smoothed with a Savitzky-Golay filter
            (window 7, polyorder 1). Points are drawn along a hue ramp in
            rainbow mode, or in red otherwise.
            """
            for item in self.full_path_points:
                self.scene.removeItem(item)
            self.full_path_points.clear()
            self._full_path_xy.clear()
            self._full_path_seg.clear()

            if len(self.anchor_points) < 2 or self.cost_image is None:
                return

            big_xy = []
            seg_ids = []  # segment index for every entry in big_xy
            for i in range(len(self.anchor_points) - 1):
                xA, yA = self.anchor_points[i]
                xB, yB = self.anchor_points[i + 1]
                sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
                # Later subpaths start where the previous one ended; drop the duplicate.
                new_pts = sub_xy if i == 0 else sub_xy[1:]
                big_xy.extend(new_pts)
                seg_ids.extend([i] * len(new_pts))

            if len(big_xy) >= 7:
                arr_xy = np.array(big_xy)
                # Smoothing keeps the number of points, so seg_ids stays aligned.
                smoothed = savgol_filter(arr_xy, window_length=7, polyorder=1, axis=0)
                big_xy = smoothed.tolist()

            self._full_path_xy = big_xy[:]
            self._full_path_seg = seg_ids

            n_points = len(big_xy)
            for i, (px, py) in enumerate(big_xy):
                fraction = i / (n_points - 1) if n_points > 1 else 0

                if self._rainbow_enabled:
                    color = self._rainbow_color(fraction)
                else:
                    color = Qt.red

                path_item = LabeledPointItem(px, py, label="",
                                             radius=self.path_radius,
                                             color=color,
                                             removable=False,
                                             z_value=0)

                self.full_path_points.append(path_item)
                self.scene.addItem(path_item)

            # Keep anchor labels on top
            for p_item in self.point_items:
                if p_item._text_item:
                    p_item.setZValue(100)

        def _compute_subpath_xy(self, xA, yA, xB, yB):
            """Return the (x, y) pixel path from A to B via ``find_path``, or [] on failure.

            Endpoints are rounded and clamped to the cost image. ``find_path``
            works in (row, col) order, so the result is swapped to (x, y).
            """
            if self.cost_image is None:
                return []
            h, w = self.cost_image.shape
            rA, cA = int(round(yA)), int(round(xA))
            rB, cB = int(round(yB)), int(round(xB))

            rA = max(0, min(rA, h - 1))
            cA = max(0, min(cA, w - 1))
            rB = max(0, min(rB, h - 1))
            cB = max(0, min(cB, w - 1))

            try:
                path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
            except ValueError as e:
                print("Error in find_path:", e)
                return []
            return [(c, r) for (r, c) in path_rc]

        def _rainbow_color(self, fraction):
            """Map fraction in [0, 1] to a fully saturated hue from 0 (red) to 300."""
            hue = int(300 * fraction)
            saturation = 255
            value = 255
            return QColor.fromHsv(hue, saturation, value)

        # --------------------------------------------------------------------
        # MOUSE EVENTS (with pan & zoom from PanZoomGraphicsView)
        # --------------------------------------------------------------------
        def mousePressEvent(self, event):
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                idx = self._find_item_near(event.pos(), threshold=10)
                if idx is not None:
                    # Grab the point and remember where inside it we clicked.
                    self._dragging_idx = idx
                    self._drag_counter = 0
                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.point_items[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                    self.setCursor(Qt.ClosedHandCursor)
                    # Skip PanZoomGraphicsView so the view does not pan while the
                    # point is dragged (same approach as CircleEditorGraphicsView).
                    QGraphicsView.mousePressEvent(self, event)
                    return

            elif event.button() == Qt.RightButton:
                self._remove_point_by_click(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            if self._dragging_idx is not None:
                idx = self._dragging_idx
                item = self.point_items[idx]
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]

                r = item._r
                x_clamped = self._clamp(x_new, r, self._img_w - r)
                y_clamped = self._clamp(y_new, r, self._img_h - r)
                item.set_pos(x_clamped, y_clamped)

                # Rebuilding the path is expensive: only do it every 4th move event.
                self._drag_counter += 1
                if self._drag_counter >= 4:
                    self._drag_counter = 0
                    # Update the anchor first so the cost image uses the new position.
                    self.anchor_points[idx] = (x_clamped, y_clamped)
                    self._revert_cost_to_original()
                    self._apply_all_guide_points_to_cost()
                    self._rebuild_full_path()

            elif self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True

            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            super().mouseReleaseEvent(event)
            if event.button() != Qt.LeftButton or not self._mouse_pressed:
                return
            self._mouse_pressed = False
            self.setCursor(Qt.ArrowCursor)

            if self._dragging_idx is not None:
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)

                # Commit the final position and rebuild once.
                self.anchor_points[idx] = self.point_items[idx].get_pos()
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
            elif not self._was_dragging:
                # A plain click (not a pan) adds a guide point.
                scene_pos = self.mapToScene(event.pos())
                self._add_guide_point(scene_pos.x(), scene_pos.y())

            self._was_dragging = False

        def _remove_point_by_click(self, view_pos):
            """Remove the removable anchor near view_pos (if any) and rebuild the path."""
            idx = self._find_item_near(view_pos, threshold=10)
            if idx is None:
                return
            if not self.point_items[idx].is_removable():
                return  # S/E cannot be removed

            self.scene.removeItem(self.point_items[idx])
            self.point_items.pop(idx)
            self.anchor_points.pop(idx)

            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _find_item_near(self, view_pos, threshold=10):
            """Return the index of the anchor closest to view_pos, if within threshold.

            The distance and threshold are measured in scene coordinates.
            Returns None when there are no anchors or none is close enough.
            """
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()

            closest_idx = None
            min_dist = float('inf')
            for i, itm in enumerate(self.point_items):
                d = itm.distance_to(x_click, y_click)
                if d < min_dist:
                    min_dist = d
                    closest_idx = i
            if closest_idx is not None and min_dist <= threshold:
                return closest_idx
            return None

        # --------------------------------------------------------------------
        # UTILS
        # --------------------------------------------------------------------
        def _clamp(self, val, mn, mx):
            return max(mn, min(val, mx))

        def _clear_all_points(self):
            """Remove every anchor (including S/E) and the drawn path."""
            for it in self.point_items:
                self.scene.removeItem(it)
            self.point_items.clear()
            self.anchor_points.clear()

            for p in self.full_path_points:
                self.scene.removeItem(p)
            self.full_path_points.clear()
            self._full_path_xy.clear()
            self._full_path_seg.clear()

        def clear_guide_points(self):
            """Remove all removable guide points (S/E stay) and rebuild the path."""
            kept_points, kept_items = [], []
            for pt, item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self.scene.removeItem(item)
                else:
                    kept_points.append(pt)
                    kept_items.append(item)
            # Slice assignment keeps the same list objects.
            self.anchor_points[:] = kept_points
            self.point_items[:] = kept_items

            # _rebuild_full_path clears the old path items before anything else.
            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def get_full_path_xy(self):
            """Return the current (smoothed) path as a list of (x, y) points."""
            return self._full_path_xy
    
    
    # ------------------------------------------------------------------------
    # Main Window
    # ------------------------------------------------------------------------
    
    class MainWindow(QMainWindow):
        """Top-level window: the live-wire image view plus a row of action buttons.

        The central widget can be temporarily swapped for a CircleEditorWidget.
        The radius chosen there is stored in ``_circle_radius_for_later_use``.
        """

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            self._last_loaded_pixmap = None
            self._circle_radius_for_later_use = 0

            # Original main widget
            self._main_widget = QWidget()
            self._main_layout = QVBoxLayout(self._main_widget)

            # Image view
            self.image_view = ImageGraphicsView()
            self._main_layout.addWidget(self.image_view)

            # Button row
            btn_layout = QHBoxLayout()

            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            self.btn_export_path = QPushButton("Export Path")
            self.btn_export_path.clicked.connect(self.export_path)
            btn_layout.addWidget(self.btn_export_path)

            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
            self.btn_toggle_rainbow.clicked.connect(self.toggle_rainbow)
            btn_layout.addWidget(self.btn_toggle_rainbow)

            # Circle editor button
            self.btn_open_editor = QPushButton("Open Circle Editor")
            self.btn_open_editor.clicked.connect(self.open_circle_editor)
            btn_layout.addWidget(self.btn_open_editor)

            self._main_layout.addLayout(btn_layout)
            self.setCentralWidget(self._main_widget)

            self.resize(900, 600)

            # While the circle editor is open, the main widget is parked here.
            self._old_central_widget = None
            self._editor = None

        def open_circle_editor(self):
            """Removes the current central widget, replaces with circle editor."""
            if self._last_loaded_pixmap is None:
                print("No image loaded yet! Cannot open circle editor.")
                return
            if self._editor is not None:
                return  # editor already open

            # takeCentralWidget() passes ownership to us, so the widget survives the swap.
            self._old_central_widget = self.takeCentralWidget()

            self._editor = CircleEditorWidget(
                pixmap=self._last_loaded_pixmap,
                init_radius=20,
                done_callback=self._on_circle_editor_done,
            )
            self.setCentralWidget(self._editor)

        def _on_circle_editor_done(self, final_radius):
            """Store the chosen radius and restore the original central widget."""
            self._circle_radius_for_later_use = final_radius
            print(f"Circle Editor done. Radius = {final_radius}")

            # Take the editor out
            editor_widget = self.takeCentralWidget()
            if editor_widget is not None:
                editor_widget.setParent(None)

            # Put back the old widget
            if self._old_central_widget is not None:
                self.setCentralWidget(self._old_central_widget)
                self._old_central_widget = None

            # Schedule the editor for deletion
            if self._editor is not None:
                self._editor.deleteLater()
                self._editor = None

        # --------------------------------------------------------------------
        # Existing Functions
        # --------------------------------------------------------------------

        def toggle_rainbow(self):
            self.image_view.toggle_rainbow()

        def load_image(self):
            """Ask for an image file, show it and compute its live-wire cost image."""
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            if not file_path:
                return

            # Validate first so a broken file doesn't replace the cost image
            # of the image that is still displayed.
            pm = QPixmap(file_path)
            if pm.isNull():
                print(f"Could not load image: {file_path}")
                return

            self.image_view.load_image(file_path)
            cost_img = compute_cost_image(file_path)
            self.image_view.cost_image_original = cost_img
            self.image_view.cost_image = cost_img.copy()

            # Store a pixmap to reuse in the circle editor
            self._last_loaded_pixmap = pm

        def export_path(self):
            """Save the current path as an (N, 2) NumPy array of (x, y) to a .npy file."""
            full_xy = self.image_view.get_full_path_xy()
            if not full_xy:
                print("No path to export.")
                return

            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Path", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=options
            )
            if file_path:
                # np.save appends ".npy" when missing; do it here so the message is accurate.
                if not file_path.endswith(".npy"):
                    file_path += ".npy"
                arr = np.array(full_xy)
                np.save(file_path, arr)
                print(f"Exported path with {len(arr)} points to {file_path}")

        def clear_points(self):
            self.image_view.clear_guide_points()

        def closeEvent(self, event):
            super().closeEvent(event)
    
    
    def main():
        """Create the Qt application, show the main window and run the event loop."""
        app = QApplication(sys.argv)
        main_window = MainWindow()
        main_window.show()
        exit_code = app.exec_()
        sys.exit(exit_code)


    if __name__ == "__main__":
        main()