Skip to content
Snippets Groups Projects
Commit 98a35b34 authored by Christian's avatar Christian
Browse files

Added interactive kernel size tuning

parent ee86bd84
No related branches found
No related tags found
No related merge requests found
......@@ -8,14 +8,191 @@ from scipy.signal import savgol_filter
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
QSlider, QLabel
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF
from PyQt5.QtCore import Qt, QRectF, QSize
from live_wire import compute_cost_image, find_path
# ------------------------------------------------------------------------
# A pan & zoom QGraphicsView
# ------------------------------------------------------------------------
class PanZoomGraphicsView(QGraphicsView):
def __init__(self, parent=None):
super().__init__(parent)
self.setDragMode(QGraphicsView.NoDrag) # We'll handle panning manually
self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
self._panning = False
self._pan_start = None
def wheelEvent(self, event):
"""
Zoom in/out with mouse wheel.
"""
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
def mousePressEvent(self, event):
"""
If left button: Start panning (unless overridden in a subclass).
"""
if event.button() == Qt.LeftButton:
self._panning = True
self._pan_start = event.pos()
self.setCursor(Qt.ClosedHandCursor)
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
"""
If panning, translate the scene.
"""
if self._panning and self._pan_start is not None:
delta = event.pos() - self._pan_start
self._pan_start = event.pos()
self.translate(delta.x(), delta.y())
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
"""
End panning.
"""
if event.button() == Qt.LeftButton:
self._panning = False
self.setCursor(Qt.ArrowCursor)
super().mouseReleaseEvent(event)
# ------------------------------------------------------------------------
# A specialized PanZoomGraphicsView for the circle editor
# Only pan if user did NOT click on the draggable circle
# ------------------------------------------------------------------------
class CircleEditorGraphicsView(PanZoomGraphicsView):
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
# Check if the user clicked on the circle item
clicked_item = self.itemAt(event.pos().x(), event.pos().y())
if clicked_item is not None:
# Walk up parent chain to see if it is our DraggableCircleItem
it = clicked_item
while it is not None and not hasattr(it, "boundingRect"):
it = it.parentItem()
from PyQt5.QtWidgets import QGraphicsEllipseItem
if isinstance(it, DraggableCircleItem):
# Let normal item-dragging occur, don't initiate panning
return QGraphicsView.mousePressEvent(self, event)
# Otherwise proceed with normal panning logic
super().mousePressEvent(event)
# ------------------------------------------------------------------------
# Draggable circle item (centered at (x, y) with radius)
# ------------------------------------------------------------------------
class DraggableCircleItem(QGraphicsEllipseItem):
def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
super().__init__(0, 0, 2*radius, 2*radius, parent)
self._r = radius
pen = QPen(color)
brush = QBrush(color)
self.setPen(pen)
self.setBrush(brush)
# Enable item-based dragging
self.setFlags(QGraphicsEllipseItem.ItemIsMovable |
QGraphicsEllipseItem.ItemIsSelectable |
QGraphicsEllipseItem.ItemSendsScenePositionChanges)
# Position so that (x, y) is the center
self.setPos(x - radius, y - radius)
def set_radius(self, r):
# Keep the same center, just change radius
old_center = self.sceneBoundingRect().center()
self._r = r
self.setRect(0, 0, 2*r, 2*r)
new_center = self.sceneBoundingRect().center()
diff_x = old_center.x() - new_center.x()
diff_y = old_center.y() - new_center.y()
self.moveBy(diff_x, diff_y)
def radius(self):
return self._r
# ------------------------------------------------------------------------
# Circle editor widget with slider + done
# ------------------------------------------------------------------------
class CircleEditorWidget(QWidget):
def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
super().__init__(parent)
self._pixmap = pixmap
self._done_callback = done_callback
self._init_radius = init_radius
layout = QVBoxLayout(self)
self.setLayout(layout)
# Use specialized CircleEditorGraphicsView
self._graphics_view = CircleEditorGraphicsView()
self._scene = QGraphicsScene(self)
self._graphics_view.setScene(self._scene)
layout.addWidget(self._graphics_view)
self._image_item = QGraphicsPixmapItem(self._pixmap)
self._scene.addItem(self._image_item)
# Put circle in center
cx = self._pixmap.width() / 2
cy = self._pixmap.height() / 2
self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
self._scene.addItem(self._circle_item)
# Fit in view
self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)
# Bottom controls (slider + done)
bottom_layout = QHBoxLayout()
layout.addLayout(bottom_layout)
lbl = QLabel("size:")
bottom_layout.addWidget(lbl)
self._slider = QSlider(Qt.Horizontal)
self._slider.setRange(1, 200)
self._slider.setValue(self._init_radius)
bottom_layout.addWidget(self._slider)
self._btn_done = QPushButton("Done")
bottom_layout.addWidget(self._btn_done)
# Connect signals
self._slider.valueChanged.connect(self._on_slider_changed)
self._btn_done.clicked.connect(self._on_done_clicked)
def _on_slider_changed(self, value):
self._circle_item.set_radius(value)
def _on_done_clicked(self):
final_radius = self._circle_item.radius()
if self._done_callback is not None:
self._done_callback(final_radius)
def sizeHint(self):
return QSize(800, 600)
# ------------------------------------------------------------------------
# Labeled point item
# ------------------------------------------------------------------------
class LabeledPointItem(QGraphicsEllipseItem):
def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
super().__init__(0, 0, 2*radius, 2*radius, parent)
......@@ -84,25 +261,23 @@ class LabeledPointItem(QGraphicsEllipseItem):
return self._removable
class ImageGraphicsView(QGraphicsView):
# ------------------------------------------------------------------------
# The original ImageGraphicsView with pan & zoom
# ------------------------------------------------------------------------
class ImageGraphicsView(PanZoomGraphicsView):
def __init__(self, parent=None):
super().__init__(parent)
self.scene = QGraphicsScene(self)
self.setScene(self.scene)
# Zoom around mouse pointer
self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
# Image display
self.image_item = QGraphicsPixmapItem()
self.scene.addItem(self.image_item)
self.anchor_points = [] # List[(x, y)]
self.point_items = [] # LabeledPointItem objects
self.full_path_points = [] # QGraphicsEllipseItems for the path
# We'll store the entire path coords (smoothed) for reference
self._full_path_xy = []
self.point_items = [] # LabeledPointItem
self.full_path_points = [] # QGraphicsEllipseItems for path
self._full_path_xy = [] # entire path coords (smoothed)
self.dot_radius = 4
self.path_radius = 1
......@@ -110,10 +285,6 @@ class ImageGraphicsView(QGraphicsView):
self._img_w = 0
self._img_h = 0
# Pan/Drag
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ArrowCursor)
self._mouse_pressed = False
self._press_view_pos = None
self._drag_threshold = 5
......@@ -185,33 +356,25 @@ class ImageGraphicsView(QGraphicsView):
self.scene.addItem(item)
def _add_guide_point(self, x, y):
"""User clicked => find the correct sub-path, insert the point in that sub-path."""
x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)
self._revert_cost_to_original()
if not self._full_path_xy:
# If there's no existing path built, just insert at the end
self._insert_anchor_point(-1, x_clamped, y_clamped,
label="", removable=True, z_val=1, radius=self.dot_radius)
else:
# Insert the new anchor in between the correct anchors,
# by finding nearest coordinate in _full_path_xy, then
# walking left+right until we find bounding anchors.
self._insert_anchor_between_subpath(x_clamped, y_clamped)
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def _insert_anchor_between_subpath(self, x_new, y_new):
"""Find the subpath bounding (x_new,y_new) and insert the new anchor accordingly."""
if not self._full_path_xy:
# Fallback if no path
self._insert_anchor_point(-1, x_new, y_new)
return
# 1) Find nearest coordinate in the path
best_idx = None
best_d2 = float('inf')
for i, (px, py) in enumerate(self._full_path_xy):
......@@ -223,7 +386,6 @@ class ImageGraphicsView(QGraphicsView):
best_idx = i
if best_idx is None:
# fallback
self._insert_anchor_point(-1, x_new, y_new)
return
......@@ -237,7 +399,7 @@ class ImageGraphicsView(QGraphicsView):
return True
return False
# 2) Walk left
# Walk left
left_anchor_pt = None
iL = best_idx
while iL >= 0:
......@@ -247,7 +409,7 @@ class ImageGraphicsView(QGraphicsView):
break
iL -= 1
# 3) Walk right
# Walk right
right_anchor_pt = None
iR = best_idx
while iR < len(self._full_path_xy):
......@@ -257,17 +419,14 @@ class ImageGraphicsView(QGraphicsView):
break
iR += 1
# fallback if missing anchors
if not left_anchor_pt or not right_anchor_pt:
self._insert_anchor_point(-1, x_new, y_new)
return
# If they happen to be the same anchor, fallback
if left_anchor_pt == right_anchor_pt:
self._insert_anchor_point(-1, x_new, y_new)
return
# 4) Map these anchor coords to indices in self.anchor_points
left_idx = None
right_idx = None
for i, (ax, ay) in enumerate(self.anchor_points):
......@@ -280,7 +439,6 @@ class ImageGraphicsView(QGraphicsView):
self._insert_anchor_point(-1, x_new, y_new)
return
# 5) Insert new point in between
if left_idx < right_idx:
insert_idx = left_idx + 1
else:
......@@ -326,7 +484,6 @@ class ImageGraphicsView(QGraphicsView):
# PATH BUILDING
# --------------------------------------------------------------------
def _rebuild_full_path(self):
# Clear old path visuals
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
......@@ -343,28 +500,19 @@ class ImageGraphicsView(QGraphicsView):
if i == 0:
big_xy.extend(sub_xy)
else:
# Avoid repeating the shared anchor
if len(sub_xy) > 1:
big_xy.extend(sub_xy[1:])
# Smooth if we have enough points
if len(big_xy) >= 7:
arr_xy = np.array(big_xy)
smoothed = savgol_filter(arr_xy, window_length=7, polyorder=1, axis=0)
big_xy = smoothed.tolist()
# Store the entire path
self._full_path_xy = big_xy[:]
# Draw the path
n_points = len(big_xy)
for i, (px, py) in enumerate(big_xy):
if n_points > 1:
fraction = i / (n_points - 1)
else:
fraction = 0
# If rainbow is on, use the rainbow color; else use a constant color
fraction = i / (n_points - 1) if n_points > 1 else 0
if self._rainbow_enabled:
color = self._rainbow_color(fraction)
else:
......@@ -378,13 +526,12 @@ class ImageGraphicsView(QGraphicsView):
self.full_path_points.append(path_item)
self.scene.addItem(path_item)
# Keep S/E on top if they have labels
# Keep anchor labels on top
for p_item in self.point_items:
if p_item._text_item:
p_item.setZValue(100)
def _compute_subpath_xy(self, xA, yA, xB, yB):
"""Return the raw path from (xA,yA)->(xB,yB)."""
if self.cost_image is None:
return []
h, w = self.cost_image.shape
......@@ -402,18 +549,13 @@ class ImageGraphicsView(QGraphicsView):
return [(c, r) for (r, c) in path_rc]
def _rainbow_color(self, fraction):
"""
fraction: 0..1
Returns a QColor whose hue is fraction * 300 (for example),
at full saturation and full brightness.
"""
hue = int(300 * fraction) # up to 300 degrees
hue = int(300 * fraction)
saturation = 255
value = 255
return QColor.fromHsv(hue, saturation, value)
# --------------------------------------------------------------------
# MOUSE EVENTS
# MOUSE EVENTS (with pan & zoom from PanZoomGraphicsView)
# --------------------------------------------------------------------
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
......@@ -421,37 +563,27 @@ class ImageGraphicsView(QGraphicsView):
self._was_dragging = False
self._press_view_pos = event.pos()
# See if user is clicking near an existing anchor => drag it
idx = self._find_item_near(event.pos(), threshold=10)
if idx is not None:
self._dragging_idx = idx
self._drag_counter = 0
scene_pos = self.mapToScene(event.pos())
px, py = self.point_items[idx].get_pos()
self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
self.setDragMode(QGraphicsView.NoDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
self.setCursor(Qt.ClosedHandCursor)
return
else:
# No anchor => we may add a new point
self.setDragMode(QGraphicsView.ScrollHandDrag)
self.viewport().setCursor(Qt.ClosedHandCursor)
elif event.button() == Qt.RightButton:
# Right-click => remove anchor if removable
self._remove_point_by_click(event.pos())
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
if self._dragging_idx is not None:
# Dragging anchor
scene_pos = self.mapToScene(event.pos())
x_new = scene_pos.x() - self._drag_offset[0]
y_new = scene_pos.y() - self._drag_offset[1]
# clamp so user can't drag outside
r = self.point_items[self._dragging_idx]._r
x_clamped = self._clamp(x_new, r, self._img_w - r)
y_clamped = self._clamp(y_new, r, self._img_h - r)
......@@ -459,33 +591,29 @@ class ImageGraphicsView(QGraphicsView):
self._drag_counter += 1
if self._drag_counter >= 4:
# partial path update
self._drag_counter = 0
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
self._rebuild_full_path()
else:
if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
dist = (event.pos() - self._press_view_pos).manhattanLength()
if dist > self._drag_threshold:
self._was_dragging = True
super().mouseMoveEvent(event)
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
super().mouseReleaseEvent(event)
if event.button() == Qt.LeftButton and self._mouse_pressed:
self._mouse_pressed = False
self.viewport().setCursor(Qt.ArrowCursor)
self.setCursor(Qt.ArrowCursor)
if self._dragging_idx is not None:
# finished dragging => final update
idx = self._dragging_idx
self._dragging_idx = None
self._drag_offset = (0, 0)
self.setDragMode(QGraphicsView.ScrollHandDrag)
newX, newY = self.point_items[idx].get_pos()
self.anchor_points[idx] = (newX, newY)
......@@ -493,9 +621,8 @@ class ImageGraphicsView(QGraphicsView):
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
else:
# If user wasn't dragging => add new guide point
# No drag => add point
if not self._was_dragging:
scene_pos = self.mapToScene(event.pos())
x, y = scene_pos.x(), scene_pos.y()
......@@ -507,7 +634,6 @@ class ImageGraphicsView(QGraphicsView):
idx = self._find_item_near(view_pos, threshold=10)
if idx is None:
return
# skip if S/E
if not self.point_items[idx].is_removable():
return
......@@ -534,18 +660,6 @@ class ImageGraphicsView(QGraphicsView):
return closest_idx
return None
# --------------------------------------------------------------------
# ZOOM
# --------------------------------------------------------------------
def wheelEvent(self, event):
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
# --------------------------------------------------------------------
# UTILS
# --------------------------------------------------------------------
......@@ -564,7 +678,6 @@ class ImageGraphicsView(QGraphicsView):
self._full_path_xy.clear()
def clear_guide_points(self):
"""Remove all removable anchors, keep S/E. Rebuild path."""
i = 0
while i < len(self.anchor_points):
if self.point_items[i].is_removable():
......@@ -584,50 +697,106 @@ class ImageGraphicsView(QGraphicsView):
self._rebuild_full_path()
def get_full_path_xy(self):
"""Return the entire path (x,y) array after smoothing."""
return self._full_path_xy
# ------------------------------------------------------------------------
# Main Window
# ------------------------------------------------------------------------
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Test GUI")
main_widget = QWidget()
main_layout = QVBoxLayout(main_widget)
self._last_loaded_pixmap = None
self._circle_radius_for_later_use = 0
# Original main widget
self._main_widget = QWidget()
self._main_layout = QVBoxLayout(self._main_widget)
# Image view
self.image_view = ImageGraphicsView()
main_layout.addWidget(self.image_view)
self._main_layout.addWidget(self.image_view)
# Buttons layout
# Button row
btn_layout = QHBoxLayout()
# Load Image
self.btn_load_image = QPushButton("Load Image")
self.btn_load_image.clicked.connect(self.load_image)
btn_layout.addWidget(self.btn_load_image)
# Export Path
self.btn_export_path = QPushButton("Export Path")
self.btn_export_path.clicked.connect(self.export_path)
btn_layout.addWidget(self.btn_export_path)
# Clear Points
self.btn_clear_points = QPushButton("Clear Points")
self.btn_clear_points.clicked.connect(self.clear_points)
btn_layout.addWidget(self.btn_clear_points)
# Toggle Rainbow
self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
self.btn_toggle_rainbow.clicked.connect(self.toggle_rainbow)
btn_layout.addWidget(self.btn_toggle_rainbow)
main_layout.addLayout(btn_layout)
self.setCentralWidget(main_widget)
# New circle editor button
self.btn_open_editor = QPushButton("Open Circle Editor")
self.btn_open_editor.clicked.connect(self.open_circle_editor)
btn_layout.addWidget(self.btn_open_editor)
self._main_layout.addLayout(btn_layout)
self.setCentralWidget(self._main_widget)
self.resize(900, 600)
# We keep references for old/new
self._old_central_widget = None
self._editor = None
def open_circle_editor(self):
"""Removes the current central widget, replaces with circle editor."""
if not self._last_loaded_pixmap:
print("No image loaded yet! Cannot open circle editor.")
return
# Step 1: take the old widget out of QMainWindow ownership
old_widget = self.takeCentralWidget()
self._old_central_widget = old_widget
# Step 2: create the editor
init_radius = 20
editor = CircleEditorWidget(
pixmap=self._last_loaded_pixmap,
init_radius=init_radius,
done_callback=self._on_circle_editor_done
)
self._editor = editor
# Step 3: set the new editor as the central widget
self.setCentralWidget(editor)
def _on_circle_editor_done(self, final_radius):
self._circle_radius_for_later_use = final_radius
print(f"Circle Editor done. Radius = {final_radius}")
# Take the editor out
editor_widget = self.takeCentralWidget()
if editor_widget is not None:
editor_widget.setParent(None)
# Put back the old widget
if self._old_central_widget is not None:
self.setCentralWidget(self._old_central_widget)
self._old_central_widget = None
# We can delete the editor if we like
if self._editor is not None:
self._editor.deleteLater()
self._editor = None
# --------------------------------------------------------------------
# Existing Functions
# --------------------------------------------------------------------
def toggle_rainbow(self):
"""Toggle the rainbow mode in the view."""
self.image_view.toggle_rainbow()
def load_image(self):
......@@ -643,8 +812,12 @@ class MainWindow(QMainWindow):
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
# Store a pixmap to reuse
pm = QPixmap(file_path)
if not pm.isNull():
self._last_loaded_pixmap = pm
def export_path(self):
"""Export the full path (x,y) as a .npy file."""
full_xy = self.image_view.get_full_path_xy()
if not full_xy:
print("No path to export.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment