Skip to content
Snippets Groups Projects
mainWindow.py 8.75 KiB
Newer Older
  • Learn to ignore specific revisions
  • import math
    import numpy as np
    from scipy.signal import savgol_filter
    from PyQt5.QtWidgets import (
        QMainWindow, QPushButton, QHBoxLayout, 
        QVBoxLayout, QWidget, QFileDialog
    )
    from PyQt5.QtGui import QPixmap, QImage
    from compute_cost_image import compute_cost_image
    from preprocess_image import preprocess_image
    from advancedSettingsWidget import AdvancedSettingsWidget
    from imageGraphicsView import ImageGraphicsView
    from circleEditorWidget import CircleEditorWidget
    
    class MainWindow(QMainWindow):
        """Main application window.

        Layout: a left panel (image view + button row) and a right-hand
        advanced-settings panel that is hidden until toggled on. The central
        widget can be temporarily swapped for a CircleEditorWidget to calibrate
        the radius passed to compute_cost_image.
        """

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Test GUI")

            # State of the most recently loaded image. The pixmap is None until a
            # file has been loaded successfully.
            self._last_loaded_pixmap = None
            self._last_loaded_file_path = None
            # Radius passed to compute_cost_image; also the circle editor's
            # initial radius. Updated when the editor finishes.
            self._circle_calibrated_radius = 6

            # Clip limit used for the contrast slider (compute_cost_image and
            # preprocess_image both receive it).
            self._current_clip_limit = 0.01

            # Set only while the circle editor replaces the central widget.
            self._old_central_widget = None
            self._editor = None

            # Outer widget + horizontal layout: [left panel | advanced panel]
            self._main_widget = QWidget()
            self._main_layout = QHBoxLayout(self._main_widget)

            # Left panel: image view on top, button row underneath. It lives in
            # a container widget so it can be given a stretch factor.
            self._left_panel = QVBoxLayout()
            self._left_container = QWidget()
            self._left_container.setLayout(self._left_panel)
            # Stretch factors 7 (here) and 3 (advanced panel) => ~70% / 30%.
            self._main_layout.addWidget(self._left_container, 7)

            # Right panel: advanced settings, hidden until the toggle is checked.
            self._advanced_widget = AdvancedSettingsWidget(self)
            self._advanced_widget.hide()
            self._main_layout.addWidget(self._advanced_widget, 3)

            self.setCentralWidget(self._main_widget)

            # The image view
            self.image_view = ImageGraphicsView()
            self._left_panel.addWidget(self.image_view)

            # Button row
            btn_layout = QHBoxLayout()
            self.btn_load_image = QPushButton("Load Image")
            self.btn_load_image.clicked.connect(self.load_image)
            btn_layout.addWidget(self.btn_load_image)

            self.btn_export_path = QPushButton("Export Path")
            self.btn_export_path.clicked.connect(self.export_path)
            btn_layout.addWidget(self.btn_export_path)

            self.btn_clear_points = QPushButton("Clear Points")
            self.btn_clear_points.clicked.connect(self.clear_points)
            btn_layout.addWidget(self.btn_clear_points)

            # "Advanced Settings" toggle
            self.btn_advanced = QPushButton("Advanced Settings")
            self.btn_advanced.setCheckable(True)
            self.btn_advanced.clicked.connect(self._toggle_advanced_settings)
            btn_layout.addWidget(self.btn_advanced)

            self._left_panel.addLayout(btn_layout)

            self.resize(1000, 600)

        def _toggle_advanced_settings(self, checked):
            """Show the advanced panel when *checked* is true, hide it otherwise."""
            self._advanced_widget.setVisible(bool(checked))
            # Force re-layout
            self.adjustSize()

        def open_circle_editor(self):
            """Replace the central widget with a CircleEditorWidget.

            The current central widget is stashed in ``_old_central_widget``
            and restored by ``_on_circle_editor_done``. Prints a message and
            does nothing if no image is loaded or an editor is already open.
            """
            if self._last_loaded_pixmap is None:
                print("No image loaded yet! Cannot open circle editor.")
                return
            if self._editor is not None:
                # A second takeCentralWidget() would stash the open editor as the
                # "old" widget and the real main widget would be lost.
                print("Circle editor is already open.")
                return

            self._old_central_widget = self.takeCentralWidget()
            self._editor = CircleEditorWidget(
                pixmap=self._last_loaded_pixmap,
                init_radius=self._circle_calibrated_radius,
                done_callback=self._on_circle_editor_done,
            )
            self.setCentralWidget(self._editor)

        def _on_circle_editor_done(self, final_radius):
            """Editor callback: store the radius, rebuild the cost image, restore the UI.

            Args:
                final_radius: radius chosen in the editor; used for all later
                    compute_cost_image calls.
            """
            self._circle_calibrated_radius = final_radius
            print(f"Circle Editor done. Radius = {final_radius}")

            if self._last_loaded_file_path:
                self._recompute_cost_image(
                    self._last_loaded_file_path, reapply_guides=True
                )
                self._update_advanced_images()

            editor_widget = self.takeCentralWidget()
            if editor_widget is not None:
                editor_widget.setParent(None)

            if self._old_central_widget is not None:
                self.setCentralWidget(self._old_central_widget)
                self._old_central_widget = None

            if self._editor is not None:
                # deleteLater() instead of immediate deletion: this callback is
                # presumably invoked from inside one of the editor's own slots.
                self._editor.deleteLater()
                self._editor = None

        def _recompute_cost_image(self, file_path, reapply_guides):
            """Recompute the cost image for *file_path* and install it on the view.

            Uses the current calibrated radius and clip limit. The view receives
            the untouched result in ``cost_image_original`` and a working copy
            in ``cost_image``. If *reapply_guides* is true, the view's guide
            points are re-applied to the new cost image and its path is rebuilt.
            """
            cost_img = compute_cost_image(
                file_path,
                self._circle_calibrated_radius,
                clip_limit=self._current_clip_limit,
            )
            self.image_view.cost_image_original = cost_img
            self.image_view.cost_image = cost_img.copy()
            if reapply_guides:
                self.image_view._apply_all_guide_points_to_cost()
                self.image_view._rebuild_full_path()

        def toggle_rainbow(self):
            """Forward to the image view's rainbow toggle."""
            self.image_view.toggle_rainbow()

        def load_image(self):
            """Ask the user for an image file and load it.

            Loads the file into the image view, computes its cost image, caches
            the pixmap and path, and refreshes the advanced panel. Does nothing
            if the dialog is cancelled.
            """
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Open Image", "",
                "Images (*.png *.jpg *.jpeg *.bmp *.tif)",
                options=options
            )
            if not file_path:
                return

            self.image_view.load_image(file_path)
            # Guide points are not re-applied for a newly loaded image.
            self._recompute_cost_image(file_path, reapply_guides=False)

            pm = QPixmap(file_path)
            # Reset on failure so the circle editor / advanced panel never pair
            # the previous image's pixmap with the new file path.
            self._last_loaded_pixmap = None if pm.isNull() else pm
            self._last_loaded_file_path = file_path
            self._update_advanced_images()

        def update_contrast(self, clip_limit):
            """Set the contrast clip limit, then refresh the cost image and advanced panel."""
            self._current_clip_limit = clip_limit
            if self._last_loaded_file_path:
                self._recompute_cost_image(
                    self._last_loaded_file_path, reapply_guides=True
                )

            self._update_advanced_images()

        def _update_advanced_images(self):
            """Send the preprocessed image and current cost image to the advanced panel.

            No-op when no pixmap is loaded.
            """
            if self._last_loaded_pixmap is None:
                return
            gray = self._qpixmap_to_gray_float(self._last_loaded_pixmap)
            contrasted_blurred = preprocess_image(
                gray,
                sigma=3,
                clip_limit=self._current_clip_limit
            )
            self._advanced_widget.update_displays(
                contrasted_blurred, self.image_view.cost_image
            )

        def _qpixmap_to_gray_float(self, qpix):
            """Convert a non-null QPixmap into a (height, width) grayscale float array.

            Gray is the unweighted mean of the R, G and B channels scaled to
            [0, 1]; alpha is ignored.
            """
            # RGBA8888 is laid out as R, G, B, A bytes on every platform, unlike
            # ARGB32 whose byte order depends on endianness (A first on
            # big-endian), so slicing [..., :3] always selects the colour bytes.
            img = qpix.toImage().convertToFormat(QImage.Format_RGBA8888)
            height, width = img.height(), img.width()
            stride = img.bytesPerLine()

            ptr = img.bits()
            # sizeInBytes() supersedes byteCount(), deprecated since Qt 5.10.
            if hasattr(img, "sizeInBytes"):
                ptr.setsize(img.sizeInBytes())
            else:
                ptr.setsize(img.byteCount())

            # Address rows by bytesPerLine() so any scanline padding is dropped
            # instead of being folded into the pixel grid.
            rows = np.frombuffer(ptr, np.uint8, count=height * stride)
            rows = rows.reshape((height, stride))
            pixels = rows[:, : width * 4].reshape((height, width, 4))

            # astype() copies, so the result does not alias the QImage buffer.
            rgb = pixels[..., :3].astype(np.float32)
            return rgb.mean(axis=2) / 255.0

        @staticmethod
        def _closest_path_index(path_xy, ax, ay):
            """Return the index of the (x, y) point in *path_xy* nearest to (ax, ay).

            On ties the earliest index wins. *path_xy* must be non-empty.
            """
            distances = (math.hypot(px - ax, py - ay) for px, py in path_xy)
            best_index, _ = min(enumerate(distances), key=lambda pair: pair[1])
            return best_index

        def export_path(self):
            """
            Exports the path as a CSV in the format: x, y, TYPE,
            ensuring that each anchor influences exactly one path point.

            Every path point is written with TYPE "PATH", except the single
            path point nearest to each entry of ``image_view.anchor_points``,
            which is written as "USER-PLACED". Two anchors may map to the same
            point. Write failures are reported instead of raised.
            """
            full_xy = self.image_view.get_full_path_xy()
            # len() rather than truthiness so the check also works if the path
            # comes back as a NumPy array (truth-testing one raises ValueError).
            if full_xy is None or len(full_xy) == 0:
                print("No path to export.")
                return

            # Ask user for the CSV filename before doing any per-anchor work.
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Export Path", "",
                "CSV Files (*.csv);;All Files (*)",
                options=options
            )
            if not file_path:
                return

            # Unlike a distance-threshold approach, each anchor marks exactly
            # one path point: the closest one.
            user_placed_indices = {
                self._closest_path_index(full_xy, ax, ay)
                for ax, ay in self.image_view.anchor_points
            }

            import csv
            try:
                with open(file_path, 'w', newline='') as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow(["x", "y", "TYPE"])
                    writer.writerows(
                        [x, y, "USER-PLACED" if i in user_placed_indices else "PATH"]
                        for i, (x, y) in enumerate(full_xy)
                    )
            except OSError as err:
                # An exception escaping a Qt slot aborts the app under PyQt5
                # >= 5.5, so report the failure instead.
                print(f"Failed to export path to {file_path}: {err}")
                return

            print(f"Exported path with {len(full_xy)} points to {file_path}")

        def clear_points(self):
            """Remove all guide points from the image view."""
            self.image_view.clear_guide_points()

        def closeEvent(self, event):
            """Delegate window-close handling to QMainWindow."""
            super().closeEvent(event)