import sys

import numpy as np
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
    QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
    QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QMessageBox
)
from PyQt5.QtGui import QPixmap, QPen, QBrush
from PyQt5.QtCore import Qt, QRectF


class ImageGraphicsView(QGraphicsView):
    """
    QGraphicsView that displays a single image and lets the user place red dots on it.

    Interaction:
      - Mouse wheel (vertical): zoom around the mouse pointer.
      - Left-drag: pan (QGraphicsView's built-in ScrollHandDrag mode).
      - Left-click without dragging, in editor mode: add a dot on the image.
      - Right-click, in editor mode: remove the nearest dot within ``remove_threshold``.

    Dot coordinates are stored in ``self.points`` as ``(x, y)`` scene coordinates.
    ``self.point_items`` holds the matching ellipse items at the same indices.
    The image item sits at the scene origin, so scene coordinates equal image
    pixel coordinates.
    """
    def __init__(self, parent=None):
        super().__init__(parent)

        # Create scene and add it to the view.
        # NOTE: this attribute shadows QGraphicsView.scene() for Python callers;
        # it is kept because external code may use ``view.scene`` directly.
        self.scene = QGraphicsScene(self)
        self.setScene(self.scene)

        # Zoom around mouse pointer
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        # Image item (positioned at the scene origin)
        self.image_item = QGraphicsPixmapItem()
        self.scene.addItem(self.image_item)

        # Points (x, y) and their dot items, kept index-aligned
        self.points = []
        self.point_items = []
        self.editor_mode = False
        self.dot_radius = 4
        # Max scene-space distance at which a right-click removes a dot
        self.remove_threshold = 10
        # Zoom multiplier applied per standard wheel notch (120 angle-delta units)
        self.zoom_factor = 1.25

        # Enable built-in panning around image, but force arrow cursor initially
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.viewport().setCursor(Qt.ArrowCursor)

        # Track clicking vs. dragging
        self._mouse_pressed = False
        self._press_view_pos = None
        self._drag_threshold = 5  # view pixels (Manhattan distance)
        self._was_dragging = False

    def load_image(self, image_path):
        """Load an image, clear existing dots and fit the image in the view.

        Args:
            image_path: Path of the image file to load.

        Returns:
            True if the image was loaded. False if it could not be read; the
            current image and dots are then left untouched.
        """
        pixmap = QPixmap(image_path)
        if pixmap.isNull():
            return False

        self.image_item.setPixmap(pixmap)

        # setSceneRect needs a QRectF; pixmap.rect() is a QRect
        self.setSceneRect(QRectF(pixmap.rect()))

        # Clear existing dots from previous image
        self.points.clear()
        self._clear_point_items()

        # Reset transform then fit image in view
        self.resetTransform()
        self.fitInView(self.image_item, Qt.KeepAspectRatio)
        return True

    def set_editor_mode(self, mode: bool):
        """If True, clicks place/remove dots; if False, clicks only pan."""
        self.editor_mode = mode

    def mousePressEvent(self, event):
        """Start click/drag tracking on left press; remove a dot on right press in editor mode."""
        if event.button() == Qt.LeftButton:
            self._mouse_pressed = True
            self._was_dragging = False
            self._press_view_pos = event.pos()

            # Switch to closed-hand cursor while left mouse is down
            self.viewport().setCursor(Qt.ClosedHandCursor)

        elif event.button() == Qt.RightButton:
            # If Editor Mode is on remove the nearest dot
            if self.editor_mode:
                self._remove_point(event.pos())

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """
        Flag the gesture as a drag once movement exceeds ``_drag_threshold``.

        The actual panning is done by QGraphicsView in ScrollHandDrag mode.
        """
        if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
            dist = (event.pos() - self._press_view_pos).manhattanLength()
            if dist > self._drag_threshold:
                self._was_dragging = True

        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Restore the arrow cursor after a left release.

        If the gesture was not a drag, it counts as a click; in editor mode
        that adds a dot.
        """
        # Let QGraphicsView handle release first (it may change the cursor)
        super().mouseReleaseEvent(event)

        if event.button() == Qt.LeftButton and self._mouse_pressed:
            self._mouse_pressed = False

            # Override whatever cursor QGraphicsView restored
            self.viewport().setCursor(Qt.ArrowCursor)

            if not self._was_dragging and self.editor_mode:
                self._add_point(event.pos())

            self._was_dragging = False

    def wheelEvent(self, event):
        """Zoom around the mouse pointer with the vertical wheel.

        The zoom is proportional to the wheel delta. A standard notch (120
        units) zooms by exactly ``zoom_factor``, while high-resolution wheels
        and touchpads zoom smoothly. Events with no vertical component (e.g.
        horizontal scrolling) go to QGraphicsView instead of being treated
        as "zoom out".
        """
        delta_y = event.angleDelta().y()
        if delta_y == 0:
            super().wheelEvent(event)
            return

        # Positive delta -> factor > 1 (zoom in); negative -> factor < 1 (zoom out)
        factor = self.zoom_factor ** (delta_y / 120.0)
        self.scale(factor, factor)
        event.accept()

    # -------------- Red Dots --------------
    def _add_point(self, view_pos):
        """Add a red dot at the scene position under ``view_pos``.

        Clicks are ignored when no image is loaded or when they fall outside
        the image, because such positions are not valid image coordinates.
        """
        scene_pos = self.mapToScene(view_pos)
        pixmap = self.image_item.pixmap()
        if pixmap.isNull() or not QRectF(pixmap.rect()).contains(scene_pos):
            return

        x, y = scene_pos.x(), scene_pos.y()
        self.points.append((x, y))
        dot = self._create_dot_item(x, y)
        self.point_items.append(dot)
        self.scene.addItem(dot)

    def _remove_point(self, view_pos):
        """Remove the dot nearest to ``view_pos`` if it is within ``remove_threshold``."""
        scene_pos = self.mapToScene(view_pos)
        x_click, y_click = scene_pos.x(), scene_pos.y()

        closest_idx = None
        min_dist_sq = float('inf')

        # Find the closest point (compare squared distances; no sqrt needed)
        for i, (x, y) in enumerate(self.points):
            dist_sq = (x - x_click) ** 2 + (y - y_click) ** 2
            if dist_sq < min_dist_sq:
                min_dist_sq = dist_sq
                closest_idx = i

        if closest_idx is not None and min_dist_sq <= self.remove_threshold ** 2:
            self.scene.removeItem(self.point_items[closest_idx])
            del self.point_items[closest_idx]
            del self.points[closest_idx]

    def _create_dot_item(self, x, y):
        """Return a filled red circle of radius ``dot_radius`` centered at (x, y)."""
        r = self.dot_radius
        ellipse = QGraphicsEllipseItem(x - r, y - r, 2 * r, 2 * r)
        ellipse.setBrush(QBrush(Qt.red))
        ellipse.setPen(QPen(Qt.red))
        return ellipse

    def _clear_point_items(self):
        """Remove all dot items from the scene."""
        for item in self.point_items:
            self.scene.removeItem(item)
        self.point_items = []


class MainWindow(QMainWindow):
    """
    Main window with:
      - an image view for placing dots,
      - a button to load an image,
      - an editor-mode toggle button,
      - a button for exporting the placed points.
    """
    def __init__(self):
        super().__init__()
        self.setWindowTitle("Test GUI")

        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        self.image_view = ImageGraphicsView()
        main_layout.addWidget(self.image_view)

        btn_layout = QHBoxLayout()

        # Load Image
        self.btn_load_image = QPushButton("Load Image")
        self.btn_load_image.clicked.connect(self.load_image)
        btn_layout.addWidget(self.btn_load_image)

        # Editor Mode (checkable; its checked state drives the view's editor mode)
        self.btn_editor_mode = QPushButton("Editor Mode: OFF")
        self.btn_editor_mode.setCheckable(True)
        self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
        self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
        btn_layout.addWidget(self.btn_editor_mode)

        # Export Points
        self.btn_export_points = QPushButton("Export Points")
        self.btn_export_points.clicked.connect(self.export_points)
        btn_layout.addWidget(self.btn_export_points)

        main_layout.addLayout(btn_layout)
        self.setCentralWidget(main_widget)
        self.resize(900, 600)

    def load_image(self):
        """Open a file dialog to pick an image, then load it into the view."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open Image", "",
            "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
            options=options
        )
        if file_path:
            self.image_view.load_image(file_path)

    def toggle_editor_mode(self):
        """Sync the view's editor mode and the button's label/color with the button state."""
        is_checked = self.btn_editor_mode.isChecked()
        self.image_view.set_editor_mode(is_checked)

        if is_checked:
            self.btn_editor_mode.setText("Editor Mode: ON")
            self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
        else:
            self.btn_editor_mode.setText("Editor Mode: OFF")
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

    def export_points(self):
        """Save the placed dot coordinates to a ``.npy`` file.

        The saved array always has shape ``(N, 2)`` with columns ``(x, y)``,
        where ``N`` may be 0. If writing fails, a warning dialog is shown
        instead of letting the exception escape the Qt slot.
        """
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Export Points", "",
            "NumPy Files (*.npy);;All Files (*)",
            options=options
        )
        if not file_path:
            return

        # reshape keeps the (N, 2) layout even with no points;
        # np.array([]) alone would have shape (0,).
        points_array = np.array(self.image_view.points, dtype=float).reshape(-1, 2)

        # np.save appends ".npy" to names lacking it; mirror that so the
        # reported path is the file actually written.
        if not file_path.endswith(".npy"):
            file_path += ".npy"

        try:
            np.save(file_path, points_array)
        except OSError as err:
            # An unhandled exception in a PyQt5 slot aborts the application.
            QMessageBox.warning(
                self, "Export Points", f"Could not save points:\n{err}"
            )
            return

        print(f"Exported {len(points_array)} points to {file_path}")


def main():
    """Create the Qt application, show the main window and run the event loop.

    Exits the process with the event loop's return code.
    """
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == "__main__":
    # Launch the GUI only when run as a script, not when imported.
    main()