Skip to content
Snippets Groups Projects
Commit c33b8c2c authored by Christian's avatar Christian
Browse files

added live updating of shown path

parent a2aa6286
No related branches found
No related tags found
No related merge requests found
import sys
import math
import numpy as np
# For smoothing the path
from scipy.signal import savgol_filter
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF
from live_wire import compute_cost_image, find_path
class LabeledPointItem(QGraphicsEllipseItem):
def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
super().__init__(0, 0, 2*radius, 2*radius, parent)
self._x = x
self._y = y
self._r = radius
self._removable = removable
pen = QPen(color)
brush = QBrush(color)
self.setPen(pen)
self.setBrush(brush)
self.setZValue(z_value)
self._text_item = None
if label:
self._text_item = QGraphicsTextItem(self)
self._text_item.setPlainText(label)
self._text_item.setDefaultTextColor(QColor("black"))
font = QFont("Arial", 14)
font.setBold(True)
self._text_item.setFont(font)
self._scale_text_to_fit()
self.set_pos(x, y)
def _scale_text_to_fit(self):
if not self._text_item:
return
self._text_item.setScale(1.0)
circle_diam = 2 * self._r
raw_rect = self._text_item.boundingRect()
text_w = raw_rect.width()
text_h = raw_rect.height()
if text_w > circle_diam or text_h > circle_diam:
scale_factor = min(circle_diam / text_w, circle_diam / text_h)
self._text_item.setScale(scale_factor)
self._center_label()
def _center_label(self):
if not self._text_item:
return
ellipse_w = 2 * self._r
ellipse_h = 2 * self._r
raw_rect = self._text_item.boundingRect()
scale_factor = self._text_item.scale()
scaled_w = raw_rect.width() * scale_factor
scaled_h = raw_rect.height() * scale_factor
tx = (ellipse_w - scaled_w) * 0.5
ty = (ellipse_h - scaled_h) * 0.5
self._text_item.setPos(tx, ty)
def set_pos(self, x, y):
self._x = x
self._y = y
self.setPos(x - self._r, y - self._r)
def get_pos(self):
return (self._x, self._y)
def distance_to(self, x_other, y_other):
return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2)
def is_removable(self):
return self._removable
class ImageGraphicsView(QGraphicsView):
def __init__(self, parent=None):
super().__init__(parent)
self.scene = QGraphicsScene(self)
self.setScene(self.scene)
# Allow zoom around mouse pointer
self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
# Image display item
self.image_item = QGraphicsPixmapItem()
self.scene.addItem(self.image_item)
# Parallel lists
self.anchor_points = [] # List[(x, y)]
self.point_items = [] # List[LabeledPointItem]
self.editor_mode = False
self.dot_radius = 4
self.path_radius = 1
self.radius_cost_image = 2 # cost-lowering radius
self._img_w = 0
self._img_h = 0
# For pan/drag
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ArrowCursor)
self._mouse_pressed = False
self._press_view_pos = None
self._drag_threshold = 5
self._was_dragging = False
self._dragging_idx = None
self._drag_offset = (0, 0)
# NEW: We'll count how many times we've updated the drag => partial path update
self._drag_counter = 0
# Keep original cost image to revert changes
self.cost_image_original = None
self.cost_image = None
# The path is displayed as small magenta circles in self.full_path_points
self.full_path_points = []
# --------------------------------------------------------------------
# LOADING
# --------------------------------------------------------------------
def load_image(self, path):
pixmap = QPixmap(path)
if not pixmap.isNull():
self.image_item.setPixmap(pixmap)
self.setSceneRect(QRectF(pixmap.rect()))
self._img_w = pixmap.width()
self._img_h = pixmap.height()
self._clear_all_points()
self.resetTransform()
self.fitInView(self.image_item, Qt.KeepAspectRatio)
# Create S/E
s_x, s_y = 0.15*self._img_w, 0.5*self._img_h
e_x, e_y = 0.85*self._img_w, 0.5*self._img_h
# S => not removable
self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
# E => not removable
self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
def set_editor_mode(self, mode: bool):
self.editor_mode = mode
# --------------------------------------------------------------------
# ANCHOR POINTS
# --------------------------------------------------------------------
def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
"""
Insert at index=idx, or -1 => append just before E if E exists.
"""
if idx < 0:
# If we have at least 2 anchors, the last is E => insert before that
if len(self.anchor_points) >= 2:
idx = len(self.anchor_points) - 1
else:
idx = len(self.anchor_points)
self.anchor_points.insert(idx, (x, y))
color = Qt.green if label in ("S","E") else Qt.red
item = LabeledPointItem(x, y, label=label, radius=radius, color=color,
removable=removable, z_value=z_val)
self.point_items.insert(idx, item)
self.scene.addItem(item)
def _add_guide_point(self, x, y):
"""
User added a red guide point => lower cost, insert anchor, rebuild path.
"""
# 1) Revert cost
self._revert_cost_to_original()
# 2) Insert new anchor (removable)
self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius)
# 3) Re-apply cost-lowering for all existing guide points
self._apply_all_guide_points_to_cost()
# 4) Rebuild path
self._rebuild_full_path()
# --------------------------------------------------------------------
# COST IMAGE
# --------------------------------------------------------------------
def _revert_cost_to_original(self):
"""self.cost_image <- copy of self.cost_image_original"""
if self.cost_image_original is not None:
self.cost_image = self.cost_image_original.copy()
def _apply_all_guide_points_to_cost(self):
"""Lower cost around every removable anchor."""
if self.cost_image is None:
return
for i, (ax, ay) in enumerate(self.anchor_points):
if self.point_items[i].is_removable():
self._lower_cost_in_circle(ax, ay, self.radius_cost_image)
def _lower_cost_in_circle(self, x_f, y_f, radius):
"""Set cost_image row,col in circle of radius -> global min."""
if self.cost_image is None:
return
h, w = self.cost_image.shape
row_c = int(round(y_f))
col_c = int(round(x_f))
if not (0 <= row_c < h and 0 <= col_c < w):
return
global_min = self.cost_image.min()
r_s = max(0, row_c - radius)
r_e = min(h, row_c + radius + 1)
c_s = max(0, col_c - radius)
c_e = min(w, col_c + radius + 1)
for rr in range(r_s, r_e):
for cc in range(c_s, c_e):
dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2)
if dist <= radius:
self.cost_image[rr, cc] = global_min
# --------------------------------------------------------------------
# PATH BUILDING
# --------------------------------------------------------------------
def _rebuild_full_path(self):
# Remove old path items
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
# Build subpaths
if len(self.anchor_points) < 2 or self.cost_image is None:
return
big_xy = []
for i in range(len(self.anchor_points)-1):
xA, yA = self.anchor_points[i]
xB, yB = self.anchor_points[i+1]
sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
if i == 0:
big_xy.extend(sub_xy)
else:
# avoid duplicating the point between subpaths
if len(sub_xy) > 1:
big_xy.extend(sub_xy[1:])
# Smoothing with Savitzky-Golay
if len(big_xy) >= 7:
arr_xy = np.array(big_xy) # shape (N,2)
smoothed = savgol_filter(arr_xy, window_length=7, polyorder=1, axis=0)
big_xy = smoothed.tolist()
# Draw them
for (px, py) in big_xy:
path_item = LabeledPointItem(px, py, label="", radius=self.path_radius,
color=Qt.magenta, removable=False, z_value=0)
self.full_path_points.append(path_item)
self.scene.addItem(path_item)
# Ensure S/E stay on top
for p_item in self.point_items:
if p_item._text_item:
p_item.setZValue(100)
def _compute_subpath_xy(self, xA, yA, xB, yB):
if self.cost_image is None:
return []
h, w = self.cost_image.shape
rA, cA = int(round(yA)), int(round(xA))
rB, cB = int(round(yB)), int(round(xB))
rA = max(0, min(rA, h-1))
cA = max(0, min(cA, w-1))
rB = max(0, min(rB, h-1))
cB = max(0, min(cB, w-1))
try:
path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
except ValueError as e:
print("Error in find_path:", e)
return []
return [(c, r) for (r, c) in path_rc]
# --------------------------------------------------------------------
# MOUSE EVENTS (dragging, adding, removing points)
# --------------------------------------------------------------------
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self._mouse_pressed = True
self._was_dragging = False
self._press_view_pos = event.pos()
if self.editor_mode:
idx = self._find_item_near(event.pos(), 10)
if idx is not None:
# drag existing anchor
self._dragging_idx = idx
# Reset drag counter
self._drag_counter = 0
scene_pos = self.mapToScene(event.pos())
px, py = self.point_items[idx].get_pos()
self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
self.setDragMode(QGraphicsView.NoDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
return
else:
# If no anchor near, user might be panning
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
else:
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
elif event.button() == Qt.RightButton:
if self.editor_mode:
self._remove_point_by_click(event.pos())
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
if self._dragging_idx is not None:
# dragging an anchor
scene_pos = self.mapToScene(event.pos())
x_new = scene_pos.x() - self._drag_offset[0]
y_new = scene_pos.y() - self._drag_offset[1]
r = self.point_items[self._dragging_idx]._r
x_clamped = self._clamp(x_new, r, self._img_w - r)
y_clamped = self._clamp(y_new, r, self._img_h - r)
self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped)
# THROTTLE: only recalc path every few frames
self._drag_counter += 1
if self._drag_counter >= 4:
self._drag_counter = 0
# do partial path update:
# (We won't revert cost if you want the user to see the “final” cost-lowered path only at the end
# or you can do the entire revert+reapply if you like)
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
# update anchor_points
newX, newY = self.point_items[self._dragging_idx].get_pos()
self.anchor_points[self._dragging_idx] = (newX, newY)
self._rebuild_full_path()
return
else:
# if movement > threshold => pan
if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
dist = (event.pos() - self._press_view_pos).manhattanLength()
if dist > self._drag_threshold:
self._was_dragging = True
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
super().mouseReleaseEvent(event)
if event.button() == Qt.LeftButton and self._mouse_pressed:
self._mouse_pressed = False
self.viewport().setCursor(Qt.ArrowCursor)
if self._dragging_idx is not None:
idx = self._dragging_idx
self._dragging_idx = None
self._drag_offset = (0, 0)
self.setDragMode(QGraphicsView.ScrollHandDrag)
# Final big update
newX, newY = self.point_items[idx].get_pos()
self.anchor_points[idx] = (newX, newY)
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
else:
if not self._was_dragging and self.editor_mode:
# user clicked an empty spot => add a guide point
scene_pos = self.mapToScene(event.pos())
x, y = scene_pos.x(), scene_pos.y()
self._add_guide_point(x, y)
self._was_dragging = False
def _remove_point_by_click(self, view_pos):
idx = self._find_item_near(view_pos, threshold=10)
if idx is None:
return
# check if removable => skip S/E
if not self.point_items[idx].is_removable():
return # do nothing
# remove anchor
self.scene.removeItem(self.point_items[idx])
self.point_items.pop(idx)
self.anchor_points.pop(idx)
# revert + re-apply cost, rebuild path
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def _find_item_near(self, view_pos, threshold=10):
scene_pos = self.mapToScene(view_pos)
x_click, y_click = scene_pos.x(), scene_pos.y()
min_dist = float('inf')
closest_idx = None
for i, itm in enumerate(self.point_items):
d = itm.distance_to(x_click, y_click)
if d < min_dist:
min_dist = d
closest_idx = i
if closest_idx is not None and min_dist <= threshold:
return closest_idx
return None
# --------------------------------------------------------------------
# ZOOM
# --------------------------------------------------------------------
def wheelEvent(self, event):
"""
Zoom in/out with mouse wheel
"""
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
# If the user scrolls upward => zoom in. Otherwise => zoom out.
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
# --------------------------------------------------------------------
# UTILS
# --------------------------------------------------------------------
def _clamp(self, val, mn, mx):
return max(mn, min(val, mx))
def _clear_all_points(self):
for it in self.point_items:
self.scene.removeItem(it)
self.point_items.clear()
self.anchor_points.clear()
for p in self.full_path_points:
self.scene.removeItem(p)
self.full_path_points.clear()
def clear_guide_points(self):
"""
Removes all anchors that are 'removable' (guide points),
keeps S/E in place. Then reverts cost, re-applies, rebuilds path.
"""
i = 0
while i < len(self.anchor_points):
if self.point_items[i].is_removable():
self.scene.removeItem(self.point_items[i])
del self.point_items[i]
del self.anchor_points[i]
else:
i += 1
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Test GUI")
main_widget = QWidget()
main_layout = QVBoxLayout(main_widget)
self.image_view = ImageGraphicsView()
main_layout.addWidget(self.image_view)
btn_layout = QHBoxLayout()
self.btn_load_image = QPushButton("Load Image")
self.btn_load_image.clicked.connect(self.load_image)
btn_layout.addWidget(self.btn_load_image)
self.btn_editor_mode = QPushButton("Editor Mode: OFF")
self.btn_editor_mode.setCheckable(True)
self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
btn_layout.addWidget(self.btn_editor_mode)
self.btn_export_points = QPushButton("Export Points")
self.btn_export_points.clicked.connect(self.export_points)
btn_layout.addWidget(self.btn_export_points)
self.btn_clear_points = QPushButton("Clear Points")
self.btn_clear_points.clicked.connect(self.clear_points)
btn_layout.addWidget(self.btn_clear_points)
main_layout.addLayout(btn_layout)
self.setCentralWidget(main_widget)
self.resize(900, 600)
def load_image(self):
options = QFileDialog.Options()
file_path, _ = QFileDialog.getOpenFileName(
self, "Open Image", "",
"Images (*.png *.jpg *.jpeg *.bmp *.tif)",
options=options
)
if file_path:
self.image_view.load_image(file_path)
cost_img = compute_cost_image(file_path)
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
def toggle_editor_mode(self):
is_checked = self.btn_editor_mode.isChecked()
self.image_view.set_editor_mode(is_checked)
if is_checked:
self.btn_editor_mode.setText("Editor Mode: ON")
self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
else:
self.btn_editor_mode.setText("Editor Mode: OFF")
self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
def export_points(self):
if not self.image_view.anchor_points:
print("No anchor points to export.")
return
options = QFileDialog.Options()
file_path, _ = QFileDialog.getSaveFileName(
self, "Export Points", "",
"NumPy Files (*.npy);;All Files (*)",
options=options
)
if file_path:
points_array = np.array(self.image_view.anchor_points)
np.save(file_path, points_array)
print(f"Exported {len(points_array)} points to {file_path}")
def clear_points(self):
"""Remove all removable anchors (guide points), keep S/E in place."""
self.image_view.clear_guide_points()
def closeEvent(self, event):
super().closeEvent(event)
def main():
app = QApplication(sys.argv)
window = MainWindow()
window.show()
sys.exit(app.exec_())
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment