import sys
import math
import numpy as np

# For smoothing the path
from scipy.signal import savgol_filter

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
    QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    QSlider, QLabel
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF, QSize

from live_wire import compute_cost_image, find_path


# ------------------------------------------------------------------------
# A pan & zoom QGraphicsView
# ------------------------------------------------------------------------
class PanZoomGraphicsView(QGraphicsView):
    """
    QGraphicsView with mouse-wheel zoom and left-button drag panning.

    Zoom is anchored under the mouse cursor.  Panning moves the view's
    scroll bars, so it only has a visible effect when the (transformed)
    scene is larger than the viewport, e.g. after zooming in.
    """

    # Multiplicative zoom step applied per wheel event.
    zoom_factor = 1.25

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setDragMode(QGraphicsView.NoDrag)  # We'll handle panning manually
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self._panning = False
        self._pan_start = None  # last viewport position seen while panning

    def wheelEvent(self, event):
        """
        Zoom in (wheel up) or out (wheel down) around the mouse cursor.

        Events without a vertical component (e.g. pure horizontal scrolling)
        are accepted but leave the zoom unchanged.
        """
        dy = event.angleDelta().y()
        if dy > 0:
            self.scale(self.zoom_factor, self.zoom_factor)
        elif dy < 0:
            inverse = 1.0 / self.zoom_factor
            self.scale(inverse, inverse)
        event.accept()

    def mousePressEvent(self, event):
        """
        If left button: start panning (unless overridden in a subclass).
        """
        if event.button() == Qt.LeftButton:
            self._panning = True
            self._pan_start = event.pos()
            self.setCursor(Qt.ClosedHandCursor)
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """
        While panning, scroll the view by the mouse movement since the last event.

        QGraphicsView.translate() is not used: the view re-applies its
        transformation anchor / alignment after each transform change, which
        largely undoes a translation, and translate() works in pre-scale units.
        Moving the scroll bars (as Qt's ScrollHandDrag mode does) pans in
        viewport pixels at any zoom level.
        """
        if self._panning and self._pan_start is not None:
            delta = event.pos() - self._pan_start
            self._pan_start = event.pos()
            hbar = self.horizontalScrollBar()
            vbar = self.verticalScrollBar()
            dx = delta.x() if self.isRightToLeft() else -delta.x()
            hbar.setValue(hbar.value() + dx)
            vbar.setValue(vbar.value() - delta.y())
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """
        End panning.
        """
        if event.button() == Qt.LeftButton:
            self._panning = False
            self._pan_start = None
            self.setCursor(Qt.ArrowCursor)
        super().mouseReleaseEvent(event)


# ------------------------------------------------------------------------
# A specialized PanZoomGraphicsView for the circle editor
#    - Only pan if user did NOT click on the draggable circle
#    - If the mouse is over the circle item, scrolling changes radius
# ------------------------------------------------------------------------
class CircleEditorGraphicsView(PanZoomGraphicsView):
    """
    Pan/zoom view used by CircleEditorWidget.

    - A left press on the DraggableCircleItem goes straight to
      QGraphicsView so the item can be dragged; anywhere else starts panning.
    - A wheel event over the circle changes its radius by 1 per event
      (minimum 1) instead of zooming, and syncs the editor's slider.
    """

    def __init__(self, circle_editor_widget, parent=None):
        """
        :param circle_editor_widget: Reference to the parent CircleEditorWidget
                                     so we can communicate (e.g. update slider).
        """
        super().__init__(parent)
        self._circle_editor_widget = circle_editor_widget

    def _circle_item_at(self, view_pos):
        """
        Return the DraggableCircleItem at ``view_pos`` (viewport coordinates), or None.

        Walks up the parent chain so that hitting a child item of the circle
        still resolves to the circle itself.
        """
        item = self.itemAt(view_pos)
        while item is not None and not isinstance(item, DraggableCircleItem):
            item = item.parentItem()
        return item

    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton and self._circle_item_at(event.pos()) is not None:
            # Let normal item-dragging occur; skip the base class's panning logic.
            QGraphicsView.mousePressEvent(self, event)
            return
        # Otherwise proceed with normal panning logic
        super().mousePressEvent(event)

    def wheelEvent(self, event):
        """
        Overridden so that if the mouse is hovering over the circle,
        we adjust the circle's radius instead of zooming the image.
        """
        circle = self._circle_item_at(event.pos())
        if circle is None:
            # else do normal pan/zoom
            super().wheelEvent(event)
            return

        delta = event.angleDelta().y()
        if delta != 0:
            # Scroll up -> increase radius, scroll down -> decrease (1 unit per event).
            step = 1 if delta > 0 else -1
            new_r = max(1, circle.radius() + step)
            circle.set_radius(new_r)
            # Also update the slider in the parent CircleEditorWidget
            self._circle_editor_widget.update_slider_value(new_r)

        event.accept()


