Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sys
import math
import numpy as np
# For smoothing the path
from scipy.signal import savgol_filter
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF
from live_wire import compute_cost_image, find_path
class LabeledPointItem(QGraphicsEllipseItem):
    """Filled circle marker centred on a scene point, with an optional text label.

    The ellipse is built in local coordinates ``(0, 0)``-``(2r, 2r)`` and the
    item is then positioned so that the circle's centre sits on ``(x, y)``.
    A non-empty *label* is drawn in bold black 14pt Arial. It is shrunk (never
    enlarged) so its bounding box fits inside the circle, and it is centred on
    the circle.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
        diameter = 2 * radius
        super().__init__(0, 0, diameter, diameter, parent)
        self._x = x  # centre x in scene coordinates
        self._y = y  # centre y in scene coordinates
        self._r = radius
        self._removable = removable
        self.setPen(QPen(color))
        self.setBrush(QBrush(color))
        self.setZValue(z_value)

        self._text_item = None
        if label:
            text = QGraphicsTextItem(self)
            text.setPlainText(label)
            text.setDefaultTextColor(QColor("black"))
            label_font = QFont("Arial", 14)
            label_font.setBold(True)
            text.setFont(label_font)
            self._text_item = text
            self._scale_text_to_fit()
        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink the label so its bounding box fits in the circle, then centre it."""
        label_item = self._text_item
        if not label_item:
            return
        # Measure at natural size so repeated calls do not compound the scale.
        label_item.setScale(1.0)
        diameter = 2 * self._r
        bounds = label_item.boundingRect()
        width, height = bounds.width(), bounds.height()
        if width > diameter or height > diameter:
            label_item.setScale(min(diameter / width, diameter / height))
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label so it is centred in the circle."""
        label_item = self._text_item
        if not label_item:
            return
        diameter = 2 * self._r
        bounds = label_item.boundingRect()
        factor = label_item.scale()
        label_item.setPos(
            (diameter - bounds.width() * factor) * 0.5,
            (diameter - bounds.height() * factor) * 0.5,
        )

    def set_pos(self, x, y):
        """Positions the circle so that its center is at (x, y)."""
        self._x, self._y = x, y
        # Item origin is the top-left of the bounding square, so offset by r.
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the circle's centre as an ``(x, y)`` tuple in scene coordinates."""
        return (self._x, self._y)

    def distance_to(self, x_other, y_other):
        """Return the Euclidean distance from the circle's centre to ``(x_other, y_other)``."""
        dx = self._x - x_other
        dy = self._y - y_other
        return math.sqrt(dx**2 + dy**2)

    def is_removable(self):
        """Return whether the user is allowed to delete this point."""
        return self._removable
