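# PyQt5 GUI for interactive live-wire path tracing: pan/zoom image view,
# circle-size editor and an advanced-settings panel.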
import math
import sys

import numpy as np
# For smoothing the path
from scipy.signal import savgol_filter
from PyQt5.QtCore import Qt, QRectF, QSize
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont, QImage
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
    QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    QSlider, QLabel, QGridLayout, QSizePolicy,
)

# Make sure the following imports exist in live_wire.py (or similar):
# from skimage import exposure
# from skimage.filters import gaussian
# def preprocess_image(image, sigma=3, clip_limit=0.01): ...
from live_wire import compute_cost_image, find_path, preprocess_image
# ------------------------------------------------------------------------
# A pan & zoom QGraphicsView
# ------------------------------------------------------------------------
class PanZoomGraphicsView(QGraphicsView):
    """QGraphicsView with mouse-wheel zooming and left-button drag panning.

    The built-in drag mode is disabled; panning is tracked manually through
    ``_panning`` / ``_pan_start`` so that subclasses can decide per click
    whether a press starts a pan or does something else.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setDragMode(QGraphicsView.NoDrag)  # We'll handle panning manually
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self._panning = False    # True while the left button is held for a pan
        self._pan_start = None   # last mouse position (view coords) during a pan
        # Let it expand in layouts
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

    def wheelEvent(self, event):
        """Zoom in (wheel up) or out (anything else) by a fixed factor."""
        zoom_in_factor = 1.25
        zoom_out_factor = 1 / zoom_in_factor
        if event.angleDelta().y() > 0:
            factor = zoom_in_factor
        else:
            factor = zoom_out_factor
        self.scale(factor, factor)
        event.accept()

    def mousePressEvent(self, event):
        """Start panning on a left-button press (subclasses may override)."""
        if event.button() == Qt.LeftButton:
            self._panning = True
            self._pan_start = event.pos()
            self.setCursor(Qt.ClosedHandCursor)
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """While panning, translate the view by the mouse movement."""
        if self._panning and self._pan_start is not None:
            delta = event.pos() - self._pan_start
            self._pan_start = event.pos()
            # NOTE(review): translate() is reported to be ineffective while the
            # transformation anchor is AnchorUnderMouse - confirm panning works.
            self.translate(delta.x(), delta.y())
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Stop panning when the left button is released."""
        if event.button() == Qt.LeftButton:
            self._panning = False
            self.setCursor(Qt.ArrowCursor)
        super().mouseReleaseEvent(event)
# ------------------------------------------------------------------------
# A specialized PanZoomGraphicsView for the circle editor
# ------------------------------------------------------------------------
class CircleEditorGraphicsView(PanZoomGraphicsView):
    """Pan/zoom view for the circle editor.

    Clicks and wheel events that land on the ``DraggableCircleItem`` act on
    the circle (drag / resize); everything else falls back to pan and zoom.

    Args:
        circle_editor_widget: owning widget; must provide
            ``update_slider_value(radius)`` so the slider can follow
            wheel-driven radius changes.
        parent: optional Qt parent.
    """

    def __init__(self, circle_editor_widget, parent=None):
        super().__init__(parent)
        self._circle_editor_widget = circle_editor_widget

    def _circle_item_at(self, view_pos):
        """Return the DraggableCircleItem under ``view_pos`` (view coords), or None.

        Walks up the parent chain from the top-most item so that a hit on a
        child of the circle still resolves to the circle itself.
        """
        item = self.itemAt(view_pos)
        while item is not None and not isinstance(item, DraggableCircleItem):
            item = item.parentItem()
        return item

    def mousePressEvent(self, event):
        """Let the circle handle its own dragging; otherwise start panning."""
        if event.button() == Qt.LeftButton:
            if self._circle_item_at(event.pos()) is not None:
                # Bypass PanZoomGraphicsView so no pan starts and the normal
                # item-dragging of QGraphicsView takes over.
                return QGraphicsView.mousePressEvent(self, event)
        # Otherwise proceed with normal panning logic
        super().mousePressEvent(event)

    def wheelEvent(self, event):
        """Resize the circle when hovering over it, otherwise zoom the view.

        Scrolling up grows the radius by 1, any other wheel event shrinks it
        by 1; the radius never drops below 1.
        """
        circle = self._circle_item_at(event.pos())
        if circle is None:
            # Not over the circle: do the normal zoom
            super().wheelEvent(event)
            return

        step = 1 if event.angleDelta().y() > 0 else -1
        new_r = max(1, circle.radius() + step)
        circle.set_radius(new_r)
        # Also update the slider in the parent CircleEditorWidget
        self._circle_editor_widget.update_slider_value(new_r)
        event.accept()
# (stray diff-gutter line numbers removed)
# ------------------------------------------------------------------------
# Draggable circle item (centered at (x, y) with radius)
# ------------------------------------------------------------------------
class DraggableCircleItem(QGraphicsEllipseItem):
    """Filled, mouse-draggable circle whose centre starts at ``(x, y)``.

    Args:
        x, y: centre of the circle in parent/scene coordinates.
        radius: initial radius.
        color: used for both outline and fill.
        parent: optional parent graphics item.
    """

    def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
        # The ellipse rect is anchored at the item's local origin; the item is
        # then positioned so that its centre lands on (x, y).
        super().__init__(0, 0, 2 * radius, 2 * radius, parent)
        self._r = radius

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))

        # Enable item-based dragging
        self.setFlags(QGraphicsEllipseItem.ItemIsMovable |
                      QGraphicsEllipseItem.ItemIsSelectable |
                      QGraphicsEllipseItem.ItemSendsScenePositionChanges)

        # Position so that (x, y) is the center
        self.setPos(x - radius, y - radius)

    def set_radius(self, r):
        """Change the radius to ``r`` while keeping the circle's centre fixed."""
        old_center = self.sceneBoundingRect().center()
        self._r = r
        self.setRect(0, 0, 2 * r, 2 * r)
        # Resizing the rect moves the centre; shift the item back by the difference.
        new_center = self.sceneBoundingRect().center()
        self.moveBy(old_center.x() - new_center.x(),
                    old_center.y() - new_center.y())

    def radius(self):
        """Return the current radius."""
        return self._r
# ------------------------------------------------------------------------
# Circle editor widget with slider + done
# ------------------------------------------------------------------------
class CircleEditorWidget(QWidget):
    """Editor showing an image with a resizable circle, a size slider and a Done button.

    Args:
        pixmap: image shown behind the circle.
        init_radius: initial circle radius (also the initial slider value).
        done_callback: optional callable invoked with the final radius when
            "Done" is clicked.
        parent: optional Qt parent.
    """

    def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
        super().__init__(parent)
        self._pixmap = pixmap
        self._done_callback = done_callback
        self._init_radius = init_radius

        # Passing ``self`` installs the layout on this widget.
        layout = QVBoxLayout(self)

        # Use specialized CircleEditorGraphicsView
        self._graphics_view = CircleEditorGraphicsView(circle_editor_widget=self)
        self._scene = QGraphicsScene(self)
        self._graphics_view.setScene(self._scene)
        layout.addWidget(self._graphics_view)

        self._image_item = QGraphicsPixmapItem(self._pixmap)
        self._scene.addItem(self._image_item)

        # Put circle in center
        cx = self._pixmap.width() / 2
        cy = self._pixmap.height() / 2
        self._circle_item = DraggableCircleItem(
            cx, cy, radius=self._init_radius, color=Qt.red)
        self._scene.addItem(self._circle_item)

        # Fit in view
        self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
        self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

        # Bottom controls (slider + done)
        bottom_layout = QHBoxLayout()
        layout.addLayout(bottom_layout)

        bottom_layout.addWidget(QLabel("size:"))

        self._slider = QSlider(Qt.Horizontal)
        self._slider.setRange(1, 200)
        self._slider.setValue(self._init_radius)
        bottom_layout.addWidget(self._slider)

        self._btn_done = QPushButton("Done")
        bottom_layout.addWidget(self._btn_done)

        # Connect signals
        self._slider.valueChanged.connect(self._on_slider_changed)
        self._btn_done.clicked.connect(self._on_done_clicked)

    def _on_slider_changed(self, value):
        """Slider moved: resize the circle to the slider value."""
        self._circle_item.set_radius(value)

    def _on_done_clicked(self):
        """Report the final radius through ``done_callback`` (if one was given)."""
        if self._done_callback is not None:
            self._done_callback(self._circle_item.radius())

    def update_slider_value(self, new_radius):
        """Sync the slider to ``new_radius`` without re-triggering a resize.

        Called by CircleEditorGraphicsView when the user scrolls on the circle
        item. The view has already resized the circle, so the slider's
        ``valueChanged`` signal is blocked while the value is updated.
        """
        self._slider.blockSignals(True)
        try:
            self._slider.setValue(new_radius)
        finally:
            self._slider.blockSignals(False)

    def sizeHint(self):
        """Preferred initial size of the editor."""
        return QSize(800, 600)
# ------------------------------------------------------------------------
# Labeled point item
# ------------------------------------------------------------------------
# (stray diff-gutter line numbers removed)
class LabeledPointItem(QGraphicsEllipseItem):
    """Filled circle centred at ``(x, y)`` with an optional text label inside.

    Args:
        x, y: centre of the point.
        label: text drawn inside the circle; no text item is created if empty.
        radius: circle radius.
        color: outline and fill colour.
        removable: flag reported by ``is_removable()``.
        z_value: stacking order passed to ``setZValue``.
        parent: optional parent graphics item.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red,
                 removable=True, z_value=0, parent=None):
        super().__init__(0, 0, 2 * radius, 2 * radius, parent)
        self._x = x
        self._y = y
        self._r = radius
        self._removable = removable

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))
        self.setZValue(z_value)

        self._text_item = None
        if label:
            self._text_item = QGraphicsTextItem(self)
            self._text_item.setPlainText(label)
            self._text_item.setDefaultTextColor(QColor("black"))
            font = QFont("Arial", 14)
            font.setBold(True)
            self._text_item.setFont(font)
            self._scale_text_to_fit()

        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink the label (never enlarge it) so it fits the circle, then centre it."""
        if not self._text_item:
            return
        self._text_item.setScale(1.0)
        circle_diam = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        text_w = raw_rect.width()
        text_h = raw_rect.height()
        if text_w > circle_diam or text_h > circle_diam:
            # Uniform scale limited by the tighter of the two dimensions.
            self._text_item.setScale(
                min(circle_diam / text_w, circle_diam / text_h))
        self._center_label()

    def _center_label(self):
        """Place the (scaled) label so it is centred within the circle."""
        if not self._text_item:
            return
        diameter = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        scale_factor = self._text_item.scale()
        tx = (diameter - raw_rect.width() * scale_factor) * 0.5
        ty = (diameter - raw_rect.height() * scale_factor) * 0.5
        self._text_item.setPos(tx, ty)

    def set_pos(self, x, y):
        """Positions the circle so its center is at (x, y)."""
        self._x = x
        self._y = y
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the centre as an ``(x, y)`` tuple."""
        return (self._x, self._y)

    def distance_to(self, x_other, y_other):
        """Euclidean distance from the centre to ``(x_other, y_other)``."""
        return math.sqrt((self._x - x_other) ** 2 + (self._y - y_other) ** 2)

    def is_removable(self):
        """Return the ``removable`` flag given at construction."""
        return self._removable
# ------------------------------------------------------------------------
# The original ImageGraphicsView with pan & zoom
# ------------------------------------------------------------------------
class ImageGraphicsView(PanZoomGraphicsView):
    """Image view holding anchor points and the minimum-cost path between them.

    ``anchor_points`` (``(x, y)`` tuples) and ``point_items``
    (``LabeledPointItem``) are parallel lists: index ``i`` of one describes the
    same anchor as index ``i`` of the other. The first and last anchors are
    the non-removable "S" and "E" points created by ``load_image``; removable
    anchors in between are guide points added by the user.

    ``cost_image_original`` / ``cost_image`` are expected to be assigned by the
    owner; ``cost_image`` is the working copy in which the cost around guide
    points is lowered.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # NOTE: this attribute shadows QGraphicsView.scene(); the class uses
        # ``self.scene`` as an object everywhere, so the name is kept.
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        self.anchor_points = []     # (x, y) tuples, parallel to point_items
        self.point_items = []       # LabeledPointItem
        self.full_path_points = []  # graphics items drawn along the path
        self._full_path_xy = []     # entire path coords (smoothed)

        self.dot_radius = 4
        self.path_radius = 1
        # Radius (pixels) of the low-cost disc stamped around each guide point.
        # NOTE(review): the initial value was not visible in the source this
        # was restored from - confirm the intended default.
        self.radius_cost_image = 2

        self._img_w = 0
        self._img_h = 0

        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5    # manhattan pixels before a press counts as a drag
        self._was_dragging = False
        self._dragging_idx = None   # index of the anchor being dragged, if any
        self._drag_offset = (0, 0)
        self._drag_counter = 0      # throttles path rebuilds while dragging

        self.cost_image_original = None
        self.cost_image = None

        # Rainbow toggle => start with OFF
        self._rainbow_enabled = False

        # Smoothing parameters
        self._savgol_window_length = 7
        self._savgol_polyorder = 1

    def set_rainbow_enabled(self, enabled: bool):
        """Enable/disable rainbow mode, then rebuild the path."""
        self._rainbow_enabled = enabled
        self._rebuild_full_path()

    def toggle_rainbow(self):
        """Flip the rainbow mode and rebuild path."""
        self._rainbow_enabled = not self._rainbow_enabled
        self._rebuild_full_path()

    def set_savgol_window_length(self, wlen: int):
        """Set the Savitzky-Golay window length, derive polyorder, rebuild the path.

        The window length is forced to be odd and at least 3. The polynomial
        order becomes ``round(window_length / 7)``, limited to the range
        ``1 .. window_length - 1``.
        """
        wlen = max(3, wlen)
        if wlen % 2 == 0:
            wlen += 1
        self._savgol_window_length = wlen

        self._savgol_polyorder = min(max(1, round(wlen / 7.0)), wlen - 1)
        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """Load the image at ``path``, reset all points and place the S/E anchors.

        Nothing happens if the pixmap cannot be loaded.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))
        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self._clear_all_points()
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # Start and end anchors sit on the horizontal mid-line at 15% / 85% width.
        mid_y = 0.5 * self._img_h
        self._insert_anchor_point(-1, 0.15 * self._img_w, mid_y, label="S",
                                  removable=False, z_val=100, radius=6)
        self._insert_anchor_point(-1, 0.85 * self._img_w, mid_y, label="E",
                                  removable=False, z_val=100, radius=6)

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
        """Insert anchor at index=idx (or -1 => before E). Clamps x,y to image bounds."""
        x_clamped = self._clamp(x, radius, self._img_w - radius)
        y_clamped = self._clamp(y, radius, self._img_h - radius)

        if idx < 0:
            # With S and E present, append just before E; otherwise at the end.
            count = len(self.anchor_points)
            idx = count - 1 if count >= 2 else count

        self.anchor_points.insert(idx, (x_clamped, y_clamped))
        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(x_clamped, y_clamped,
                                label=label, radius=radius, color=color,
                                removable=removable, z_value=z_val)
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Add a removable guide point at (x, y) and recompute the path."""
        r = self.dot_radius
        x_clamped = self._clamp(x, r, self._img_w - r)
        # Fixed: the y coordinate is bounded by the image height, not its width.
        y_clamped = self._clamp(y, r, self._img_h - r)

        if not self._full_path_xy:
            self._insert_anchor_point(-1, x_clamped, y_clamped,
                                      label="", removable=True, z_val=1, radius=r)
        else:
            self._insert_anchor_between_subpath(x_clamped, y_clamped)

        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _insert_anchor_between_subpath(self, x_new, y_new):
        """Insert a guide point between the two anchors surrounding the nearest path point.

        Finds the path coordinate closest to (x_new, y_new), then scans the
        path left and right of it for coordinates that coincide (within 1e-3)
        with an anchor. The new anchor is inserted right after the earlier of
        those two anchors. If the bracket cannot be determined, the point is
        inserted before "E" instead.
        """
        def insert_at(idx):
            # All guide points share the same look, whichever branch adds them.
            self._insert_anchor_point(idx, x_new, y_new, label="", removable=True,
                                      z_val=1, radius=self.dot_radius)

        path = self._full_path_xy
        if not path:
            insert_at(-1)
            return

        # First index with the smallest squared distance to the new point.
        nearest = min(
            range(len(path)),
            key=lambda i: (path[i][0] - x_new) ** 2 + (path[i][1] - y_new) ** 2,
        )

        def anchor_index(pt, tol=1e-3):
            """Index of the last anchor within ``tol`` of ``pt``, or None."""
            found = None
            for k, (ax, ay) in enumerate(self.anchor_points):
                if abs(ax - pt[0]) < tol and abs(ay - pt[1]) < tol:
                    found = k
            return found

        def first_anchor(indices):
            """First path point along ``indices`` that is an anchor, with its anchor index."""
            for i in indices:
                k = anchor_index(path[i])
                if k is not None:
                    return tuple(path[i]), k
            return None, None

        left_pt, left_idx = first_anchor(range(nearest, -1, -1))
        right_pt, right_idx = first_anchor(range(nearest, len(path)))

        # No bracket found, or both scans hit the very same path point.
        if left_pt is None or right_pt is None or left_pt == right_pt:
            insert_at(-1)
            return

        insert_at(min(left_idx, right_idx) + 1)

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _revert_cost_to_original(self):
        """Reset the working cost image to a fresh copy of the original."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower the cost around every removable (guide) anchor."""
        if self.cost_image is None:
            return
        for item, (ax, ay) in zip(self.point_items, self.anchor_points):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """Set every cost pixel within ``radius`` of (x_f, y_f) to the image minimum.

        Does nothing if there is no cost image or the rounded centre lies
        outside it. ``radius`` is expected to be an integer pixel count.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return

        global_min = self.cost_image.min()
        # Bounding box of the disc, clipped to the image.
        r_s = max(0, row_c - radius)
        r_e = min(h, row_c + radius + 1)
        c_s = max(0, col_c - radius)
        c_e = min(w, col_c + radius + 1)

        # Vectorised disc mask (replaces a per-pixel Python loop).
        rows = np.arange(r_s, r_e).reshape(-1, 1)
        cols = np.arange(c_s, c_e).reshape(1, -1)
        inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
        # Basic slicing yields a view, so this writes into cost_image itself.
        self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _rebuild_full_path(self):
        """Recompute, smooth and redraw the path through all anchors in order."""
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()
        self._full_path_xy = []

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        # Concatenate the sub-paths between consecutive anchors. From the
        # second segment on, the first coordinate is skipped (it repeats the
        # shared anchor).
        big_xy = []
        segments = zip(self.anchor_points, self.anchor_points[1:])
        for seg, ((xa, ya), (xb, yb)) in enumerate(segments):
            sub_xy = self._compute_subpath_xy(xa, ya, xb, yb)
            if seg == 0:
                big_xy.extend(sub_xy)
            elif len(sub_xy) > 1:
                big_xy.extend(sub_xy[1:])

        # savgol_filter needs at least window_length samples; shorter paths
        # are left unsmoothed.
        if len(big_xy) >= self._savgol_window_length:
            smoothed = savgol_filter(
                np.array(big_xy, dtype=float),
                window_length=self._savgol_window_length,
                polyorder=self._savgol_polyorder,
                axis=0
            )
            big_xy = [(float(px), float(py)) for px, py in smoothed]

        self._full_path_xy = list(big_xy)

        n_points = len(big_xy)
        for i, (px, py) in enumerate(big_xy):
            fraction = i / (n_points - 1) if n_points > 1 else 0
            color = self._rainbow_color(fraction) if self._rainbow_enabled else Qt.red
            path_item = LabeledPointItem(px, py, label="",
                                         radius=self.path_radius,
                                         color=color,
                                         removable=False,
                                         z_value=0)
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Keep labelled anchors (S/E) above the freshly drawn path.
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """Return the path from (xA, yA) to (xB, yB) as a list of (x, y).

        Coordinates are rounded and clamped to the cost image before calling
        ``find_path``. An empty list is returned when there is no cost image
        or ``find_path`` raises ``ValueError``.
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape

        def to_rc(x, y):
            row = max(0, min(int(round(y)), h - 1))
            col = max(0, min(int(round(x)), w - 1))
            return (row, col)

        try:
            path_rc = find_path(self.cost_image, [to_rc(xA, yA), to_rc(xB, yB)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        # find_path yields (row, col); the view works in (x, y).
        return [(c, r) for (r, c) in path_rc]

    def _rainbow_color(self, fraction):
        """Map ``fraction`` (0..1 along the path) to a fully saturated HSV colour."""
        # NOTE(review): the hue range was not visible in the source this was
        # restored from; 0..300 is assumed so the end of the path does not
        # wrap back to the red used at the start - confirm.
        hue = int(300 * fraction)
        saturation = 255
        value = 255
        return QColor.fromHsv(hue, saturation, value)

    # --------------------------------------------------------------------
    # MOUSE EVENTS
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        """Left press near an anchor starts dragging it; right press removes a guide point."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            idx = self._find_item_near(event.pos(), threshold=10)
            if idx is not None:
                self._dragging_idx = idx
                self._drag_counter = 0
                scene_pos = self.mapToScene(event.pos())
                px, py = self.point_items[idx].get_pos()
                # Remember where inside the dot the user grabbed it.
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                # Skip the base class so that no pan is started.
                return
        elif event.button() == Qt.RightButton:
            self._remove_point_by_click(event.pos())
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the dragged anchor, or detect that the press turned into a pan."""
        if self._dragging_idx is not None:
            idx = self._dragging_idx
            item = self.point_items[idx]
            scene_pos = self.mapToScene(event.pos())
            r = item._r
            x_clamped = self._clamp(scene_pos.x() - self._drag_offset[0],
                                    r, self._img_w - r)
            y_clamped = self._clamp(scene_pos.y() - self._drag_offset[1],
                                    r, self._img_h - r)
            item.set_pos(x_clamped, y_clamped)

            # Only every 4th move event triggers the expensive path rebuild.
            self._drag_counter += 1
            if self._drag_counter >= 4:
                self._drag_counter = 0
                # Update the anchor first so the cost edits use its new position.
                self.anchor_points[idx] = (x_clamped, y_clamped)
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
        elif self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True
        # Forward to the base class so panning keeps working.
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish an anchor drag, or add a guide point if the press was a plain click."""
        super().mouseReleaseEvent(event)
        if not (event.button() == Qt.LeftButton and self._mouse_pressed):
            return

        self._mouse_pressed = False
        if self._dragging_idx is not None:
            idx = self._dragging_idx
            self._dragging_idx = None
            self._drag_offset = (0, 0)
            self.anchor_points[idx] = self.point_items[idx].get_pos()
            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()
        elif not self._was_dragging:
            scene_pos = self.mapToScene(event.pos())
            self._add_guide_point(scene_pos.x(), scene_pos.y())
        self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Remove the removable anchor nearest to ``view_pos`` (if within range)."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None or not self.point_items[idx].is_removable():
            return

        self.scene.removeItem(self.point_items[idx])
        self.point_items.pop(idx)
        self.anchor_points.pop(idx)

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _find_item_near(self, view_pos, threshold=10):
        """Index of the anchor closest to ``view_pos``, or None if farther than ``threshold``.

        ``view_pos`` is in view coordinates; the distance is measured in scene
        coordinates.
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()

        closest_idx = None
        min_dist = float('inf')
        for i, itm in enumerate(self.point_items):
            d = itm.distance_to(x_click, y_click)
            if d < min_dist:
                min_dist = d
                closest_idx = i

        if closest_idx is not None and min_dist <= threshold:
            return closest_idx
        return None

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Limit ``val`` to ``[mn, mx]`` (returns ``mn`` if ``mn > mx``)."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor and the drawn path from the scene."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()

        for p in self.full_path_points:
            self.scene.removeItem(p)
        self.full_path_points.clear()
        self._full_path_xy = []

    def clear_guide_points(self):
        """Remove all removable anchors (S and E stay) and recompute the path."""
        kept_items = []
        kept_anchors = []
        for item, anchor in zip(self.point_items, self.anchor_points):
            if item.is_removable():
                self.scene.removeItem(item)
            else:
                kept_items.append(item)
                kept_anchors.append(anchor)
        self.point_items[:] = kept_items
        self.anchor_points[:] = kept_anchors

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        # _rebuild_full_path removes the old path items itself; removing them
        # here as well would take them out of the scene twice.
        self._rebuild_full_path()

    def get_full_path_xy(self):
        """Return the current (smoothed) path as a list of (x, y) coordinates."""
        return self._full_path_xy
# (stray diff-gutter line numbers removed)
# ------------------------------------------------------------------------
# Advanced Settings Widget
# ------------------------------------------------------------------------
class AdvancedSettingsWidget(QWidget):
    """Side panel with rainbow toggle, circle editor button, two sliders and two previews.

    Args:
        main_window: owner providing ``open_circle_editor()``,
            ``toggle_rainbow()``, ``update_contrast(clip_limit)`` and an
            ``image_view`` with ``set_savgol_window_length()``.
        parent: optional Qt parent.
    """

    def __init__(self, main_window, parent=None):
        super().__init__(parent)
        self._main_window = main_window  # to call e.g. main_window.open_circle_editor()

        main_layout = QVBoxLayout()
        self.setLayout(main_layout)

        # A small grid for the controls (buttons/sliders)
        controls_layout = QGridLayout()

        # 1) Rainbow toggle
        self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
        self.btn_toggle_rainbow.clicked.connect(self._on_toggle_rainbow)
        controls_layout.addWidget(self.btn_toggle_rainbow, 0, 0)

        # 2) Circle editor
        self.btn_circle_editor = QPushButton("Open Circle Editor")
        self.btn_circle_editor.clicked.connect(self._main_window.open_circle_editor)
        controls_layout.addWidget(self.btn_circle_editor, 0, 1)

        # 3) Line smoothing slider + label
        controls_layout.addWidget(
            QLabel("Line smoothing (SavGol window_length)"), 1, 0)
        self.line_smoothing_slider = QSlider(Qt.Horizontal)
        self.line_smoothing_slider.setRange(3, 51)  # allow from 3 to 51
        self.line_smoothing_slider.setValue(7)      # default
        self.line_smoothing_slider.valueChanged.connect(self._on_line_smoothing_slider)
        controls_layout.addWidget(self.line_smoothing_slider, 1, 1)

        # 4) Contrast slider + label
        controls_layout.addWidget(QLabel("Contrast (clip_limit)"), 2, 0)
        self.contrast_slider = QSlider(Qt.Horizontal)
        self.contrast_slider.setRange(0, 100)  # 0..100 => 0..1 with step of 0.01
        self.contrast_slider.setValue(1)       # default is 0.01
        self.contrast_slider.valueChanged.connect(self._on_contrast_slider)
        controls_layout.addWidget(self.contrast_slider, 2, 1)

        main_layout.addLayout(controls_layout)

        # Now a horizontal layout for the two images
        images_layout = QHBoxLayout()
        self.label_contrasted_blurred = self._make_image_label("CONTRASTED BLURRED IMG")
        images_layout.addWidget(self.label_contrasted_blurred)
        self.label_cost_image = self._make_image_label("Current COST IMAGE")
        images_layout.addWidget(self.label_cost_image)
        main_layout.addLayout(images_layout)

    def _make_image_label(self, placeholder):
        """Create a centred, expanding label that scales its pixmap to fit."""
        label = QLabel()
        label.setText(placeholder)
        label.setAlignment(Qt.AlignCenter)
        label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        label.setScaledContents(True)
        return label

    def _on_toggle_rainbow(self):
        """Forward the button click to the main window."""
        self._main_window.toggle_rainbow()

    def _on_line_smoothing_slider(self, value):
        """Forward the slider value as the new Savitzky-Golay window length."""
        self._main_window.image_view.set_savgol_window_length(value)

    def _on_contrast_slider(self, value):
        """Convert the 0..100 slider value to a 0..1 clip limit and forward it."""
        clip_limit = value / 100.0
        self._main_window.update_contrast(clip_limit)

    def update_displays(self, contrasted_img_np, cost_img_np):
        """
        Called by main_window to refresh the two images in the advanced panel.
        contrasted_img_np = the grayscaled blurred+contrasted image as float or 0-1 array
        cost_img_np = the current cost image (numpy array); it is min-max
                      normalised before display
        A label is left untouched when its array cannot be converted.
        """
        cb_pix = self._np_array_to_qpixmap(contrasted_img_np)
        cost_pix = self._np_array_to_qpixmap(cost_img_np, normalize=True)

        if cb_pix is not None:
            self.label_contrasted_blurred.setPixmap(cb_pix)
        if cost_pix is not None:
            self.label_cost_image.setPixmap(cost_pix)

    def _np_array_to_qpixmap(self, arr, normalize=False):
        """Convert a 2-D array to an 8-bit grayscale QPixmap.

        Values are clipped to ``[0, 1]`` and scaled to 0..255. With
        ``normalize=True`` the array is first min-max scaled (a constant array
        becomes all zeros). Returns None for ``None`` or non-2-D input.
        """
        if arr is None:
            return None
        arr_ = np.array(arr, dtype=float)  # float copy; the caller's array is untouched
        if arr_.ndim != 2:
            return None

        if normalize:
            mn, mx = arr_.min(), arr_.max()
            if abs(mx - mn) < 1e-12:
                arr_[:] = 0
            else:
                arr_ = (arr_ - mn) / (mx - mn)

        arr_255 = np.ascontiguousarray(
            (np.clip(arr_, 0, 1) * 255).astype(np.uint8))
        h, w = arr_255.shape
        # bytesPerLine == w: one byte per pixel, rows tightly packed.
        qimage = QImage(arr_255.data, w, h, w, QImage.Format_Grayscale8)
        # fromImage() is called while arr_255 is still alive, so the QImage's
        # borrowed buffer is valid for the conversion.
        return QPixmap.fromImage(qimage)
# ------------------------------------------------------------------------
# Main Window
# ------------------------------------------------------------------------
class MainWindow(QMainWindow):
def __init__(self):
    """Build the main window: image view, button row and hidden advanced panel."""
    super().__init__()
    self.setWindowTitle("Test GUI")

    self._circle_calibrated_radius = 6
    self._last_loaded_file_path = None
    # Read by open_circle_editor(); must exist before any image is loaded.
    self._last_loaded_pixmap = None

    # For the contrast slider
    self._current_clip_limit = 0.01

    # Outer widget + layout
    self._main_widget = QWidget()
    self._main_layout = QHBoxLayout(self._main_widget)  # horizontal so we can place advanced on the right
    self._left_panel = QVBoxLayout()  # for the image & row of buttons
    self._main_layout.addLayout(self._left_panel)
    self.setCentralWidget(self._main_widget)

    # Image view on top of the left panel
    self.image_view = ImageGraphicsView()
    self._left_panel.addWidget(self.image_view)

    # Row of buttons below the image
    btn_layout = QHBoxLayout()
    self._left_panel.addLayout(btn_layout)

    self.btn_load_image = QPushButton("Load Image")
    self.btn_load_image.clicked.connect(self.load_image)
    btn_layout.addWidget(self.btn_load_image)

    self.btn_export_path = QPushButton("Export Path")
    self.btn_export_path.clicked.connect(self.export_path)
    btn_layout.addWidget(self.btn_export_path)

    self.btn_clear_points = QPushButton("Clear Points")
    self.btn_clear_points.clicked.connect(self.clear_points)
    btn_layout.addWidget(self.btn_clear_points)

    # "Advanced Settings" button
    self.btn_advanced = QPushButton("Advanced Settings")
    self.btn_advanced.setCheckable(True)
    self.btn_advanced.clicked.connect(self._toggle_advanced_settings)
    btn_layout.addWidget(self.btn_advanced)

    # Create advanced settings widget (hidden by default)
    self._advanced_widget = AdvancedSettingsWidget(self)
    self._advanced_widget.hide()
    self._main_layout.addWidget(self._advanced_widget)

    # State used while the circle editor temporarily replaces the central widget
    self._old_central_widget = None
    self._editor = None
def _toggle_advanced_settings(self, checked):
    """Show or hide the advanced panel to match the button's checked state."""
    self._advanced_widget.setVisible(bool(checked))
    # Ask Qt to re-layout the window so it can expand/shrink as needed:
    self.adjustSize()
def open_circle_editor(self):
    """Removes the current central widget, replaces with circle editor.

    The editor starts at the currently calibrated radius. The previous
    central widget is kept in ``_old_central_widget`` so it can be restored
    when the editor is done. Does nothing (besides printing a message) when
    no image has been loaded.
    """
    # getattr: the attribute may not exist yet if no image was ever loaded.
    pixmap = getattr(self, "_last_loaded_pixmap", None)
    if not pixmap:
        print("No image loaded yet! Cannot open circle editor.")
        return

    # takeCentralWidget() hands ownership back, so the widget is not deleted.
    self._old_central_widget = self.takeCentralWidget()

    init_radius = self._circle_calibrated_radius
    editor = CircleEditorWidget(
        pixmap=pixmap,
        init_radius=init_radius,
        done_callback=self._on_circle_editor_done
    )
    self._editor = editor
    self.setCentralWidget(editor)
def _on_circle_editor_done(self, final_radius):
    """Store the chosen radius, recompute the cost image and restore the main view.

    Args:
        final_radius: radius selected in the circle editor.
    """
    # Remember the calibrated radius so the cost image below (and the next
    # editor session) uses the value the user just picked.
    self._circle_calibrated_radius = final_radius
    print(f"Circle Editor done. Radius = {final_radius}")

    if self._last_loaded_file_path:
        cost_img = compute_cost_image(
            self._last_loaded_file_path,
            self._circle_calibrated_radius,
            clip_limit=self._current_clip_limit
        )
        self.image_view.cost_image_original = cost_img
        self.image_view.cost_image = cost_img.copy()
        self.image_view._apply_all_guide_points_to_cost()
        self.image_view._rebuild_full_path()
        self._update_advanced_images()

    # Detach the editor, then put the previous central widget back.
    editor_widget = self.takeCentralWidget()
    if editor_widget is not None:
        editor_widget.setParent(None)

    if self._old_central_widget is not None:
        self.setCentralWidget(self._old_central_widget)
        self._old_central_widget = None