# GUI_draft_live.py
import sys
import math

import numpy as np

# For smoothing the path
from scipy.signal import savgol_filter

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
    QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    QSlider, QLabel, QCheckBox, QGridLayout, QSizePolicy
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont, QImage
from PyQt5.QtCore import Qt, QRectF, QSize

# live_wire.py must contain something like:
#   from skimage import exposure
#   from skimage.filters import gaussian
#   def preprocess_image(image, sigma=3, clip_limit=0.01): ...
#   def compute_cost_image(path, user_radius, sigma=3, clip_limit=0.01): ...
#   def find_path(cost_image, points): ...
#   ...
from live_wire import compute_cost_image, find_path, preprocess_image
# ------------------------------------------------------------------------
# A pan & zoom QGraphicsView
# ------------------------------------------------------------------------
class PanZoomGraphicsView(QGraphicsView):
    """
    QGraphicsView that zooms with the mouse wheel and pans while the left
    mouse button is held down.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setDragMode(QGraphicsView.NoDrag)  # We'll handle panning manually
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self._panning = False
        self._pan_start = None

        # Let it expand in layouts
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

    def wheelEvent(self, event):
        """ Zoom in/out with mouse wheel. """
        zoom_in_factor = 1.25
        delta_y = event.angleDelta().y()
        if delta_y > 0:
            self.scale(zoom_in_factor, zoom_in_factor)
        elif delta_y < 0:
            zoom_out_factor = 1 / zoom_in_factor
            self.scale(zoom_out_factor, zoom_out_factor)
        # A zero vertical delta (e.g. a purely horizontal scroll) must not
        # change the zoom level, so neither branch runs in that case.
        event.accept()

    def mousePressEvent(self, event):
        """ If left button: Start panning (unless overridden). """
        if event.button() == Qt.LeftButton:
            self._panning = True
            self._pan_start = event.pos()
            self.setCursor(Qt.ClosedHandCursor)
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """ If panning, translate the scene. """
        if self._panning and self._pan_start is not None:
            delta = event.pos() - self._pan_start
            # Re-anchor on every move so each delta is relative to the last event
            self._pan_start = event.pos()
            self.translate(delta.x(), delta.y())
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """ End panning. """
        if event.button() == Qt.LeftButton:
            self._panning = False
            self._pan_start = None
            self.setCursor(Qt.ArrowCursor)
        super().mouseReleaseEvent(event)


# ------------------------------------------------------------------------
# A specialized PanZoomGraphicsView for the circle editor
# ------------------------------------------------------------------------
class CircleEditorGraphicsView(PanZoomGraphicsView):
    """
    Pan/zoom view used by the circle editor.

    Left-clicking the circle drags the circle instead of panning, and
    scrolling the wheel over the circle changes its radius instead of zooming.
    """

    def __init__(self, circle_editor_widget, parent=None):
        """
        circle_editor_widget: object whose update_slider_value(radius) is
        called whenever the wheel changes the circle radius.
        """
        super().__init__(parent)
        self._circle_editor_widget = circle_editor_widget

    def _circle_item_at(self, view_pos):
        """
        Return the DraggableCircleItem at view_pos (view coordinates), or None.

        Starts from the item under the cursor and climbs the parent chain
        until a DraggableCircleItem is found or the chain ends.
        """
        item = self.itemAt(view_pos)
        while item is not None and not isinstance(item, DraggableCircleItem):
            item = item.parentItem()
        return item

    def mousePressEvent(self, event):
        """Start an item drag when the circle is clicked, otherwise pan."""
        if (event.button() == Qt.LeftButton
                and self._circle_item_at(event.pos()) is not None):
            # Skip PanZoomGraphicsView.mousePressEvent so no pan is started:
            # let normal item-dragging occur.
            return QGraphicsView.mousePressEvent(self, event)
        super().mousePressEvent(event)

    def wheelEvent(self, event):
        """
        If the mouse is hovering over the circle, we adjust the circle's radius
        instead of zooming the image.
        """
        circle = self._circle_item_at(event.pos())
        if circle is None:
            super().wheelEvent(event)
            return

        delta = event.angleDelta().y()
        if delta != 0:
            # One radius unit per wheel event, never below 1
            step = 1 if delta > 0 else -1
            new_r = max(1, circle.radius() + step)
            circle.set_radius(new_r)
            self._circle_editor_widget.update_slider_value(new_r)
        # Over the circle the wheel never zooms, even for a zero delta
        event.accept()


# ------------------------------------------------------------------------
# Draggable circle item (centered at (x, y) with radius)
# ------------------------------------------------------------------------
class DraggableCircleItem(QGraphicsEllipseItem):
    """Filled, mouse-movable circle whose radius can be changed in place."""

    def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
        """Create a circle of the given radius and colour centred at (x, y)."""
        diameter = 2*radius
        super().__init__(0, 0, diameter, diameter, parent)
        self._r = radius

        # Outline and fill share one colour
        self.setPen(QPen(color))
        self.setBrush(QBrush(color))

        # Enable item-based dragging
        drag_flags = (QGraphicsEllipseItem.ItemIsMovable |
                      QGraphicsEllipseItem.ItemIsSelectable |
                      QGraphicsEllipseItem.ItemSendsScenePositionChanges)
        self.setFlags(drag_flags)

        # The rect starts at the item's origin, so shift the item by one
        # radius to put (x, y) at the centre
        self.setPos(x - radius, y - radius)

    def set_radius(self, r):
        """Resize the circle to radius r while keeping its centre where it was."""
        center_before = self.sceneBoundingRect().center()
        self._r = r
        diameter = 2*r
        self.setRect(0, 0, diameter, diameter)
        center_after = self.sceneBoundingRect().center()
        # Resizing moved the centre; move the item back by the difference
        self.moveBy(center_before.x() - center_after.x(),
                    center_before.y() - center_after.y())

    def radius(self):
        """Return the current radius."""
        return self._r


# ------------------------------------------------------------------------
# Circle editor widget with slider + done
# ------------------------------------------------------------------------
class CircleEditorWidget(QWidget):
    """
    Widget showing an image with a draggable, resizable circle on top of it,
    plus a size slider and a "Done" button.

    When "Done" is clicked, done_callback (if given) is called with the
    circle's final radius.
    """

    def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
        """
        pixmap:        QPixmap displayed behind the circle.
        init_radius:   initial circle radius and slider value.
        done_callback: optional callable taking the final radius.
        """
        super().__init__(parent)
        self._pixmap = pixmap
        self._done_callback = done_callback
        self._init_radius = init_radius

        # Passing self to the layout constructor already installs it on this
        # widget, so no separate setLayout() call is needed.
        layout = QVBoxLayout(self)

        #
        # 1) ADD A CENTERED LABEL ABOVE THE IMAGE, WITH BIGGER FONT
        #
        label_instructions = QLabel("Scale the dot to be of the size of your ridge")
        label_instructions.setAlignment(Qt.AlignCenter)
        big_font = QFont("Arial", 20)
        big_font.setBold(True)
        label_instructions.setFont(big_font)
        layout.addWidget(label_instructions)

        #
        # 2) THE SPECIALIZED GRAPHICS VIEW THAT SHOWS THE IMAGE
        #
        self._graphics_view = CircleEditorGraphicsView(circle_editor_widget=self)
        self._scene = QGraphicsScene(self)
        self._graphics_view.setScene(self._scene)
        layout.addWidget(self._graphics_view)

        self._image_item = QGraphicsPixmapItem(self._pixmap)
        self._scene.addItem(self._image_item)

        # Put circle in center
        cx = self._pixmap.width() / 2
        cy = self._pixmap.height() / 2
        self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
        self._scene.addItem(self._circle_item)

        # Fit in view
        self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
        self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

        #
        # 3) BOTTOM ROW: SIZE LABEL, SLIDER, DONE BUTTON
        #
        bottom_layout = QHBoxLayout()
        layout.addLayout(bottom_layout)

        self._lbl_size = QLabel(f"size ({self._init_radius})")
        bottom_layout.addWidget(self._lbl_size)

        self._slider = QSlider(Qt.Horizontal)
        self._slider.setRange(1, 200)
        self._slider.setValue(self._init_radius)
        bottom_layout.addWidget(self._slider)

        self._btn_done = QPushButton("Done")
        bottom_layout.addWidget(self._btn_done)

        # Connect signals
        self._slider.valueChanged.connect(self._on_slider_changed)
        self._btn_done.clicked.connect(self._on_done_clicked)

    def _on_slider_changed(self, value):
        """Apply the slider value to the circle and refresh the size label."""
        self._circle_item.set_radius(value)
        self._lbl_size.setText(f"size ({value})")

    def _on_done_clicked(self):
        """Report the circle's current radius through done_callback, if set."""
        final_radius = self._circle_item.radius()
        if self._done_callback is not None:
            self._done_callback(final_radius)

    def update_slider_value(self, new_radius):
        """
        Mirror a radius that was changed elsewhere (the view's wheel handler)
        in the slider and the label.

        Signals are blocked while the slider is updated so that
        _on_slider_changed does not run and resize the circle a second time.
        """
        self._slider.blockSignals(True)
        self._slider.setValue(new_radius)
        self._slider.blockSignals(False)
        self._lbl_size.setText(f"size ({new_radius})")

    def sizeHint(self):
        """Preferred size of the editor: 800 x 600."""
        return QSize(800, 600)


