import sys
import math
import numpy as np

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
    QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem
)
from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt, QRectF

from live_wire import compute_cost_image, find_path


class LabeledPointItem(QGraphicsEllipseItem):
    """Filled circle marking a point in scene coordinates, with an optional label.

    The item is positioned so that its circle is centred on ``(x, y)``. When a
    label is given it is drawn in bold black Arial 14, shrunk (never enlarged)
    to fit inside the circle, and centred on it.
    """

    def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, z_value=0, parent=None):
        diameter = 2 * radius
        super().__init__(0, 0, diameter, diameter, parent)
        self._x = x
        self._y = y
        self._r = radius
        self._removable = removable

        self.setPen(QPen(color))
        self.setBrush(QBrush(color))
        self.setZValue(z_value)

        self._text_item = None
        if label:
            label_font = QFont("Arial", 14)
            label_font.setBold(True)
            text = QGraphicsTextItem(self)
            text.setPlainText(label)
            text.setDefaultTextColor(QColor("black"))
            text.setFont(label_font)
            self._text_item = text
            self._scale_text_to_fit()

        self.set_pos(x, y)

    def _scale_text_to_fit(self):
        """Shrink the label (never enlarge it) to fit in the circle, then centre it."""
        label = self._text_item
        if label is None:
            return
        label.setScale(1.0)
        diameter = 2 * self._r
        bounds = label.boundingRect()
        width, height = bounds.width(), bounds.height()
        if width > diameter or height > diameter:
            label.setScale(min(diameter / width, diameter / height))
        self._center_label()

    def _center_label(self):
        """Position the (possibly scaled) label so it sits centred in the circle."""
        label = self._text_item
        if label is None:
            return
        diameter = 2 * self._r
        bounds = label.boundingRect()
        factor = label.scale()
        offset_x = 0.5 * (diameter - bounds.width() * factor)
        offset_y = 0.5 * (diameter - bounds.height() * factor)
        label.setPos(offset_x, offset_y)

    def set_pos(self, x, y):
        """Move the item so its circle is centred on scene point (x, y)."""
        self._x, self._y = x, y
        self.setPos(x - self._r, y - self._r)

    def get_pos(self):
        """Return the circle centre as an ``(x, y)`` tuple."""
        return self._x, self._y

    def distance_to(self, x_other, y_other):
        """Return the Euclidean distance from the circle centre to (x_other, y_other)."""
        dx = self._x - x_other
        dy = self._y - y_other
        return math.sqrt(dx**2 + dy**2)

    def is_removable(self):
        """Return whether the user may delete this point (False for S/E and path dots)."""
        return self._removable


class ImageGraphicsView(QGraphicsView):
    """Zoomable, pannable image view with editable live-wire anchor points.

    Anchors live in two parallel lists: ``anchor_points`` holds the ``(x, y)``
    scene coordinates and ``point_items`` holds the matching
    ``LabeledPointItem`` graphics items. After ``load_image`` the first and
    last anchors are the fixed start ("S") and end ("E") points. Every anchor
    in between is a removable guide point. The path through consecutive
    anchors is computed with ``find_path`` on ``cost_image`` and drawn as
    small magenta dots.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Zoom around the mouse pointer rather than the view centre.
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        # Image display item
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        # Parallel lists: anchor_points[i] is the position of point_items[i].
        self.anchor_points = []  # List[(x, y)]
        self.point_items = []    # List[LabeledPointItem]

        self.editor_mode = False
        self.dot_radius = 4
        self.path_radius = 1
        self.radius_cost_image = 2  # radius (pixels) of the cost-lowering disc
        self._img_w = 0
        self._img_h = 0

        # Left-drag pans the view by default.
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.viewport().setCursor(Qt.ArrowCursor)

        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5    # Manhattan distance (view px) before a press counts as a pan
        self._was_dragging = False
        self._dragging_idx = None   # index of the anchor being dragged, if any
        self._drag_offset = (0, 0)  # click position relative to the dragged anchor's centre

        # Keep the original cost image so guide-point edits can be reverted.
        self.cost_image_original = None
        self.cost_image = None

        # Graphics items (small magenta circles) forming the displayed path.
        self.full_path_points = []

    # --------------------------------------------------------------------
    # LOADING
    # --------------------------------------------------------------------
    def load_image(self, path):
        """Load the image at *path* and reset the anchors to default S/E points.

        S is placed at 15% and E at 85% of the image width, both on the
        vertical centre line. The cost image is not touched here; the caller
        (MainWindow.load_image) sets ``cost_image_original``/``cost_image``.

        Returns:
            True if the image was loaded, False if Qt could not read it. In
            that case the view is left unchanged.
        """
        pixmap = QPixmap(path)
        if pixmap.isNull():
            return False

        self.image_item.setPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))
        self._img_w = pixmap.width()
        self._img_h = pixmap.height()

        self._clear_all_points()
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)

        s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
        e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
        # S and E are fixed endpoints: drawn on top and never removable.
        self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
        self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
        return True

    def set_editor_mode(self, mode: bool):
        """Enable/disable editing (adding, dragging and removing anchors)."""
        self.editor_mode = mode

    # --------------------------------------------------------------------
    # ANCHOR POINTS
    # --------------------------------------------------------------------
    def _insert_anchor_point(self, idx, x, y, label="", removable=True, z_val=0, radius=4):
        """Insert an anchor at list index *idx* and add its item to the scene.

        A negative *idx* means "just before E". When at least two anchors
        exist, the last one is taken to be E. Guide points are therefore
        ordered by click order, not by where they lie along the path.
        """
        if idx < 0:
            idx = len(self.anchor_points) - 1 if len(self.anchor_points) >= 2 else len(self.anchor_points)

        self.anchor_points.insert(idx, (x, y))
        color = Qt.green if label in ("S", "E") else Qt.red
        item = LabeledPointItem(x, y, label=label, radius=radius, color=color,
                                removable=removable, z_value=z_val)
        self.point_items.insert(idx, item)
        self.scene.addItem(item)

    def _add_guide_point(self, x, y):
        """Add a removable guide point at scene (x, y) and rebuild the path.

        Points outside the loaded image, or any point when no image is loaded,
        are ignored. Such a point cannot lower any cost and would only pull
        the path to the clamped image border.
        """
        if not (0 <= x < self._img_w and 0 <= y < self._img_h):
            return
        self._revert_cost_to_original()
        self._insert_anchor_point(-1, x, y, label="", removable=True, z_val=1, radius=self.dot_radius)
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    # --------------------------------------------------------------------
    # COST IMAGE
    # --------------------------------------------------------------------
    def _revert_cost_to_original(self):
        """Reset ``cost_image`` to a fresh copy of ``cost_image_original`` (if any)."""
        if self.cost_image_original is not None:
            self.cost_image = self.cost_image_original.copy()

    def _apply_all_guide_points_to_cost(self):
        """Lower the cost around every removable anchor (i.e. every guide point)."""
        if self.cost_image is None:
            return
        for (ax, ay), item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self._lower_cost_in_circle(ax, ay, self.radius_cost_image)

    def _lower_cost_in_circle(self, x_f, y_f, radius):
        """Set every cost pixel within *radius* of round(x_f, y_f) to the current global minimum.

        The centre is rounded to the nearest pixel. A centre outside the cost
        image is a no-op. *radius* must be an integer number of pixels.
        """
        if self.cost_image is None:
            return
        h, w = self.cost_image.shape
        row_c = int(round(y_f))
        col_c = int(round(x_f))
        if not (0 <= row_c < h and 0 <= col_c < w):
            return
        global_min = self.cost_image.min()
        r_s = max(0, row_c - radius)
        r_e = min(h, row_c + radius + 1)
        c_s = max(0, col_c - radius)
        c_e = min(w, col_c + radius + 1)
        # Vectorized disc mask over the clipped bounding box; the basic slice is
        # a view, so the masked assignment writes into cost_image itself.
        rows, cols = np.ogrid[r_s:r_e, c_s:c_e]
        inside = (rows - row_c) ** 2 + (cols - col_c) ** 2 <= radius ** 2
        self.cost_image[r_s:r_e, c_s:c_e][inside] = global_min

    # --------------------------------------------------------------------
    # PATH BUILDING
    # --------------------------------------------------------------------
    def _rebuild_full_path(self):
        """Remove the drawn path and redraw it through all anchors in list order."""
        for item in self.full_path_points:
            self.scene.removeItem(item)
        self.full_path_points.clear()

        if len(self.anchor_points) < 2 or self.cost_image is None:
            return

        big_xy = []
        for (xA, yA), (xB, yB) in zip(self.anchor_points, self.anchor_points[1:]):
            sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
            # Consecutive sub-paths share their joining anchor. Drop the
            # duplicate only when it really is the same pixel, so an empty
            # (failed) earlier segment does not swallow this segment's start.
            if big_xy and sub_xy and big_xy[-1] == sub_xy[0]:
                sub_xy = sub_xy[1:]
            big_xy.extend(sub_xy)

        for px, py in big_xy:
            path_item = LabeledPointItem(px, py, label="", radius=self.path_radius,
                                         color=Qt.magenta, removable=False, z_value=0)
            self.full_path_points.append(path_item)
            self.scene.addItem(path_item)

        # Ensure labelled anchors (S/E) stay on top of the path dots.
        for p_item in self.point_items:
            if p_item._text_item:
                p_item.setZValue(100)

    def _compute_subpath_xy(self, xA, yA, xB, yB):
        """Return the (x, y) pixel path from A to B, or [] on failure.

        Both endpoints are rounded to pixels and clamped into the image.
        ``find_path`` is called with (row, col) points. Its result is
        converted back to (x, y) = (col, row). A ValueError from ``find_path``
        is reported and yields [].
        """
        if self.cost_image is None:
            return []
        h, w = self.cost_image.shape
        rA = self._clamp(int(round(yA)), 0, h - 1)
        cA = self._clamp(int(round(xA)), 0, w - 1)
        rB = self._clamp(int(round(yB)), 0, h - 1)
        cB = self._clamp(int(round(xB)), 0, w - 1)
        try:
            path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
        except ValueError as e:
            print("Error in find_path:", e)
            return []
        return [(c, r) for (r, c) in path_rc]

    # --------------------------------------------------------------------
    # MOUSE EVENTS (dragging, adding, removing points)
    # --------------------------------------------------------------------
    def mousePressEvent(self, event):
        """Handle left and right presses.

        Left: in editor mode, grab a nearby anchor for dragging; otherwise
        start panning. Right: in editor mode, remove a nearby guide point.
        """
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            if self.editor_mode:
                idx = self._find_item_near(event.pos(), 10)
                if idx is not None:
                    # Remember where inside the anchor the user clicked so
                    # it does not jump to the cursor while dragging.
                    self._dragging_idx = idx
                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.point_items[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
                    self.setDragMode(QGraphicsView.NoDrag)
                    self.viewport().setCursor(Qt.ClosedHandCursor)
                    # Skip the base handler so the view does not also pan.
                    return

            # No anchor grabbed: a left-drag pans the view.
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ClosedHandCursor)

        elif event.button() == Qt.RightButton and self.editor_mode:
            self._remove_point_by_click(event.pos())

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Move the dragged anchor, or track whether a left press became a pan."""
        if self._dragging_idx is not None:
            item = self.point_items[self._dragging_idx]
            scene_pos = self.mapToScene(event.pos())
            x_new = scene_pos.x() - self._drag_offset[0]
            y_new = scene_pos.y() - self._drag_offset[1]
            r = item._r
            # Keep the whole circle inside the image.
            item.set_pos(self._clamp(x_new, r, self._img_w - r),
                         self._clamp(y_new, r, self._img_h - r))
            return

        # Movement past the threshold turns the click into a pan, so the
        # release will not add a guide point.
        if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish a drag or pan, or add a guide point on a plain editor-mode click."""
        super().mouseReleaseEvent(event)
        if event.button() != Qt.LeftButton or not self._mouse_pressed:
            return

        self._mouse_pressed = False
        self.viewport().setCursor(Qt.ArrowCursor)
        if self._dragging_idx is not None:
            idx = self._dragging_idx
            self._dragging_idx = None
            self._drag_offset = (0, 0)
            self.setDragMode(QGraphicsView.ScrollHandDrag)

            # Commit the item's final position (S/E included) to anchor_points.
            self.anchor_points[idx] = self.point_items[idx].get_pos()

            self._revert_cost_to_original()
            self._apply_all_guide_points_to_cost()
            self._rebuild_full_path()
        elif not self._was_dragging and self.editor_mode:
            # A click without a pan: add a guide point there.
            scene_pos = self.mapToScene(event.pos())
            self._add_guide_point(scene_pos.x(), scene_pos.y())

        self._was_dragging = False

    def _remove_point_by_click(self, view_pos):
        """Remove the removable anchor nearest *view_pos*; S/E are never removed."""
        idx = self._find_item_near(view_pos, threshold=10)
        if idx is None or not self.point_items[idx].is_removable():
            return

        self.scene.removeItem(self.point_items.pop(idx))
        self.anchor_points.pop(idx)

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()

    def _find_item_near(self, view_pos, threshold=10):
        """Return the index of the closest anchor to *view_pos*, or None.

        The distance is measured in scene units, after mapping *view_pos* to
        the scene. Anchors farther than *threshold* are ignored.
        """
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()
        min_dist = float('inf')
        closest_idx = None
        for i, itm in enumerate(self.point_items):
            d = itm.distance_to(x_click, y_click)
            if d < min_dist:
                min_dist = d
                closest_idx = i
        if closest_idx is not None and min_dist <= threshold:
            return closest_idx
        return None

    # --------------------------------------------------------------------
    # ZOOM
    # --------------------------------------------------------------------
    def wheelEvent(self, event):
        """Zoom in (scroll up) or out (scroll down) by 1.25x per wheel event.

        Events with no vertical component (horizontal wheel, sideways trackpad
        scroll) are passed to the default handler instead of being treated
        as a zoom-out.
        """
        dy = event.angleDelta().y()
        if dy == 0:
            super().wheelEvent(event)
            return
        zoom_in_factor = 1.25
        factor = zoom_in_factor if dy > 0 else 1 / zoom_in_factor
        self.scale(factor, factor)
        event.accept()

    # --------------------------------------------------------------------
    # UTILS
    # --------------------------------------------------------------------
    def _clamp(self, val, mn, mx):
        """Return *val* limited to the closed range [mn, mx]."""
        return max(mn, min(val, mx))

    def _clear_all_points(self):
        """Remove every anchor (S/E included) and every path dot from the scene."""
        for it in self.point_items:
            self.scene.removeItem(it)
        self.point_items.clear()
        self.anchor_points.clear()

        for p in self.full_path_points:
            self.scene.removeItem(p)
        self.full_path_points.clear()

    def clear_guide_points(self):
        """Remove all removable anchors (guide points), keep S/E, and rebuild the path.

        The cost image is reverted to the original, because no guide points
        remain to lower it. ``_rebuild_full_path`` also removes the old path
        dots.
        """
        kept_points, kept_items = [], []
        for pt, item in zip(self.anchor_points, self.point_items):
            if item.is_removable():
                self.scene.removeItem(item)
            else:
                kept_points.append(pt)
                kept_items.append(item)
        # Slice-assign so the list objects themselves stay the same.
        self.anchor_points[:] = kept_points
        self.point_items[:] = kept_items

        self._revert_cost_to_original()
        self._apply_all_guide_points_to_cost()
        self._rebuild_full_path()


class MainWindow(QMainWindow):
    """Main window: an ImageGraphicsView above a row of control buttons."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")

        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        self.image_view = ImageGraphicsView()
        main_layout.addWidget(self.image_view)

        btn_layout = QHBoxLayout()

        self.btn_load_image = QPushButton("Load Image")
        self.btn_load_image.clicked.connect(self.load_image)
        btn_layout.addWidget(self.btn_load_image)

        self.btn_editor_mode = QPushButton("Editor Mode: OFF")
        self.btn_editor_mode.setCheckable(True)
        self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
        self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
        btn_layout.addWidget(self.btn_editor_mode)

        self.btn_export_points = QPushButton("Export Points")
        self.btn_export_points.clicked.connect(self.export_points)
        btn_layout.addWidget(self.btn_export_points)

        self.btn_clear_points = QPushButton("Clear Points")
        self.btn_clear_points.clicked.connect(self.clear_points)
        btn_layout.addWidget(self.btn_clear_points)

        main_layout.addLayout(btn_layout)
        self.setCentralWidget(main_widget)
        self.resize(900, 600)

    def load_image(self):
        """Prompt for an image, load it with its cost image, and draw the initial S -> E path."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options
        )
        if not file_path:
            return
        self.image_view.load_image(file_path)
        cost_img = compute_cost_image(file_path)
        # Keep a pristine copy so guide-point cost edits can be reverted.
        self.image_view.cost_image_original = cost_img
        self.image_view.cost_image = cost_img.copy()
        # The view placed S/E before this cost image existed, so no path was
        # drawn. Draw it now instead of waiting for the first edit.
        self.image_view._rebuild_full_path()

    def toggle_editor_mode(self):
        """Sync the view's editor mode and the button's label and colour with the checked state."""
        is_checked = self.btn_editor_mode.isChecked()
        self.image_view.set_editor_mode(is_checked)
        if is_checked:
            self.btn_editor_mode.setText("Editor Mode: ON")
            self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
        else:
            self.btn_editor_mode.setText("Editor Mode: OFF")
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

    def export_points(self):
        """Save the anchor points (S, guide points, E) as an (N, 2) array of (x, y) to a .npy file."""
        if not self.image_view.anchor_points:
            print("No anchor points to export.")
            return
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Points", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options
        )
        if file_path:
            points_array = np.array(self.image_view.anchor_points)
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")

    def clear_points(self):
        """Remove all removable anchors (guide points) and keep S/E in place."""
        self.image_view.clear_guide_points()


def main():
    """Start the Qt event loop with a MainWindow and exit with its return code."""
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()