Skip to content
Snippets Groups Projects
GUI_draft.py 16.5 KiB
Newer Older
  • Learn to ignore specific revisions
    import math
    import sys

    import numpy as np
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QGraphicsView, QGraphicsScene,
        QGraphicsEllipseItem, QGraphicsPixmapItem, QPushButton,
        QHBoxLayout, QVBoxLayout, QWidget, QFileDialog, QGraphicsTextItem,
    )
    from PyQt5.QtGui import QPixmap, QPen, QBrush, QColor, QFont
    from PyQt5.QtCore import Qt, QRectF

    # Import your live_wire functions
    from live_wire import compute_cost_image, find_path
    
    
    class LabeledPointItem(QGraphicsEllipseItem):
        """
        A circle with optional (bold) label (e.g. 'S'/'E'),
        which automatically scales the text if it's bigger than the circle.

        The ellipse geometry is a (2*radius x 2*radius) box whose local origin
        is its top-left corner, so ``set_pos`` offsets by ``radius`` to keep the
        circle's centre at the requested scene coordinate.
        """

        def __init__(self, x, y, label="", radius=4, color=Qt.red, removable=True, parent=None):
            """
            Args:
                x, y: centre of the circle in scene coordinates.
                label: optional text drawn (bold, black) inside the circle.
                radius: circle radius in scene units.
                color: used for both outline and fill.
                removable: whether the GUI may delete this point (S/E are not).
                parent: optional parent graphics item.
            """
            super().__init__(0, 0, 2 * radius, 2 * radius, parent)
            self._x = x       # Center x
            self._y = y       # Center y
            self._r = radius  # Circle radius
            self._removable = removable

            # Circle styling: outline and fill share the same colour.
            self.setPen(QPen(color))
            self.setBrush(QBrush(color))

            # Optional text label, parented to the circle so it moves with it.
            self._text_item = None
            if label:
                self._text_item = QGraphicsTextItem(self)
                self._text_item.setPlainText(label)
                self._text_item.setDefaultTextColor(QColor("black"))
                font = QFont("Arial", 14)
                font.setBold(True)
                self._text_item.setFont(font)
                self._scale_text_to_fit()

            # Move so center is at (x, y)
            self.set_pos(x, y)

        def _scale_text_to_fit(self):
            """Shrink (never enlarge) the label so it fits in the circle, then centre it."""
            if not self._text_item:
                return
            self._text_item.setScale(1.0)
            circle_diam = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            text_w = raw_rect.width()
            text_h = raw_rect.height()
            if text_w > circle_diam or text_h > circle_diam:
                # Guard each axis separately: one dimension may be degenerate (0).
                scale_w = circle_diam / text_w if text_w > 0 else 1.0
                scale_h = circle_diam / text_h if text_h > 0 else 1.0
                self._text_item.setScale(min(scale_w, scale_h))
            self._center_label()

        def _center_label(self):
            """Position the (possibly scaled) label at the centre of the circle."""
            if not self._text_item:
                return
            ellipse_w = 2 * self._r
            ellipse_h = 2 * self._r
            raw_rect = self._text_item.boundingRect()
            scale_factor = self._text_item.scale()
            # boundingRect() is unscaled, so apply the item's scale manually.
            scaled_w = raw_rect.width() * scale_factor
            scaled_h = raw_rect.height() * scale_factor
            tx = (ellipse_w - scaled_w) * 0.5
            ty = (ellipse_h - scaled_h) * 0.5
            self._text_item.setPos(tx, ty)

        def set_pos(self, x, y):
            """Move so the circle's center is at (x,y) in scene coords."""
            # Keep the cached centre in sync with the drawn position; get_pos()
            # and distance_to() (used for hit-testing and dragging) read it.
            self._x = x
            self._y = y
            self.setPos(x - self._r, y - self._r)

        def get_pos(self):
            """Return the circle centre as an (x, y) tuple in scene coordinates."""
            return (self._x, self._y)

        def distance_to(self, x_other, y_other):
            """Euclidean distance from the circle centre to (x_other, y_other)."""
            return math.hypot(self._x - x_other, self._y - y_other)

        def is_removable(self):
            """True if the GUI is allowed to delete this point."""
            return self._removable
    
    
    
    class ImageGraphicsView(QGraphicsView):
        """
        Displays an image and allows placing/dragging labeled points.
        Ensures points can't go outside the image boundary.

        Interaction (as implemented below):
          * Left-drag on empty space pans the view; the wheel zooms around
            the cursor.
          * Editor mode only: left-click adds a removable red dot, left-drag
            on a point moves it, right-click removes the nearest removable dot.
          * After a point is added, moved or removed, the live-wire path from
            the first to the last point is recomputed from ``cost_image`` and
            drawn in magenta.
        """

        def __init__(self, parent=None):
            super().__init__(parent)
            # NOTE: this attribute shadows QGraphicsView.scene(); kept as-is for
            # compatibility with existing callers.
            self.scene = QGraphicsScene(self)
            self.setScene(self.scene)

            # Zoom around mouse pointer
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Image item
            self.image_item = QGraphicsPixmapItem()
            self.scene.addItem(self.image_item)

            self.points = []     # LabeledPointItem objects; S first, E last
            self.editor_mode = False

            self.dot_radius = 4
            self.path_radius = 1  # radius of circles in path

            # Loaded image size in pixels (0 until an image is loaded).
            self._img_w = 0
            self._img_h = 0

            self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.viewport().setCursor(Qt.ArrowCursor)

            # Press/move bookkeeping used to tell a click from a pan.
            self._mouse_pressed = False
            self._press_view_pos = None
            self._drag_threshold = 5   # Manhattan distance, view pixels
            self._was_dragging = False

            # Index (into self.points) of the point being dragged, and the
            # offset from its centre to where it was grabbed.
            self._dragging_idx = None
            self._drag_offset = (0, 0)

            # Cost image from compute_cost_image
            self.cost_image = None

            # All path points displayed in magenta
            self.path_points = []

        def load_image(self, image_path):
            """
            Show the image at ``image_path`` and reset points to a fresh S/E pair.

            S and E are placed at 15% and 85% of the image width, vertically
            centred. If the file cannot be loaded as a QPixmap, nothing changes.
            """
            pixmap = QPixmap(image_path)
            if pixmap.isNull():
                return

            self.image_item.setPixmap(pixmap)
            self.setSceneRect(QRectF(pixmap.rect()))

            # Save image dimensions (used to clamp point positions)
            self._img_w = pixmap.width()
            self._img_h = pixmap.height()

            self._clear_point_items(remove_all=True)

            self.resetTransform()
            self.fitInView(self.image_item, Qt.KeepAspectRatio)

            # Create green S/E with radius=6
            s_point = self._create_point(self._img_w * 0.15, self._img_h * 0.5, "S", 6, Qt.green, removable=False)
            e_point = self._create_point(self._img_w * 0.85, self._img_h * 0.5, "E", 6, Qt.green, removable=False)

            self.points = [s_point, e_point]
            self.scene.addItem(s_point)
            self.scene.addItem(e_point)

        def set_editor_mode(self, mode: bool):
            """Enable/disable point editing (add, drag, remove)."""
            self.editor_mode = mode

        def _create_point(self, x, y, label, radius, color, removable=True):
            """Build (but do not add to the scene) a point clamped inside the image."""
            # Clamp coordinates so center doesn't go outside the image
            cx = self._clamp(x, radius, self._img_w - radius)
            cy = self._clamp(y, radius, self._img_h - radius)
            return LabeledPointItem(cx, cy, label=label, radius=radius, color=color, removable=removable)

        def _clamp(self, val, min_val, max_val):
            """Clamp ``val`` to [min_val, max_val] (min_val wins if the range is empty)."""
            return max(min_val, min(val, max_val))

        def mousePressEvent(self, event):
            if event.button() == Qt.LeftButton:
                self._mouse_pressed = True
                self._was_dragging = False
                self._press_view_pos = event.pos()

                idx = self._find_point_near(event.pos(), threshold=10) if self.editor_mode else None
                if idx is not None:
                    # Start dragging an existing point; remember the grab offset
                    # so the point doesn't jump to the cursor.
                    self._dragging_idx = idx
                    scene_pos = self.mapToScene(event.pos())
                    px, py = self.points[idx].get_pos()
                    self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)

                    # No panning while a point is dragged; skip the base handler
                    # so QGraphicsView does not start a hand-scroll.
                    self.setDragMode(QGraphicsView.NoDrag)
                    self.viewport().setCursor(Qt.ClosedHandCursor)
                    return

                # Otherwise this press is a pan or a click (decided on release).
                self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.viewport().setCursor(Qt.ClosedHandCursor)

            elif event.button() == Qt.RightButton and self.editor_mode:
                self._remove_point(event.pos())

            super().mousePressEvent(event)

        def mouseMoveEvent(self, event):
            if self._dragging_idx is not None:
                # Dragging an existing point: follow the cursor minus the grab offset.
                point = self.points[self._dragging_idx]
                scene_pos = self.mapToScene(event.pos())
                x_new = scene_pos.x() - self._drag_offset[0]
                y_new = scene_pos.y() - self._drag_offset[1]

                # Clamp so center doesn't go out of the image
                r = point._r
                point.set_pos(
                    self._clamp(x_new, r, self._img_w - r),
                    self._clamp(y_new, r, self._img_h - r),
                )
                return

            # If movement > threshold => treat as pan (no point added on release)
            if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
                dist = (event.pos() - self._press_view_pos).manhattanLength()
                if dist > self._drag_threshold:
                    self._was_dragging = True
            super().mouseMoveEvent(event)

        def mouseReleaseEvent(self, event):
            # Forward first: in ScrollHandDrag mode QGraphicsView ends its
            # hand-scroll on release; without this the view keeps panning.
            super().mouseReleaseEvent(event)

            if event.button() == Qt.LeftButton and self._mouse_pressed:
                self._mouse_pressed = False

                if self._dragging_idx is not None:
                    # The user was dragging a point and now released
                    self._dragging_idx = None
                    self._drag_offset = (0, 0)
                    self.setDragMode(QGraphicsView.ScrollHandDrag)
                    self._run_find_path()  # Recompute path
                elif not self._was_dragging and self.editor_mode:
                    # Not dragged => treat as a click and add a new point
                    self._add_point(event.pos())
                    self._run_find_path()

                self._was_dragging = False
                # Set last, after setDragMode() and the base handler above.
                self.viewport().setCursor(Qt.ArrowCursor)

        def wheelEvent(self, event):
            """Zoom in/out by 25% per wheel step; purely horizontal scrolls are ignored."""
            zoom_in_factor = 1.25
            delta = event.angleDelta().y()
            if delta > 0:
                self.scale(zoom_in_factor, zoom_in_factor)
            elif delta < 0:
                zoom_out_factor = 1 / zoom_in_factor
                self.scale(zoom_out_factor, zoom_out_factor)
            event.accept()

        def _add_point(self, view_pos):
            """Add a removable red dot at the clicked location."""
            scene_pos = self.mapToScene(view_pos)
            dot = self._create_point(scene_pos.x(), scene_pos.y(), label="", radius=self.dot_radius,
                                     color=Qt.red, removable=True)

            # Insert before the final E point if S/E exist
            if len(self.points) >= 2:
                self.points.insert(len(self.points) - 1, dot)
            else:
                self.points.append(dot)

            self.scene.addItem(dot)

        def _remove_point(self, view_pos):
            """Right-click => remove nearest dot if it's removable."""
            idx = self._find_point_near(view_pos, threshold=10)
            # Only the single nearest point is considered: if it is a fixed S/E
            # point, nothing is removed even if a red dot is also nearby.
            if idx is None or not self.points[idx].is_removable():
                return
            self.scene.removeItem(self.points[idx])
            del self.points[idx]
            self._run_find_path()

        def _find_point_near(self, view_pos, threshold=10):
            """
            Return the index of the point closest to ``view_pos``, or None.

            ``threshold`` is measured in scene units (image pixels), not view pixels.
            """
            scene_pos = self.mapToScene(view_pos)
            x_click, y_click = scene_pos.x(), scene_pos.y()

            closest_idx = None
            min_dist = float('inf')
            for i, p in enumerate(self.points):
                dist = p.distance_to(x_click, y_click)
                if dist < min_dist:
                    min_dist = dist
                    closest_idx = i
            if closest_idx is not None and min_dist <= threshold:
                return closest_idx
            return None

        def _clear_path_items(self):
            """Remove the drawn path dots from the scene."""
            for item in self.path_points:
                self.scene.removeItem(item)
            self.path_points.clear()

        def _clear_point_items(self, remove_all=False):
            """Remove all points if remove_all=True; else just removable ones. Always clears the path."""
            if remove_all:
                for p in self.points:
                    self.scene.removeItem(p)
                self.points.clear()
            else:
                kept = []
                for p in self.points:
                    if p.is_removable():
                        self.scene.removeItem(p)
                    else:
                        kept.append(p)
                self.points = kept

            # Also remove any path points from the scene
            self._clear_path_items()

        @staticmethod
        def _to_row_col(x, y, h, w):
            """Convert scene (x, y) to (row, col) = (round(y), round(x)), clamped to an h x w grid."""
            row = max(0, min(int(round(y)), h - 1))
            col = max(0, min(int(round(x)), w - 1))
            return row, col

        def _run_find_path(self):
            """
            Recompute the live-wire path between the first and last point and
            draw it as small magenta dots.

            Does nothing if there are fewer than two points or no cost image.
            A ValueError from find_path() is printed and leaves no path drawn.
            """
            if len(self.points) < 2 or self.cost_image is None:
                return

            # Clear old path visualization
            self._clear_path_items()

            # NOTE(review): only the endpoints are routed; intermediate red dots
            # don't affect the path. find_path() takes a list of points — confirm
            # whether it supports waypoints before passing self.points.
            h, w = self.cost_image.shape  # assumes a 2-D cost image
            start_rc = self._to_row_col(*self.points[0].get_pos(), h, w)
            end_rc = self._to_row_col(*self.points[-1].get_pos(), h, w)

            # Attempt path
            try:
                path_rc = find_path(self.cost_image, [start_rc, end_rc])
            except ValueError as e:
                print("Error in find_path:", e)
                return

            # Convert path (row,col) => (x, y) and draw
            for (r, c) in path_rc:
                item = self._create_point(c, r, "", self.path_radius, Qt.magenta, removable=False)
                self.path_points.append(item)
                self.scene.addItem(item)
    
    
    
    class MainWindow(QMainWindow):
        """Main window: an ImageGraphicsView above a row of control buttons."""

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            main_widget = QWidget()
            main_layout = QVBoxLayout(main_widget)

            self.image_view = ImageGraphicsView()
            main_layout.addWidget(self.image_view)

            btn_layout = QHBoxLayout()

            # Load Image
            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            # Editor Mode (checkable toggle)
            self.btn_editor_mode = QPushButton("Editor Mode: OFF")
            self.btn_editor_mode.setCheckable(True)
            self.btn_editor_mode.setStyleSheet("background-color: lightgray;")
            self.btn_editor_mode.clicked.connect(self.toggle_editor_mode)
            btn_layout.addWidget(self.btn_editor_mode)

            # Export Points
            self.btn_export_points = QPushButton("Export Points")
            self.btn_export_points.clicked.connect(self.export_points)
            btn_layout.addWidget(self.btn_export_points)

            # Clear Points (removable dots only; S/E stay)
            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            main_layout.addLayout(btn_layout)
            self.setCentralWidget(main_widget)
            self.resize(900, 600)

        def load_image(self):
            """Open file dialog to pick an image, then load it."""
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            if not file_path:
                return

            self.image_view.load_image(file_path)

            # Compute cost image
            self.image_view.cost_image = compute_cost_image(file_path)

            # Draw the initial S -> E path now that the cost image exists;
            # ImageGraphicsView.load_image() does not compute one itself.
            self.image_view._run_find_path()

        def toggle_editor_mode(self):
            """Sync the view's editor mode and the button's label/colour with the toggle state."""
            is_checked = self.btn_editor_mode.isChecked()
            self.image_view.set_editor_mode(is_checked)
            if is_checked:
                self.btn_editor_mode.setText("Editor Mode: ON")
                self.btn_editor_mode.setStyleSheet("background-color: #ffcccc;")
            else:
                self.btn_editor_mode.setText("Editor Mode: OFF")
                self.btn_editor_mode.setStyleSheet("background-color: lightgray;")

        def export_points(self):
            """Save all point centres (S, dots..., E) as an (N, 2) array of (x, y) to a .npy file."""
            if not self.image_view.points:
                print("No points to export.")
                return

            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Points", "",
                "NumPy Files (*.npy);;All Files (*)",
                options=options
            )
            if not file_path:
                return

            points_array = np.array([p.get_pos() for p in self.image_view.points])
            np.save(file_path, points_array)
            print(f"Exported {len(points_array)} points to {file_path}")

        def clear_points(self):
            """Remove all removable dots, then redraw the S -> E path."""
            self.image_view._clear_point_items(remove_all=False)
            # _clear_point_items() also wipes the drawn path, but S/E remain,
            # so recompute it to keep the display consistent with the points.
            self.image_view._run_find_path()
    
    s224389's avatar
    s224389 committed
    
    
    def main():
        """Create the Qt application, show the main window and run the event loop."""
        application = QApplication(sys.argv)
        main_window = MainWindow()
        main_window.show()
        # exec_() blocks until the last window closes; propagate its return code.
        exit_code = application.exec_()
        sys.exit(exit_code)
    
    
    
    s224389's avatar
    s224389 committed
    if __name__ == "__main__":