# ------------------------------------------------------------------------
# Labeled point item
# ------------------------------------------------------------------------
class LabeledPointItem(QGraphicsEllipseItem):
    """
    Filled circle centred on a logical (x, y) position, with an optional
    bold text label that is scaled down to fit and centred inside the circle.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
        """
        x, y:      centre of the circle.
        label:     text drawn inside the circle; no text item is created when empty.
        radius:    circle radius.
        color:     pen and brush colour.
        removable: value reported by is_removable().
        z_value:   stacking order passed to setZValue().
        """
        super().__init__(0, 0, 2*radius, 2*radius, parent)
        self._x = x
        self._y = y
        self._r = radius
        self._removable = removable

        pen = QPen(color)
        brush = QBrush(color)
        self.setPen(pen)
        self.setBrush(brush)
        self.setZValue(z_value)

        self._text_item = None
        if label:
            # The text item is a child of this item, so it moves with the circle
            self._text_item = QGraphicsTextItem(self)
            self._text_item.setPlainText(label)
            self._text_item.setDefaultTextColor(QColor("black"))
            font = QFont("Arial", 14)
            font.setBold(True)
            self._text_item.setFont(font)
            self._scale_text_to_fit()

        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink the label (never enlarge it) to fit the circle's diameter, then centre it."""
        if not self._text_item:
            return
        # Reset first so the scale factor is computed from the unscaled size
        self._text_item.setScale(1.0)
        circle_diam = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        text_w = raw_rect.width()
        text_h = raw_rect.height()
        # The positive-size check avoids a ZeroDivisionError for a degenerate
        # text rectangle.
        if text_w > 0 and text_h > 0 and (text_w > circle_diam or text_h > circle_diam):
            scale_factor = min(circle_diam / text_w, circle_diam / text_h)
            self._text_item.setScale(scale_factor)
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label so it is centred in the circle."""
        if not self._text_item:
            return
        ellipse_w = 2 * self._r
        ellipse_h = 2 * self._r
        raw_rect = self._text_item.boundingRect()
        # boundingRect() is unscaled, so apply the item's scale manually
        scale_factor = self._text_item.scale()
        scaled_w = raw_rect.width() * scale_factor
        scaled_h = raw_rect.height() * scale_factor
        tx = (ellipse_w - scaled_w) * 0.5
        ty = (ellipse_h - scaled_h) * 0.5
        self._text_item.setPos(tx, ty)

    def set_pos(self, x, y):
        """Positions the circle so its center is at (x, y)."""
        self._x = x
        self._y = y
        # The ellipse rect starts at the item's origin, hence the radius offset
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the centre as an (x, y) tuple."""
        return (self._x, self._y)

    def distance_to(self, x_other, y_other):
        """Return the Euclidean distance from the centre to (x_other, y_other)."""
        return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2)

    def is_removable(self):
        """Return the removable flag given at construction."""
        return self._removable


# ------------------------------------------------------------------------
# The original ImageGraphicsView with pan & zoom
# ------------------------------------------------------------------------
class ImageGraphicsView(PanZoomGraphicsView):
    """
    Pan/zoom view that displays an image, lets the user add, drag and remove
    anchor points, and draws the path found between consecutive anchors.

    The path is computed with find_path() on self.cost_image, which callers
    are expected to set together with self.cost_image_original; nothing is
    drawn while self.cost_image is None.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # NOTE(review): this attribute shadows the inherited
        # QGraphicsView.scene() method. It is kept because this class (and
        # possibly callers) use `self.scene` as an attribute.
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Image display
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        self.anchor_points = []    # List[(x, y)]
        self.point_items = []      # LabeledPointItem, parallel to anchor_points
        self.full_path_points = [] # graphics items drawn for the path
        self._full_path_xy = []    # entire path coords (smoothed)

        self.dot_radius = 4
        self.path_radius = 1
        self.radius_cost_image = 2
        self._img_w = 0
        self._img_h = 0

        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5
        self._was_dragging = False
        self._dragging_idx = None
        self._drag_offset = (0, 0)
        self._drag_counter = 0

        # Cost images
        self.cost_image_original = None
        self.cost_image = None

        # Rainbow toggle => start with OFF
        self._rainbow_enabled = False

        # Smoothing parameters
        self._savgol_window_length = 7

    def set_rainbow_enabled(self, enabled: bool):
        """Turn rainbow path colouring on or off and redraw the path."""
        self._rainbow_enabled = enabled
        self._rebuild_full_path()

    def toggle_rainbow(self):
        """Flip rainbow path colouring and redraw the path."""
        self._rainbow_enabled = not self._rainbow_enabled
        self._rebuild_full_path()

    def set_savgol_window_length(self, wlen: int):
        """
        Set the smoothing window length and redraw the path.

        The value is raised to at least 3 and bumped to the next odd number
        if it is even.
        """
        if wlen < 3:
            wlen = 3
        if wlen % 2 == 0:
            wlen += 1
        self._savgol_window_length = wlen

        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """
        Load the image at `path`, reset the view and all points, and add the
        default start ("S") and end ("E") anchors.

        Does nothing if the file cannot be loaded into a QPixmap.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))

        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self._clear_all_points()
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # By default, add S/E on the horizontal mid-line
        s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
        e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
        self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
        self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
        """
        Insert an anchor at list position idx and add its item to the scene.

        (x, y) is clamped so the whole dot stays inside the image. A negative
        idx means "insert just before the last anchor" once there are at
        least two anchors, and "append" otherwise.
        """
        x_clamped = self._clamp(x, radius, self._img_w - radius)
        y_clamped = self._clamp(y, radius, self._img_h - radius)

        if idx < 0:
            # Insert before E if there's at least 2 anchors
            if len(self.anchor_points) >= 2:
                idx = len(self.anchor_points) - 1
            else:
                idx = len(self.anchor_points)

        self.anchor_points.insert(idx, (x_clamped, y_clamped))
        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(x_clamped, y_clamped,
                                label=label, radius=radius, color=color,
                                removable=removable, z_value=z_val)
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Add a removable guide point at (x, y) and recompute cost and path."""
        # Ensure we clamp properly
        x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
        y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)
        self._revert_cost_to_original()
        if not self._full_path_xy:
            self._insert_anchor_point(-1, x_clamped, y_clamped,
                                      label="", removable=True, z_val=1, radius=self.dot_radius)
        else:
            self._insert_anchor_between_subpath(x_clamped, y_clamped)
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _insert_anchor_between_subpath(self, x_new, y_new):
        """
        Insert a guide point at (x_new, y_new) between the two anchors that
        bracket the nearest point of the current path.

        Whenever such a pair of anchors cannot be determined, the point is
        inserted just before the last anchor instead.
        """
        def insert_before_end():
            # Same appearance as a regular guide point, whichever branch adds it
            self._insert_anchor_point(-1, x_new, y_new, label="", removable=True,
                                      z_val=1, radius=self.dot_radius)

        # If somehow we have no path yet
        if not self._full_path_xy:
            insert_before_end()
            return

        # Find nearest point in the current full path
        best_idx = None
        best_d2 = float('inf')
        for i, (px, py) in enumerate(self._full_path_xy):
            dx = px - x_new
            dy = py - y_new
            d2 = dx*dx + dy*dy
            if d2 < best_d2:
                best_d2 = d2
                best_idx = i

        if best_idx is None:
            insert_before_end()
            return

        def approx_equal(xa, ya, xb, yb, tol=1e-3):
            return (abs(xa - xb) < tol) and (abs(ya - yb) < tol)

        def is_anchor(cx, cy):
            return any(approx_equal(ax, ay, cx, cy)
                       for (ax, ay) in self.anchor_points)

        # Walk left from the nearest path point to the first anchor
        left_anchor_pt = None
        for px, py in reversed(self._full_path_xy[:best_idx + 1]):
            if is_anchor(px, py):
                left_anchor_pt = (px, py)
                break

        # Walk right from the nearest path point to the first anchor
        right_anchor_pt = None
        for px, py in self._full_path_xy[best_idx:]:
            if is_anchor(px, py):
                right_anchor_pt = (px, py)
                break

        # If we can't find distinct anchors on left & right,
        # just insert before E.
        if (left_anchor_pt is None or right_anchor_pt is None
                or left_anchor_pt == right_anchor_pt):
            insert_before_end()
            return

        # Convert anchor coords -> anchor_points indices
        left_idx = None
        right_idx = None
        for i, (ax, ay) in enumerate(self.anchor_points):
            if approx_equal(ax, ay, left_anchor_pt[0], left_anchor_pt[1]):
                left_idx = i
            if approx_equal(ax, ay, right_anchor_pt[0], right_anchor_pt[1]):
                right_idx = i

        if left_idx is None or right_idx is None:
            insert_before_end()
            return

        # Insert between them, i.e. right after whichever comes first
        insert_idx = min(left_idx, right_idx) + 1

        self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
                                  z_val=1, radius=self.dot_radius)

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _revert_cost_to_original(self):
        """Replace the working cost image with a fresh copy of the original."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower the cost around every removable (guide) anchor."""
        if self.cost_image is None:
            return
        for (ax, ay), item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """
        Set every cost pixel within `radius` of (x_f, y_f) to the current
        global minimum of the cost image.

        Does nothing when there is no cost image or when the rounded centre
        lies outside it.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return
        if radius < 0:
            return
        global_min = self.cost_image.min()
        # Bounding box of the disc, clipped to the image
        r_s = max(0, row_c - radius)
        r_e = min(h, row_c + radius + 1)
        c_s = max(0, col_c - radius)
        c_e = min(w, col_c + radius + 1)
        # Vectorised disc mask instead of a per-pixel Python loop; comparing
        # squared distances selects the same pixels as sqrt(d2) <= radius.
        rows, cols = np.ogrid[r_s:r_e, c_s:c_e]
        inside = (rows - row_c)**2 + (cols - col_c)**2 <= radius**2
        window = self.cost_image[r_s:r_e, c_s:c_e]  # view: writes reach cost_image
        window[inside] = global_min

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _rebuild_full_path(self):
        """
        Remove the drawn path and rebuild it from the current anchors.

        Sub-paths between consecutive anchors are concatenated, smoothed with
        a Savitzky-Golay filter when long enough, stored in _full_path_xy and
        drawn as one small dot per path point.
        """
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()
        self._full_path_xy.clear()

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        big_xy = []
        for i in range(len(self.anchor_points) - 1):
            xA, yA = self.anchor_points[i]
            xB, yB = self.anchor_points[i + 1]
            sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
            if i == 0:
                big_xy.extend(sub_xy)
            elif len(sub_xy) > 1:
                # Drop the first point: it repeats the previous sub-path's end
                big_xy.extend(sub_xy[1:])

        if len(big_xy) >= self._savgol_window_length:
            arr_xy = np.array(big_xy)
            # NOTE(review): the polyorder argument was missing from this copy
            # of the file although savgol_filter requires it. 2 is the largest
            # order valid for the minimum window length of 3 -- confirm
            # against the original source.
            smoothed = savgol_filter(
                arr_xy,
                window_length=self._savgol_window_length,
                polyorder=2,
                axis=0
            )
            big_xy = smoothed.tolist()

        self._full_path_xy = big_xy[:]

        n_points = len(big_xy)
        for i, (px, py) in enumerate(big_xy):
            fraction = i / (n_points - 1) if n_points > 1 else 0
            color = Qt.red
            if self._rainbow_enabled:
                color = self._rainbow_color(fraction)

            path_item = LabeledPointItem(px, py, label="",
                                         radius=self.path_radius,
                                         color=color,
                                         removable=False,
                                         z_value=0)
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Keep anchor labels on top
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """
        Return the path from (xA, yA) to (xB, yB) as a list of (x, y) tuples.

        The endpoints are rounded and clamped to the cost image before
        find_path() is called. Returns [] when there is no cost image or
        find_path() raises ValueError.
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape
        rA, cA = int(round(yA)), int(round(xA))
        rB, cB = int(round(yB)), int(round(xB))
        rA = max(0, min(rA, h - 1))
        cA = max(0, min(cA, w - 1))
        rB = max(0, min(rB, h - 1))
        cB = max(0, min(cB, w - 1))
        try:
            path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        # Convert from (row, col) to (x, y)
        return [(c, r) for (r, c) in path_rc]

    def _rainbow_color(self, fraction):
        """Map fraction (0..1 along the path) to a fully saturated hue in 0..300."""
        hue = int(300 * fraction)
        saturation = 255
        value = 255
        return QColor.fromHsv(hue, saturation, value)

    # --------------------------------------------------------------------
    # MOUSE EVENTS
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        """Left button: start dragging a nearby anchor, if any. Right button: remove one."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            idx = self._find_item_near(event.pos(), threshold=10)
            if idx is not None:
                self._dragging_idx = idx
                self._drag_counter = 0
                scene_pos = self.mapToScene(event.pos())
                px, py = self.point_items[idx].get_pos()
                # Remember where inside the dot it was grabbed so it does not jump
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                self.setCursor(Qt.ClosedHandCursor)

        elif event.button() == Qt.RightButton:
            self._remove_point_by_click(event.pos())

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the dragged anchor, or record that the press turned into a drag."""
        if self._dragging_idx is not None:
            scene_pos = self.mapToScene(event.pos())
            x_new = scene_pos.x() - self._drag_offset[0]
            y_new = scene_pos.y() - self._drag_offset[1]
            r = self.point_items[self._dragging_idx]._r
            x_clamped = self._clamp(x_new, r, self._img_w - r)
            y_clamped = self._clamp(y_new, r, self._img_h - r)
            self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped)

            self._drag_counter += 1
            # Update path every 4 moves
            if self._drag_counter >= 4:
                self._drag_counter = 0
                # Store the new coordinates first, so the cost is lowered at
                # the anchor's current position rather than its previous one.
                self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
        else:
            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish an anchor drag, or add a guide point after a plain click."""
        super().mouseReleaseEvent(event)
        if event.button() == Qt.LeftButton and self._mouse_pressed:
            self._mouse_pressed = False
            self.setCursor(Qt.ArrowCursor)
            if self._dragging_idx is not None:
                idx = self._dragging_idx
                self._dragging_idx = None
                self._drag_offset = (0, 0)
                newX, newY = self.point_items[idx].get_pos()
                self.anchor_points[idx] = (newX, newY)
                self._revert_cost_to_original()
                self._apply_all_guide_points_to_cost()
                self._rebuild_full_path()
            elif not self._was_dragging:
                # No drag => add point (a pan gesture must not add one)
                scene_pos = self.mapToScene(event.pos())
                x, y = scene_pos.x(), scene_pos.y()
                self._add_guide_point(x, y)

            self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Remove the removable anchor near view_pos, if any, and rebuild the path."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None:
            return
        if not self.point_items[idx].is_removable():
            return

        self.scene.removeItem(self.point_items[idx])
        self.point_items.pop(idx)
        self.anchor_points.pop(idx)

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _find_item_near(self, view_pos, threshold=10):
        """
        Return the index of the anchor item closest to view_pos, or None if
        no anchor lies within `threshold` (measured in scene coordinates).
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()
        closest_idx = None
        min_dist = float('inf')
        for i, itm in enumerate(self.point_items):
            d = itm.distance_to(x_click, y_click)
            if d < min_dist:
                min_dist = d
                closest_idx = i
        if closest_idx is not None and min_dist <= threshold:
            return closest_idx
        return None

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Clamp val to [mn, mx]; mn wins if the bounds are inverted."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor item and every path item from the scene."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()

        for p in self.full_path_points:
            self.scene.removeItem(p)
        self.full_path_points.clear()
        self._full_path_xy.clear()

    def clear_guide_points(self):
        """Remove all removable anchors, then recompute the cost image and the path."""
        kept_points = []
        kept_items = []
        for coords, item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self.scene.removeItem(item)
            else:
                kept_points.append(coords)
                kept_items.append(item)
        # Slice-assign so the existing list objects are updated in place
        self.anchor_points[:] = kept_points
        self.point_items[:] = kept_items

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        # _rebuild_full_path() starts by removing the old path items
        self._rebuild_full_path()

    def get_full_path_xy(self):
        """Return the list of path coordinates built by the last path rebuild."""
        return self._full_path_xy

