diff --git a/GUI_draft.py b/GUI_draft.py index c7c2bc94fb65292ab7f7b134a8203295a388546d..fb0231c6f144e79a0b70f31fddd1ef2a019ea9d7 100644 --- a/GUI_draft.py +++ b/GUI_draft.py @@ -10,29 +10,23 @@ from PyQt5.QtWidgets import ( from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont from PyQt5.QtCore import Qt, QRectF -# Import your live_wire functions from live_wire import compute_cost_image, find_path class LabeledPointItem(QGraphicsEllipseItem): - """ - A circle with optional (bold) label (e.g. 'S'/'E'), - which automatically scales the text if it's bigger than the circle. - """ - def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, parent=None): + def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None): super().__init__(0, 0, 2*radius, 2*radius, parent) - self._x = x # Center x - self._y = y # Center y - self._r = radius # Circle radius + self._x = x + self._y = y + self._r = radius self._removable = removable - # Circle styling pen = QPen(color) brush = QBrush(color) self.setPen(pen) self.setBrush(brush) + self.setZValue(z_value) - # Optional text label self._text_item = None if label: self._text_item = QGraphicsTextItem(self) @@ -43,7 +37,6 @@ class LabeledPointItem(QGraphicsEllipseItem): self._text_item.setFont(font) self._scale_text_to_fit() - # Move so center is at (x, y) self.set_pos(x, y) def _scale_text_to_fit(self): @@ -55,9 +48,7 @@ class LabeledPointItem(QGraphicsEllipseItem): text_w = raw_rect.width() text_h = raw_rect.height() if text_w > circle_diam or text_h > circle_diam: - scale_w = circle_diam / text_w - scale_h = circle_diam / text_h - scale_factor = min(scale_w, scale_h) + scale_factor = min(circle_diam / text_w, circle_diam / text_h) self._text_item.setScale(scale_factor) self._center_label() @@ -75,7 +66,6 @@ class LabeledPointItem(QGraphicsEllipseItem): self._text_item.setPos(tx, ty) def set_pos(self, x, y): - """Move so the circle's center is at (x,y) in scene coords.""" self._x = x self._y = y self.setPos(x - self._r, y - self._r) @@ -84,39 +74,37 @@ class LabeledPointItem(QGraphicsEllipseItem): return (self._x, self._y) def distance_to(self, x_other, y_other): - dx = self._x - x_other - dy = self._y - y_other - return math.sqrt(dx*dx + dy*dy) + return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2) def is_removable(self): return self._removable class ImageGraphicsView(QGraphicsView): - """ - Displays an image and allows placing/dragging labeled points. - Ensures points can't go outside the image boundary. - """ def __init__(self, parent=None): super().__init__(parent) self.scene = QGraphicsScene(self) self.setScene(self.scene) - # Zoom around mouse pointer + # Allow zoom around mouse pointer self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse) - # Image item + # Image display item self.image_item = QGraphicsPixmapItem() self.scene.addItem(self.image_item) - self.points = [] # LabeledPointItem objects - self.editor_mode = False + # Parallel lists + self.anchor_points = [] # List[(x, y)] + self.point_items = [] # List[LabeledPointItem] + self.editor_mode = False self.dot_radius = 4 - self.path_radius = 1 # radius of circles in path + self.path_radius = 1 + self.radius_something = 3 # cost-lowering radius self._img_w = 0 self._img_h = 0 + # For pan/drag self.setDragMode(QGraphicsView.ScrollHandDrag) self.viewport().setCursor(Qt.ArrowCursor) @@ -127,52 +115,168 @@ class ImageGraphicsView(QGraphicsView): self._dragging_idx = None self._drag_offset = (0, 0) - # Cost image from compute_cost_image + # Keep original cost image to revert changes + self.cost_image_original = None self.cost_image = None - # All path points displayed in magenta - self.path_points = [] + # The path is displayed as small magenta circles in self.full_path_points + self.full_path_points = [] - def load_image(self, image_path): - pixmap = QPixmap(image_path) + # -------------------------------------------------------------------- + # LOADING + # -------------------------------------------------------------------- + def load_image(self, path): + pixmap = QPixmap(path) if not pixmap.isNull(): self.image_item.setPixmap(pixmap) self.setSceneRect(QRectF(pixmap.rect())) - # Save image dimensions self._img_w = pixmap.width() self._img_h = pixmap.height() - self._clear_point_items(remove_all=True) + self._clear_all_points() self.resetTransform() self.fitInView(self.image_item, Qt.KeepAspectRatio) - # Positions for S/E - s_x = self._img_w * 0.15 - s_y = self._img_h * 0.5 - e_x = self._img_w * 0.85 - e_y = self._img_h * 0.5 - - # Create green S/E with radius=6 - s_point = self._create_point(s_x, s_y, "S", 6, Qt.green, removable=False) - e_point = self._create_point(e_x, e_y, "E", 6, Qt.green, removable=False) + # Create S/E + s_x, s_y = 0.15*self._img_w, 0.5*self._img_h + e_x, e_y = 0.85*self._img_w, 0.5*self._img_h - self.points = [s_point, e_point] - self.scene.addItem(s_point) - self.scene.addItem(e_point) + # S => not removable + self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6) + # E => not removable + self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6) def set_editor_mode(self, mode: bool): self.editor_mode = mode - def _create_point(self, x, y, label, radius, color, removable=True): - # Clamp coordinates so center doesn't go outside the image - cx = self._clamp(x, radius, self._img_w - radius) - cy = self._clamp(y, radius, self._img_h - radius) - return LabeledPointItem(cx, cy, label=label, radius=radius, color=color, removable=removable) + # -------------------------------------------------------------------- + # ANCHOR POINTS + # -------------------------------------------------------------------- + def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4): + """ + Insert at index=idx, or -1 => append just before E if E exists. + """ + if idx < 0: + # If we have at least 2 anchors, the last is E => insert before that + if len(self.anchor_points) >= 2: + idx = len(self.anchor_points) - 1 + else: + idx = len(self.anchor_points) + + self.anchor_points.insert(idx, (x, y)) + color = Qt.green if label in ("S","E") else Qt.red + item = LabeledPointItem(x, y, label=label, radius=radius, color=color, + removable=removable, z_value=z_val) + self.point_items.insert(idx, item) + self.scene.addItem(item) + + def _add_guide_point(self, x, y): + """ + User added a red guide point => lower cost, insert anchor, rebuild path. + """ + # 1) Revert cost + self._revert_cost_to_original() + # 2) Insert new anchor (removable) + self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius) + # 3) Re-apply cost-lowering for all existing guide points + self._apply_all_guide_points_to_cost() + # 4) Rebuild path + self._rebuild_full_path() + + # -------------------------------------------------------------------- + # COST IMAGE + # -------------------------------------------------------------------- + def _revert_cost_to_original(self): + """self.cost_image <- copy of self.cost_image_original""" + if self.cost_image_original is not None: + self.cost_image = self.cost_image_original.copy() + + def _apply_all_guide_points_to_cost(self): + """Lower cost around every removable anchor.""" + if self.cost_image is None: + return + for i, (ax, ay) in enumerate(self.anchor_points): + if self.point_items[i].is_removable(): + self._lower_cost_in_circle(ax, ay, self.radius_something) + + def _lower_cost_in_circle(self, x_f, y_f, radius): + """Set cost_image row,col in circle of radius -> global min.""" + if self.cost_image is None: + return + h, w = self.cost_image.shape + row_c = int(round(y_f)) + col_c = int(round(x_f)) + if not (0 <= row_c < h and 0 <= col_c < w): + return + global_min = self.cost_image.min() + r_s = max(0, row_c - radius) + r_e = min(h, row_c + radius + 1) + c_s = max(0, col_c - radius) + c_e = min(w, col_c + radius + 1) + for rr in range(r_s, r_e): + for cc in range(c_s, c_e): + dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2) + if dist <= radius: + self.cost_image[rr, cc] = global_min + + # -------------------------------------------------------------------- + # PATH BUILDING + # -------------------------------------------------------------------- + def _rebuild_full_path(self): + # Remove old path items + for item in self.full_path_points: + self.scene.removeItem(item) + self.full_path_points.clear() + + # Build subpaths + if len(self.anchor_points) < 2 or self.cost_image is None: + return - def _clamp(self, val, min_val, max_val): - return max(min_val, min(val, max_val)) + big_xy = [] + for i in range(len(self.anchor_points)-1): + xA, yA = self.anchor_points[i] + xB, yB = self.anchor_points[i+1] + sub_xy = self._compute_subpath_xy(xA, yA, xB, yB) + if i == 0: + big_xy.extend(sub_xy) + else: + # avoid duplicating the point between subpaths + if len(sub_xy) > 1: + big_xy.extend(sub_xy[1:]) + + # Draw them + for (px, py) in big_xy: + path_item = LabeledPointItem(px, py, label="", radius=self.path_radius, + color=Qt.magenta, removable=False, z_value=0) + self.full_path_points.append(path_item) + self.scene.addItem(path_item) + + # Ensure S/E stay on top + for p_item in self.point_items: + if p_item._text_item: + p_item.setZValue(100) + + def _compute_subpath_xy(self, xA, yA, xB, yB): + if self.cost_image is None: + return [] + h, w = self.cost_image.shape + rA, cA = int(round(yA)), int(round(xA)) + rB, cB = int(round(yB)), int(round(xB)) + rA = max(0, min(rA, h-1)) + cA = max(0, min(cA, w-1)) + rB = max(0, min(rB, h-1)) + cB = max(0, min(cB, w-1)) + try: + path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)]) + except ValueError as e: + print("Error in find_path:", e) + return [] + return [(c, r) for (r, c) in path_rc] + # -------------------------------------------------------------------- + # MOUSE EVENTS (dragging, adding, removing points) + # -------------------------------------------------------------------- def mousePressEvent(self, event): if event.button() == Qt.LeftButton: self._mouse_pressed = True @@ -180,17 +284,18 @@ class ImageGraphicsView(QGraphicsView): self._press_view_pos = event.pos() if self.editor_mode: - idx = self._find_point_near(event.pos(), threshold=10) + idx = self._find_item_near(event.pos(), 10) if idx is not None: + # drag existing anchor self._dragging_idx = idx scene_pos = self.mapToScene(event.pos()) - px, py = self.points[idx].get_pos() + px, py = self.point_items[idx].get_pos() self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py) - self.setDragMode(QGraphicsView.NoDrag) self.viewport().setCursor(Qt.ClosedHandCursor) return else: + # If no anchor near, user might be panning self.setDragMode(QGraphicsView.ScrollHandDrag) self.viewport().setCursor(Qt.ClosedHandCursor) else: @@ -199,25 +304,23 @@ class ImageGraphicsView(QGraphicsView): elif event.button() == Qt.RightButton: if self.editor_mode: - self._remove_point(event.pos()) + self._remove_point_by_click(event.pos()) super().mousePressEvent(event) def mouseMoveEvent(self, event): if self._dragging_idx is not None: - # Dragging an existing point + # dragging an anchor scene_pos = self.mapToScene(event.pos()) x_new = scene_pos.x() - self._drag_offset[0] y_new = scene_pos.y() - self._drag_offset[1] - - # Clamp so center doesn't go out of the image - r = self.points[self._dragging_idx]._r + r = self.point_items[self._dragging_idx]._r x_clamped = self._clamp(x_new, r, self._img_w - r) y_clamped = self._clamp(y_new, r, self._img_h - r) - self.points[self._dragging_idx].set_pos(x_clamped, y_clamped) + self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped) return else: - # If movement > threshold => treat as pan + # if movement > threshold => pan if self._mouse_pressed and (event.buttons() & Qt.LeftButton): dist = (event.pos() - self._press_view_pos).manhattanLength() if dist > self._drag_threshold: @@ -229,142 +332,117 @@ class ImageGraphicsView(QGraphicsView): if event.button() == Qt.LeftButton and self._mouse_pressed: self._mouse_pressed = False self.viewport().setCursor(Qt.ArrowCursor) - if self._dragging_idx is not None: - # The user was dragging a point and now released + idx = self._dragging_idx self._dragging_idx = None self._drag_offset = (0, 0) self.setDragMode(QGraphicsView.ScrollHandDrag) - self._run_find_path() # Recompute path + + # update anchor_points + newX, newY = self.point_items[idx].get_pos() + # even if S/E => update coords + self.anchor_points[idx] = (newX, newY) + + # revert + re-apply cost, rebuild path + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() + else: - # If not dragged, maybe add a new point if not self._was_dragging and self.editor_mode: - self._add_point(event.pos()) - self._run_find_path() + # user clicked an empty spot => add a guide point + scene_pos = self.mapToScene(event.pos()) + x, y = scene_pos.x(), scene_pos.y() + self._add_guide_point(x, y) self._was_dragging = False - def wheelEvent(self, event): - zoom_in_factor = 1.25 - zoom_out_factor = 1 / zoom_in_factor - if event.angleDelta().y() > 0: - self.scale(zoom_in_factor, zoom_in_factor) - else: - self.scale(zoom_out_factor, zoom_out_factor) - event.accept() - - def _add_point(self, view_pos): - """Add a removable red dot at the clicked location.""" - scene_pos = self.mapToScene(view_pos) - x, y = scene_pos.x(), scene_pos.y() - dot = self._create_point(x, y, label="", radius=self.dot_radius, color=Qt.red, removable=True) - # Insert before the final E point if S/E exist - if len(self.points) >= 2: - self.points.insert(len(self.points) - 1, dot) - else: - self.points.append(dot) - self.scene.addItem(dot) + def _remove_point_by_click(self, view_pos): + idx = self._find_item_near(view_pos, threshold=10) + if idx is None: + return + # check if removable => skip S/E + if not self.point_items[idx].is_removable(): + return # do nothing - def _remove_point(self, view_pos): - """Right-click => remove nearest dot if it's removable.""" - scene_pos = self.mapToScene(view_pos) - x_click, y_click = scene_pos.x(), scene_pos.y() + # remove anchor + self.scene.removeItem(self.point_items[idx]) + self.point_items.pop(idx) + self.anchor_points.pop(idx) - threshold = 10 - closest_idx = None - min_dist = float('inf') - for i, p in enumerate(self.points): - dist = p.distance_to(x_click, y_click) - if dist < min_dist: - min_dist = dist - closest_idx = i - if closest_idx is not None and min_dist <= threshold: - if self.points[closest_idx].is_removable(): - self.scene.removeItem(self.points[closest_idx]) - del self.points[closest_idx] - self._run_find_path() + # revert + re-apply cost, rebuild path + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() - def _find_point_near(self, view_pos, threshold=10): + def _find_item_near(self, view_pos, threshold=10): scene_pos = self.mapToScene(view_pos) x_click, y_click = scene_pos.x(), scene_pos.y() - - closest_idx = None min_dist = float('inf') - for i, p in enumerate(self.points): - dist = p.distance_to(x_click, y_click) - if dist < min_dist: - min_dist = dist + closest_idx = None + for i, itm in enumerate(self.point_items): + d = itm.distance_to(x_click, y_click) + if d < min_dist: + min_dist = d closest_idx = i if closest_idx is not None and min_dist <= threshold: return closest_idx return None - def _clear_point_items(self, remove_all=False): - """Remove all points if remove_all=True; else just removable ones.""" - if remove_all: - for p in self.points: - self.scene.removeItem(p) - self.points.clear() - else: - still_needed = [] - for p in self.points: - if p.is_removable(): - self.scene.removeItem(p) - else: - still_needed.append(p) - self.points = still_needed - - # Also remove any path points from the scene - for p_item in self.path_points: - self.scene.removeItem(p_item) - self.path_points.clear() - - def _run_find_path(self): + # -------------------------------------------------------------------- + # ZOOM + # -------------------------------------------------------------------- + def wheelEvent(self, event): """ - Convert the first two points (S/E) from (x,y) to (row,col) - and call find_path(). Then display the path in magenta. + Zoom in/out with mouse wheel """ - # If we don't have at least 2 points, no path - if len(self.points) < 2: - return - if self.cost_image is None: - return + zoom_in_factor = 1.25 + zoom_out_factor = 1 / zoom_in_factor - # Clear old path visualization - for item in self.path_points: - self.scene.removeItem(item) - self.path_points.clear() + # If the user scrolls upward => zoom in. Otherwise => zoom out. + if event.angleDelta().y() > 0: + self.scale(zoom_in_factor, zoom_in_factor) + else: + self.scale(zoom_out_factor, zoom_out_factor) + event.accept() - # We'll define the path between the first and last point, - # or if you specifically want the first two, you can do self.points[:2]. - s_x, s_y = self.points[0].get_pos() - e_x, e_y = self.points[-1].get_pos() + # -------------------------------------------------------------------- + # UTILS + # -------------------------------------------------------------------- + def _clamp(self, val, mn, mx): + return max(mn, min(val, mx)) - # Convert (x, y) => (row, col) = (int(y), int(x)) and clamp - h, w = self.cost_image.shape - s_r = int(round(s_y)); s_c = int(round(s_x)) - e_r = int(round(e_y)); e_c = int(round(e_x)) + def _clear_all_points(self): + for it in self.point_items: + self.scene.removeItem(it) + self.point_items.clear() + self.anchor_points.clear() - # Ensure they're inside the cost_image boundary - s_r = max(0, min(s_r, h-1)) - s_c = max(0, min(s_c, w-1)) - e_r = max(0, min(e_r, h-1)) - e_c = max(0, min(e_c, w-1)) + for p in self.full_path_points: + self.scene.removeItem(p) + self.full_path_points.clear() - # Attempt path - try: - path_rc = find_path(self.cost_image, [(s_r, s_c), (e_r, e_c)]) - except ValueError as e: - print("Error in find_path:", e) - return + def clear_guide_points(self): + """ + Removes all anchors that are 'removable' (guide points), + keeps S/E in place. Then reverts cost, re-applies, rebuilds path. + """ + i = 0 + while i < len(self.anchor_points): + if self.point_items[i].is_removable(): + self.scene.removeItem(self.point_items[i]) + del self.point_items[i] + del self.anchor_points[i] + else: + i += 1 + + for item in self.full_path_points: + self.scene.removeItem(item) + self.full_path_points.clear() - # Convert path (row,col) => (x, y) - for (r, c) in path_rc: - x = c - y = r - item = self._create_point(x, y, "", self.path_radius, Qt.red, removable=False) - self.path_points.append(item) - self.scene.addItem(item) + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() class MainWindow(QMainWindow): @@ -380,24 +458,20 @@ class MainWindow(QMainWindow): btn_layout = QHBoxLayout() - # Load Image self.btn_load_image = QPushButton("Load Image") self.btn_load_image.clicked.connect(self.load_image) btn_layout.addWidget(self.btn_load_image) - # Editor Mode self.btn_editor_mode = QPushButton("Editor Mode: OFF") self.btn_editor_mode.setCheckable(True) self.btn_editor_mode.setStyleSheet("background-color: lightgray;") self.btn_editor_mode.clicked.connect(self.toggle_editor_mode) btn_layout.addWidget(self.btn_editor_mode) - # Export Points self.btn_export_points = QPushButton("Export Points") self.btn_export_points.clicked.connect(self.export_points) btn_layout.addWidget(self.btn_export_points) - # Clear Points self.btn_clear_points = QPushButton("Clear Points") self.btn_clear_points.clicked.connect(self.clear_points) btn_layout.addWidget(self.btn_clear_points) @@ -407,7 +481,6 @@ class MainWindow(QMainWindow): self.resize(900, 600) def load_image(self): - """Open file dialog to pick an image, then load it.""" options = QFileDialog.Options() file_path, _ = QFileDialog.getOpenFileName( self, "Open Image", "", @@ -416,8 +489,9 @@ class MainWindow(QMainWindow): ) if file_path: self.image_view.load_image(file_path) - # Compute cost image - self.image_view.cost_image = compute_cost_image(file_path) + cost_img = compute_cost_image(file_path) + self.image_view.cost_image_original = cost_img + self.image_view.cost_image = cost_img.copy() def toggle_editor_mode(self): is_checked = self.btn_editor_mode.isChecked() @@ -430,10 +504,9 @@ class MainWindow(QMainWindow): self.btn_editor_mode.setStyleSheet("background-color: lightgray;") def export_points(self): - if not self.image_view.points: - print("No points to export.") + if not self.image_view.anchor_points: + print("No anchor points to export.") return - options = QFileDialog.Options() file_path, _ = QFileDialog.getSaveFileName( self, "Export Points", "", @@ -441,13 +514,16 @@ class MainWindow(QMainWindow): options=options ) if file_path: - coords = [p.get_pos() for p in self.image_view.points] - points_array = np.array(coords) + points_array = np.array(self.image_view.anchor_points) np.save(file_path, points_array) print(f"Exported {len(points_array)} points to {file_path}") def clear_points(self): - self.image_view._clear_point_items(remove_all=False) + """Remove all removable anchors (guide points), keep S/E in place.""" + self.image_view.clear_guide_points() + + def closeEvent(self, event): + super().closeEvent(event) def main():