import math
import sys

import numpy as np
from PyQt5.QtCore import QRectF, Qt
from PyQt5.QtGui import QBrush, QColor, QFont, QPen, QPixmap
from PyQt5.QtWidgets import (
    QApplication,
    QFileDialog,
    QGraphicsEllipseItem,
    QGraphicsPixmapItem,
    QGraphicsScene,
    QGraphicsTextItem,
    QGraphicsView,
    QHBoxLayout,
    QMainWindow,
    QPushButton,
    QVBoxLayout,
    QWidget,
)
from scipy.signal import savgol_filter

from live_wire import compute_cost_image, find_path


class LabeledPointItem(QGraphicsEllipseItem):
    """Filled circle centred on (x, y), optionally carrying a bold text label.

    The ellipse's local geometry starts at (0, 0), so ``set_pos`` offsets the
    item by the radius to keep the circle centred on the logical point.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red,
                 removable=True, z_value=0, parent=None):
        super().__init__(0, 0, 2 * radius, 2 * radius, parent)
        self._x = x
        self._y = y
        self._r = radius
        self._removable = removable

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))
        self.setZValue(z_value)

        self._text_item = None
        if label:
            self._text_item = QGraphicsTextItem(self)
            self._text_item.setPlainText(label)
            self._text_item.setDefaultTextColor(QColor("black"))
            font = QFont("Arial", 14)
            font.setBold(True)
            self._text_item.setFont(font)
            self._scale_text_to_fit()

        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink the label (never enlarge it) so it fits inside the circle."""
        if not self._text_item:
            return
        self._text_item.setScale(1.0)
        circle_diam = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        text_w = raw_rect.width()
        text_h = raw_rect.height()
        if text_w > circle_diam or text_h > circle_diam:
            scale_factor = min(circle_diam / text_w, circle_diam / text_h)
            self._text_item.setScale(scale_factor)
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label at the centre of the circle."""
        if not self._text_item:
            return
        ellipse_w = 2 * self._r
        ellipse_h = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        scale_factor = self._text_item.scale()
        scaled_w = raw_rect.width() * scale_factor
        scaled_h = raw_rect.height() * scale_factor
        self._text_item.setPos((ellipse_w - scaled_w) * 0.5,
                               (ellipse_h - scaled_h) * 0.5)

    def set_pos(self, x, y):
        """Move the circle so that its centre is at scene point (x, y)."""
        self._x = x
        self._y = y
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the centre of the circle as an ``(x, y)`` tuple."""
        return (self._x, self._y)

    def distance_to(self, x_other, y_other):
        """Euclidean distance from this point's centre to (x_other, y_other)."""
        return math.hypot(self._x - x_other, self._y - y_other)

    def is_removable(self):
        """True for user guide points; False for fixed points such as S/E."""
        return self._removable


