Skip to content
Snippets Groups Projects
GUI_draft.py 8.85 KiB
Newer Older
  • Learn to ignore specific revisions
  • Christian's avatar
    Christian committed
    import sys
    import numpy as np
    
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush
    from PyQt5.QtCore import Qt, QRectF
    
    
    class ImageGraphicsView(QGraphicsView):
        """
        Graphics view that displays one image and lets the user mark it with red dots.

        Interaction, as implemented below:
          - Left-drag pans the view (built-in ScrollHandDrag mode).
          - The mouse wheel zooms around the pointer.
          - With editor mode on, a left click (press and release without moving
            further than the drag threshold) adds a dot, and a right click
            removes the nearest dot within a fixed scene-space distance.

        ``points`` holds the (x, y) scene coordinates of every dot and
        ``point_items`` holds the matching ellipse items, index for index.
        """
        def __init__(self, parent=None):
            super().__init__(parent)

            # Create scene and add it to the view.
            # NOTE: on this instance the attribute below shadows the inherited
            # QGraphicsView.scene() accessor; the name is kept so existing
            # users of ``view.scene`` keep working.
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Zoom around mouse pointer
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Image item (empty until load_image() succeeds)
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            # Dot coordinates and their graphics items, kept in the same order
            self.points = []
            self.point_items = []
            self.editor_mode = False
            self.dot_radius = 4

            # Enable built-in panning around image, but force arrow cursor initially
            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ArrowCursor)

            # Track clicking vs. dragging
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5  # Manhattan distance in view pixels
            self._was_dragging = False

        def load_image(self, image_path) -> bool:
            """
            Load an image, discard the dots of the previous image and fit it in the view.

            Returns True when the image was loaded, False when the file could not
            be read as an image (in which case the current state is left untouched).
            """
            pixmap = QPixmap(image_path)

            if pixmap.isNull():
                # Unreadable / unsupported file: keep whatever is currently shown
                return False

            self.image_item.setPixmap(pixmap)

            # Avoid TypeError by converting to QRectF
            self.setSceneRect(QRectF(pixmap.rect()))

            # Clear existing dots from previous image
            self.points.clear()
            self._clear_point_items()

            # Reset transform then fit image in view
            self.resetTransform()
            self.fitInView(self.image_item, Qt.KeepAspectRatio)
            return True

        def set_editor_mode(self, mode: bool):
            """If True: clicks place/remove dots; if False: clicks do nothing."""
            self.editor_mode = mode

        def mousePressEvent(self, event):
            """
            Left button: remember the press position so the release can tell a
            click from a drag. Right button: remove the nearest dot (editor mode only).
            """
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                # Switch to closed-hand cursor while left mouse is down
                self.viewport().setCursor(Qt.ClosedHandCursor)

            elif event.button() == Qt.RightButton:
                # If Editor Mode is on remove the nearest dot
                if self.editor_mode:
                    self._remove_point(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            """
            If movement > _drag_threshold: consider it a drag.
            The actual panning is handled by QGraphicsView in ScrollHandDrag mode.
            """
            # Check if the mouse is being dragged
            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                # If the mouse moved more than the threshold, consider it a drag
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True

            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            """
            After releasing the left button, go back to arrow cursor.
            If it wasn't a drag, treat as a click (Editor Mode: add dot).
            """
            # Let QGraphicsView handle release first
            super().mouseReleaseEvent(event)

            if event.button() == Qt.LeftButton and self._mouse_pressed:
                self._mouse_pressed = False

                # Always go back to arrow cursor AFTER letting QGraphicsView handle release
                self.viewport().setCursor(Qt.ArrowCursor)

                # It's a click (not a drag): if editor mode is ON add a dot
                if not self._was_dragging and self.editor_mode:
                    self._add_point(event.pos())

                self._was_dragging = False

        def wheelEvent(self, event):
            """Mouse wheel = zoom by a fixed factor per event, in the wheel's direction."""
            zoom_in_factor = 1.25
            zoom_out_factor = 1 / zoom_in_factor

            delta = event.angleDelta().y()
            if delta > 0:
                self.scale(zoom_in_factor, zoom_in_factor)
            elif delta < 0:
                self.scale(zoom_out_factor, zoom_out_factor)
            # delta == 0 (no vertical wheel movement, e.g. a purely horizontal
            # scroll) must not zoom; previously it was treated as "zoom out".

            event.accept()

        # -------------- Red Dots --------------
        def _add_point(self, view_pos):
            """Add a red dot at the scene coords corresponding to view_pos."""
            scene_pos = self.mapToScene(view_pos)
            x, y = scene_pos.x(), scene_pos.y()

            # Keep points and point_items aligned: same index = same dot
            self.points.append((x, y))
            dot = self._create_dot_item(x, y)
            self.point_items.append(dot)
            self.scene.addItem(dot)

        def _remove_point(self, view_pos):
            """Right-click: remove the dot nearest to view_pos if within threshold."""
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()

            # Maximum distance (scene units) between click and dot for a removal
            threshold = 10
            closest_idx = None
            min_dist_sq = float('inf')

            # Find the closest point to the click; squared distances are enough
            # for comparison, so no square root is taken.
            for i, (x, y) in enumerate(self.points):
                dist_sq = (x - x_click)**2 + (y - y_click)**2
                if dist_sq < min_dist_sq:
                    min_dist_sq = dist_sq
                    closest_idx = i

            # Remove the closest point if it's within the threshold
            if closest_idx is not None and min_dist_sq <= threshold**2:
                self.scene.removeItem(self.point_items[closest_idx])
                del self.point_items[closest_idx]
                del self.points[closest_idx]

        def _create_dot_item(self, x, y):
            """Helper returning a red ellipse item of radius dot_radius centred on (x, y)."""
            r = self.dot_radius
            ellipse = QGraphicsEllipseItem(x - r, y - r, 2*r, 2*r)
            ellipse.setBrush(QBrush(Qt.red))
            ellipse.setPen(QPen(Qt.red))
            return ellipse

        def _clear_point_items(self):
            """Remove all dot items from the scene (the coordinate list is not touched)."""
            for item in self.point_items:
                self.scene.removeItem(item)
            self.point_items = []
    
    
    class MainWindow(QMainWindow):
        """
        Main window with:
          - an ImageGraphicsView showing the image and the placed dots
          - Button to load in image
          - Editor mode toggle button
          - Button for exporting placed points
        """
        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            main_widget = QWidget()
            main_layout = QVBoxLayout(main_widget)

            # Image view on top, row of buttons underneath
            self.image_view = ImageGraphicsView()
            main_layout.addWidget(self.image_view)

            btn_layout = QHBoxLayout()

            # Load Image
            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            # Editor Mode (checkable: checked state == editor mode on)
            self.btn_editor_mode = QPushButton("Editor Mode: OFF")
            self.btn_editor_mode.setCheckable(True)
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
            self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
            btn_layout.addWidget(self.btn_editor_mode)

            # Export Points
            self.btn_export_points = QPushButton("Export Points")
            self.btn_export_points.clicked.connect(self.export_points)
            btn_layout.addWidget(self.btn_export_points)

            main_layout.addLayout(btn_layout)
            self.setCentralWidget(main_widget)
            self.resize(900, 600)

        def load_image(self):
            """Open a file dialog to pick an image then load it into the view."""
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            # An empty path means the dialog was cancelled
            if file_path:
                self.image_view.load_image(file_path)

        def toggle_editor_mode(self):
            """Toggle whether left-click places dots and right-click removes dots."""
            is_checked = self.btn_editor_mode.isChecked()
            self.image_view.set_editor_mode(is_checked)

            # Mirror the state in the button's label and colour
            if is_checked:
                self.btn_editor_mode.setText("Editor Mode: ON")
                self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
            else:
                self.btn_editor_mode.setText("Editor Mode: OFF")
                self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

        def export_points(self):
            """
            Save the dot coordinates to a .npy file chosen by the user.

            The saved array always has shape (n_points, 2) with columns (x, y),
            including the (0, 2) case when no dots have been placed.
            """
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Points", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=options
            )
            # An empty path means the dialog was cancelled
            if not file_path:
                return

            # np.save appends ".npy" itself when the name lacks it; do it here
            # so the path reported below is the file that is actually written.
            if not file_path.endswith(".npy"):
                file_path += ".npy"

            # reshape keeps the 2-column layout even for an empty point list,
            # which would otherwise be saved as a 1-D array of shape (0,).
            points_array = np.array(self.image_view.points, dtype=float).reshape(-1, 2)
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")
    
    
    def main():
        """Create the Qt application, show the main window and exit with the event loop's status."""
        application = QApplication(sys.argv)

        main_window = MainWindow()
        main_window.show()

        # exec_() blocks until the event loop ends; its return value becomes the exit code
        exit_code = application.exec_()
        sys.exit(exit_code)
    
    
    # Start the GUI only when this file is executed as a script, not when imported
    if "__main__" == __name__:
        main()