import math
import sys

import numpy as np
from PyQt5.QtCore import QRectF, Qt
from PyQt5.QtGui import QBrush, QColor, QFont, QPen, QPixmap
from PyQt5.QtWidgets import (
    QApplication,
    QFileDialog,
    QGraphicsEllipseItem,
    QGraphicsPixmapItem,
    QGraphicsScene,
    QGraphicsTextItem,
    QGraphicsView,
    QHBoxLayout,
    QMainWindow,
    QPushButton,
    QVBoxLayout,
    QWidget,
)

# Import your live_wire functions
from live_wire import compute_cost_image, find_path

class LabeledPointItem(QGraphicsEllipseItem):
    """
    A circle with an optional (bold) label (e.g. 'S'/'E'),
    which automatically scales the text if it's bigger than the circle.

    The ellipse geometry is the local rectangle (0, 0, 2*radius, 2*radius),
    so the item position is kept at (x - radius, y - radius) to put the
    circle's center at (x, y).
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, parent=None):
        super().__init__(0, 0, 2 * radius, 2 * radius, parent)
        self._x = x  # Center x
        self._y = y  # Center y
        self._r = radius  # Circle radius
        self._removable = removable  # False for points the user may not delete

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))

        self._text_item = None
        if label:
            # Child item, so it moves together with the circle.
            self._text_item = QGraphicsTextItem(self)
            self._text_item.setPlainText(label)
            self._text_item.setDefaultTextColor(QColor("black"))
            font = QFont("Arial", 14)
            font.setBold(True)
            self._text_item.setFont(font)
            self._scale_text_to_fit()

        # Move so center is at (x, y)
        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink (never enlarge) the label so it fits in the circle, then center it."""
        if not self._text_item:
            return
        self._text_item.setScale(1.0)
        circle_diam = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        text_w = raw_rect.width()
        text_h = raw_rect.height()
        if text_w > circle_diam or text_h > circle_diam:
            scale_factor = min(circle_diam / text_w, circle_diam / text_h)
            self._text_item.setScale(scale_factor)
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label so it is centered in the circle."""
        if not self._text_item:
            return
        ellipse_w = 2 * self._r
        ellipse_h = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        scale_factor = self._text_item.scale()
        # boundingRect() is unscaled; apply the item's scale to get the drawn size.
        scaled_w = raw_rect.width() * scale_factor
        scaled_h = raw_rect.height() * scale_factor
        tx = (ellipse_w - scaled_w) * 0.5
        ty = (ellipse_h - scaled_h) * 0.5
        self._text_item.setPos(tx, ty)

    def set_pos(self, x, y):
        """Move so the circle's center is at (x,y) in scene coords."""
        self._x = x
        self._y = y
        # The ellipse rect starts at local (0, 0), so offset by the radius.
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the circle's center as an ``(x, y)`` tuple."""
        return self._x, self._y

    def is_removable(self):
        """Return True if the user is allowed to delete this point."""
        return self._removable

    def distance_to(self, x_other, y_other):
        """Return the Euclidean distance from the circle's center to (x_other, y_other)."""
        return math.hypot(self._x - x_other, self._y - y_other)
class ImageGraphicsView(QGraphicsView):
    """
    Displays an image and allows placing/dragging labeled points.
    Ensures points can't go outside the image boundary.

    Interaction:
      * Left-drag on empty space pans the view; the mouse wheel zooms
        around the cursor.
      * In editor mode, a left click (without dragging) adds a removable red
        dot, a left-drag starting on an existing point moves that point, and
        a right click removes the nearest removable point.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # NOTE: this attribute shadows QGraphicsView.scene(); kept for compatibility.
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Zoom around mouse pointer
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        # Image item
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        # Image dimensions; 0 until an image has been loaded.
        self._img_w = 0
        self._img_h = 0

        self.editor_mode = False
        self.dot_radius = 4
        self.path_radius = 1  # radius of circles in path

        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.viewport().setCursor(Qt.ArrowCursor)

        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5  # view pixels (Manhattan) before a press counts as a drag
        self._was_dragging = False

        self._dragging_idx = None  # index into self.points of the point being dragged
        self._drag_offset = (0, 0)

        # Control points (S, user-added dots, E); empty until an image is loaded.
        self.points = []
        # Cost image from compute_cost_image
        self.cost_image = None
        # All path points displayed in magenta
        self.path_points = []

    def _has_image(self):
        """Return True once an image has been successfully loaded."""
        return not self.image_item.pixmap().isNull()

    def load_image(self, image_path):
        """
        Load the image at ``image_path``, fit it in the view and place fresh
        green start ('S') and end ('E') points.

        Points and path items from a previously loaded image are removed
        first. If Qt cannot read the file, nothing changes.
        """
        pixmap = QPixmap(image_path)
        if pixmap.isNull():
            return

        # Drop items belonging to a previous image so they don't linger.
        self._clear_point_items(remove_all=True)
        self._dragging_idx = None

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))
        # Save image dimensions
        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # S/E on the horizontal midline, 15% in from the left/right edges.
        s_point = self._create_point(
            self._img_w * 0.15, self._img_h * 0.5, "S", 6, Qt.green, removable=False
        )
        e_point = self._create_point(
            self._img_w * 0.85, self._img_h * 0.5, "E", 6, Qt.green, removable=False
        )
        self.points = [s_point, e_point]
        self.scene.addItem(s_point)
        self.scene.addItem(e_point)

    def set_editor_mode(self, mode: bool):
        """Enable or disable point editing (add / drag / remove)."""
        self.editor_mode = mode

    def _create_point(self, x, y, label, radius, color, removable=True):
        """
        Build (without adding it to the scene) a LabeledPointItem centered
        at (x, y), clamped so the whole circle stays inside the image.
        """
        # Clamp coordinates so center doesn't go outside the image
        cx = self._clamp(x, radius, self._img_w - radius)
        cy = self._clamp(y, radius, self._img_h - radius)
        return LabeledPointItem(
            cx, cy, label=label, radius=radius, color=color, removable=removable
        )

    def _clamp(self, val, min_val, max_val):
        """Limit ``val`` to [min_val, max_val] (min_val wins if the range is empty)."""
        return max(min_val, min(val, max_val))

    def mousePressEvent(self, event):
        """Start a point drag, a pan, or remove a point, depending on button and mode."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            idx = self._find_point_near(event.pos(), threshold=10) if self.editor_mode else None
            if idx is not None:
                # Grab the point; remember the offset so it doesn't jump under the cursor.
                self._dragging_idx = idx
                scene_pos = self.mapToScene(event.pos())
                px, py = self.points[idx].get_pos()
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                # Disable hand-drag panning while a point is being moved.
                self.setDragMode(QGraphicsView.NoDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)
                return
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ClosedHandCursor)
        elif event.button() == Qt.RightButton and self.editor_mode:
            self._remove_point(event.pos())
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the grabbed point (clamped to the image) or track pan-drag distance."""
        if self._dragging_idx is not None:
            scene_pos = self.mapToScene(event.pos())
            x_new = scene_pos.x() - self._drag_offset[0]
            y_new = scene_pos.y() - self._drag_offset[1]
            point = self.points[self._dragging_idx]
            r = point._r
            x_clamped = self._clamp(x_new, r, self._img_w - r)
            y_clamped = self._clamp(y_new, r, self._img_h - r)
            point.set_pos(x_clamped, y_clamped)
        elif self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            # Distinguish a click (adds a point) from a pan drag.
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish a point drag, or add a point if the left press was a plain click."""
        super().mouseReleaseEvent(event)
        if event.button() == Qt.LeftButton and self._mouse_pressed:
            self._mouse_pressed = False
            self.viewport().setCursor(Qt.ArrowCursor)

            if self._dragging_idx is not None:
                # The user was dragging a point and now released
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                self.setDragMode(QGraphicsView.ScrollHandDrag)
            elif not self._was_dragging and self.editor_mode:
                self._add_point(event.pos())
            self._was_dragging = False

    def wheelEvent(self, event):
        """Zoom in/out by a factor of 1.25 per wheel step."""
        zoom_in_factor = 1.25
        zoom_out_factor = 1 / zoom_in_factor
        if event.angleDelta().y() > 0:
            self.scale(zoom_in_factor, zoom_in_factor)
        else:
            self.scale(zoom_out_factor, zoom_out_factor)
        event.accept()

    def _add_point(self, view_pos):
        """
        Add a removable red dot at the clicked location.

        When at least two points exist, the dot is inserted just before the
        last one ('E'). Ignored when no image is loaded.
        """
        if not self._has_image():
            return
        scene_pos = self.mapToScene(view_pos)
        dot = self._create_point(
            scene_pos.x(), scene_pos.y(),
            label="", radius=self.dot_radius, color=Qt.red, removable=True,
        )
        # Insert before the final E point if S/E exist
        if len(self.points) >= 2:
            self.points.insert(len(self.points) - 1, dot)
        else:
            self.points.append(dot)
        self.scene.addItem(dot)

    def _remove_point(self, view_pos):
        """Remove the point nearest to ``view_pos`` if it is within 10 scene units and removable."""
        idx = self._find_point_near(view_pos, threshold=10)
        if idx is not None and self.points[idx].is_removable():
            self.scene.removeItem(self.points[idx])
            del self.points[idx]

    def _find_point_near(self, view_pos, threshold=10):
        """
        Return the index of the point whose center is closest to ``view_pos``,
        or None if there are no points or the closest one is farther than
        ``threshold`` (measured in scene coordinates).
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()
        closest_idx = None
        min_dist = float('inf')
        for i, p in enumerate(self.points):
            dist = p.distance_to(x_click, y_click)
            if dist < min_dist:
                min_dist = dist
                closest_idx = i
        if closest_idx is not None and min_dist <= threshold:
            return closest_idx
        return None

    def _clear_point_items(self, remove_all=False):
        """
        Remove point items from the scene.

        With ``remove_all=True`` every control point (including S/E) is
        removed; otherwise only removable (user-added) points are removed.
        The displayed path is always removed as well.
        """
        kept = []
        for p in self.points:
            if remove_all or p.is_removable():
                self.scene.removeItem(p)
            else:
                kept.append(p)
        self.points = kept

        # Also remove any path points from the scene
        for p_item in self.path_points:
            self.scene.removeItem(p_item)
        self.path_points.clear()

    def _run_find_path(self):
        """
        Compute and display the live-wire path between the first and last
        control points (normally 'S' and 'E').

        Point centers are converted from scene (x, y) to cost-image
        (row, col) = (round(y), round(x)), clamped to the cost image bounds,
        and passed to find_path(). The returned (row, col) pixels are drawn
        as small magenta dots. Does nothing when there are fewer than two
        points or no cost image has been set.
        """
        if len(self.points) < 2 or self.cost_image is None:
            return

        # Clear old path visualization
        for item in self.path_points:
            self.scene.removeItem(item)
        self.path_points.clear()

        # Unpacking requires a 2-D cost image.
        h, w = self.cost_image.shape

        def to_row_col(point):
            x, y = point.get_pos()
            row = self._clamp(int(round(y)), 0, h - 1)
            col = self._clamp(int(round(x)), 0, w - 1)
            return row, col

        start_rc = to_row_col(self.points[0])
        end_rc = to_row_col(self.points[-1])
        try:
            path_rc = find_path(self.cost_image, [start_rc, end_rc])
        except ValueError as e:
            print("Error in find_path:", e)
            return

        # Convert path (row, col) => scene (x, y) = (col, row)
        for r, c in path_rc:
            item = self._create_point(c, r, "", self.path_radius, Qt.magenta, removable=False)
            self.path_points.append(item)
            self.scene.addItem(item)
class MainWindow(QMainWindow):
    """Top-level window: the image view on top and a row of action buttons below it."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")

        self.image_view = ImageGraphicsView()

        # Buttons, created in the order they appear left to right.
        self.btn_load_image = self._make_button("Load Image", self.load_image)

        self.btn_editor_mode = QPushButton("Editor Mode: OFF")
        self.btn_editor_mode.setCheckable(True)
        self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
        self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)

        self.btn_export_points = self._make_button("Export Points", self.export_points)
        self.btn_clear_points = self._make_button("Clear Points", self.clear_points)

        button_row = QHBoxLayout()
        for button in (
            self.btn_load_image,
            self.btn_editor_mode,
            self.btn_export_points,
            self.btn_clear_points,
        ):
            button_row.addWidget(button)

        central = QWidget()
        column = QVBoxLayout(central)
        column.addWidget(self.image_view)
        column.addLayout(button_row)

        self.setCentralWidget(central)
        self.resize(900, 600)

    @staticmethod
    def _make_button(text, slot):
        """Create a push button labelled ``text`` whose click invokes ``slot``."""
        button = QPushButton(text)
        button.clicked.connect(slot)
        return button

    def load_image(self):
        """Open file dialog to pick an image, then load it and compute its cost image."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=QFileDialog.Options(),
        )
        if not file_path:
            return
        self.image_view.load_image(file_path)
        # Compute cost image
        self.image_view.cost_image = compute_cost_image(file_path)

    def toggle_editor_mode(self):
        """Sync the view's editor mode and the button's label/colour with the toggle state."""
        enabled = self.btn_editor_mode.isChecked()
        self.image_view.set_editor_mode(enabled)
        text, style = (
            ("Editor Mode: ON", "background-color: #ffcccc;")
            if enabled
            else ("Editor Mode: OFF", "background-color: lightgray;")
        )
        self.btn_editor_mode.setText(text)
        self.btn_editor_mode.setStyleSheet(style)

    def export_points(self):
        """Save the (x, y) centers of all control points to a NumPy ``.npy`` file."""
        if not self.image_view.points:
            print("No points to export.")
            return
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Points", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=QFileDialog.Options(),
        )
        if not file_path:
            return
        points_array = np.array([point.get_pos() for point in self.image_view.points])
        np.save(file_path, points_array)
        print(f"Exported {len(points_array)} points to {file_path}")

    def clear_points(self):
        """Remove the user-added points, keeping the non-removable ones (S/E)."""
        self.image_view._clear_point_items(remove_all=False)
def main():
    """Start the Qt event loop with the main window and exit with its status code."""
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)