# ------------------------------------------------------------------------
# Advanced Settings Widget
# ------------------------------------------------------------------------
class AdvancedSettingsWidget(QWidget):
    """
    Panel with the advanced controls and two image previews.

    Controls: rainbow toggle, circle editor ("Calibrate Kernel Size") button,
    line smoothing slider and contrast slider.  Below them the
    contrasted-blurred image and the cost image are stacked vertically, each
    with a title label above it, and are re-scaled with their aspect ratio
    preserved whenever the widget is resized.
    """

    def __init__(self, main_window, parent=None):
        """
        Build the panel.

        Args:
            main_window: Object this panel forwards user actions to. It must
                provide ``toggle_rainbow()``, ``open_circle_editor()``,
                ``update_contrast(clip_limit)`` and an ``image_view`` with
                ``set_savgol_window_length(value)``.
            parent: Optional Qt parent widget.
        """
        super().__init__(parent)
        self._main_window = main_window

        # Unscaled pixmaps of the two previews; _update_labels() always
        # re-scales from these so repeated resizes do not degrade quality.
        self._last_cb_pix = None    # contrasted-blurred image
        self._last_cost_pix = None  # cost image

        main_layout = QVBoxLayout()
        self.setLayout(main_layout)

        main_layout.addLayout(self._build_controls())

        # We'll set a minimum width so that the main window expands
        # rather than overlapping the image
        self.setMinimumWidth(350)

        main_layout.addLayout(self._build_previews())

    # -- construction helpers --------------------------------------------
    def _build_controls(self):
        """Create the buttons and sliders and return them in a grid layout."""
        grid = QGridLayout()

        # Row 0, left: rainbow toggle
        self.btn_toggle_rainbow = QPushButton("Toggle Rainbow")
        self.btn_toggle_rainbow.clicked.connect(self._on_toggle_rainbow)
        grid.addWidget(self.btn_toggle_rainbow, 0, 0)

        # Row 0, right: circle editor
        self.btn_circle_editor = QPushButton("Calibrate Kernel Size")
        self.btn_circle_editor.clicked.connect(self._main_window.open_circle_editor)
        grid.addWidget(self.btn_circle_editor, 0, 1)

        # Row 1: line smoothing label + slider
        self._lab_smoothing = QLabel("Line smoothing (3)")
        grid.addWidget(self._lab_smoothing, 1, 0)

        self.line_smoothing_slider = QSlider(Qt.Horizontal)
        self.line_smoothing_slider.setRange(3, 51)
        self.line_smoothing_slider.setValue(3)
        # Connected only after the initial value is set, so construction
        # does not call back into the main window.
        self.line_smoothing_slider.valueChanged.connect(self._on_line_smoothing_slider)
        grid.addWidget(self.line_smoothing_slider, 1, 1)

        # Row 2: contrast label + slider (slider value is clip_limit * 100)
        self._lab_contrast = QLabel("Contrast (0.01)")
        grid.addWidget(self._lab_contrast, 2, 0)

        self.contrast_slider = QSlider(Qt.Horizontal)
        self.contrast_slider.setRange(1, 20)
        self.contrast_slider.setValue(1)  # i.e. 0.01
        self.contrast_slider.setSingleStep(1)
        self.contrast_slider.valueChanged.connect(self._on_contrast_slider)
        grid.addWidget(self.contrast_slider, 2, 1)

        return grid

    def _build_previews(self):
        """Create both title/image label pairs and return them in a vertical layout."""
        images_layout = QVBoxLayout()

        # 1) Contrasted-blurred title + image
        self.label_cb_title = QLabel("Contrasted Blurred Image")
        self.label_cb_title.setAlignment(Qt.AlignCenter)
        images_layout.addWidget(self.label_cb_title)

        self.label_contrasted_blurred = self._make_preview_label()
        images_layout.addWidget(self.label_contrasted_blurred)

        # 2) Cost image title + image
        self.label_cost_title = QLabel("Current COST IMAGE")
        self.label_cost_title.setAlignment(Qt.AlignCenter)
        images_layout.addWidget(self.label_cost_title)

        self.label_cost_image = self._make_preview_label()
        images_layout.addWidget(self.label_cost_image)

        return images_layout

    @staticmethod
    def _make_preview_label():
        """Return an empty, centered QLabel that expands in both directions."""
        label = QLabel()
        label.setAlignment(Qt.AlignCenter)
        label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        return label

    # -- Qt events ---------------------------------------------------------
    def showEvent(self, event):
        """ When shown, ask parent to resize to accommodate. """
        super().showEvent(event)
        parent = self.parentWidget()
        if parent:
            parent.adjustSize()

    def resizeEvent(self, event):
        """
        Keep the images at correct aspect ratio by re-scaling
        our stored pixmaps to the new label sizes.
        """
        super().resizeEvent(event)
        self._update_labels()

    def _update_labels(self):
        """Scale each stored pixmap to its label's current size and show it."""
        previews = (
            (self._last_cb_pix, self.label_contrasted_blurred),
            (self._last_cost_pix, self.label_cost_image),
        )
        for pixmap, label in previews:
            if pixmap is None:
                # Nothing stored for this preview yet.
                continue
            label.setPixmap(
                pixmap.scaled(
                    label.size(),
                    Qt.KeepAspectRatio,
                    Qt.SmoothTransformation
                )
            )

    # -- slots -------------------------------------------------------------
    def _on_toggle_rainbow(self):
        """Forward the rainbow button click to the main window."""
        self._main_window.toggle_rainbow()

    def _on_line_smoothing_slider(self, value):
        """Show the new slider value and pass it to the image view."""
        self._lab_smoothing.setText(f"Line smoothing ({value})")
        self._main_window.image_view.set_savgol_window_length(value)

    def _on_contrast_slider(self, value):
        """Convert the slider value to a clip limit (value / 100) and apply it."""
        clip_limit = value / 100.0
        self._lab_contrast.setText(f"Contrast ({clip_limit:.2f})")
        self._main_window.update_contrast(clip_limit)

    # -- preview handling ----------------------------------------------------
    def update_displays(self, contrasted_img_np, cost_img_np):
        """
        Called by main_window to refresh the two images in the advanced panel.
        We'll store them as QPixmaps, then do the re-scale in _update_labels().

        Args:
            contrasted_img_np: 2-D array shown as-is after clipping to [0, 1],
                or None.
            cost_img_np: 2-D array, min/max normalized before display, or None.
        """
        self._last_cb_pix = self._np_array_to_qpixmap(contrasted_img_np)
        self._last_cost_pix = self._np_array_to_qpixmap(cost_img_np, normalize=True)
        self._update_labels()

    def _np_array_to_qpixmap(self, arr, normalize=False):
        """
        Convert a 2-D array into an 8-bit grayscale QPixmap.

        Args:
            arr: 2-D numpy array, or None.
            normalize: If True, rescale the values to [0, 1] by min/max
                first; a (near-)constant array becomes all zeros.

        Returns:
            A QPixmap, or None when ``arr`` is None.
        """
        if arr is None:
            return None

        scaled = arr.copy()  # work on a copy; the caller's array is left untouched
        if normalize:
            lo, hi = scaled.min(), scaled.max()
            if abs(hi - lo) < 1e-12:
                # Flat image: avoid dividing by (almost) zero.
                scaled[:] = 0
            else:
                scaled = (scaled - lo) / (hi - lo)
        gray8 = (np.clip(scaled, 0, 1) * 255).astype(np.uint8)

        height, width = gray8.shape
        # bytesPerLine equals width: one byte per pixel in Grayscale8.
        qimage = QImage(gray8.data, width, height, width, QImage.Format_Grayscale8)
        return QPixmap.fromImage(qimage)


# ------------------------------------------------------------------------
# Main Window
# ------------------------------------------------------------------------
class MainWindow(QMainWindow):
    """
    Top-level window.

    The left side holds the image view with a row of buttons underneath;
    the right side holds the advanced settings panel, hidden until the
    "Advanced Settings" button is checked.
    """

    def __init__(self):
        """Create all widgets and wire up the button signals."""
        super().__init__()
        self.setWindowTitle("Test GUI")

        self._last_loaded_pixmap = None      # QPixmap of the loaded image
        self._circle_calibrated_radius = 6   # radius passed to compute_cost_image
        self._last_loaded_file_path = None
        # For the contrast slider
        self._current_clip_limit = 0.01

        # Outer widget + layout
        self._main_widget = QWidget()
        self._main_layout = QHBoxLayout(self._main_widget)

        # The "left" part: image area + its controls, wrapped in a container
        # widget so it can be given a stretch factor.
        self._left_panel = QVBoxLayout()
        self._left_container = QWidget()
        self._left_container.setLayout(self._left_panel)
        self._main_layout.addWidget(self._left_container, 7)  # 70%

        # The advanced panel gets stretch 3 (30%) and starts hidden.
        self._advanced_widget = AdvancedSettingsWidget(self)
        self._advanced_widget.hide()
        self._main_layout.addWidget(self._advanced_widget, 3)

        self.setCentralWidget(self._main_widget)

        self.image_view = ImageGraphicsView()
        self._left_panel.addWidget(self.image_view)

        # Button row
        btn_layout = QHBoxLayout()

        self.btn_load_image = QPushButton("Load Image")
        self.btn_load_image.clicked.connect(self.load_image)
        btn_layout.addWidget(self.btn_load_image)

        self.btn_export_path = QPushButton("Export Path")
        self.btn_export_path.clicked.connect(self.export_path)
        btn_layout.addWidget(self.btn_export_path)

        self.btn_clear_points = QPushButton("Clear Points")
        self.btn_clear_points.clicked.connect(self.clear_points)
        btn_layout.addWidget(self.btn_clear_points)

        # "Advanced Settings" toggle
        self.btn_advanced = QPushButton("Advanced Settings")
        self.btn_advanced.setCheckable(True)
        self.btn_advanced.clicked.connect(self._toggle_advanced_settings)
        btn_layout.addWidget(self.btn_advanced)

        self._left_panel.addLayout(btn_layout)

        self.resize(1000, 600)

        # State used while the circle editor temporarily replaces the
        # central widget.
        self._old_central_widget = None
        self._editor = None

    def _toggle_advanced_settings(self, checked):
        """Show or hide the advanced panel, then let the window re-fit."""
        self._advanced_widget.setVisible(bool(checked))
        self.adjustSize()

    def _recompute_cost_image(self, file_path, reapply_guides):
        """
        Recompute the cost image for ``file_path`` with the current radius
        and clip limit and hand it to the image view.

        Args:
            file_path: Image file passed to ``compute_cost_image``.
            reapply_guides: If True, also re-apply the guide points to the
                new cost image and rebuild the path.
        """
        cost_img = compute_cost_image(
            file_path,
            self._circle_calibrated_radius,
            clip_limit=self._current_clip_limit
        )
        self.image_view.cost_image_original = cost_img
        # The view gets its own working copy; the original stays unmodified.
        self.image_view.cost_image = cost_img.copy()
        if reapply_guides:
            self.image_view._apply_all_guide_points_to_cost()
            self.image_view._rebuild_full_path()

    def open_circle_editor(self):
        """ Replace central widget with circle editor. """
        if not self._last_loaded_pixmap:
            print("No image loaded yet! Cannot open circle editor.")
            return

        # Keep the normal UI so _on_circle_editor_done() can restore it.
        self._old_central_widget = self.takeCentralWidget()

        editor = CircleEditorWidget(
            pixmap=self._last_loaded_pixmap,
            init_radius=self._circle_calibrated_radius,
            done_callback=self._on_circle_editor_done
        )
        self._editor = editor
        self.setCentralWidget(editor)

    def _on_circle_editor_done(self, final_radius):
        """
        Callback given to the circle editor: store the chosen radius,
        recompute the cost image and path, and restore the normal UI.
        """
        self._circle_calibrated_radius = final_radius
        print(f"Circle Editor done. Radius = {final_radius}")

        if self._last_loaded_file_path:
            self._recompute_cost_image(self._last_loaded_file_path, reapply_guides=True)
            self._update_advanced_images()

        # Detach the editor from the window ...
        editor_widget = self.takeCentralWidget()
        if editor_widget is not None:
            editor_widget.setParent(None)

        # ... put the previous central widget back ...
        if self._old_central_widget is not None:
            self.setCentralWidget(self._old_central_widget)
            self._old_central_widget = None

        # ... and schedule the editor for deletion.
        if self._editor is not None:
            self._editor.deleteLater()
            self._editor = None

    def toggle_rainbow(self):
        """Forward the rainbow toggle to the image view."""
        self.image_view.toggle_rainbow()

    def load_image(self):
        """
        Ask the user for an image file; if one is chosen, load it into the
        view, compute its cost image and refresh the advanced previews.
        """
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options
        )
        if not file_path:
            # Dialog was cancelled.
            return

        self.image_view.load_image(file_path)
        self._recompute_cost_image(file_path, reapply_guides=False)

        # Only replace the stored pixmap when the file could be decoded.
        pm = QPixmap(file_path)
        if not pm.isNull():
            self._last_loaded_pixmap = pm

        self._last_loaded_file_path = file_path
        self._update_advanced_images()

    def update_contrast(self, clip_limit):
        """
        Store the new clip limit, recompute cost image and path for the
        loaded file (if any), and refresh the advanced previews.
        """
        self._current_clip_limit = clip_limit
        if self._last_loaded_file_path:
            self._recompute_cost_image(self._last_loaded_file_path, reapply_guides=True)

        self._update_advanced_images()

    def _update_advanced_images(self):
        """Push the preprocessed image and current cost image to the advanced panel."""
        if self._last_loaded_pixmap is None:
            return
        pm_np = self._qpixmap_to_gray_float(self._last_loaded_pixmap)
        contrasted_blurred = preprocess_image(
            pm_np,
            sigma=3,
            clip_limit=self._current_clip_limit
        )
        cost_img_np = self.image_view.cost_image
        self._advanced_widget.update_displays(contrasted_blurred, cost_img_np)

    def _qpixmap_to_gray_float(self, qpix):
        """
        Convert a QPixmap into a 2-D float32 array in [0, 1].

        The pixmap is converted to a 4-bytes-per-pixel image; the first
        three bytes of every pixel are averaged and divided by 255.
        """
        img = qpix.toImage().convertToFormat(QImage.Format_ARGB32)
        ptr = img.bits()
        ptr.setsize(img.byteCount())
        pixels = np.frombuffer(ptr, np.uint8).reshape((img.height(), img.width(), 4))
        # astype() copies, so the result does not alias the QImage buffer.
        rgb = pixels[..., :3].astype(np.float32)
        gray = rgb.mean(axis=2) / 255.0
        return gray

    def export_path(self):
        """
        Save the current path as a ``.npy`` file chosen by the user.

        Does nothing (apart from a console message) when there is no path,
        and nothing at all when the dialog is cancelled.
        """
        full_xy = self.image_view.get_full_path_xy()
        if not full_xy:
            print("No path to export.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Path", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options
        )
        if not file_path:
            return

        # np.save() appends ".npy" when the name lacks it; do that here so
        # the path reported below is the file that is actually written.
        if not file_path.endswith(".npy"):
            file_path += ".npy"

        arr = np.array(full_xy)
        np.save(file_path, arr)
        print(f"Exported path with {len(arr)} points to {file_path}")

    def clear_points(self):
        """Remove the removable guide points via the image view."""
        self.image_view.clear_guide_points()

    def closeEvent(self, event):
        """Defer to the default QMainWindow close handling."""
        super().closeEvent(event)


def main():
    """Create the Qt application, show the main window and run the event loop."""
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    # exec_() blocks until the application quits; its return value becomes
    # the process exit status.
    exit_code = qt_app.exec_()
    sys.exit(exit_code)


# Start the GUI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()