class ImageGraphicsView(QGraphicsView):
    """Zoomable, pannable image view for editing a live-wire path.

    Anchors are kept in two parallel lists: ``anchor_points`` holds the
    ``(x, y)`` scene coordinates and ``point_items`` holds the matching
    ``LabeledPointItem``. After ``load_image`` the first anchor is the start
    point "S" and the last is the end point "E"; guide points added by the
    user are inserted between them. The path between consecutive anchors is
    computed with ``find_path`` on ``cost_image`` and drawn as magenta dots.
    """

    # Savitzky-Golay smoothing applied to the concatenated path.
    SMOOTH_WINDOW = 7
    SMOOTH_POLYORDER = 1

    def __init__(self, parent=None):
        super().__init__(parent)
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Zoom around the mouse pointer rather than the view centre.
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        # Image display item.
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        # Parallel lists: same length, same order.
        self.anchor_points = []  # list[(x, y)]
        self.point_items = []    # list[LabeledPointItem]

        self.editor_mode = False
        self.dot_radius = 4
        self.path_radius = 1
        self.radius_cost_image = 2  # cost-lowering radius around guide points

        self._img_w = 0
        self._img_h = 0

        # Pan / drag state.
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.viewport().setCursor(Qt.ArrowCursor)
        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5
        self._was_dragging = False
        self._dragging_idx = None
        self._drag_offset = (0, 0)

        # Pristine cost image (restored before re-applying guide points) and
        # the working copy that find_path actually uses.
        self.cost_image_original = None
        self.cost_image = None

        # Magenta dot items currently drawn for the path.
        self.full_path_points = []

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """Display the image at ``path`` and reset the S/E anchors.

        Returns True on success and False if Qt could not load the file; in
        that case the current image, anchors and cost image are left intact.
        On success the previous cost image is discarded because it belongs to
        the old image; install the new one with ``set_cost_image``.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return False

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))
        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self._clear_all_points()
        self.cost_image_original = None
        self.cost_image = None
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # Start/end anchors at 15% / 85% of the width, vertically centred.
        s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
        e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
        # S and E are not removable and are drawn above everything else.
        self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False,
                                  z_val=100, radius=6)
        self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False,
                                  z_val=100, radius=6)
        return True

    def set_cost_image(self, cost_image):
        """Install a new cost image, re-apply guide points and redraw the path.

        ``cost_image`` is kept untouched as the pristine original; a copy is
        modified by guide points. Passing None clears the cost (no path).
        """
        self.cost_image_original = cost_image
        self.cost_image = None if cost_image is None else cost_image.copy()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def set_editor_mode(self, mode: bool):
        """Enable/disable adding, dragging and removing anchors with the mouse."""
        self.editor_mode = mode

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True,
                             z_val=0, radius=4):
        """Insert an anchor at index ``idx``.

        A negative ``idx`` means "just before the last anchor" when there are
        at least two anchors (the last one is assumed to be E), otherwise
        "append". Labelled S/E points are green, all others red.
        """
        if idx < 0:
            if len(self.anchor_points) >= 2:
                idx = len(self.anchor_points) - 1
            else:
                idx = len(self.anchor_points)

        self.anchor_points.insert(idx, (x, y))

        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(x, y, label=label, radius=radius, color=color,
                                removable=removable, z_value=z_val)
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Add a removable red guide point at (x, y) and rebuild the path.

        Points outside the loaded image (or when no image is loaded) are
        ignored: they cannot lower any cost and would only pull the path to
        the image border.
        """
        if not (0 <= x < self._img_w and 0 <= y < self._img_h):
            return
        self._insert_anchor_point(-1, x, y, label="", removable=True,
                                  z_val=1, radius=self.dot_radius)
        self._refresh_cost_and_path()

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _refresh_cost_and_path(self):
        """Restore the pristine cost, re-apply every guide point, redraw."""
        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _revert_cost_to_original(self):
        """self.cost_image <- copy of self.cost_image_original (if any)."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower the cost around every removable anchor."""
        if self.cost_image is None:
            return
        for (ax, ay), item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """Set every cost pixel within ``radius`` of (x_f, y_f) to the global min.

        Does nothing if the rounded centre lies outside the cost image. The
        cost image is indexed as ``[row, col]``, i.e. ``[y, x]``.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return

        global_min = self.cost_image.min()
        r_s, r_e = max(0, row_c - radius), min(h, row_c + radius + 1)
        c_s, c_e = max(0, col_c - radius), min(w, col_c + radius + 1)

        # Disc mask over the clipped bounding box (squared distances avoid sqrt).
        rows = np.arange(r_s, r_e)[:, None]
        cols = np.arange(c_s, c_e)[None, :]
        inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
        # Basic slicing returns a view, so the masked write updates cost_image.
        self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _rebuild_full_path(self):
        """Recompute and redraw the path through all anchors, in list order."""
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        big_xy = []
        segments = zip(self.anchor_points, self.anchor_points[1:])
        for i, ((x_a, y_a), (x_b, y_b)) in enumerate(segments):
            sub_xy = self._compute_subpath_xy(x_a, y_a, x_b, y_b)
            # Consecutive subpaths share their junction point; keep it once.
            big_xy.extend(sub_xy if i == 0 else sub_xy[1:])

        # Smooth the concatenated (x, y) path along its length.
        window = self.SMOOTH_WINDOW
        if len(big_xy) >= window:
            smoothed = savgol_filter(np.asarray(big_xy, dtype=float),
                                     window_length=window,
                                     polyorder=self.SMOOTH_POLYORDER,
                                     axis=0)
            big_xy = smoothed.tolist()

        for px, py in big_xy:
            path_item = LabeledPointItem(px, py, label="",
                                         radius=self.path_radius,
                                         color=Qt.magenta, removable=False,
                                         z_value=0)
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Keep labelled anchors (S/E) above the path dots.
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """Return the ``find_path`` route from A to B as a list of (x, y).

        Endpoints are rounded and clamped into the cost image. A ValueError
        from ``find_path`` is reported and yields an empty list.
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape
        r_a = self._clamp(int(round(yA)), 0, h - 1)
        c_a = self._clamp(int(round(xA)), 0, w - 1)
        r_b = self._clamp(int(round(yB)), 0, h - 1)
        c_b = self._clamp(int(round(xB)), 0, w - 1)
        try:
            path_rc = find_path(self.cost_image, [(r_a, c_a), (r_b, c_b)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        # find_path works in (row, col); the scene uses (x, y) = (col, row).
        return [(c, r) for (r, c) in path_rc]

    # --------------------------------------------------------------------
    # MOUSE EVENTS (dragging, adding, removing points)
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        """Left: start dragging a nearby anchor (editor mode) or start panning.
        Right: remove a nearby removable anchor (editor mode)."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            idx = self._find_item_near(event.pos(), 10) if self.editor_mode else None
            if idx is not None:
                # Drag an existing anchor; remember where it was grabbed.
                self._dragging_idx = idx
                scene_pos = self.mapToScene(event.pos())
                px, py = self.point_items[idx].get_pos()
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                self.setDragMode(QGraphicsView.NoDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)
                return

            # No anchor under the cursor: the user may be panning.
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ClosedHandCursor)

        elif event.button() == Qt.RightButton:
            if self.editor_mode:
                self._remove_point_by_click(event.pos())

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the dragged anchor (clamped to the image) or track panning."""
        if self._dragging_idx is not None:
            scene_pos = self.mapToScene(event.pos())
            x_new = scene_pos.x() - self._drag_offset[0]
            y_new = scene_pos.y() - self._drag_offset[1]
            item = self.point_items[self._dragging_idx]
            r = item._r
            item.set_pos(self._clamp(x_new, r, self._img_w - r),
                         self._clamp(y_new, r, self._img_h - r))
            return

        # Movement beyond the threshold marks this press as a pan, not a click.
        if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish an anchor drag, or add a guide point on a plain click."""
        super().mouseReleaseEvent(event)
        if event.button() != Qt.LeftButton or not self._mouse_pressed:
            return

        self._mouse_pressed = False
        self.viewport().setCursor(Qt.ArrowCursor)

        if self._dragging_idx is not None:
            idx = self._dragging_idx
            self._dragging_idx = None
            self._drag_offset = (0, 0)
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            # Commit the new position (S/E included) and redraw.
            self.anchor_points[idx] = self.point_items[idx].get_pos()
            self._refresh_cost_and_path()
        elif not self._was_dragging and self.editor_mode:
            # Click on an empty spot => add a guide point.
            scene_pos = self.mapToScene(event.pos())
            self._add_guide_point(scene_pos.x(), scene_pos.y())

        self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Remove the removable anchor nearest ``view_pos`` (S/E are kept)."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None or not self.point_items[idx].is_removable():
            return
        self.scene.removeItem(self.point_items[idx])
        self.point_items.pop(idx)
        self.anchor_points.pop(idx)
        self._refresh_cost_and_path()

    def _find_item_near(self, view_pos, threshold=10):
        """Index of the anchor closest to ``view_pos`` within ``threshold``.

        The distance is measured in scene coordinates. Returns None if there
        are no anchors or the closest one is farther than ``threshold``.
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()
        if not self.point_items:
            return None
        closest_idx, min_dist = min(
            ((i, itm.distance_to(x_click, y_click))
             for i, itm in enumerate(self.point_items)),
            key=lambda pair: pair[1],
        )
        return closest_idx if min_dist <= threshold else None

    # --------------------------------------------------------------------
    # ZOOM
    # --------------------------------------------------------------------
    def wheelEvent(self, event):
        """Zoom in/out with the vertical mouse wheel.

        Events with no vertical component (e.g. horizontal trackpad scrolls)
        are passed to QGraphicsView so they scroll instead of zooming out.
        """
        delta = event.angleDelta().y()
        if delta == 0:
            super().wheelEvent(event)
            return
        zoom_in_factor = 1.25
        factor = zoom_in_factor if delta > 0 else 1 / zoom_in_factor
        self.scale(factor, factor)
        event.accept()

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Clamp ``val`` into [mn, mx]."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor (including S/E) and every path dot."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()

        for p in self.full_path_points:
            self.scene.removeItem(p)
        self.full_path_points.clear()

    def clear_guide_points(self):
        """Remove all removable anchors (guide points), keeping S/E in place.

        Then reverts the cost image and rebuilds the path.
        """
        kept = []
        for anchor, item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self.scene.removeItem(item)
            else:
                kept.append((anchor, item))
        # Mutate in place so outside references to the lists stay valid.
        self.anchor_points[:] = [anchor for anchor, _ in kept]
        self.point_items[:] = [item for _, item in kept]

        self._refresh_cost_and_path()


class MainWindow(QMainWindow):
    """Main window: the image view plus load / editor / export / clear buttons."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")

        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        self.image_view = ImageGraphicsView()
        main_layout.addWidget(self.image_view)

        btn_layout = QHBoxLayout()

        self.btn_load_image = QPushButton("Load Image")
        self.btn_load_image.clicked.connect(self.load_image)
        btn_layout.addWidget(self.btn_load_image)

        self.btn_editor_mode = QPushButton("Editor Mode: OFF")
        self.btn_editor_mode.setCheckable(True)
        self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
        self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
        btn_layout.addWidget(self.btn_editor_mode)

        self.btn_export_points = QPushButton("Export Points")
        self.btn_export_points.clicked.connect(self.export_points)
        btn_layout.addWidget(self.btn_export_points)

        self.btn_clear_points = QPushButton("Clear Points")
        self.btn_clear_points.clicked.connect(self.clear_points)
        btn_layout.addWidget(self.btn_clear_points)

        main_layout.addLayout(btn_layout)
        self.setCentralWidget(main_widget)
        self.resize(900, 600)

    def load_image(self):
        """Ask for an image, display it, compute its cost image, draw the path."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options,
        )
        if not file_path:
            return
        # Only compute a cost image if the view actually accepted the image,
        # otherwise the cost would not match what is displayed.
        if not self.image_view.load_image(file_path):
            print(f"Could not load image: {file_path}")
            return
        self.image_view.set_cost_image(compute_cost_image(file_path))

    def toggle_editor_mode(self):
        """Sync the view's editor mode and the button's look with its state."""
        is_checked = self.btn_editor_mode.isChecked()
        self.image_view.set_editor_mode(is_checked)
        if is_checked:
            self.btn_editor_mode.setText("Editor Mode: ON")
            self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
        else:
            self.btn_editor_mode.setText("Editor Mode: OFF")
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

    def export_points(self):
        """Save the anchor points (S, guide points, E) as an (N, 2) .npy array."""
        if not self.image_view.anchor_points:
            print("No anchor points to export.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Points", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options,
        )
        if file_path:
            points_array = np.array(self.image_view.anchor_points)
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")

    def clear_points(self):
        """Remove all removable anchors (guide points), keep S/E in place."""
        self.image_view.clear_guide_points()

    def closeEvent(self, event):
        super().closeEvent(event)


def main():
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()