Skip to content
Snippets Groups Projects
Select Git revision
  • 7418827e42761bb1d1fef77e6c69a2aaa7afab69
  • main default protected
  • GUI
  • christian_test
4 results

GUI_draft_live.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    GUI_draft_live.py 23.61 KiB
    import sys
    import math
    import numpy as np
    
    # For smoothing the path
    from scipy.signal import savgol_filter
    
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
    from PyQt5.QtCore import Qt, QRectF
    
    from live_wire import compute_cost_image, find_path
    
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """A filled circle centred on (x, y), optionally carrying a bold text label.

        The ellipse is defined in local coordinates as a ``2*radius`` square whose
        top-left corner is the item's position; ``set_pos`` offsets that position
        so the circle's centre sits on the requested scene point.  A label, when
        given, is shrunk (never enlarged) to fit inside the circle and centred.

        Args:
            x, y: Scene coordinates of the circle centre.
            label: Optional text drawn on the circle; empty means no label.
            radius: Circle radius in scene units.
            color: Pen and brush colour of the circle.
            removable: Flag reported by ``is_removable`` (callers use it to
                protect fixed points from deletion).
            z_value: Stacking order passed to ``setZValue``.
            parent: Optional parent graphics item.
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
            diameter = 2 * radius
            super().__init__(0, 0, diameter, diameter, parent)
            self._x, self._y = x, y
            self._r = radius
            self._removable = removable

            self.setPen(QPen(color))
            self.setBrush(QBrush(color))
            self.setZValue(z_value)

            self._text_item = None
            if label:
                self._text_item = self._make_label(label)
                self._scale_text_to_fit()

            self.set_pos(x, y)

        def _make_label(self, label):
            """Create the child text item showing *label* in bold black 14pt Arial."""
            text_item = QGraphicsTextItem(self)
            text_item.setPlainText(label)
            text_item.setDefaultTextColor(QColor("black"))
            label_font = QFont("Arial", 14)
            label_font.setBold(True)
            text_item.setFont(label_font)
            return text_item

        def _scale_text_to_fit(self):
            """Shrink the label (if any) until it fits in the circle, then centre it."""
            if not self._text_item:
                return
            self._text_item.setScale(1.0)
            diameter = 2 * self._r
            bounds = self._text_item.boundingRect()
            width, height = bounds.width(), bounds.height()
            # Only ever scale down; a small label keeps its natural size.
            if max(width, height) > diameter:
                self._text_item.setScale(min(diameter / width, diameter / height))
            self._center_label()

        def _center_label(self):
            """Place the (possibly scaled) label so it is centred on the circle."""
            if not self._text_item:
                return
            factor = self._text_item.scale()
            bounds = self._text_item.boundingRect()
            diameter = 2 * self._r
            offset_x = (diameter - bounds.width() * factor) * 0.5
            offset_y = (diameter - bounds.height() * factor) * 0.5
            self._text_item.setPos(offset_x, offset_y)

        def set_pos(self, x, y):
            """Move the item so that the circle's centre is at scene point (x, y)."""
            self._x, self._y = x, y
            self.setPos(x - self._r, y - self._r)

        def get_pos(self):
            """Return the circle centre as an ``(x, y)`` tuple."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Euclidean distance from the circle centre to (x_other, y_other)."""
            return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2)

        def is_removable(self):
            """Whether the user may delete this point."""
            return self._removable
    
    
    class ImageGraphicsView(QGraphicsView):
        """Zoomable, pannable image view for editing a live-wire path.

        The path is defined by an ordered list of anchors: a fixed start "S",
        a fixed end "E" and any number of removable guide points in between.
        For each consecutive pair of anchors a sub-path is computed with
        ``find_path`` on ``cost_image``.  The sub-paths are concatenated,
        smoothed and drawn as small dots.

        Mouse interaction:
          * left click on the image      -> insert a guide point into the
            anchor segment whose path passes closest to the click;
          * left drag on empty area      -> pan;
          * left drag on an anchor       -> move it (path refreshes live);
          * right click on a guide point -> remove it;
          * mouse wheel                  -> zoom around the cursor.

        ``cost_image_original`` / ``cost_image`` are assigned from outside the
        class (``MainWindow.load_image`` sets them after calling ``load_image``).
        ``cost_image`` is indexed as ``[row, col]`` with ``h, w = shape``.
        Guide points lower the cost around themselves in ``cost_image`` so the
        path is attracted through them.
        """

        def __init__(self, parent=None):
            super().__init__(parent)
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Zoom around mouse pointer
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Image display
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            self.anchor_points = []     # List[(x, y)] in image coords; S first, E last
            self.point_items = []       # LabeledPointItem per anchor, same order
            self.full_path_points = []  # dot items drawn along the path

            # Smoothed path coordinates, plus (in parallel) the index i of the
            # anchor segment anchor[i] -> anchor[i+1] each coordinate came from.
            # The segment ids let a click be inserted between the right anchors.
            self._full_path_xy = []
            self._full_path_segments = []

            self.dot_radius = 4
            self.path_radius = 1
            self.radius_cost_image = 2
            self._img_w = 0
            self._img_h = 0

            # Pan/Drag
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ArrowCursor)

            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5    # view pixels (Manhattan) before a press counts as a pan
            self._was_dragging = False
            self._dragging_idx = None   # index of the anchor being dragged, if any
            self._drag_offset = (0, 0)
            self._drag_counter = 0

            # Cost images
            self.cost_image_original = None
            self.cost_image = None

            # Rainbow toggle
            self._rainbow_enabled = True

        def set_rainbow_enabled(self, enabled: bool):
            """Enable/disable rainbow colouring of the path and refresh its colours."""
            self._rainbow_enabled = enabled
            self._refresh_path_colors()

        def toggle_rainbow(self):
            """Flip rainbow colouring of the path and refresh its colours."""
            self._rainbow_enabled = not self._rainbow_enabled
            self._refresh_path_colors()

        def _refresh_path_colors(self):
            """Recolour the drawn path, building it first if nothing is drawn yet.

            Recolouring in place avoids recomputing every live-wire sub-path
            just to change colours.
            """
            if not self.full_path_points:
                self._rebuild_full_path()
                return
            count = len(self.full_path_points)
            for i, item in enumerate(self.full_path_points):
                color = self._path_color(i, count)
                item.setPen(QPen(color))
                item.setBrush(QBrush(color))

        # --------------------------------------------------------------------
        # LOADING
        # --------------------------------------------------------------------
        def load_image(self, path):
            """Show the image at *path* and reset the anchors to default S/E points.

            Does nothing if the image cannot be loaded.  No path is drawn here;
            that needs a cost image, which is assigned separately.
            """
            pixmap = QPixmap(path)
            if pixmap.isNull():
                return

            self.image_item.setPixmap(pixmap)
            self.setSceneRect(QRectF(pixmap.rect()))

            self._img_w = pixmap.width()
            self._img_h = pixmap.height()

            self._clear_all_points()
            self.resetTransform()
            self.fitInView(self.image_item, Qt.KeepAspectRatio)

            # By default, add S/E
            s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
            e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
            self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
            self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

        # --------------------------------------------------------------------
        # ANCHOR POINTS
        # --------------------------------------------------------------------
        def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
            """Insert an anchor at list index *idx*, clamped inside the image.

            A negative *idx* means "just before E" once S and E exist, or
            "append" while fewer than two anchors exist.  S/E (by label) are
            drawn green, everything else red.
            """
            x_clamped = self._clamp(x, radius, self._img_w - radius)
            y_clamped = self._clamp(y, radius, self._img_h - radius)

            if idx < 0:
                if len(self.anchor_points) >= 2:
                    idx = len(self.anchor_points) - 1
                else:
                    idx = len(self.anchor_points)

            self.anchor_points.insert(idx, (x_clamped, y_clamped))

            color = Qt.green if label in ("S", "E") else Qt.red
            item = LabeledPointItem(x_clamped, y_clamped,
                                    label=label, radius=radius, color=color,
                                    removable=removable, z_value=z_val)
            self.point_items.insert(idx, item)
            self.scene.addItem(item)

        def _add_guide_point(self, x, y):
            """Handle a plain left click at scene (x, y): add a guide point and rebuild."""
            if self._img_w <= 0 or self._img_h <= 0:
                # No image loaded yet: clamping against a 0x0 image would pin
                # the point to the corner, so ignore the click instead.
                return

            x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
            y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)

            self._revert_cost_to_original()
            self._insert_anchor_between_subpath(x_clamped, y_clamped)
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _insert_anchor_between_subpath(self, x_new, y_new):
            """Insert a guide point into the anchor segment passing closest to it.

            The nearest point of the current path is looked up.  The segment it
            belongs to, anchor[i] -> anchor[i+1], is read from the segment ids
            recorded when the path was built, and the new anchor is inserted at
            index i+1.  This does not rely on path coordinates coinciding with
            anchor coordinates (they don't: path points are rounded to pixels
            and then smoothed).  If no path exists yet, the point is inserted
            just before E.
            """
            insert_idx = -1  # default: just before E
            if self._full_path_xy and len(self._full_path_segments) == len(self._full_path_xy):
                pts = np.asarray(self._full_path_xy, dtype=float)
                d2 = (pts[:, 0] - x_new) ** 2 + (pts[:, 1] - y_new) ** 2
                segment = self._full_path_segments[int(np.argmin(d2))]
                # Segment ids are only valid for the anchors the path was built
                # from; guard anyway so S always stays first.
                if 0 <= segment < len(self.anchor_points) - 1:
                    insert_idx = segment + 1

            self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
                                      z_val=1, radius=self.dot_radius)

        # --------------------------------------------------------------------
        # COST IMAGE
        # --------------------------------------------------------------------
        def _revert_cost_to_original(self):
            """Reset the working cost image to a fresh copy of the original."""
            if self.cost_image_original is not None:
                self.cost_image = self.cost_image_original.copy()

        def _apply_all_guide_points_to_cost(self):
            """Lower the cost around every removable anchor (S/E are skipped)."""
            if self.cost_image is None:
                return
            for (ax, ay), item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

        def _lower_cost_in_circle(self, x_f, y_f, radius):
            """Set cost pixels within *radius* of (x_f, y_f) to the image's minimum cost.

            The centre is rounded to the nearest pixel; centres outside the image
            are ignored.  *radius* may be an int or a float.
            """
            if self.cost_image is None:
                return
            h, w = self.cost_image.shape
            row_c = int(round(y_f))
            col_c = int(round(x_f))
            if not (0 <= row_c < h and 0 <= col_c < w):
                return
            global_min = self.cost_image.min()
            reach = int(radius)
            r_s, r_e = max(0, row_c - reach), min(h, row_c + reach + 1)
            c_s, c_e = max(0, col_c - reach), min(w, col_c + reach + 1)
            d_rows = np.arange(r_s, r_e)[:, None] - row_c
            d_cols = np.arange(c_s, c_e)[None, :] - col_c
            inside = d_rows * d_rows + d_cols * d_cols <= radius * radius
            # Basic slicing yields a view, so the masked assignment writes
            # straight into cost_image.
            self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

        # --------------------------------------------------------------------
        # PATH BUILDING
        # --------------------------------------------------------------------
        def _clear_path_visuals(self):
            """Remove the drawn path dots and forget the stored path coordinates."""
            for item in self.full_path_points:
                self.scene.removeItem(item)
            self.full_path_points.clear()
            self._full_path_xy.clear()
            self._full_path_segments.clear()

        def _path_color(self, index, count):
            """Colour of path dot *index* out of *count* (rainbow or constant red)."""
            if not self._rainbow_enabled:
                return Qt.red
            fraction = index / (count - 1) if count > 1 else 0
            return self._rainbow_color(fraction)

        def _rebuild_full_path(self):
            """Recompute, smooth and redraw the whole path through all anchors."""
            self._clear_path_visuals()

            if len(self.anchor_points) < 2 or self.cost_image is None:
                return

            path_xy = []
            segments = []
            pairs = zip(self.anchor_points, self.anchor_points[1:])
            for seg, ((xA, yA), (xB, yB)) in enumerate(pairs):
                sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
                if seg > 0:
                    # Avoid repeating the anchor shared with the previous segment
                    sub_xy = sub_xy[1:]
                path_xy.extend(sub_xy)
                segments.extend([seg] * len(sub_xy))

            # Smooth if we have enough points (savgol needs >= window_length).
            # Smoothing keeps the point count, so segment ids stay aligned.
            if len(path_xy) >= 7:
                smoothed = savgol_filter(np.array(path_xy), window_length=7, polyorder=1, axis=0)
                path_xy = [tuple(p) for p in smoothed.tolist()]

            self._full_path_xy = list(path_xy)
            self._full_path_segments = segments

            n_points = len(path_xy)
            for i, (px, py) in enumerate(path_xy):
                path_item = LabeledPointItem(px, py, label="",
                                             radius=self.path_radius,
                                             color=self._path_color(i, n_points),
                                             removable=False,
                                             z_value=0)
                self.full_path_points.append(path_item)
                self.scene.addItem(path_item)

            # Keep S/E on top if they have labels
            for p_item in self.point_items:
                if p_item._text_item:
                    p_item.setZValue(100)

        def _compute_subpath_xy(self, xA, yA, xB, yB):
            """Return the raw live-wire path (list of (x, y)) from (xA, yA) to (xB, yB).

            Endpoints are rounded to pixels and clamped into the cost image.
            Returns [] if there is no cost image or ``find_path`` raises
            ValueError.
            """
            if self.cost_image is None:
                return []
            h, w = self.cost_image.shape
            rA = self._clamp(int(round(yA)), 0, h - 1)
            cA = self._clamp(int(round(xA)), 0, w - 1)
            rB = self._clamp(int(round(yB)), 0, h - 1)
            cB = self._clamp(int(round(xB)), 0, w - 1)
            try:
                path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
            except ValueError as e:
                print("Error in find_path:", e)
                return []
            # find_path works in (row, col); the view works in (x, y).
            return [(c, r) for (r, c) in path_rc]

        def _rainbow_color(self, fraction):
            """
            fraction: 0..1
            Returns a QColor whose hue is fraction * 300 degrees (red -> magenta),
            at full saturation and full brightness.
            """
            hue = int(300 * fraction)  # up to 300 degrees
            saturation = 255
            value = 255
            return QColor.fromHsv(hue, saturation, value)

        # --------------------------------------------------------------------
        # MOUSE EVENTS
        # --------------------------------------------------------------------
        def mousePressEvent(self, event):
            """Start an anchor drag, start a pan/click, or remove a point (right button)."""
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                # See if user is clicking near an existing anchor => drag it
                idx = self._find_item_near(event.pos(), threshold=10)
                if idx is not None:
                    self._dragging_idx = idx
                    self._drag_counter = 0

                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.point_items[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                    self.setDragMode(QGraphicsView.NoDrag)
                    self.viewport().setCursor(Qt.ClosedHandCursor)
                    return

                # No anchor => either a pan or (on release) a new guide point
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)

            elif event.button() == Qt.RightButton:
                # Right-click => remove anchor if removable.  Ignored while an
                # anchor is being dragged: removing a point would shift indices
                # and leave _dragging_idx pointing at the wrong item.
                if self._dragging_idx is None:
                    self._remove_point_by_click(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """Move the dragged anchor (with throttled path refresh) or track panning."""
            if self._dragging_idx is not None:
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]

                # clamp so user can't drag outside
                item = self.point_items[self._dragging_idx]
                r = item._r
                x_clamped = self._clamp(x_new, r, self._img_w - r)
                y_clamped = self._clamp(y_new, r, self._img_h - r)
                item.set_pos(x_clamped, y_clamped)

                # Live-wire recomputation is expensive: refresh the path only on
                # every 4th move; mouseReleaseEvent does the final update.
                self._drag_counter += 1
                if self._drag_counter >= 4:
                    self._drag_counter = 0
                    # Update the anchor BEFORE re-applying guide points so the
                    # low-cost disc follows the dragged point, not its old spot.
                    self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
                    self._revert_cost_to_original()
                    self._apply_all_guide_points_to_cost()
                    self._rebuild_full_path()
                return

            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True

            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """Finish an anchor drag, or add a guide point if the press was a click."""
            super().mouseReleaseEvent(event)
            if not (event.button() == Qt.LeftButton and self._mouse_pressed):
                return

            self._mouse_pressed = False
            self.viewport().setCursor(Qt.ArrowCursor)

            if self._dragging_idx is not None:
                # finished dragging => final update
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                self.setDragMode(QGraphicsView.ScrollHandDrag)

                self.anchor_points[idx] = self.point_items[idx].get_pos()

                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
            elif not self._was_dragging:
                # A click (not a pan) => add new guide point
                scene_pos = self.mapToScene(event.pos())
                self._add_guide_point(scene_pos.x(), scene_pos.y())

            self._was_dragging = False

        def _remove_point_by_click(self, view_pos):
            """Remove the removable anchor nearest *view_pos* (if any) and rebuild."""
            idx = self._find_item_near(view_pos, threshold=10)
            if idx is None:
                return
            # skip if S/E
            if not self.point_items[idx].is_removable():
                return

            self.scene.removeItem(self.point_items.pop(idx))
            self.anchor_points.pop(idx)

            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def _find_item_near(self, view_pos, threshold=10):
            """Index of the anchor closest to *view_pos*, or None if none is within *threshold*.

            NOTE: the distance is measured in scene (image-pixel) units, so the
            on-screen pick radius changes with the zoom level.
            """
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()

            closest_idx = None
            min_dist = float('inf')
            for i, itm in enumerate(self.point_items):
                d = itm.distance_to(x_click, y_click)
                if d < min_dist:
                    min_dist = d
                    closest_idx = i
            if closest_idx is not None and min_dist <= threshold:
                return closest_idx
            return None

        # --------------------------------------------------------------------
        # ZOOM
        # --------------------------------------------------------------------
        def wheelEvent(self, event):
            """Zoom in/out by 1.25x per vertical wheel step; pass horizontal scroll on."""
            delta = event.angleDelta().y()
            if delta == 0:
                # Purely horizontal scroll (tilt wheel / trackpad swipe): let Qt
                # scroll rather than treating it as a zoom-out.
                super().wheelEvent(event)
                return
            zoom_in_factor = 1.25
            factor = zoom_in_factor if delta > 0 else 1 / zoom_in_factor
            self.scale(factor, factor)
            event.accept()

        # --------------------------------------------------------------------
        # UTILS
        # --------------------------------------------------------------------
        def _clamp(self, val, mn, mx):
            """Clamp *val* to [mn, mx] (mn wins if mn > mx)."""
            return max(mn, min(val, mx))

        def _clear_all_points(self):
            """Remove every anchor (including S/E) and the drawn path."""
            for it in self.point_items:
                self.scene.removeItem(it)
            self.point_items.clear()
            self.anchor_points.clear()
            self._clear_path_visuals()

        def clear_guide_points(self):
            """Remove all removable anchors, keep S/E. Rebuild path."""
            kept_anchors = []
            kept_items = []
            for anchor, item in zip(self.anchor_points, self.point_items):
                if item.is_removable():
                    self.scene.removeItem(item)
                else:
                    kept_anchors.append(anchor)
                    kept_items.append(item)
            # Mutate in place so any outside references to these lists stay valid.
            self.anchor_points[:] = kept_anchors
            self.point_items[:] = kept_items

            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()

        def get_full_path_xy(self):
            """Return a copy of the entire smoothed path as a list of (x, y) points."""
            return list(self._full_path_xy)
    
    
    class MainWindow(QMainWindow):
        """Top-level window: the image/path view plus a row of action buttons."""

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            main_widget = QWidget()
            main_layout = QVBoxLayout(main_widget)

            self.image_view = ImageGraphicsView()
            main_layout.addWidget(self.image_view)

            # Buttons layout
            btn_layout = QHBoxLayout()

            # Load Image
            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            # Export Path
            self.btn_export_path = QPushButton("Export Path")
            self.btn_export_path.clicked.connect(self.export_path)
            btn_layout.addWidget(self.btn_export_path)

            # Clear Points
            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            # Toggle Rainbow
            self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
            self.btn_toggle_rainbow.clicked.connect(self.toggle_rainbow)
            btn_layout.addWidget(self.btn_toggle_rainbow)

            main_layout.addLayout(btn_layout)
            self.setCentralWidget(main_widget)
            self.resize(900, 600)

        def toggle_rainbow(self):
            """Toggle the rainbow mode in the view."""
            self.image_view.toggle_rainbow()

        def load_image(self):
            """Ask for an image, show it, attach its cost image and draw the S->E path."""
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            if not file_path:
                return

            self.image_view.load_image(file_path)
            cost_img = compute_cost_image(file_path)
            self.image_view.cost_image_original = cost_img
            self.image_view.cost_image = cost_img.copy()
            # The view placed S/E before any cost image was available, so no
            # path has been drawn yet; build it now that the costs are set.
            self.image_view._rebuild_full_path()

        def export_path(self):
            """Export the full path (x,y) as a .npy file."""
            full_xy = self.image_view.get_full_path_xy()
            if not full_xy:
                print("No path to export.")
                return

            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Path", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=options
            )
            if not file_path:
                return

            # np.save appends ".npy" to names lacking it; do it here so the
            # message below reports the file that was actually written.
            if not file_path.endswith(".npy"):
                file_path += ".npy"
            arr = np.array(full_xy)
            np.save(file_path, arr)
            print(f"Exported path with {len(arr)} points to {file_path}")

        def clear_points(self):
            """Remove all guide points (S/E stay) and rebuild the path."""
            self.image_view.clear_guide_points()
    
    
    def main():
        """Create the Qt application, show the main window and run the event loop.

        The process exits with the event loop's return code.
        """
        application = QApplication(sys.argv)
        main_window = MainWindow()  # local reference keeps the window alive
        main_window.show()
        exit_code = application.exec_()
        sys.exit(exit_code)


    if __name__ == "__main__":
        # Only launch the GUI when run as a script, not when imported.
        main()