Skip to content
Snippets Groups Projects
GUI_draft.py 20.4 KiB
Newer Older
  • Learn to ignore specific revisions
    import math
    import sys

    import numpy as np
    from scipy.signal import savgol_filter

    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
    from PyQt5.QtCore import Qt, QRectF

    from live_wire import compute_cost_image, find_path
    
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """A filled circular marker, optionally carrying a bold text label.

        The item remembers its *centre* in scene coordinates (``_x``, ``_y``),
        while Qt's own ``pos()`` is the top-left corner of the item's
        (2r x 2r) local square; ``set_pos`` converts between the two.

        Args:
            x, y: Centre of the circle in scene coordinates.
            label: Text drawn inside the circle; an empty string means no label.
            radius: Circle radius in scene units.
            color: Colour used for both the outline pen and the fill brush.
            removable: Whether the user may delete this point (see ``is_removable``).
            z_value: Stacking order passed to ``setZValue``.
            parent: Optional parent graphics item.
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
            # Local geometry: a (2r x 2r) square anchored at (0, 0); set_pos shifts it by -r.
            super().__init__(0, 0, 2 * radius, 2 * radius, parent)

            self._x = x
            self._y = y
            self._r = radius
            self._removable = removable

            self.setPen(QPen(color))
            self.setBrush(QBrush(color))
            self.setZValue(z_value)

            self._text_item = None
            if label:
                # Child item, so the label moves together with the circle.
                self._text_item = QGraphicsTextItem(self)
                self._text_item.setPlainText(label)
                self._text_item.setDefaultTextColor(QColor("black"))
                font = QFont("Arial", 14)
                font.setBold(True)
                self._text_item.setFont(font)
                self._scale_text_to_fit()

            self.set_pos(x, y)

        def _scale_text_to_fit(self):
            """Shrink the label (never enlarge it) so it fits in the circle, then centre it."""
            if not self._text_item:
                return
            self._text_item.setScale(1.0)
            circle_diam = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            text_w = raw_rect.width()
            text_h = raw_rect.height()
            # Skip scaling for a degenerate (zero-size) rect to avoid dividing by zero.
            if text_w > 0 and text_h > 0 and (text_w > circle_diam or text_h > circle_diam):
                scale_factor = min(circle_diam / text_w, circle_diam / text_h)
                self._text_item.setScale(scale_factor)
            self._center_label()

        def _center_label(self):
            """Place the (possibly scaled) label in the middle of the circle."""
            if not self._text_item:
                return
            diameter = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            scale_factor = self._text_item.scale()
            # The child's position is in this item's coordinates, so use the scaled size.
            scaled_w = raw_rect.width() * scale_factor
            scaled_h = raw_rect.height() * scale_factor
            self._text_item.setPos((diameter - scaled_w) * 0.5, (diameter - scaled_h) * 0.5)

        def set_pos(self, x, y):
            """Move the marker so that its centre is at scene coordinates (x, y)."""
            self._x = x
            self._y = y
            self.setPos(x - self._r, y - self._r)

        def get_pos(self):
            """Return the marker centre as an (x, y) tuple in scene coordinates."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Return the Euclidean distance from the marker centre to (x_other, y_other)."""
            return math.hypot(self._x - x_other, self._y - y_other)

        def is_removable(self):
            """Return True if the user is allowed to delete this point."""
            return self._removable
    
    
    
    class ImageGraphicsView(QGraphicsView):
        """Zoomable, pannable image view for editing a live-wire path.

        The path runs through an ordered list of anchors: a fixed start point
        "S", any number of user-added guide points, and a fixed end point "E".
        Consecutive anchors are joined with ``find_path`` on ``cost_image``.
        Around every guide point the cost is set to the image's minimum
        (presumably pulling the minimum-cost path through it). The joined path
        is smoothed with a Savitzky-Golay filter and drawn as magenta dots.

        Mouse interaction:
          * wheel: zoom in/out around the pointer;
          * left drag on empty space: pan;
          * editor mode, left drag on an anchor: move it;
          * editor mode, left click on empty space: add a guide point;
          * editor mode, right click on a guide point: remove it.
        """

        def __init__(self, parent=None):
            super().__init__(parent)
            # NOTE(review): this attribute shadows QGraphicsView.scene(); kept
            # because other code may access ``view.scene`` as an attribute.
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Allow zoom around mouse pointer
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Image display item
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            # Parallel lists: anchor_points[i] is the centre of point_items[i].
            self.anchor_points = []  # List[(x, y)]
            self.point_items = []    # List[LabeledPointItem]

            self.editor_mode = False

            self.dot_radius = 4         # radius of guide-point markers
            # NOTE(review): used by _rebuild_full_path but its definition was
            # missing from this copy of the file; 1 is an assumed value.
            self.path_radius = 1        # radius of the dots that draw the path
            self.radius_cost_image = 2  # cost-lowering radius (pixels) around guide points

            self._img_w = 0
            self._img_h = 0

            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ArrowCursor)

            # Press bookkeeping used to tell a click from a pan.
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5  # view pixels, Manhattan distance
            self._was_dragging = False

            # Index into point_items of the anchor being dragged, if any.
            self._dragging_idx = None
            self._drag_offset = (0, 0)

            # Pristine cost image; cost_image is the working copy with guide points applied.
            self.cost_image_original = None
            self.cost_image = None

            # Graphics items currently drawing the path.
            self.full_path_points = []

        # --------------------------------------------------------------------
        # LOADING
        # --------------------------------------------------------------------
        def load_image(self, path):
            """Display the image at *path* and reset the anchors to default S/E points.

            Does nothing if Qt cannot load the file. The cost image is not set
            here (MainWindow.load_image assigns it after calling this).
            """
            pixmap = QPixmap(path)
            if pixmap.isNull():
                return

            self.image_item.setPixmap(pixmap)
            self.setSceneRect(QRectF(pixmap.rect()))
            self._img_w = pixmap.width()
            self._img_h = pixmap.height()

            self._clear_all_points()

            self.resetTransform()
            self.fitInView(self.image_item, Qt.KeepAspectRatio)

            # Default S/E on the vertical middle, 15% in from the left/right edges.
            s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
            e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
            # S and E are not removable and are stacked above everything else.
            self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
            self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

        def set_editor_mode(self, mode: bool):
            """Enable or disable adding, moving and removing anchors with the mouse."""
            self.editor_mode = mode

        # --------------------------------------------------------------------
        # ANCHOR POINTS
        # --------------------------------------------------------------------
        def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
            """Insert an anchor at list index *idx* and add its marker to the scene.

            A negative *idx* means "just before E": the anchor goes before the
            last element when there are at least two anchors, otherwise at the end.
            Labelled S/E anchors are green, all others red.
            """
            if idx < 0:
                idx = len(self.anchor_points) - 1 if len(self.anchor_points) >= 2 else len(self.anchor_points)

            self.anchor_points.insert(idx, (x, y))
            color = Qt.green if label in ("S", "E") else Qt.red
            item = LabeledPointItem(x, y, label=label, radius=radius, color=color,
                                    removable=removable, z_value=z_val)
            self.point_items.insert(idx, item)
            self.scene.addItem(item)

        def _add_guide_point(self, x, y):
            """Add a removable guide point at scene position (x, y) and rebuild the path.

            The point is inserted just before E, so guide points are visited in
            the order they were added.
            """
            self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius)
            self._refresh_cost_and_path()

        # --------------------------------------------------------------------
        # COST IMAGE
        # --------------------------------------------------------------------
        def _revert_cost_to_original(self):
            """Reset cost_image to a fresh copy of cost_image_original (if one is set)."""
            if self.cost_image_original is not None:
                self.cost_image = self.cost_image_original.copy()

        def _apply_all_guide_points_to_cost(self):
            """Lower the cost around every removable anchor (i.e. every guide point)."""
            if self.cost_image is None:
                return
            for (ax, ay), item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

        def _refresh_cost_and_path(self):
            """Rebuild cost_image from the original plus all guide points, then redraw the path."""
            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _lower_cost_in_circle(self, x_f, y_f, radius):
            """Set every cost_image pixel within *radius* of (x_f, y_f) to the global minimum.

            The centre is rounded to the nearest pixel (column = x, row = y).
            Does nothing without a cost image or when the centre lies outside it.
            *radius* is expected to be an integer number of pixels.
            """
            if self.cost_image is None:
                return
            h, w = self.cost_image.shape
            row_c = int(round(y_f))
            col_c = int(round(x_f))
            if not (0 <= row_c < h and 0 <= col_c < w):
                return
            global_min = self.cost_image.min()
            r_s = max(0, row_c - radius)
            r_e = min(h, row_c + radius + 1)
            c_s = max(0, col_c - radius)
            c_e = min(w, col_c + radius + 1)
            # Disc mask over the clipped bounding box, computed in one NumPy pass.
            rows = np.arange(r_s, r_e)[:, np.newaxis]
            cols = np.arange(c_s, c_e)[np.newaxis, :]
            inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
            # Slicing gives a view, so the masked assignment writes into cost_image.
            self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

        # --------------------------------------------------------------------
        # PATH BUILDING
        # --------------------------------------------------------------------
        def _remove_path_items(self):
            """Remove the currently drawn path dots from the scene."""
            for item in self.full_path_points:
                self.scene.removeItem(item)
            self.full_path_points.clear()

        def _rebuild_full_path(self):
            """Recompute the path through all anchors in order, smooth it and redraw it."""
            self._remove_path_items()

            if len(self.anchor_points) < 2 or self.cost_image is None:
                return

            big_xy = []
            segments = zip(self.anchor_points, self.anchor_points[1:])
            for i, ((xA, yA), (xB, yB)) in enumerate(segments):
                sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
                # Each subpath starts where the previous one ended: skip the shared point.
                big_xy.extend(sub_xy if i == 0 else sub_xy[1:])

            # Savitzky-Golay smoothing along the path; needs at least window_length samples.
            window = 7
            if len(big_xy) >= window:
                smoothed = savgol_filter(np.array(big_xy), window_length=window, polyorder=1, axis=0)
                big_xy = smoothed.tolist()

            for px, py in big_xy:
                path_item = LabeledPointItem(px, py, label="", radius=self.path_radius,
                                             color=Qt.magenta, removable=False, z_value=0)
                self.full_path_points.append(path_item)
                self.scene.addItem(path_item)

            # Ensure labelled anchors (S/E) stay on top of the path dots.
            for p_item in self.point_items:
                if p_item._text_item:
                    p_item.setZValue(100)

        def _compute_subpath_xy(self, xA, yA, xB, yB):
            """Return the find_path route from (xA, yA) to (xB, yB) as a list of (x, y).

            Endpoints are rounded and clamped into the image. Returns [] when
            there is no cost image or find_path raises ValueError.
            """
            if self.cost_image is None:
                return []
            h, w = self.cost_image.shape
            rA = max(0, min(int(round(yA)), h - 1))
            cA = max(0, min(int(round(xA)), w - 1))
            rB = max(0, min(int(round(yB)), h - 1))
            cB = max(0, min(int(round(xB)), w - 1))
            try:
                path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
            except ValueError as e:
                print("Error in find_path:", e)
                return []
            # find_path works in (row, col); the scene uses (x, y) = (col, row).
            return [(c, r) for (r, c) in path_rc]

        # --------------------------------------------------------------------
        # MOUSE EVENTS (dragging, adding, removing points)
        # --------------------------------------------------------------------
        def mousePressEvent(self, event):
            """Start an anchor drag (editor mode) or a pan; right click removes a guide point."""
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                if self.editor_mode:
                    idx = self._find_item_near(event.pos(), 10)
                    if idx is not None:
                        # Drag an existing anchor; remember where it was grabbed.
                        self._dragging_idx = idx
                        scene_pos = self.mapToScene(event.pos())
                        px, py = self.point_items[idx].get_pos()
                        self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                        # Disable hand-panning while the anchor is moved.
                        self.setDragMode(QGraphicsView.NoDrag)
                        self.viewport().setCursor(Qt.ClosedHandCursor)
                        return

                # No anchor under the pointer (or not editing): the user may be panning.
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)

            elif event.button() == Qt.RightButton:
                if self.editor_mode:
                    self._remove_point_by_click(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """Move the dragged anchor (clamped to the image), or track pan distance."""
            if self._dragging_idx is not None:
                item = self.point_items[self._dragging_idx]
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]
                # Keep the whole marker inside the image.
                r = item._r
                item.set_pos(self._clamp(x_new, r, self._img_w - r),
                             self._clamp(y_new, r, self._img_h - r))
                return

            # Movement beyond the threshold means this press is a pan, not a click.
            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True
            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """Finish an anchor drag, or add a guide point on a plain click in editor mode."""
            # Let QGraphicsView finish its own ScrollHandDrag handling for this
            # press (it keeps hand-scrolling state between press and release).
            super().mouseReleaseEvent(event)

            if event.button() != Qt.LeftButton or not self._mouse_pressed:
                return
            self._mouse_pressed = False
            self.viewport().setCursor(Qt.ArrowCursor)

            if self._dragging_idx is not None:
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                # Commit the marker's new centre (S/E included) and rebuild.
                self.anchor_points[idx] = self.point_items[idx].get_pos()
                self._refresh_cost_and_path()
            elif not self._was_dragging and self.editor_mode:
                # A click on empty space (not a pan) adds a guide point.
                scene_pos = self.mapToScene(event.pos())
                self._add_guide_point(scene_pos.x(), scene_pos.y())

            self._was_dragging = False

        def _remove_point_by_click(self, view_pos):
            """Remove the guide point nearest to *view_pos*, if close enough; S/E are kept."""
            idx = self._find_item_near(view_pos, threshold=10)
            if idx is None or not self.point_items[idx].is_removable():
                return
            self.scene.removeItem(self.point_items.pop(idx))
            self.anchor_points.pop(idx)
            self._refresh_cost_and_path()

        def _find_item_near(self, view_pos, threshold=10):
            """Return the index of the anchor closest to *view_pos*, or None.

            Distance is measured in scene units; None is returned when there
            are no anchors or the closest one is farther than *threshold*.
            """
            if not self.point_items:
                return None
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()
            dists = [itm.distance_to(x_click, y_click) for itm in self.point_items]
            # min() keeps the first index on ties.
            closest_idx = min(range(len(dists)), key=dists.__getitem__)
            return closest_idx if dists[closest_idx] <= threshold else None

        # --------------------------------------------------------------------
        # ZOOM
        # --------------------------------------------------------------------
        def wheelEvent(self, event):
            """Zoom in (wheel up) or out (wheel down) by a factor of 1.25."""
            zoom_in_factor = 1.25
            zoom_out_factor = 1 / zoom_in_factor

            if event.angleDelta().y() > 0:
                self.scale(zoom_in_factor, zoom_in_factor)
            else:
                self.scale(zoom_out_factor, zoom_out_factor)
            event.accept()

        # --------------------------------------------------------------------
        # UTILS
        # --------------------------------------------------------------------
        def _clamp(self, val, mn, mx):
            """Return *val* limited to the range [mn, mx]."""
            return max(mn, min(val, mx))

        def _clear_all_points(self):
            """Remove every anchor (including S/E) and the drawn path."""
            for it in self.point_items:
                self.scene.removeItem(it)
            self.point_items.clear()
            self.anchor_points.clear()
            self._remove_path_items()

        def clear_guide_points(self):
            """Remove all removable anchors (guide points), keep S/E, and rebuild the path."""
            kept_points, kept_items = [], []
            for pt, item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self.scene.removeItem(item)
                else:
                    kept_points.append(pt)
                    kept_items.append(item)
            # Slice-assign so the list objects themselves are preserved.
            self.anchor_points[:] = kept_points
            self.point_items[:] = kept_items
            self._refresh_cost_and_path()
    
    
    class MainWindow(QMainWindow):
        """Top-level window: an ImageGraphicsView above a row of control buttons."""

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            main_widget = QWidget()
            main_layout = QVBoxLayout(main_widget)

            self.image_view = ImageGraphicsView()
            main_layout.addWidget(self.image_view)

            btn_layout = QHBoxLayout()

            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            # Checkable toggle; its text and colour are updated in toggle_editor_mode.
            self.btn_editor_mode = QPushButton("Editor Mode: OFF")
            self.btn_editor_mode.setCheckable(True)
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
            self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
            btn_layout.addWidget(self.btn_editor_mode)

            self.btn_export_points = QPushButton("Export Points")
            self.btn_export_points.clicked.connect(self.export_points)
            btn_layout.addWidget(self.btn_export_points)

            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            main_layout.addLayout(btn_layout)
            self.setCentralWidget(main_widget)
            self.resize(900, 600)

        def load_image(self):
            """Ask for an image file, display it, and compute its cost image for path finding."""
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            if not file_path:
                return

            self.image_view.load_image(file_path)

            cost_img = compute_cost_image(file_path)
            self.image_view.cost_image_original = cost_img
            self.image_view.cost_image = cost_img.copy()
            # Draw the initial S -> E path now instead of waiting for the first edit.
            self.image_view._rebuild_full_path()

        def toggle_editor_mode(self):
            """Sync the view's editor mode and the button's label/colour with the toggle state."""
            is_checked = self.btn_editor_mode.isChecked()
            self.image_view.set_editor_mode(is_checked)
            if is_checked:
                self.btn_editor_mode.setText("Editor Mode: ON")
                self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
            else:
                self.btn_editor_mode.setText("Editor Mode: OFF")
                self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

        def export_points(self):
            """Save the anchor points (in path order) to a .npy file as an (N, 2) array of (x, y).

            Nothing is saved, and no dialog is shown, when there are no anchors.
            """
            if not self.image_view.anchor_points:
                print("No anchor points to export.")
                return

            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Points", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=options
            )
            if file_path:
                points_array = np.array(self.image_view.anchor_points)
                np.save(file_path, points_array)
                print(f"Exported {len(points_array)} points to {file_path}")

        def clear_points(self):
            """Remove all removable anchors (guide points), keep S/E in place."""
            self.image_view.clear_guide_points()

        def closeEvent(self, event):
            super().closeEvent(event)
    
    s224389's avatar
    s224389 committed
    
    
    def main():
        """Create the Qt application, show the main window, and exit with Qt's return code."""
        qt_app = QApplication(sys.argv)
        main_window = MainWindow()
        main_window.show()
        exit_code = qt_app.exec_()
        sys.exit(exit_code)
    
    
    
    s224389's avatar
    s224389 committed
    if __name__ == "__main__":