# ------------------------------------------------------------------------
# Draggable circle item (centered at (x, y) with radius)
# ------------------------------------------------------------------------
class DraggableCircleItem(QGraphicsEllipseItem):
    """
    Filled, movable, selectable circle that keeps track of its radius.

    The ellipse rect is (0, 0, 2r, 2r) in local coordinates and the item's
    position is the rect's top-left corner, so the visual centre sits at
    pos() + (r, r).
    """

    def __init__(self, x, y, radius=20, color=Qt.red, parent=None):
        diameter = 2 * radius
        super().__init__(0, 0, diameter, diameter, parent)
        self._r = radius

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))

        # Drag with the mouse, allow selection, and emit scene-position changes.
        interaction_flags = (
            QGraphicsEllipseItem.ItemIsMovable
            | QGraphicsEllipseItem.ItemIsSelectable
            | QGraphicsEllipseItem.ItemSendsScenePositionChanges
        )
        self.setFlags(interaction_flags)

        # Offset by the radius so that (x, y) ends up at the centre.
        self.setPos(x - radius, y - radius)

    def set_radius(self, r):
        """Resize the circle to radius ``r`` while keeping its centre in place."""
        anchor = self.sceneBoundingRect().center()
        self._r = r
        diameter = 2 * r
        self.setRect(0, 0, diameter, diameter)
        drift = anchor - self.sceneBoundingRect().center()
        self.moveBy(drift.x(), drift.y())

    def radius(self):
        """Return the current radius."""
        return self._r


# ------------------------------------------------------------------------
# Circle editor widget with slider + done
# ------------------------------------------------------------------------
class CircleEditorWidget(QWidget):
    """
    Image viewer with a draggable red circle whose radius is chosen with a
    slider (range 1-200) or by scrolling the mouse wheel over the circle.

    Clicking "Done" calls ``done_callback(final_radius)`` if a callback was given.
    The circle radius is always kept within the slider's range so the two
    controls never disagree.
    """

    def __init__(self, pixmap, init_radius=20, done_callback=None, parent=None):
        super().__init__(parent)
        self._pixmap = pixmap
        self._done_callback = done_callback
        self._init_radius = init_radius

        # Constructing the layout with `self` already installs it on this widget.
        layout = QVBoxLayout(self)

        # Use specialized CircleEditorGraphicsView
        self._graphics_view = CircleEditorGraphicsView(circle_editor_widget=self)
        self._scene = QGraphicsScene(self)
        self._graphics_view.setScene(self._scene)
        layout.addWidget(self._graphics_view)

        self._image_item = QGraphicsPixmapItem(self._pixmap)
        self._scene.addItem(self._image_item)

        # Put circle in center
        cx = self._pixmap.width() / 2
        cy = self._pixmap.height() / 2
        self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red)
        self._scene.addItem(self._circle_item)

        # Fit in view.
        # NOTE(review): this runs before the widget has been laid out, so the fit
        # uses a provisional viewport size; confirm the initial zoom is acceptable.
        self._graphics_view.setSceneRect(QRectF(self._pixmap.rect()))
        self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio)

        # Bottom controls (slider + done)
        bottom_layout = QHBoxLayout()
        layout.addLayout(bottom_layout)

        lbl = QLabel("size:")
        bottom_layout.addWidget(lbl)

        self._slider = QSlider(Qt.Horizontal)
        self._slider.setRange(1, 200)
        self._slider.setValue(self._init_radius)
        bottom_layout.addWidget(self._slider)

        # setValue() clamps to the slider range; make the circle match so both
        # start in sync even for an out-of-range init_radius.
        if self._circle_item.radius() != self._slider.value():
            self._circle_item.set_radius(self._slider.value())

        self._btn_done = QPushButton("Done")
        bottom_layout.addWidget(self._btn_done)

        # Connect signals
        self._slider.valueChanged.connect(self._on_slider_changed)
        self._btn_done.clicked.connect(self._on_done_clicked)

    def _on_slider_changed(self, value):
        """Apply a slider change to the circle radius."""
        self._circle_item.set_radius(value)

    def _on_done_clicked(self):
        """Report the final radius to the done callback, if any."""
        final_radius = self._circle_item.radius()
        if self._done_callback is not None:
            self._done_callback(final_radius)

    def update_slider_value(self, new_radius):
        """
        Called by CircleEditorGraphicsView when the user scrolls on the circle item.

        Syncs the slider to ``new_radius``.  If the radius falls outside the
        slider's range, the circle is clamped to that range as well, so the
        slider and the circle never show different sizes.
        """
        clamped = max(self._slider.minimum(), min(new_radius, self._slider.maximum()))
        if clamped != self._circle_item.radius():
            self._circle_item.set_radius(clamped)
        self._slider.blockSignals(True)  # to avoid recursively calling set_radius
        self._slider.setValue(clamped)
        self._slider.blockSignals(False)

    def sizeHint(self):
        return QSize(800, 600)


# ------------------------------------------------------------------------
# Labeled point item
# ------------------------------------------------------------------------
class LabeledPointItem(QGraphicsEllipseItem):
    """
    Filled circle of radius ``radius`` centred at (x, y), optionally with a
    bold black text label drawn inside it.

    The centre is stored separately in ``_x``/``_y`` because Qt's pos() is
    the top-left corner of the ellipse rect, not its centre.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
        diameter = 2 * radius
        super().__init__(0, 0, diameter, diameter, parent)
        self._x = x
        self._y = y
        self._r = radius
        self._removable = removable

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))
        self.setZValue(z_value)

        self._text_item = None
        if label:
            caption = QGraphicsTextItem(self)
            caption.setPlainText(label)
            caption.setDefaultTextColor(QColor("black"))
            bold_font = QFont("Arial", 14)
            bold_font.setBold(True)
            caption.setFont(bold_font)
            self._text_item = caption
            self._scale_text_to_fit()

        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink (never enlarge) the label to fit the circle's diameter, then centre it."""
        caption = self._text_item
        if not caption:
            return
        caption.setScale(1.0)
        diameter = 2 * self._r
        bounds = caption.boundingRect()
        width, height = bounds.width(), bounds.height()
        if max(width, height) > diameter:
            caption.setScale(min(diameter / width, diameter / height))
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label so its box is centred in the circle."""
        caption = self._text_item
        if not caption:
            return
        diameter = 2 * self._r
        bounds = caption.boundingRect()
        factor = caption.scale()
        caption.setPos((diameter - bounds.width() * factor) * 0.5,
                       (diameter - bounds.height() * factor) * 0.5)

    def set_pos(self, x, y):
        """Move the item so that its centre lands on (x, y)."""
        self._x, self._y = x, y
        offset = self._r
        self.setPos(x - offset, y - offset)

    def get_pos(self):
        """Return the centre as an (x, y) tuple."""
        return (self._x, self._y)

    def distance_to(self, x_other, y_other):
        """Euclidean distance from the centre to (x_other, y_other)."""
        dx = self._x - x_other
        dy = self._y - y_other
        return math.sqrt(dx ** 2 + dy ** 2)

    def is_removable(self):
        """Whether the user may delete this point."""
        return self._removable


# ------------------------------------------------------------------------
# The original ImageGraphicsView with pan & zoom
# ------------------------------------------------------------------------
class ImageGraphicsView(PanZoomGraphicsView):
    """
    Pan/zoom image view for interactive live-wire path editing.

    The path runs through ``anchor_points`` in order: a fixed start "S", any
    number of user guide points, and a fixed end "E".  Each consecutive pair
    of anchors is joined with ``find_path`` on ``cost_image``.  The joined
    path is smoothed when it is long enough and drawn as small coloured dots.

    Mouse interaction:
      * left click (no drag)            -> add a guide point
      * left drag starting on an anchor -> move that anchor
      * right click on a guide point    -> remove it (S/E are not removable)
      * other left drags / wheel        -> handled by PanZoomGraphicsView

    Every removable guide point sets ``cost_image`` to its global minimum in
    a small disc (``radius_cost_image``) around the point, which presumably
    pulls the path through it.  ``cost_image_original`` keeps the unmodified
    costs so this effect can always be rebuilt from scratch.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # NOTE: this attribute shadows QGraphicsView.scene(); inside this class
        # use `self.scene`, not `self.scene()`.
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Image display
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        self.anchor_points = []     # List[(x, y)] in path order (S, guides..., E)
        self.point_items = []       # LabeledPointItem per anchor, parallel to anchor_points
        self.full_path_points = []  # LabeledPointItems drawing the path
        self._full_path_xy = []     # entire path coords (smoothed)
        self._full_path_seg = []    # per path point: index i of the segment anchor i -> i + 1

        self.dot_radius = 4          # radius of guide-point markers
        self.path_radius = 1         # radius of path dots
        self.radius_cost_image = 2   # radius (pixels) of the low-cost disc per guide point
        self._img_w = 0
        self._img_h = 0

        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5     # Manhattan distance (view px) before a press counts as a drag
        self._was_dragging = False
        self._dragging_idx = None    # index of the anchor being dragged, if any
        self._drag_offset = (0, 0)   # press position minus anchor centre, in scene coords
        self._drag_counter = 0       # move events since the last path refresh during a drag

        # Cost images
        self.cost_image_original = None
        self.cost_image = None

        # Rainbow toggle
        self._rainbow_enabled = True

    def set_rainbow_enabled(self, enabled: bool):
        """Enable/disable rainbow mode, then rebuild the path."""
        self._rainbow_enabled = enabled
        self._rebuild_full_path()

    def toggle_rainbow(self):
        """Flip the rainbow mode and rebuild path."""
        self._rainbow_enabled = not self._rainbow_enabled
        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """
        Show the image at ``path``, reset all points and fit it in the view.

        Places the fixed "S" and "E" anchors at 15% / 85% of the width, mid-height.
        Does nothing if the image cannot be loaded.  The cost images are NOT
        touched here; the caller is expected to set them.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))

        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self._clear_all_points()
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        # By default, add S/E
        s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
        e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
        self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
        self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
        """Insert anchor at index=idx (or -1 => before E). Clamps x,y to image bounds."""
        x_clamped = self._clamp(x, radius, self._img_w - radius)
        y_clamped = self._clamp(y, radius, self._img_h - radius)

        if idx < 0:
            # With S and E present, "before E" is the last slot; otherwise append.
            if len(self.anchor_points) >= 2:
                idx = len(self.anchor_points) - 1
            else:
                idx = len(self.anchor_points)

        self.anchor_points.insert(idx, (x_clamped, y_clamped))

        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(x_clamped, y_clamped,
                                label=label, radius=radius, color=color,
                                removable=removable, z_value=z_val)
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Add a removable guide point at (x, y), clamped to the image, and refresh the path."""
        if self._img_w <= 0 or self._img_h <= 0:
            # No image loaded yet: there is nothing to guide.
            return
        x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
        y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)

        self._insert_anchor_between_subpath(x_clamped, y_clamped)
        self._refresh_cost_and_path()

    def _insert_anchor_between_subpath(self, x_new, y_new):
        """
        Insert a guide point between the two anchors whose sub-path passes
        closest to (x_new, y_new).

        Uses the per-point segment tags recorded by _rebuild_full_path rather
        than matching path coordinates against anchor coordinates, because
        rounding to pixels and smoothing move path points off the anchors.
        Falls back to inserting just before "E" when no usable path exists.
        """
        insert_idx = -1
        path_xy = self._full_path_xy
        if path_xy and len(self._full_path_seg) == len(path_xy):
            pts = np.asarray(path_xy, dtype=float)
            d2 = (pts[:, 0] - x_new) ** 2 + (pts[:, 1] - y_new) ** 2
            seg = self._full_path_seg[int(np.argmin(d2))]
            # Segment `seg` joins anchor `seg` and anchor `seg + 1`.
            if 0 <= seg < len(self.anchor_points) - 1:
                insert_idx = seg + 1

        self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
                                  z_val=1, radius=self.dot_radius)

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _revert_cost_to_original(self):
        """Reset cost_image to a fresh copy of cost_image_original (if set)."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower the cost around every removable anchor (i.e. guide points, not S/E)."""
        if self.cost_image is None:
            return
        for (ax, ay), item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """
        Set every cost_image pixel within Euclidean distance ``radius`` of the
        pixel nearest (x_f, y_f) to the image's current minimum cost.

        Does nothing if that centre pixel lies outside the image.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape  # assumes a 2-D (row, col) cost array
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return
        global_min = self.cost_image.min()
        r_s = max(0, row_c - radius)
        r_e = min(h, row_c + radius + 1)
        c_s = max(0, col_c - radius)
        c_e = min(w, col_c + radius + 1)
        rows, cols = np.ogrid[r_s:r_e, c_s:c_e]
        inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
        # Basic slicing returns a view, so the masked write updates cost_image in place.
        window = self.cost_image[r_s:r_e, c_s:c_e]
        window[inside] = global_min

    def _refresh_cost_and_path(self):
        """Rebuild cost_image from the original plus all guide points, then redraw the path."""
        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _clear_path_items(self):
        """Remove the drawn path dots from the scene and forget the path data."""
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()
        self._full_path_xy.clear()
        self._full_path_seg.clear()

    def _rebuild_full_path(self):
        """
        Recompute and redraw the whole path through anchor_points.

        Each consecutive anchor pair is joined with _compute_subpath_xy, and
        the point shared by neighbouring segments is kept only once.  Every
        path point is tagged with the index of the segment it came from.
        Paths of 7+ points are smoothed with a Savitzky-Golay filter
        (window 7, linear fit), which keeps the number of points unchanged.
        Does nothing beyond clearing if there are < 2 anchors or no cost image.
        """
        self._clear_path_items()

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        big_xy = []
        seg_ids = []
        segments = zip(self.anchor_points, self.anchor_points[1:])
        for i, ((xA, yA), (xB, yB)) in enumerate(segments):
            sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
            if i > 0:
                # Drop the first point: it presumably repeats the previous
                # segment's last point (find_path assumed to include endpoints).
                sub_xy = sub_xy[1:]
            big_xy.extend(sub_xy)
            seg_ids.extend([i] * len(sub_xy))

        if len(big_xy) >= 7:
            smoothed = savgol_filter(np.array(big_xy), window_length=7, polyorder=1, axis=0)
            big_xy = smoothed.tolist()

        self._full_path_xy = big_xy[:]
        self._full_path_seg = seg_ids

        last = len(big_xy) - 1
        for i, (px, py) in enumerate(big_xy):
            if self._rainbow_enabled:
                color = self._rainbow_color(i / last if last > 0 else 0)
            else:
                color = Qt.red

            path_item = LabeledPointItem(px, py, label="",
                                         radius=self.path_radius,
                                         color=color,
                                         removable=False,
                                         z_value=0)
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Keep anchor labels on top
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """
        Return the find_path route from (xA, yA) to (xB, yB) as a list of (x, y).

        Coordinates are rounded to the nearest pixel and clamped to the cost
        image before being passed to find_path as (row, col).  Returns [] if
        there is no cost image or find_path raises ValueError.
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape
        rA, cA = int(round(yA)), int(round(xA))
        rB, cB = int(round(yB)), int(round(xB))
        rA = max(0, min(rA, h - 1))
        cA = max(0, min(cA, w - 1))
        rB = max(0, min(rB, h - 1))
        cB = max(0, min(cB, w - 1))
        try:
            path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        return [(c, r) for (r, c) in path_rc]

    def _rainbow_color(self, fraction):
        """Map fraction in [0, 1] to an HSV hue from 0 (red) to 300 (magenta)."""
        hue = int(300 * fraction)
        saturation = 255
        value = 255
        return QColor.fromHsv(hue, saturation, value)

    # --------------------------------------------------------------------
    # MOUSE EVENTS (with pan & zoom from PanZoomGraphicsView)
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            idx = self._find_item_near(event.pos(), threshold=10)
            if idx is not None:
                # Start dragging an anchor; skip the base class so no panning starts.
                self._dragging_idx = idx
                self._drag_counter = 0
                scene_pos = self.mapToScene(event.pos())
                px, py = self.point_items[idx].get_pos()
                self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                self.setCursor(Qt.ClosedHandCursor)
                return

        elif event.button() == Qt.RightButton:
            self._remove_point_by_click(event.pos())

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        if self._dragging_idx is not None:
            scene_pos = self.mapToScene(event.pos())
            item = self.point_items[self._dragging_idx]
            r = item._r
            x_clamped = self._clamp(scene_pos.x() - self._drag_offset[0], r, self._img_w - r)
            y_clamped = self._clamp(scene_pos.y() - self._drag_offset[1], r, self._img_h - r)
            item.set_pos(x_clamped, y_clamped)

            # Path recomputation is expensive; refresh only every 4th move event.
            self._drag_counter += 1
            if self._drag_counter >= 4:
                self._drag_counter = 0
                # Update the anchor BEFORE rebuilding the cost image so the
                # low-cost disc follows the point's current position.
                self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
                self._refresh_cost_and_path()
        elif self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True

        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        super().mouseReleaseEvent(event)
        if event.button() != Qt.LeftButton or not self._mouse_pressed:
            return

        self._mouse_pressed = False
        self.setCursor(Qt.ArrowCursor)

        if self._dragging_idx is not None:
            # Finish an anchor drag: commit its final position.
            idx = self._dragging_idx
            self._dragging_idx = None
            self._drag_offset = (0, 0)
            self.anchor_points[idx] = self.point_items[idx].get_pos()
            self._refresh_cost_and_path()
        elif not self._was_dragging:
            # No drag => add point
            scene_pos = self.mapToScene(event.pos())
            self._add_guide_point(scene_pos.x(), scene_pos.y())

        self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Remove the removable anchor nearest ``view_pos`` (if within range) and refresh."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None or not self.point_items[idx].is_removable():
            return

        self.scene.removeItem(self.point_items[idx])
        self.point_items.pop(idx)
        self.anchor_points.pop(idx)

        self._refresh_cost_and_path()

    def _find_item_near(self, view_pos, threshold=10):
        """
        Return the index of the anchor whose centre is closest to ``view_pos``,
        or None if none lies within ``threshold`` (measured in scene units).
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()

        closest_idx = None
        min_dist = float('inf')
        for i, itm in enumerate(self.point_items):
            d = itm.distance_to(x_click, y_click)
            if d < min_dist:
                min_dist = d
                closest_idx = i
        if closest_idx is not None and min_dist <= threshold:
            return closest_idx
        return None

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Clamp ``val`` to [mn, mx] (returns mn if mn > mx)."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor (including S/E) and the drawn path."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()
        self._clear_path_items()

    def clear_guide_points(self):
        """Remove all removable guide points (S/E stay), then rebuild cost image and path."""
        i = 0
        while i < len(self.anchor_points):
            if self.point_items[i].is_removable():
                self.scene.removeItem(self.point_items[i])
                del self.point_items[i]
                del self.anchor_points[i]
            else:
                i += 1

        self._refresh_cost_and_path()

    def get_full_path_xy(self):
        """Return the current (smoothed) path as a list of (x, y) points."""
        return self._full_path_xy


# ------------------------------------------------------------------------
# Main Window
# ------------------------------------------------------------------------
class MainWindow(QMainWindow):
    """
    Main application window: an ImageGraphicsView for live-wire path editing
    plus a row of buttons (load, export, clear, rainbow toggle, circle editor).

    The circle editor temporarily replaces the central widget; the chosen
    radius comes back through _on_circle_editor_done and is stored in
    _circle_radius_for_later_use.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")

        self._last_loaded_pixmap = None
        self._circle_radius_for_later_use = 0
        # Main widget parked while the circle editor is shown, and the editor itself.
        self._old_central_widget = None
        self._editor = None

        # Original main widget
        self._main_widget = QWidget()
        self._main_layout = QVBoxLayout(self._main_widget)

        # Image view
        self.image_view = ImageGraphicsView()
        self._main_layout.addWidget(self.image_view)

        # Button row
        btn_layout = QHBoxLayout()
        self.btn_load_image = self._make_button(btn_layout, "Load Image", self.load_image)
        self.btn_export_path = self._make_button(btn_layout, "Export Path", self.export_path)
        self.btn_clear_points = self._make_button(btn_layout, "Clear Points", self.clear_points)
        self.btn_toggle_rainbow = self._make_button(btn_layout, "Toggle Rainbow", self.toggle_rainbow)
        self.btn_open_editor = self._make_button(btn_layout, "Open Circle Editor",
                                                 self.open_circle_editor)

        self._main_layout.addLayout(btn_layout)
        self.setCentralWidget(self._main_widget)

        self.resize(900, 600)

    @staticmethod
    def _make_button(layout, text, slot):
        """Create a QPushButton labelled ``text``, connect it to ``slot`` and add it to ``layout``."""
        button = QPushButton(text)
        button.clicked.connect(slot)
        layout.addWidget(button)
        return button

    def open_circle_editor(self):
        """Removes the current central widget, replaces with circle editor."""
        if self._last_loaded_pixmap is None:
            print("No image loaded yet! Cannot open circle editor.")
            return

        # Step 1: take the old widget out of QMainWindow ownership
        self._old_central_widget = self.takeCentralWidget()

        # Step 2: create the editor
        init_radius = 20
        self._editor = CircleEditorWidget(
            pixmap=self._last_loaded_pixmap,
            init_radius=init_radius,
            done_callback=self._on_circle_editor_done
        )

        # Step 3: set the new editor as the central widget
        self.setCentralWidget(self._editor)

    def _on_circle_editor_done(self, final_radius):
        """Store the chosen radius and swap the original main widget back in."""
        self._circle_radius_for_later_use = final_radius
        print(f"Circle Editor done. Radius = {final_radius}")

        # Take the editor out
        editor_widget = self.takeCentralWidget()
        if editor_widget is not None:
            editor_widget.setParent(None)

        # Put back the old widget
        if self._old_central_widget is not None:
            self.setCentralWidget(self._old_central_widget)
            self._old_central_widget = None

        # Deferred deletion: we are still inside the editor's button callback.
        if self._editor is not None:
            self._editor.deleteLater()
            self._editor = None

    # --------------------------------------------------------------------
    # Existing Functions
    # --------------------------------------------------------------------
    def toggle_rainbow(self):
        self.image_view.toggle_rainbow()

    def load_image(self):
        """
        Ask for an image file, show it, compute its cost image and draw the
        initial S -> E path.  An unreadable image is reported and leaves the
        current view and cost image untouched.
        """
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options
        )
        if not file_path:
            return

        # Validate first: the view silently ignores an unreadable image, which
        # would otherwise leave it paired with a cost image from another file.
        pm = QPixmap(file_path)
        if pm.isNull():
            print(f"Could not load image: {file_path}")
            return

        self.image_view.load_image(file_path)
        cost_img = compute_cost_image(file_path)
        self.image_view.cost_image_original = cost_img
        self.image_view.cost_image = cost_img.copy()
        # No guide points exist right after loading, so this only rebuilds the
        # cost image and draws the initial S -> E path on the new costs.
        self.image_view.clear_guide_points()

        # Store a pixmap to reuse
        self._last_loaded_pixmap = pm

    def export_path(self):
        """Save the current path as an (N, 2) array of (x, y) to a .npy file."""
        full_xy = self.image_view.get_full_path_xy()
        if not full_xy:
            print("No path to export.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Path", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options
        )
        if file_path:
            arr = np.array(full_xy)
            np.save(file_path, arr)
            print(f"Exported path with {len(arr)} points to {file_path}")

    def clear_points(self):
        self.image_view.clear_guide_points()


def main():
    """Create the Qt application, show the main window and run the event loop."""
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()