class ImageGraphicsView(QGraphicsView):
    """Zoomable, pannable image view for editing a live-wire path.

    The user edits an ordered list of anchor points. There is a fixed start
    ("S") and end ("E") plus any number of removable red guide points between
    them. Each consecutive pair of anchors is connected with ``find_path`` over
    ``self.cost_image``. The concatenated path is smoothed and drawn as small
    magenta dots.

    ``self.anchor_points`` (scene ``(x, y)`` tuples) and ``self.point_items``
    (their ``LabeledPointItem`` markers) are parallel lists and are always
    modified together.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)
        # Zoom around the mouse pointer.
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        # Image display item.
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        # Parallel lists: anchor_points[i] is the position of point_items[i].
        self.anchor_points = []  # list of (x, y) in scene coordinates
        self.point_items = []  # list of LabeledPointItem
        self.dot_radius = 4  # radius of red guide-point markers
        self.path_radius = 1  # radius of the magenta path dots
        self.radius_cost_image = 2  # cost-lowering radius (pixels) around guide points
        self._img_w = 0
        self._img_h = 0

        # Panning / dragging state.
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.viewport().setCursor(Qt.ArrowCursor)
        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5  # view pixels (Manhattan) before a press counts as a pan
        self._was_dragging = False
        self._dragging_idx = None  # index of the anchor being dragged, or None
        self._drag_offset = (0, 0)
        self._drag_counter = 0  # throttles path updates while dragging
        self._drag_update_every = 4  # rebuild path every N move events during a drag

        # Two copies of the cost image: the pristine one, and a working copy
        # with discs around the guide points lowered to the global minimum.
        self.cost_image_original = None
        self.cost_image = None

        # Path dots currently displayed.
        self.full_path_points = []

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """Load the image at *path*, reset all points and place S/E.

        S is placed at 15% and E at 85% of the image width, both vertically
        centred. The cost image is reset to ``None`` because it belongs to
        the previous image. The caller assigns ``cost_image_original`` and
        ``cost_image`` for the new image afterwards.

        Returns:
            True if the image was loaded. False if ``QPixmap`` could not read
            it, in which case the view is left unchanged.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return False
        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))
        self._img_w = pixmap.width()
        self._img_h = pixmap.height()
        self._clear_all_points()
        # The previous cost array no longer matches this image.
        self.cost_image_original = None
        self.cost_image = None
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # Place S/E at left and right.
        s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
        e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
        self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
        self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
        return True

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
        """Insert an anchor at list position *idx* and add its marker to the scene.

        A negative *idx* (or ``None``) means "just before E". Once at least
        two anchors exist the last one is E, so the new point goes before it;
        otherwise it is appended. A non-negative *idx* is used as given,
        clipped to the list length.

        The position is clamped so the whole marker stays inside the image.
        Markers labelled "S" or "E" are green; all others are red.
        """
        x_clamped = self._clamp(x, radius, self._img_w - radius)
        y_clamped = self._clamp(y, radius, self._img_h - radius)

        count = len(self.anchor_points)
        if idx is None or idx < 0:
            # With >= 2 anchors the last is E => insert before it.
            idx = count - 1 if count >= 2 else count
        else:
            idx = min(idx, count)

        self.anchor_points.insert(idx, (x_clamped, y_clamped))
        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(
            x_clamped, y_clamped, label=label,
            radius=radius, color=color, removable=removable, z_value=z_val,
        )
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Insert a red guide point at scene (x, y) and recompute the path.

        Called when the user left-clicks an empty spot. Ignored while no
        image is loaded, because there are no bounds to clamp against.
        """
        if self._img_w <= 0 or self._img_h <= 0:
            return
        x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
        y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)
        self._insert_anchor_point(
            -1, x_clamped, y_clamped, label="", removable=True, z_val=1, radius=self.dot_radius
        )
        # Rebuild the working cost image so the new guide point is lowered too.
        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _revert_cost_to_original(self):
        """Reset the working cost image to a fresh copy of the original, if one is set."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower cost around every REMOVABLE anchor (the red ones)."""
        if self.cost_image is None:
            return
        for (ax, ay), item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """Set cost-image pixels within *radius* of (x_f, y_f) to the image minimum.

        The centre is rounded to the nearest pixel (column from x, row from y).
        A centre outside the image is ignored. ``self.cost_image`` is modified
        in place and is assumed to be 2-D.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return
        global_min = self.cost_image.min()
        # No integer pixel offset larger than floor(radius) can lie inside the disc.
        reach = int(math.floor(radius))
        r_s = max(0, row_c - reach)
        r_e = min(h, row_c + reach + 1)
        c_s = max(0, col_c - reach)
        c_e = min(w, col_c + reach + 1)
        rows, cols = np.ogrid[r_s:r_e, c_s:c_e]
        inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
        # Basic slicing returns a view, so the masked write updates cost_image.
        self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _rebuild_full_path(self):
        """Recompute and redraw the path through all anchors, in list order.

        Each consecutive anchor pair is joined with ``_compute_subpath_xy``.
        Subpaths are concatenated without repeating the shared anchor, then
        smoothed with a Savitzky-Golay filter (window 7, linear) when at
        least 7 points exist. The result is drawn as magenta dots. With fewer
        than two anchors, or no cost image, only the old path is cleared.
        """
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        big_xy = []
        for (x_a, y_a), (x_b, y_b) in zip(self.anchor_points, self.anchor_points[1:]):
            sub_xy = self._compute_subpath_xy(x_a, y_a, x_b, y_b)
            # Consecutive subpaths share their joining anchor; keep it only once.
            # If nothing has been collected yet (e.g. an earlier subpath
            # failed), keep the whole subpath.
            big_xy.extend(sub_xy[1:] if big_xy else sub_xy)

        window = 7
        if len(big_xy) >= window:
            arr_xy = np.asarray(big_xy, dtype=float)  # shape (N, 2)
            smoothed = savgol_filter(arr_xy, window_length=window, polyorder=1, axis=0)
            big_xy = smoothed.tolist()

        for px, py in big_xy:
            path_item = LabeledPointItem(
                px, py, label="", radius=self.path_radius,
                color=Qt.magenta, removable=False, z_value=0,
            )
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Ensure labelled markers (S/E) remain on top of the path dots.
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """Return the ``find_path`` route from scene (xA, yA) to (xB, yB) as (x, y) tuples.

        Endpoints are rounded to pixels and clamped into the cost image.
        If ``find_path`` raises ValueError, the error is printed and an empty
        list is returned.
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape
        # Scene x is the column index; scene y is the row index.
        rA = max(0, min(int(round(yA)), h - 1))
        cA = max(0, min(int(round(xA)), w - 1))
        rB = max(0, min(int(round(yB)), h - 1))
        cB = max(0, min(int(round(xB)), w - 1))
        try:
            path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        return [(c, r) for (r, c) in path_rc]

    # --------------------------------------------------------------------
    # MOUSE EVENTS
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        """Left: start dragging a nearby anchor, or start a pan/click. Right: delete a guide point."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()
            idx = self._find_item_near(event.pos(), threshold=10)
            if idx is not None:
                # Drag an existing anchor; the view itself must not pan meanwhile.
                self._dragging_idx = idx
                self._drag_counter = 0
                scene_pos = self.mapToScene(event.pos())
                px, py = self.point_items[idx].get_pos()
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                self.setDragMode(QGraphicsView.NoDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)
                return
            # No anchor near => either a pan (if the mouse moves) or a new point on release.
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ClosedHandCursor)
        elif event.button() == Qt.RightButton:
            self._remove_point_by_click(event.pos())
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the dragged anchor (throttled path rebuild), or track whether a press became a pan."""
        if self._dragging_idx is not None:
            idx = self._dragging_idx
            item = self.point_items[idx]
            scene_pos = self.mapToScene(event.pos())
            x_new = scene_pos.x() - self._drag_offset[0]
            y_new = scene_pos.y() - self._drag_offset[1]
            # Clamp so the marker can't be dragged outside the image.
            r = item._r
            x_clamped = self._clamp(x_new, r, self._img_w - r)
            y_clamped = self._clamp(y_new, r, self._img_h - r)
            item.set_pos(x_clamped, y_clamped)
            # Update the model BEFORE re-deriving the cost image so the disc is
            # lowered at the new position, not the stale one.
            self.anchor_points[idx] = (x_clamped, y_clamped)
            self._drag_counter += 1
            if self._drag_counter >= self._drag_update_every:
                self._drag_counter = 0
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
            return

        if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish an anchor drag, or add a guide point on a plain left click.

        A left press/release that neither started on an anchor nor moved more
        than ``_drag_threshold`` view pixels adds a guide point at the release
        position. Right-button releases do nothing here, since removal
        happens on press.
        """
        super().mouseReleaseEvent(event)
        if event.button() == Qt.LeftButton and self._mouse_pressed:
            self._mouse_pressed = False
            self.viewport().setCursor(Qt.ArrowCursor)
            if self._dragging_idx is not None:
                # Done dragging => final, unthrottled path update.
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.anchor_points[idx] = self.point_items[idx].get_pos()
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
            elif not self._was_dragging:
                scene_pos = self.mapToScene(event.pos())
                self._add_guide_point(scene_pos.x(), scene_pos.y())
        self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Delete the removable anchor near *view_pos* (S/E are kept) and rebuild the path."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None or not self.point_items[idx].is_removable():
            return
        self.scene.removeItem(self.point_items[idx])
        self.point_items.pop(idx)
        self.anchor_points.pop(idx)
        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _find_item_near(self, view_pos, threshold=10):
        """Return the index of the anchor closest to *view_pos*, or None.

        *view_pos* is in viewport coordinates and is mapped to the scene.
        *threshold* is compared against the scene-space distance, so its
        on-screen size depends on the zoom level.
        """
        if not self.point_items:
            return None
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()
        closest_idx, min_dist = min(
            ((i, itm.distance_to(x_click, y_click)) for i, itm in enumerate(self.point_items)),
            key=lambda pair: pair[1],
        )
        return closest_idx if min_dist <= threshold else None

    # --------------------------------------------------------------------
    # ZOOM
    # --------------------------------------------------------------------
    def wheelEvent(self, event):
        """Zoom in/out by a factor of 1.25 per vertical wheel step."""
        delta_y = event.angleDelta().y()
        if delta_y == 0:
            # Horizontal-only scroll (e.g. trackpad swipe): don't treat it as zoom-out.
            super().wheelEvent(event)
            return
        zoom_in_factor = 1.25
        factor = zoom_in_factor if delta_y > 0 else 1 / zoom_in_factor
        self.scale(factor, factor)
        event.accept()

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Clamp *val* into [mn, mx]; if mn > mx the result is mn."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor marker and path dot from the scene and empty the lists."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()
        for p in self.full_path_points:
            self.scene.removeItem(p)
        self.full_path_points.clear()

    def clear_guide_points(self):
        """Remove all removable (guide) anchors, keep S/E. Then rebuild."""
        kept_items = []
        kept_points = []
        for item, point in zip(self.point_items, self.anchor_points):
            if item.is_removable():
                self.scene.removeItem(item)
            else:
                kept_items.append(item)
                kept_points.append(point)
        # Mutate in place so any external references to the lists stay valid.
        self.point_items[:] = kept_items
        self.anchor_points[:] = kept_points
        # _rebuild_full_path also removes the old path dots.
        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()
class MainWindow(QMainWindow):
    """Main window: the image view above a row of Load / Export / Clear buttons."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")
        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)
        self.image_view = ImageGraphicsView()
        main_layout.addWidget(self.image_view)

        # Buttons layout
        btn_layout = QHBoxLayout()
        # Load Image
        self.btn_load_image = QPushButton("Load Image")
        self.btn_load_image.clicked.connect(self.load_image)
        btn_layout.addWidget(self.btn_load_image)
        # Export Points
        self.btn_export_points = QPushButton("Export Points")
        self.btn_export_points.clicked.connect(self.export_points)
        btn_layout.addWidget(self.btn_export_points)
        # Clear Points
        self.btn_clear_points = QPushButton("Clear Points")
        self.btn_clear_points.clicked.connect(self.clear_points)
        btn_layout.addWidget(self.btn_clear_points)
        main_layout.addLayout(btn_layout)

        self.setCentralWidget(main_widget)
        self.resize(900, 600)

    def load_image(self):
        """Open file dialog, load image, compute cost image, store in view.

        The initial S -> E path is drawn immediately once the cost image is set.
        """
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options,
        )
        if not file_path:
            return
        self.image_view.load_image(file_path)
        cost_img = compute_cost_image(file_path)
        self.image_view.cost_image_original = cost_img
        self.image_view.cost_image = cost_img.copy()
        # Without this the path only appears after the user's first interaction.
        self.image_view._revert_cost_to_original()
        self.image_view._apply_all_guide_points_to_cost()
        self.image_view._rebuild_full_path()

    def export_points(self):
        """Save the anchor points (S, guide points, E in path order) as an (N, 2) .npy array."""
        if not self.image_view.anchor_points:
            print("No anchor points to export.")
            return
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Points", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options,
        )
        if file_path:
            points_array = np.array(self.image_view.anchor_points)
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")

    def clear_points(self):
        """Remove all removable anchors (guide points), keep S/E in place."""
        self.image_view.clear_guide_points()

    def closeEvent(self, event):
        """Delegate to the default close handling."""
        super().closeEvent(event)
def main():
    """Create the Qt application, show the main window and run the event loop until exit."""
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()