Skip to content
Snippets Groups Projects
Commit d60e2b78 authored by Christian's avatar Christian
Browse files

Integrated live_wire into GUI

parent 521876af
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,9 @@ from PyQt5.QtWidgets import ( ...@@ -10,6 +10,9 @@ from PyQt5.QtWidgets import (
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF from PyQt5.QtCore import Qt, QRectF
# Import your live_wire functions
from live_wire import compute_cost_image, find_path
class LabeledPointItem(QGraphicsEllipseItem): class LabeledPointItem(QGraphicsEllipseItem):
""" """
...@@ -35,62 +38,46 @@ class LabeledPointItem(QGraphicsEllipseItem): ...@@ -35,62 +38,46 @@ class LabeledPointItem(QGraphicsEllipseItem):
self._text_item = QGraphicsTextItem(self) self._text_item = QGraphicsTextItem(self)
self._text_item.setPlainText(label) self._text_item.setPlainText(label)
self._text_item.setDefaultTextColor(QColor("black")) self._text_item.setDefaultTextColor(QColor("black"))
# Bold text
font = QFont("Arial", 14) font = QFont("Arial", 14)
font.setBold(True) font.setBold(True)
self._text_item.setFont(font) self._text_item.setFont(font)
self._scale_text_to_fit() self._scale_text_to_fit()
# Move so center is at (x, y) # Move so center is at (x, y)
self.set_pos(x, y) self.set_pos(x, y)
def _scale_text_to_fit(self): def _scale_text_to_fit(self):
"""Scale the text down so it fits fully within the circle's diameter."""
if not self._text_item: if not self._text_item:
return return
# Reset scale first
self._text_item.setScale(1.0) self._text_item.setScale(1.0)
circle_diam = 2 * self._r circle_diam = 2 * self._r
raw_rect = self._text_item.boundingRect() raw_rect = self._text_item.boundingRect()
text_w = raw_rect.width() text_w = raw_rect.width()
text_h = raw_rect.height() text_h = raw_rect.height()
if text_w > circle_diam or text_h > circle_diam: if text_w > circle_diam or text_h > circle_diam:
scale_w = circle_diam / text_w scale_w = circle_diam / text_w
scale_h = circle_diam / text_h scale_h = circle_diam / text_h
scale_factor = min(scale_w, scale_h) scale_factor = min(scale_w, scale_h)
self._text_item.setScale(scale_factor) self._text_item.setScale(scale_factor)
self._center_label() self._center_label()
def _center_label(self): def _center_label(self):
"""Center the text in the circle, taking into account any scaling."""
if not self._text_item: if not self._text_item:
return return
ellipse_w = 2 * self._r ellipse_w = 2 * self._r
ellipse_h = 2 * self._r ellipse_h = 2 * self._r
raw_rect = self._text_item.boundingRect() raw_rect = self._text_item.boundingRect()
scale_factor = self._text_item.scale() scale_factor = self._text_item.scale()
scaled_w = raw_rect.width() * scale_factor scaled_w = raw_rect.width() * scale_factor
scaled_h = raw_rect.height() * scale_factor scaled_h = raw_rect.height() * scale_factor
tx = (ellipse_w - scaled_w) * 0.5 tx = (ellipse_w - scaled_w) * 0.5
ty = (ellipse_h - scaled_h) * 0.5 ty = (ellipse_h - scaled_h) * 0.5
self._text_item.setPos(tx, ty) self._text_item.setPos(tx, ty)
def set_pos(self, x, y): def set_pos(self, x, y):
""" """Move so the circle's center is at (x,y) in scene coords."""
Move so the circle's center is at (x,y) in scene coords.
"""
self._x = x self._x = x
self._y = y self._y = y
# Because our ellipse is (0,0,2*r,2*r) in local coords,
# we shift by (x-r, y-r).
self.setPos(x - self._r, y - self._r) self.setPos(x - self._r, y - self._r)
def get_pos(self): def get_pos(self):
...@@ -122,13 +109,11 @@ class ImageGraphicsView(QGraphicsView): ...@@ -122,13 +109,11 @@ class ImageGraphicsView(QGraphicsView):
self.image_item = QGraphicsPixmapItem() self.image_item = QGraphicsPixmapItem()
self.scene.addItem(self.image_item) self.scene.addItem(self.image_item)
self.points = [] self.points = [] # LabeledPointItem objects
self.editor_mode = False self.editor_mode = False
# For normal red dots
self.dot_radius = 4 self.dot_radius = 4
self.path_radius = 1 # radius of circles in path
# Keep track of image size
self._img_w = 0 self._img_w = 0
self._img_h = 0 self._img_h = 0
...@@ -142,6 +127,12 @@ class ImageGraphicsView(QGraphicsView): ...@@ -142,6 +127,12 @@ class ImageGraphicsView(QGraphicsView):
self._dragging_idx = None self._dragging_idx = None
self._drag_offset = (0, 0) self._drag_offset = (0, 0)
# Cost image from compute_cost_image
self.cost_image = None
# All path points displayed in magenta
self.path_points = []
def load_image(self, image_path): def load_image(self, image_path):
pixmap = QPixmap(image_path) pixmap = QPixmap(image_path)
if not pixmap.isNull(): if not pixmap.isNull():
...@@ -166,7 +157,6 @@ class ImageGraphicsView(QGraphicsView): ...@@ -166,7 +157,6 @@ class ImageGraphicsView(QGraphicsView):
s_point = self._create_point(s_x, s_y, "S", 6, Qt.green, removable=False) s_point = self._create_point(s_x, s_y, "S", 6, Qt.green, removable=False)
e_point = self._create_point(e_x, e_y, "E", 6, Qt.green, removable=False) e_point = self._create_point(e_x, e_y, "E", 6, Qt.green, removable=False)
# Put S in front, E in back
self.points = [s_point, e_point] self.points = [s_point, e_point]
self.scene.addItem(s_point) self.scene.addItem(s_point)
self.scene.addItem(e_point) self.scene.addItem(e_point)
...@@ -175,20 +165,10 @@ class ImageGraphicsView(QGraphicsView): ...@@ -175,20 +165,10 @@ class ImageGraphicsView(QGraphicsView):
self.editor_mode = mode self.editor_mode = mode
def _create_point(self, x, y, label, radius, color, removable=True): def _create_point(self, x, y, label, radius, color, removable=True):
""" # Clamp coordinates so center doesn't go outside the image
Helper to create a LabeledPointItem at (x,y), but clamp inside image first.
"""
# Clamp coordinates so center doesn't go outside
cx = self._clamp(x, radius, self._img_w - radius) cx = self._clamp(x, radius, self._img_w - radius)
cy = self._clamp(y, radius, self._img_h - radius) cy = self._clamp(y, radius, self._img_h - radius)
return LabeledPointItem(cx, cy, label=label, radius=radius, color=color, removable=removable)
return LabeledPointItem(
cx, cy,
label=label,
radius=radius,
color=color,
removable=removable
)
def _clamp(self, val, min_val, max_val): def _clamp(self, val, min_val, max_val):
return max(min_val, min(val, max_val)) return max(min_val, min(val, max_val))
...@@ -230,11 +210,10 @@ class ImageGraphicsView(QGraphicsView): ...@@ -230,11 +210,10 @@ class ImageGraphicsView(QGraphicsView):
x_new = scene_pos.x() - self._drag_offset[0] x_new = scene_pos.x() - self._drag_offset[0]
y_new = scene_pos.y() - self._drag_offset[1] y_new = scene_pos.y() - self._drag_offset[1]
# Clamp center so it doesn't go out of the image # Clamp so center doesn't go out of the image
r = self.points[self._dragging_idx]._r r = self.points[self._dragging_idx]._r
x_clamped = self._clamp(x_new, r, self._img_w - r) x_clamped = self._clamp(x_new, r, self._img_w - r)
y_clamped = self._clamp(y_new, r, self._img_h - r) y_clamped = self._clamp(y_new, r, self._img_h - r)
self.points[self._dragging_idx].set_pos(x_clamped, y_clamped) self.points[self._dragging_idx].set_pos(x_clamped, y_clamped)
return return
else: else:
...@@ -252,40 +231,38 @@ class ImageGraphicsView(QGraphicsView): ...@@ -252,40 +231,38 @@ class ImageGraphicsView(QGraphicsView):
self.viewport().setCursor(Qt.ArrowCursor) self.viewport().setCursor(Qt.ArrowCursor)
if self._dragging_idx is not None: if self._dragging_idx is not None:
# The user was dragging a point and now released
self._dragging_idx = None self._dragging_idx = None
self._drag_offset = (0, 0) self._drag_offset = (0, 0)
self.setDragMode(QGraphicsView.ScrollHandDrag) self.setDragMode(QGraphicsView.ScrollHandDrag)
self._run_find_path() # Recompute path
else: else:
# If not dragged, maybe add a new point # If not dragged, maybe add a new point
if not self._was_dragging and self.editor_mode: if not self._was_dragging and self.editor_mode:
self._add_point(event.pos()) self._add_point(event.pos())
self._run_find_path()
self._was_dragging = False self._was_dragging = False
def wheelEvent(self, event): def wheelEvent(self, event):
zoom_in_factor = 1.25 zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor zoom_out_factor = 1 / zoom_in_factor
if event.angleDelta().y() > 0: if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor) self.scale(zoom_in_factor, zoom_in_factor)
else: else:
self.scale(zoom_out_factor, zoom_out_factor) self.scale(zoom_out_factor, zoom_out_factor)
event.accept() event.accept()
# ---------- Points ----------
def _add_point(self, view_pos): def _add_point(self, view_pos):
"""Add a removable red dot at the clicked location, clamped inside the image.""" """Add a removable red dot at the clicked location."""
scene_pos = self.mapToScene(view_pos) scene_pos = self.mapToScene(view_pos)
x, y = scene_pos.x(), scene_pos.y() x, y = scene_pos.x(), scene_pos.y()
dot = self._create_point(x, y, label="", radius=self.dot_radius, color=Qt.red, removable=True) dot = self._create_point(x, y, label="", radius=self.dot_radius, color=Qt.red, removable=True)
# Insert before the final E point if S/E exist
# Insert between S and E if they exist
if len(self.points) >= 2: if len(self.points) >= 2:
self.points.insert(len(self.points) - 1, dot) self.points.insert(len(self.points) - 1, dot)
else: else:
self.points.append(dot) self.points.append(dot)
self.scene.addItem(dot) self.scene.addItem(dot)
def _remove_point(self, view_pos): def _remove_point(self, view_pos):
...@@ -296,17 +273,16 @@ class ImageGraphicsView(QGraphicsView): ...@@ -296,17 +273,16 @@ class ImageGraphicsView(QGraphicsView):
threshold = 10 threshold = 10
closest_idx = None closest_idx = None
min_dist = float('inf') min_dist = float('inf')
for i, p in enumerate(self.points): for i, p in enumerate(self.points):
dist = p.distance_to(x_click, y_click) dist = p.distance_to(x_click, y_click)
if dist < min_dist: if dist < min_dist:
min_dist = dist min_dist = dist
closest_idx = i closest_idx = i
if closest_idx is not None and min_dist <= threshold: if closest_idx is not None and min_dist <= threshold:
if self.points[closest_idx].is_removable(): if self.points[closest_idx].is_removable():
self.scene.removeItem(self.points[closest_idx]) self.scene.removeItem(self.points[closest_idx])
del self.points[closest_idx] del self.points[closest_idx]
self._run_find_path()
def _find_point_near(self, view_pos, threshold=10): def _find_point_near(self, view_pos, threshold=10):
scene_pos = self.mapToScene(view_pos) scene_pos = self.mapToScene(view_pos)
...@@ -314,13 +290,11 @@ class ImageGraphicsView(QGraphicsView): ...@@ -314,13 +290,11 @@ class ImageGraphicsView(QGraphicsView):
closest_idx = None closest_idx = None
min_dist = float('inf') min_dist = float('inf')
for i, p in enumerate(self.points): for i, p in enumerate(self.points):
dist = p.distance_to(x_click, y_click) dist = p.distance_to(x_click, y_click)
if dist < min_dist: if dist < min_dist:
min_dist = dist min_dist = dist
closest_idx = i closest_idx = i
if closest_idx is not None and min_dist <= threshold: if closest_idx is not None and min_dist <= threshold:
return closest_idx return closest_idx
return None return None
...@@ -340,6 +314,58 @@ class ImageGraphicsView(QGraphicsView): ...@@ -340,6 +314,58 @@ class ImageGraphicsView(QGraphicsView):
still_needed.append(p) still_needed.append(p)
self.points = still_needed self.points = still_needed
# Also remove any path points from the scene
for p_item in self.path_points:
self.scene.removeItem(p_item)
self.path_points.clear()
def _run_find_path(self):
"""
Convert the first two points (S/E) from (x,y) to (row,col)
and call find_path(). Then display the path in magenta.
"""
# If we don't have at least 2 points, no path
if len(self.points) < 2:
return
if self.cost_image is None:
return
# Clear old path visualization
for item in self.path_points:
self.scene.removeItem(item)
self.path_points.clear()
# We'll define the path between the first and last point,
# or if you specifically want the first two, you can do self.points[:2].
s_x, s_y = self.points[0].get_pos()
e_x, e_y = self.points[-1].get_pos()
# Convert (x, y) => (row, col) = (int(y), int(x)) and clamp
h, w = self.cost_image.shape
s_r = int(round(s_y)); s_c = int(round(s_x))
e_r = int(round(e_y)); e_c = int(round(e_x))
# Ensure they're inside the cost_image boundary
s_r = max(0, min(s_r, h-1))
s_c = max(0, min(s_c, w-1))
e_r = max(0, min(e_r, h-1))
e_c = max(0, min(e_c, w-1))
# Attempt path
try:
path_rc = find_path(self.cost_image, [(s_r, s_c), (e_r, e_c)])
except ValueError as e:
print("Error in find_path:", e)
return
# Convert path (row,col) => (x, y)
for (r, c) in path_rc:
x = c
y = r
item = self._create_point(x, y, "", self.path_radius, Qt.red, removable=False)
self.path_points.append(item)
self.scene.addItem(item)
class MainWindow(QMainWindow): class MainWindow(QMainWindow):
def __init__(self): def __init__(self):
...@@ -390,11 +416,12 @@ class MainWindow(QMainWindow): ...@@ -390,11 +416,12 @@ class MainWindow(QMainWindow):
) )
if file_path: if file_path:
self.image_view.load_image(file_path) self.image_view.load_image(file_path)
# Compute cost image
self.image_view.cost_image = compute_cost_image(file_path)
def toggle_editor_mode(self): def toggle_editor_mode(self):
is_checked = self.btn_editor_mode.isChecked() is_checked = self.btn_editor_mode.isChecked()
self.image_view.set_editor_mode(is_checked) self.image_view.set_editor_mode(is_checked)
if is_checked: if is_checked:
self.btn_editor_mode.setText("Editor Mode: ON") self.btn_editor_mode.setText("Editor Mode: ON")
self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;") self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment