Skip to content
Snippets Groups Projects
Commit affd43e3 authored by Christian's avatar Christian
Browse files

Added functionality for handeling mutiple points placed

parent 5358d1df
Branches
No related tags found
No related merge requests found
......@@ -10,29 +10,23 @@ from PyQt5.QtWidgets import (
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF
# Import your live_wire functions
from live_wire import compute_cost_image, find_path
class LabeledPointItem(QGraphicsEllipseItem):
"""
A circle with optional (bold) label (e.g. 'S'/'E'),
which automatically scales the text if it's bigger than the circle.
"""
def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, parent=None):
def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
super().__init__(0, 0, 2*radius, 2*radius, parent)
self._x = x # Center x
self._y = y # Center y
self._r = radius # Circle radius
self._x = x
self._y = y
self._r = radius
self._removable = removable
# Circle styling
pen = QPen(color)
brush = QBrush(color)
self.setPen(pen)
self.setBrush(brush)
self.setZValue(z_value)
# Optional text label
self._text_item = None
if label:
self._text_item = QGraphicsTextItem(self)
......@@ -43,7 +37,6 @@ class LabeledPointItem(QGraphicsEllipseItem):
self._text_item.setFont(font)
self._scale_text_to_fit()
# Move so center is at (x, y)
self.set_pos(x, y)
def _scale_text_to_fit(self):
......@@ -55,9 +48,7 @@ class LabeledPointItem(QGraphicsEllipseItem):
text_w = raw_rect.width()
text_h = raw_rect.height()
if text_w > circle_diam or text_h > circle_diam:
scale_w = circle_diam / text_w
scale_h = circle_diam / text_h
scale_factor = min(scale_w, scale_h)
scale_factor = min(circle_diam / text_w, circle_diam / text_h)
self._text_item.setScale(scale_factor)
self._center_label()
......@@ -75,7 +66,6 @@ class LabeledPointItem(QGraphicsEllipseItem):
self._text_item.setPos(tx, ty)
def set_pos(self, x, y):
"""Move so the circle's center is at (x,y) in scene coords."""
self._x = x
self._y = y
self.setPos(x - self._r, y - self._r)
......@@ -84,39 +74,37 @@ class LabeledPointItem(QGraphicsEllipseItem):
return (self._x, self._y)
def distance_to(self, x_other, y_other):
dx = self._x - x_other
dy = self._y - y_other
return math.sqrt(dx*dx + dy*dy)
return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2)
def is_removable(self):
return self._removable
class ImageGraphicsView(QGraphicsView):
"""
Displays an image and allows placing/dragging labeled points.
Ensures points can't go outside the image boundary.
"""
def __init__(self, parent=None):
super().__init__(parent)
self.scene = QGraphicsScene(self)
self.setScene(self.scene)
# Zoom around mouse pointer
# Allow zoom around mouse pointer
self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
# Image item
# Image display item
self.image_item = QGraphicsPixmapItem()
self.scene.addItem(self.image_item)
self.points = [] # LabeledPointItem objects
self.editor_mode = False
# Parallel lists
self.anchor_points = [] # List[(x, y)]
self.point_items = [] # List[LabeledPointItem]
self.editor_mode = False
self.dot_radius = 4
self.path_radius = 1 # radius of circles in path
self.path_radius = 1
self.radius_something = 3 # cost-lowering radius
self._img_w = 0
self._img_h = 0
# For pan/drag
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ArrowCursor)
......@@ -127,52 +115,168 @@ class ImageGraphicsView(QGraphicsView):
self._dragging_idx = None
self._drag_offset = (0, 0)
# Cost image from compute_cost_image
# Keep original cost image to revert changes
self.cost_image_original = None
self.cost_image = None
# All path points displayed in magenta
self.path_points = []
# The path is displayed as small magenta circles in self.full_path_points
self.full_path_points = []
def load_image(self, image_path):
pixmap = QPixmap(image_path)
# --------------------------------------------------------------------
# LOADING
# --------------------------------------------------------------------
def load_image(self, path):
pixmap = QPixmap(path)
if not pixmap.isNull():
self.image_item.setPixmap(pixmap)
self.setSceneRect(QRectF(pixmap.rect()))
# Save image dimensions
self._img_w = pixmap.width()
self._img_h = pixmap.height()
self._clear_point_items(remove_all=True)
self._clear_all_points()
self.resetTransform()
self.fitInView(self.image_item, Qt.KeepAspectRatio)
# Positions for S/E
s_x = self._img_w * 0.15
s_y = self._img_h * 0.5
e_x = self._img_w * 0.85
e_y = self._img_h * 0.5
# Create green S/E with radius=6
s_point = self._create_point(s_x, s_y, "S", 6, Qt.green, removable=False)
e_point = self._create_point(e_x, e_y, "E", 6, Qt.green, removable=False)
# Create S/E
s_x, s_y = 0.15*self._img_w, 0.5*self._img_h
e_x, e_y = 0.85*self._img_w, 0.5*self._img_h
self.points = [s_point, e_point]
self.scene.addItem(s_point)
self.scene.addItem(e_point)
# S => not removable
self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
# E => not removable
self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
def set_editor_mode(self, mode: bool):
self.editor_mode = mode
def _create_point(self, x, y, label, radius, color, removable=True):
# Clamp coordinates so center doesn't go outside the image
cx = self._clamp(x, radius, self._img_w - radius)
cy = self._clamp(y, radius, self._img_h - radius)
return LabeledPointItem(cx, cy, label=label, radius=radius, color=color, removable=removable)
# --------------------------------------------------------------------
# ANCHOR POINTS
# --------------------------------------------------------------------
def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
"""
Insert at index=idx, or -1 => append just before E if E exists.
"""
if idx < 0:
# If we have at least 2 anchors, the last is E => insert before that
if len(self.anchor_points) >= 2:
idx = len(self.anchor_points) - 1
else:
idx = len(self.anchor_points)
self.anchor_points.insert(idx, (x, y))
color = Qt.green if label in ("S","E") else Qt.red
item = LabeledPointItem(x, y, label=label, radius=radius, color=color,
removable=removable, z_value=z_val)
self.point_items.insert(idx, item)
self.scene.addItem(item)
def _add_guide_point(self, x, y):
"""
User added a red guide point => lower cost, insert anchor, rebuild path.
"""
# 1) Revert cost
self._revert_cost_to_original()
# 2) Insert new anchor (removable)
self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius)
# 3) Re-apply cost-lowering for all existing guide points
self._apply_all_guide_points_to_cost()
# 4) Rebuild path
self._rebuild_full_path()
# --------------------------------------------------------------------
# COST IMAGE
# --------------------------------------------------------------------
def _revert_cost_to_original(self):
"""self.cost_image <- copy of self.cost_image_original"""
if self.cost_image_original is not None:
self.cost_image = self.cost_image_original.copy()
def _apply_all_guide_points_to_cost(self):
"""Lower cost around every removable anchor."""
if self.cost_image is None:
return
for i, (ax, ay) in enumerate(self.anchor_points):
if self.point_items[i].is_removable():
self._lower_cost_in_circle(ax, ay, self.radius_something)
def _lower_cost_in_circle(self, x_f, y_f, radius):
"""Set cost_image row,col in circle of radius -> global min."""
if self.cost_image is None:
return
h, w = self.cost_image.shape
row_c = int(round(y_f))
col_c = int(round(x_f))
if not (0 <= row_c < h and 0 <= col_c < w):
return
global_min = self.cost_image.min()
r_s = max(0, row_c - radius)
r_e = min(h, row_c + radius + 1)
c_s = max(0, col_c - radius)
c_e = min(w, col_c + radius + 1)
for rr in range(r_s, r_e):
for cc in range(c_s, c_e):
dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2)
if dist <= radius:
self.cost_image[rr, cc] = global_min
# --------------------------------------------------------------------
# PATH BUILDING
# --------------------------------------------------------------------
def _rebuild_full_path(self):
# Remove old path items
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
# Build subpaths
if len(self.anchor_points) < 2 or self.cost_image is None:
return
def _clamp(self, val, min_val, max_val):
return max(min_val, min(val, max_val))
big_xy = []
for i in range(len(self.anchor_points)-1):
xA, yA = self.anchor_points[i]
xB, yB = self.anchor_points[i+1]
sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
if i == 0:
big_xy.extend(sub_xy)
else:
# avoid duplicating the point between subpaths
if len(sub_xy) > 1:
big_xy.extend(sub_xy[1:])
# Draw them
for (px, py) in big_xy:
path_item = LabeledPointItem(px, py, label="", radius=self.path_radius,
color=Qt.magenta, removable=False, z_value=0)
self.full_path_points.append(path_item)
self.scene.addItem(path_item)
# Ensure S/E stay on top
for p_item in self.point_items:
if p_item._text_item:
p_item.setZValue(100)
def _compute_subpath_xy(self, xA, yA, xB, yB):
if self.cost_image is None:
return []
h, w = self.cost_image.shape
rA, cA = int(round(yA)), int(round(xA))
rB, cB = int(round(yB)), int(round(xB))
rA = max(0, min(rA, h-1))
cA = max(0, min(cA, w-1))
rB = max(0, min(rB, h-1))
cB = max(0, min(cB, w-1))
try:
path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
except ValueError as e:
print("Error in find_path:", e)
return []
return [(c, r) for (r, c) in path_rc]
# --------------------------------------------------------------------
# MOUSE EVENTS (dragging, adding, removing points)
# --------------------------------------------------------------------
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self._mouse_pressed = True
......@@ -180,17 +284,18 @@ class ImageGraphicsView(QGraphicsView):
self._press_view_pos = event.pos()
if self.editor_mode:
idx = self._find_point_near(event.pos(), threshold=10)
idx = self._find_item_near(event.pos(), 10)
if idx is not None:
# drag existing anchor
self._dragging_idx = idx
scene_pos = self.mapToScene(event.pos())
px, py = self.points[idx].get_pos()
px, py = self.point_items[idx].get_pos()
self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
self.setDragMode(QGraphicsView.NoDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
return
else:
# If no anchor near, user might be panning
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
else:
......@@ -199,25 +304,23 @@ class ImageGraphicsView(QGraphicsView):
elif event.button() == Qt.RightButton:
if self.editor_mode:
self._remove_point(event.pos())
self._remove_point_by_click(event.pos())
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
if self._dragging_idx is not None:
# Dragging an existing point
# dragging an anchor
scene_pos = self.mapToScene(event.pos())
x_new = scene_pos.x() - self._drag_offset[0]
y_new = scene_pos.y() - self._drag_offset[1]
# Clamp so center doesn't go out of the image
r = self.points[self._dragging_idx]._r
r = self.point_items[self._dragging_idx]._r
x_clamped = self._clamp(x_new, r, self._img_w - r)
y_clamped = self._clamp(y_new, r, self._img_h - r)
self.points[self._dragging_idx].set_pos(x_clamped, y_clamped)
self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped)
return
else:
# If movement > threshold => treat as pan
# if movement > threshold => pan
if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
dist = (event.pos() - self._press_view_pos).manhattanLength()
if dist > self._drag_threshold:
......@@ -229,142 +332,117 @@ class ImageGraphicsView(QGraphicsView):
if event.button() == Qt.LeftButton and self._mouse_pressed:
self._mouse_pressed = False
self.viewport().setCursor(Qt.ArrowCursor)
if self._dragging_idx is not None:
# The user was dragging a point and now released
idx = self._dragging_idx
self._dragging_idx = None
self._drag_offset = (0, 0)
self.setDragMode(QGraphicsView.ScrollHandDrag)
self._run_find_path() # Recompute path
# update anchor_points
newX, newY = self.point_items[idx].get_pos()
# even if S/E => update coords
self.anchor_points[idx] = (newX, newY)
# revert + re-apply cost, rebuild path
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
else:
# If not dragged, maybe add a new point
if not self._was_dragging and self.editor_mode:
self._add_point(event.pos())
self._run_find_path()
# user clicked an empty spot => add a guide point
scene_pos = self.mapToScene(event.pos())
x, y = scene_pos.x(), scene_pos.y()
self._add_guide_point(x, y)
self._was_dragging = False
def wheelEvent(self, event):
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
def _add_point(self, view_pos):
"""Add a removable red dot at the clicked location."""
scene_pos = self.mapToScene(view_pos)
x, y = scene_pos.x(), scene_pos.y()
dot = self._create_point(x, y, label="", radius=self.dot_radius, color=Qt.red, removable=True)
# Insert before the final E point if S/E exist
if len(self.points) >= 2:
self.points.insert(len(self.points) - 1, dot)
else:
self.points.append(dot)
self.scene.addItem(dot)
def _remove_point_by_click(self, view_pos):
idx = self._find_item_near(view_pos, threshold=10)
if idx is None:
return
# check if removable => skip S/E
if not self.point_items[idx].is_removable():
return # do nothing
def _remove_point(self, view_pos):
"""Right-click => remove nearest dot if it's removable."""
scene_pos = self.mapToScene(view_pos)
x_click, y_click = scene_pos.x(), scene_pos.y()
# remove anchor
self.scene.removeItem(self.point_items[idx])
self.point_items.pop(idx)
self.anchor_points.pop(idx)
threshold = 10
closest_idx = None
min_dist = float('inf')
for i, p in enumerate(self.points):
dist = p.distance_to(x_click, y_click)
if dist < min_dist:
min_dist = dist
closest_idx = i
if closest_idx is not None and min_dist <= threshold:
if self.points[closest_idx].is_removable():
self.scene.removeItem(self.points[closest_idx])
del self.points[closest_idx]
self._run_find_path()
# revert + re-apply cost, rebuild path
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def _find_point_near(self, view_pos, threshold=10):
def _find_item_near(self, view_pos, threshold=10):
scene_pos = self.mapToScene(view_pos)
x_click, y_click = scene_pos.x(), scene_pos.y()
closest_idx = None
min_dist = float('inf')
for i, p in enumerate(self.points):
dist = p.distance_to(x_click, y_click)
if dist < min_dist:
min_dist = dist
closest_idx = None
for i, itm in enumerate(self.point_items):
d = itm.distance_to(x_click, y_click)
if d < min_dist:
min_dist = d
closest_idx = i
if closest_idx is not None and min_dist <= threshold:
return closest_idx
return None
def _clear_point_items(self, remove_all=False):
"""Remove all points if remove_all=True; else just removable ones."""
if remove_all:
for p in self.points:
self.scene.removeItem(p)
self.points.clear()
else:
still_needed = []
for p in self.points:
if p.is_removable():
self.scene.removeItem(p)
else:
still_needed.append(p)
self.points = still_needed
# Also remove any path points from the scene
for p_item in self.path_points:
self.scene.removeItem(p_item)
self.path_points.clear()
def _run_find_path(self):
# --------------------------------------------------------------------
# ZOOM
# --------------------------------------------------------------------
def wheelEvent(self, event):
"""
Convert the first two points (S/E) from (x,y) to (row,col)
and call find_path(). Then display the path in magenta.
Zoom in/out with mouse wheel
"""
# If we don't have at least 2 points, no path
if len(self.points) < 2:
return
if self.cost_image is None:
return
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
# Clear old path visualization
for item in self.path_points:
self.scene.removeItem(item)
self.path_points.clear()
# If the user scrolls upward => zoom in. Otherwise => zoom out.
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
# We'll define the path between the first and last point,
# or if you specifically want the first two, you can do self.points[:2].
s_x, s_y = self.points[0].get_pos()
e_x, e_y = self.points[-1].get_pos()
# --------------------------------------------------------------------
# UTILS
# --------------------------------------------------------------------
def _clamp(self, val, mn, mx):
return max(mn, min(val, mx))
# Convert (x, y) => (row, col) = (int(y), int(x)) and clamp
h, w = self.cost_image.shape
s_r = int(round(s_y)); s_c = int(round(s_x))
e_r = int(round(e_y)); e_c = int(round(e_x))
def _clear_all_points(self):
for it in self.point_items:
self.scene.removeItem(it)
self.point_items.clear()
self.anchor_points.clear()
# Ensure they're inside the cost_image boundary
s_r = max(0, min(s_r, h-1))
s_c = max(0, min(s_c, w-1))
e_r = max(0, min(e_r, h-1))
e_c = max(0, min(e_c, w-1))
for p in self.full_path_points:
self.scene.removeItem(p)
self.full_path_points.clear()
# Attempt path
try:
path_rc = find_path(self.cost_image, [(s_r, s_c), (e_r, e_c)])
except ValueError as e:
print("Error in find_path:", e)
return
def clear_guide_points(self):
"""
Removes all anchors that are 'removable' (guide points),
keeps S/E in place. Then reverts cost, re-applies, rebuilds path.
"""
i = 0
while i < len(self.anchor_points):
if self.point_items[i].is_removable():
self.scene.removeItem(self.point_items[i])
del self.point_items[i]
del self.anchor_points[i]
else:
i += 1
# Convert path (row,col) => (x, y)
for (r, c) in path_rc:
x = c
y = r
item = self._create_point(x, y, "", self.path_radius, Qt.red, removable=False)
self.path_points.append(item)
self.scene.addItem(item)
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
class MainWindow(QMainWindow):
......@@ -380,24 +458,20 @@ class MainWindow(QMainWindow):
btn_layout = QHBoxLayout()
# Load Image
self.btn_load_image = QPushButton("Load Image")
self.btn_load_image.clicked.connect(self.load_image)
btn_layout.addWidget(self.btn_load_image)
# Editor Mode
self.btn_editor_mode = QPushButton("Editor Mode: OFF")
self.btn_editor_mode.setCheckable(True)
self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
btn_layout.addWidget(self.btn_editor_mode)
# Export Points
self.btn_export_points = QPushButton("Export Points")
self.btn_export_points.clicked.connect(self.export_points)
btn_layout.addWidget(self.btn_export_points)
# Clear Points
self.btn_clear_points = QPushButton("Clear Points")
self.btn_clear_points.clicked.connect(self.clear_points)
btn_layout.addWidget(self.btn_clear_points)
......@@ -407,7 +481,6 @@ class MainWindow(QMainWindow):
self.resize(900, 600)
def load_image(self):
"""Open file dialog to pick an image, then load it."""
options = QFileDialog.Options()
file_path, _ = QFileDialog.getOpenFileName(
self, "Open Image", "",
......@@ -416,8 +489,9 @@ class MainWindow(QMainWindow):
)
if file_path:
self.image_view.load_image(file_path)
# Compute cost image
self.image_view.cost_image = compute_cost_image(file_path)
cost_img = compute_cost_image(file_path)
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
def toggle_editor_mode(self):
is_checked = self.btn_editor_mode.isChecked()
......@@ -430,10 +504,9 @@ class MainWindow(QMainWindow):
self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
def export_points(self):
if not self.image_view.points:
print("No points to export.")
if not self.image_view.anchor_points:
print("No anchor points to export.")
return
options = QFileDialog.Options()
file_path, _ = QFileDialog.getSaveFileName(
self, "Export Points", "",
......@@ -441,13 +514,16 @@ class MainWindow(QMainWindow):
options=options
)
if file_path:
coords = [p.get_pos() for p in self.image_view.points]
points_array = np.array(coords)
points_array = np.array(self.image_view.anchor_points)
np.save(file_path, points_array)
print(f"Exported {len(points_array)} points to {file_path}")
def clear_points(self):
self.image_view._clear_point_items(remove_all=False)
"""Remove all removable anchors (guide points), keep S/E in place."""
self.image_view.clear_guide_points()
def closeEvent(self, event):
super().closeEvent(event)
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment