Skip to content
Snippets Groups Projects
GUI_draft_live.py 19.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • import sys
    import math
    import numpy as np
    
    # For smoothing the path
    from scipy.signal import savgol_filter
    
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
    from PyQt5.QtCore import Qt, QRectF
    
    from live_wire import compute_cost_image, find_path
    
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """A filled circle whose centre sits at (x, y), with an optional bold label.

        The label is only ever shrunk (never enlarged) so that it fits inside
        the circle's bounding square, and is then centred within that square.
        The logical centre is tracked separately in ``_x`` / ``_y`` because the
        Qt item position is the top-left corner of the bounding square.
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
            diameter = 2 * radius
            super().__init__(0, 0, diameter, diameter, parent)
            self._x, self._y = x, y
            self._r = radius
            self._removable = removable

            # Same colour for outline and fill.
            self.setPen(QPen(color))
            self.setBrush(QBrush(color))
            self.setZValue(z_value)

            self._text_item = None
            if label:
                caption = QGraphicsTextItem(self)
                caption.setPlainText(label)
                caption.setDefaultTextColor(QColor("black"))
                bold_font = QFont("Arial", 14)
                bold_font.setBold(True)
                caption.setFont(bold_font)
                self._text_item = caption
                self._scale_text_to_fit()

            self.set_pos(x, y)

        def _scale_text_to_fit(self):
            """Reset the label scale, shrink it if it overflows the circle, then centre it."""
            caption = self._text_item
            if not caption:
                return
            caption.setScale(1.0)
            diameter = 2 * self._r
            bounds = caption.boundingRect()
            width, height = bounds.width(), bounds.height()
            overflows = width > diameter or height > diameter
            if overflows:
                caption.setScale(min(diameter / width, diameter / height))
            self._center_label()

        def _center_label(self):
            """Position the (possibly scaled) label in the middle of the circle's bounding square."""
            caption = self._text_item
            if not caption:
                return
            diameter = 2 * self._r
            bounds = caption.boundingRect()
            factor = caption.scale()
            offset_x = (diameter - bounds.width() * factor) * 0.5
            offset_y = (diameter - bounds.height() * factor) * 0.5
            caption.setPos(offset_x, offset_y)

        def set_pos(self, x, y):
            """Move the item so that the circle's centre lands on (x, y)."""
            self._x, self._y = x, y
            top_left_x = x - self._r
            top_left_y = y - self._r
            self.setPos(top_left_x, top_left_y)

        def get_pos(self):
            """Return the logical centre as an ``(x, y)`` tuple."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Euclidean distance from this item's centre to (x_other, y_other)."""
            dx = self._x - x_other
            dy = self._y - y_other
            return math.sqrt(dx ** 2 + dy ** 2)

        def is_removable(self):
            """True if the user is allowed to delete this point."""
            return self._removable
    
    
    class ImageGraphicsView(QGraphicsView):
        """Zoomable image view for interactive live-wire path tracing.

        The view shows an image with two fixed anchors, "S" and "E", plus any
        number of red guide points placed between them. A path is traced
        through all anchors in list order with ``find_path`` on
        ``cost_image`` and drawn as small magenta dots.

        Interaction:
            * left-click on empty space -> add a guide point (inserted before E)
            * left-drag on an anchor    -> move it (path refreshed while dragging)
            * left-drag on empty space  -> pan the view
            * right-click on a guide    -> remove it
            * mouse wheel               -> zoom in/out

        ``cost_image_original`` / ``cost_image`` are not set by this class's
        ``load_image``; the owner assigns them. They are 2-D arrays (shape
        unpacked as ``h, w``), presumably matching the image size. TODO
        confirm against ``compute_cost_image``.
        """

        def __init__(self, parent=None):
            super().__init__(parent)
            # NOTE: this attribute shadows QGraphicsView.scene(). It is kept as-is
            # for compatibility with callers that access ``view.scene``.
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Zoom around the cursor rather than the view centre.
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Image display item
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            # Parallel lists, kept in path order: S, guide points..., E.
            self.anchor_points = []  # list of (x, y) scene coordinates
            self.point_items = []    # list of LabeledPointItem, same order

            self.dot_radius = 4          # radius of red guide points
            self.path_radius = 1         # radius of each drawn path dot
            self.radius_cost_image = 2   # cost-lowering radius (pixels) around guides
            self.smooth_window = 7       # Savitzky-Golay window (odd, > polyorder)
            self.smooth_polyorder = 1    # Savitzky-Golay polynomial order
            self._img_w = 0
            self._img_h = 0

            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ArrowCursor)

            # Mouse state
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5     # view pixels before a press counts as a pan
            self._was_dragging = False
            self._dragging_idx = None    # index of the anchor being dragged, if any
            self._drag_offset = (0, 0)   # cursor-to-anchor-centre offset at press time

            # Throttle path rebuilds while dragging: rebuild every N move events.
            self._drag_counter = 0
            self._drag_update_interval = 4

            # Pristine cost image, and a working copy with guide points applied.
            self.cost_image_original = None
            self.cost_image = None

            # Magenta dot items currently drawn for the path.
            self.full_path_points = []

        # --------------------------------------------------------------------
        # LOADING
        # --------------------------------------------------------------------
        def load_image(self, path):
            """Display the image at *path* and reset the anchors to a fresh S/E pair.

            Does nothing if Qt cannot load the file. The cost images are not
            modified here.
            """
            pixmap = QPixmap(path)
            if pixmap.isNull():
                return
            self.image_item.setPixmap(pixmap)
            self.setSceneRect(QRectF(pixmap.rect()))

            self._img_w = pixmap.width()
            self._img_h = pixmap.height()

            self._clear_all_points()
            self.resetTransform()
            self.fitInView(self.image_item, Qt.KeepAspectRatio)

            # Place S and E on the horizontal midline, 15% in from each side.
            s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
            e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h

            self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
            self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

        # --------------------------------------------------------------------
        # ANCHOR POINTS
        # --------------------------------------------------------------------
        def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
            """Create an anchor item and insert it into both parallel lists.

            Args:
                idx: Insert position. A negative value means "append, but
                    before E": once at least two anchors exist the last one is
                    E, so the new anchor goes just in front of it.
                x, y: Requested centre in scene coordinates. Clamped so the
                    whole circle stays inside the image.
                label: Text drawn inside the circle; "S"/"E" are drawn green,
                    everything else red.
                removable: Whether a right-click may delete this anchor.
                z_val: Stacking order of the item.
                radius: Circle radius in scene units.
            """
            x_clamped = self._clamp(x, radius, self._img_w - radius)
            y_clamped = self._clamp(y, radius, self._img_h - radius)

            if idx < 0:
                count = len(self.anchor_points)
                # With >= 2 anchors the last one is E => insert right before it.
                idx = count - 1 if count >= 2 else count

            color = Qt.green if label in ("S", "E") else Qt.red
            item = LabeledPointItem(
                x_clamped, y_clamped, label=label,
                radius=radius, color=color, removable=removable, z_value=z_val
            )

            self.anchor_points.insert(idx, (x_clamped, y_clamped))
            self.point_items.insert(idx, item)
            self.scene.addItem(item)

        def _add_guide_point(self, x, y):
            """Insert a removable red guide point at (x, y) and recompute the path.

            Ignored when no image has been loaded yet. Without an image there
            are no bounds to clamp to, and the point would land at an
            arbitrary corner.
            """
            if self._img_w <= 0 or self._img_h <= 0:
                return
            # _insert_anchor_point clamps to the image using this same radius.
            self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius)
            self._refresh_path()

        # --------------------------------------------------------------------
        # COST IMAGE
        # --------------------------------------------------------------------
        def _refresh_path(self):
            """Rebuild the working cost image from the current guide points, then redraw the path."""
            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _revert_cost_to_original(self):
            """Reset the working cost image to a fresh copy of the original (if one is set)."""
            if self.cost_image_original is not None:
                self.cost_image = self.cost_image_original.copy()

        def _apply_all_guide_points_to_cost(self):
            """Lower the cost around every REMOVABLE anchor (the red guide points)."""
            if self.cost_image is None:
                return
            for (ax, ay), item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

        def _lower_cost_in_circle(self, x_f, y_f, radius):
            """Set every cost pixel within *radius* of (x_f, y_f) to the image's minimum cost.

            This pulls the traced path through the guide point. Centres that
            fall outside the cost image are ignored. *radius* is expected to be
            an integer pixel count.
            """
            if self.cost_image is None:
                return
            h, w = self.cost_image.shape
            row_c = int(round(y_f))
            col_c = int(round(x_f))
            if not (0 <= row_c < h and 0 <= col_c < w):
                return
            global_min = self.cost_image.min()
            r_s, r_e = max(0, row_c - radius), min(h, row_c + radius + 1)
            c_s, c_e = max(0, col_c - radius), min(w, col_c + radius + 1)
            # Vectorised disc mask over the bounding window. Comparing squared
            # distances is equivalent to sqrt(d2) <= radius for integer inputs.
            rows = np.arange(r_s, r_e)[:, np.newaxis]
            cols = np.arange(c_s, c_e)[np.newaxis, :]
            inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
            # Basic slicing yields a view, so the masked assignment writes through.
            self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

        # --------------------------------------------------------------------
        # PATH BUILDING
        # --------------------------------------------------------------------
        def _rebuild_full_path(self):
            """Redraw the path through all anchors, in list order.

            Removes the previous path dots. It then traces each consecutive
            anchor pair with ``_compute_subpath_xy`` and concatenates the
            segments, skipping the first point of every segment after the
            first (assumed to duplicate the shared anchor). Paths with at least
            ``smooth_window`` points are smoothed with a Savitzky-Golay filter,
            and one magenta dot is drawn per point. Only clears when fewer than
            two anchors exist or no cost image is set.
            """
            for item in self.full_path_points:
                self.scene.removeItem(item)
            self.full_path_points.clear()

            if len(self.anchor_points) < 2 or self.cost_image is None:
                return

            big_xy = []
            segments = zip(self.anchor_points, self.anchor_points[1:])
            for i, ((xA, yA), (xB, yB)) in enumerate(segments):
                sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
                # Avoid duplicating the anchor shared with the previous segment.
                big_xy.extend(sub_xy if i == 0 else sub_xy[1:])

            # Smoothing with Savitzky-Golay (needs at least window_length points).
            if len(big_xy) >= self.smooth_window:
                smoothed = savgol_filter(
                    np.array(big_xy),  # shape (N, 2)
                    window_length=self.smooth_window,
                    polyorder=self.smooth_polyorder,
                    axis=0,
                )
                big_xy = smoothed.tolist()

            for px, py in big_xy:
                path_item = LabeledPointItem(px, py, label="", radius=self.path_radius,
                                             color=Qt.magenta, removable=False, z_value=0)
                self.full_path_points.append(path_item)
                self.scene.addItem(path_item)

            # Keep labelled anchors (S/E) stacked above the path dots.
            for p_item in self.point_items:
                if p_item._text_item:
                    p_item.setZValue(100)

        def _compute_subpath_xy(self, xA, yA, xB, yB):
            """Trace one segment from (xA, yA) to (xB, yB) on the working cost image.

            Returns a list of (x, y) points, converted from ``find_path``'s
            (row, col) output. Returns ``[]`` if there is no cost image or
            ``find_path`` raises ``ValueError``.
            """
            if self.cost_image is None:
                return []
            h, w = self.cost_image.shape

            def to_rc(x, y):
                # Round to the nearest pixel and clamp into the cost image.
                return (self._clamp(int(round(y)), 0, h - 1),
                        self._clamp(int(round(x)), 0, w - 1))

            try:
                path_rc = find_path(self.cost_image, [to_rc(xA, yA), to_rc(xB, yB)])
            except ValueError as e:
                print("Error in find_path:", e)
                return []
            return [(c, r) for (r, c) in path_rc]

        # --------------------------------------------------------------------
        # MOUSE EVENTS
        # --------------------------------------------------------------------
        def mousePressEvent(self, event):
            """Start an anchor drag, arm a pan/click, or remove a guide point.

            A left press near an anchor begins dragging it, with panning
            disabled for the drag. A left press elsewhere becomes a pan if the
            mouse moves, or adds a new guide point on release. A right press
            removes a nearby removable anchor.
            """
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                idx = self._find_item_near(event.pos(), threshold=10)
                if idx is not None:
                    # Drag an existing anchor.
                    self._dragging_idx = idx
                    self._drag_counter = 0

                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.point_items[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                    self.setDragMode(QGraphicsView.NoDrag)
                    self.viewport().setCursor(Qt.ClosedHandCursor)
                    # Consume the event so the base class does not start a pan.
                    return

                # No anchor nearby: this is either a click (new point) or a pan.
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)

            elif event.button() == Qt.RightButton:
                # Right-click => remove point if it's removable.
                self._remove_point_by_click(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """Move the dragged anchor (throttled path refresh), or track pan distance."""
            if self._dragging_idx is not None:
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]

                item = self.point_items[self._dragging_idx]
                r = item._r
                x_clamped = self._clamp(x_new, r, self._img_w - r)
                y_clamped = self._clamp(y_new, r, self._img_h - r)
                item.set_pos(x_clamped, y_clamped)
                # Update the model BEFORE any refresh, so the cost lowering is
                # applied at the anchor's new position, not the stale one.
                self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)

                self._drag_counter += 1
                if self._drag_counter >= self._drag_update_interval:
                    self._drag_counter = 0
                    self._refresh_path()
                return

            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True

            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """Finish an anchor drag, or place a guide point if the press was a click (not a pan)."""
            super().mouseReleaseEvent(event)
            if event.button() != Qt.LeftButton or not self._mouse_pressed:
                return
            self._mouse_pressed = False
            self.viewport().setCursor(Qt.ArrowCursor)

            if self._dragging_idx is not None:
                # Done dragging => final path update.
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                self.setDragMode(QGraphicsView.ScrollHandDrag)

                self.anchor_points[idx] = self.point_items[idx].get_pos()
                self._refresh_path()
            elif not self._was_dragging:
                scene_pos = self.mapToScene(event.pos())
                self._add_guide_point(scene_pos.x(), scene_pos.y())

            self._was_dragging = False

        def _remove_point_by_click(self, view_pos):
            """Remove the anchor nearest *view_pos* if it is close enough and removable (S/E are not)."""
            idx = self._find_item_near(view_pos, threshold=10)
            if idx is None or not self.point_items[idx].is_removable():
                return

            self.scene.removeItem(self.point_items.pop(idx))
            self.anchor_points.pop(idx)
            self._refresh_path()

        def _find_item_near(self, view_pos, threshold=10):
            """Return the index of the anchor closest to *view_pos*, or None.

            The distance is measured in scene coordinates, after mapping
            *view_pos*. The anchor is returned only if it lies within
            *threshold*. Ties go to the earliest anchor.
            """
            if not self.point_items:
                return None
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()
            closest_idx, min_dist = min(
                ((i, itm.distance_to(x_click, y_click)) for i, itm in enumerate(self.point_items)),
                key=lambda pair: pair[1],
            )
            return closest_idx if min_dist <= threshold else None

        # --------------------------------------------------------------------
        # ZOOM
        # --------------------------------------------------------------------
        def wheelEvent(self, event):
            """Zoom in/out by a factor of 1.25 per wheel step."""
            zoom_in_factor = 1.25
            factor = zoom_in_factor if event.angleDelta().y() > 0 else 1 / zoom_in_factor
            self.scale(factor, factor)
            event.accept()

        # --------------------------------------------------------------------
        # UTILS
        # --------------------------------------------------------------------
        def _clamp(self, val, mn, mx):
            """Clamp *val* into [mn, mx]. If mn > mx, *mn* wins."""
            return max(mn, min(val, mx))

        def _clear_all_points(self):
            """Remove every anchor (including S/E) and every path dot from the scene."""
            for it in self.point_items:
                self.scene.removeItem(it)
            self.point_items.clear()
            self.anchor_points.clear()

            for p in self.full_path_points:
                self.scene.removeItem(p)
            self.full_path_points.clear()

        def clear_guide_points(self):
            """Remove all removable (guide) anchors, keep S/E, then rebuild the path."""
            kept = []
            for point, item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self.scene.removeItem(item)
                else:
                    kept.append((point, item))
            # Slice-assign so any external references to these lists stay valid.
            self.anchor_points[:] = [point for point, _ in kept]
            self.point_items[:] = [item for _, item in kept]

            self._refresh_path()
    
    
    class MainWindow(QMainWindow):
        """Top-level window: the image view on top, a row of action buttons below."""

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            central = QWidget()
            outer_layout = QVBoxLayout(central)

            self.image_view = ImageGraphicsView()
            outer_layout.addWidget(self.image_view)

            button_row = QHBoxLayout()
            # (attribute name, caption, click handler), in left-to-right order.
            button_specs = (
                ("btn_load_image", "Load Image", self.load_image),
                ("btn_export_points", "Export Points", self.export_points),
                ("btn_clear_points", "Clear Points", self.clear_points),
            )
            for attr_name, caption, handler in button_specs:
                button = QPushButton(caption)
                setattr(self, attr_name, button)
                button.clicked.connect(handler)
                button_row.addWidget(button)

            outer_layout.addLayout(button_row)
            self.setCentralWidget(central)
            self.resize(900, 600)

        def load_image(self):
            """Prompt for an image, show it, and give the view its cost image (plus a working copy)."""
            file_path, _selected_filter = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=QFileDialog.Options(),
            )
            if not file_path:
                return
            view = self.image_view
            view.load_image(file_path)
            cost = compute_cost_image(file_path)
            view.cost_image_original = cost
            view.cost_image = cost.copy()

        def export_points(self):
            """Save the view's anchor (x, y) points, in path order, to a .npy file chosen by the user."""
            if not self.image_view.anchor_points:
                print("No anchor points to export.")
                return
            file_path, _selected_filter = QFileDialog.getSaveFileName(
                self, "Export Points", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=QFileDialog.Options(),
            )
            if not file_path:
                return
            points_array = np.array(self.image_view.anchor_points)
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")

        def clear_points(self):
            """Drop every guide point; the S/E anchors stay where they are."""
            self.image_view.clear_guide_points()

        def closeEvent(self, event):
            # Pure pass-through to Qt's default close handling.
            super().closeEvent(event)
    
    
    def main():
        """Create the Qt application, show the main window, and exit with the event loop's status."""
        app = QApplication(sys.argv)
        main_window = MainWindow()
        main_window.show()
        exit_code = app.exec_()
        sys.exit(exit_code)


    if __name__ == "__main__":
        main()