From d60e2b78238b6f3671aa13583720bd6c28910230 Mon Sep 17 00:00:00 2001 From: Christian <s224389@dtu.dk> Date: Wed, 15 Jan 2025 12:36:37 +0100 Subject: [PATCH] Integrated live_wire into GUI --- GUI_draft.py | 127 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 77 insertions(+), 50 deletions(-) diff --git a/GUI_draft.py b/GUI_draft.py index 91e9b92..c7c2bc9 100644 --- a/GUI_draft.py +++ b/GUI_draft.py @@ -10,6 +10,9 @@ from PyQt5.QtWidgets import ( from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont from PyQt5.QtCore import Qt, QRectF +# Import your live_wire functions +from live_wire import compute_cost_image, find_path + class LabeledPointItem(QGraphicsEllipseItem): """ @@ -35,62 +38,46 @@ class LabeledPointItem(QGraphicsEllipseItem): self._text_item = QGraphicsTextItem(self) self._text_item.setPlainText(label) self._text_item.setDefaultTextColor(QColor("black")) - # Bold text font = QFont("Arial", 14) font.setBold(True) self._text_item.setFont(font) - self._scale_text_to_fit() # Move so center is at (x, y) self.set_pos(x, y) def _scale_text_to_fit(self): - """Scale the text down so it fits fully within the circle's diameter.""" if not self._text_item: return - - # Reset scale first self._text_item.setScale(1.0) - circle_diam = 2 * self._r raw_rect = self._text_item.boundingRect() text_w = raw_rect.width() text_h = raw_rect.height() - if text_w > circle_diam or text_h > circle_diam: scale_w = circle_diam / text_w scale_h = circle_diam / text_h scale_factor = min(scale_w, scale_h) self._text_item.setScale(scale_factor) - self._center_label() def _center_label(self): - """Center the text in the circle, taking into account any scaling.""" if not self._text_item: return - ellipse_w = 2 * self._r ellipse_h = 2 * self._r - raw_rect = self._text_item.boundingRect() scale_factor = self._text_item.scale() - scaled_w = raw_rect.width() * scale_factor + scaled_w = raw_rect.width() * scale_factor scaled_h = raw_rect.height() * scale_factor - tx = (ellipse_w - scaled_w) * 0.5 ty = (ellipse_h - scaled_h) * 0.5 self._text_item.setPos(tx, ty) def set_pos(self, x, y): - """ - Move so the circle's center is at (x,y) in scene coords. - """ + """Move so the circle's center is at (x,y) in scene coords.""" self._x = x self._y = y - # Because our ellipse is (0,0,2*r,2*r) in local coords, - # we shift by (x-r, y-r). self.setPos(x - self._r, y - self._r) def get_pos(self): @@ -122,13 +109,11 @@ class ImageGraphicsView(QGraphicsView): self.image_item = QGraphicsPixmapItem() self.scene.addItem(self.image_item) - self.points = [] + self.points = [] # LabeledPointItem objects self.editor_mode = False - # For normal red dots - self.dot_radius = 4 - - # Keep track of image size + self.dot_radius = 4 + self.path_radius = 1 # radius of circles in path self._img_w = 0 self._img_h = 0 @@ -142,6 +127,12 @@ class ImageGraphicsView(QGraphicsView): self._dragging_idx = None self._drag_offset = (0, 0) + # Cost image from compute_cost_image + self.cost_image = None + + # All path points displayed in magenta + self.path_points = [] + def load_image(self, image_path): pixmap = QPixmap(image_path) if not pixmap.isNull(): @@ -166,7 +157,6 @@ class ImageGraphicsView(QGraphicsView): s_point = self._create_point(s_x, s_y, "S", 6, Qt.green, removable=False) e_point = self._create_point(e_x, e_y, "E", 6, Qt.green, removable=False) - # Put S in front, E in back self.points = [s_point, e_point] self.scene.addItem(s_point) self.scene.addItem(e_point) @@ -175,20 +165,10 @@ class ImageGraphicsView(QGraphicsView): self.editor_mode = mode def _create_point(self, x, y, label, radius, color, removable=True): - """ - Helper to create a LabeledPointItem at (x,y), but clamp inside image first. - """ - # Clamp coordinates so center doesn't go outside + # Clamp coordinates so center doesn't go outside the image cx = self._clamp(x, radius, self._img_w - radius) cy = self._clamp(y, radius, self._img_h - radius) - - return LabeledPointItem( - cx, cy, - label=label, - radius=radius, - color=color, - removable=removable - ) + return LabeledPointItem(cx, cy, label=label, radius=radius, color=color, removable=removable) def _clamp(self, val, min_val, max_val): return max(min_val, min(val, max_val)) @@ -230,11 +210,10 @@ class ImageGraphicsView(QGraphicsView): x_new = scene_pos.x() - self._drag_offset[0] y_new = scene_pos.y() - self._drag_offset[1] - # Clamp center so it doesn't go out of the image + # Clamp so center doesn't go out of the image r = self.points[self._dragging_idx]._r x_clamped = self._clamp(x_new, r, self._img_w - r) y_clamped = self._clamp(y_new, r, self._img_h - r) - self.points[self._dragging_idx].set_pos(x_clamped, y_clamped) return else: @@ -252,40 +231,38 @@ class ImageGraphicsView(QGraphicsView): self.viewport().setCursor(Qt.ArrowCursor) if self._dragging_idx is not None: + # The user was dragging a point and now released self._dragging_idx = None self._drag_offset = (0, 0) self.setDragMode(QGraphicsView.ScrollHandDrag) + self._run_find_path() # Recompute path else: # If not dragged, maybe add a new point if not self._was_dragging and self.editor_mode: self._add_point(event.pos()) + self._run_find_path() self._was_dragging = False def wheelEvent(self, event): zoom_in_factor = 1.25 zoom_out_factor = 1 / zoom_in_factor - if event.angleDelta().y() > 0: self.scale(zoom_in_factor, zoom_in_factor) else: self.scale(zoom_out_factor, zoom_out_factor) event.accept() - # ---------- Points ---------- def _add_point(self, view_pos): - """Add a removable red dot at the clicked location, clamped inside the image.""" + """Add a removable red dot at the clicked location.""" scene_pos = self.mapToScene(view_pos) x, y = scene_pos.x(), scene_pos.y() - dot = self._create_point(x, y, label="", radius=self.dot_radius, color=Qt.red, removable=True) - - # Insert between S and E if they exist + # Insert before the final E point if S/E exist if len(self.points) >= 2: self.points.insert(len(self.points) - 1, dot) else: self.points.append(dot) - self.scene.addItem(dot) def _remove_point(self, view_pos): @@ -296,17 +273,16 @@ class ImageGraphicsView(QGraphicsView): threshold = 10 closest_idx = None min_dist = float('inf') - for i, p in enumerate(self.points): dist = p.distance_to(x_click, y_click) if dist < min_dist: min_dist = dist closest_idx = i - if closest_idx is not None and min_dist <= threshold: if self.points[closest_idx].is_removable(): self.scene.removeItem(self.points[closest_idx]) del self.points[closest_idx] + self._run_find_path() def _find_point_near(self, view_pos, threshold=10): scene_pos = self.mapToScene(view_pos) @@ -314,13 +290,11 @@ class ImageGraphicsView(QGraphicsView): closest_idx = None min_dist = float('inf') - for i, p in enumerate(self.points): dist = p.distance_to(x_click, y_click) if dist < min_dist: min_dist = dist closest_idx = i - if closest_idx is not None and min_dist <= threshold: return closest_idx return None @@ -340,6 +314,58 @@ class ImageGraphicsView(QGraphicsView): still_needed.append(p) self.points = still_needed + # Also remove any path points from the scene + for p_item in self.path_points: + self.scene.removeItem(p_item) + self.path_points.clear() + + def _run_find_path(self): + """ + Convert the first two points (S/E) from (x,y) to (row,col) + and call find_path(). Then display the path in magenta. + """ + # If we don't have at least 2 points, no path + if len(self.points) < 2: + return + if self.cost_image is None: + return + + # Clear old path visualization + for item in self.path_points: + self.scene.removeItem(item) + self.path_points.clear() + + # We'll define the path between the first and last point, + # or if you specifically want the first two, you can do self.points[:2]. + s_x, s_y = self.points[0].get_pos() + e_x, e_y = self.points[-1].get_pos() + + # Convert (x, y) => (row, col) = (int(y), int(x)) and clamp + h, w = self.cost_image.shape + s_r = int(round(s_y)); s_c = int(round(s_x)) + e_r = int(round(e_y)); e_c = int(round(e_x)) + + # Ensure they're inside the cost_image boundary + s_r = max(0, min(s_r, h-1)) + s_c = max(0, min(s_c, w-1)) + e_r = max(0, min(e_r, h-1)) + e_c = max(0, min(e_c, w-1)) + + # Attempt path + try: + path_rc = find_path(self.cost_image, [(s_r, s_c), (e_r, e_c)]) + except ValueError as e: + print("Error in find_path:", e) + return + + # Convert path (row,col) => (x, y) + for (r, c) in path_rc: + x = c + y = r + item = self._create_point(x, y, "", self.path_radius, Qt.red, removable=False) + self.path_points.append(item) + self.scene.addItem(item) + class MainWindow(QMainWindow): def __init__(self): @@ -390,11 +416,12 @@ class MainWindow(QMainWindow): ) if file_path: self.image_view.load_image(file_path) + # Compute cost image + self.image_view.cost_image = compute_cost_image(file_path) def toggle_editor_mode(self): is_checked = self.btn_editor_mode.isChecked() self.image_view.set_editor_mode(is_checked) - if is_checked: self.btn_editor_mode.setText("Editor Mode: ON") self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;") -- GitLab