diff --git a/ast2d/__init__.py b/ast2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f15a7c960bd7192ca1c6c06523f667e8e1785079 --- /dev/null +++ b/ast2d/__init__.py @@ -0,0 +1,13 @@ +import sys +from PyQt5.QtWidgets import QApplication +from .mainWindow import MainWindow + +def main(): + app = QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ast2d/__pycache__/__init__.cpython-310.pyc b/ast2d/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b27e7d2af13426ea3e07b8356a8d4f3a3537b643 Binary files /dev/null and b/ast2d/__pycache__/__init__.cpython-310.pyc differ diff --git a/ast2d/__pycache__/advancedSettingsWidget.cpython-310.pyc b/ast2d/__pycache__/advancedSettingsWidget.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54244ebfa2b2c008c1d1c29dc26e5e41c97c324a Binary files /dev/null and b/ast2d/__pycache__/advancedSettingsWidget.cpython-310.pyc differ diff --git a/ast2d/__pycache__/circleEditorGraphicsView.cpython-310.pyc b/ast2d/__pycache__/circleEditorGraphicsView.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f9e5767d207f9cf7705b7cc0f922fbe0e7538fc Binary files /dev/null and b/ast2d/__pycache__/circleEditorGraphicsView.cpython-310.pyc differ diff --git a/ast2d/__pycache__/circleEditorWidget.cpython-310.pyc b/ast2d/__pycache__/circleEditorWidget.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7110c576008b3c637114913c35c71e289f3bb45e Binary files /dev/null and b/ast2d/__pycache__/circleEditorWidget.cpython-310.pyc differ diff --git a/ast2d/__pycache__/circle_edge_kernel.cpython-310.pyc b/ast2d/__pycache__/circle_edge_kernel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3c7fc0d04717ae6aca9d1cdc4d35d53995bf69e Binary files /dev/null and b/ast2d/__pycache__/circle_edge_kernel.cpython-310.pyc differ diff --git a/ast2d/__pycache__/compute_cost_image.cpython-310.pyc b/ast2d/__pycache__/compute_cost_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baecc404908466ddcdf86b320446ef6c083b0be9 Binary files /dev/null and b/ast2d/__pycache__/compute_cost_image.cpython-310.pyc differ diff --git a/ast2d/__pycache__/compute_disk_size.cpython-310.pyc b/ast2d/__pycache__/compute_disk_size.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b117c107aaa6b8248d9f11a59c2d464176a5243 Binary files /dev/null and b/ast2d/__pycache__/compute_disk_size.cpython-310.pyc differ diff --git a/ast2d/__pycache__/draggableCircleItem.cpython-310.pyc b/ast2d/__pycache__/draggableCircleItem.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e52caac1a41facfb0bd0a074a987fd9d4c5e1234 Binary files /dev/null and b/ast2d/__pycache__/draggableCircleItem.cpython-310.pyc differ diff --git a/ast2d/__pycache__/find_path.cpython-310.pyc b/ast2d/__pycache__/find_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb04cf6448adf91b543234b99e84c26abb94d721 Binary files /dev/null and b/ast2d/__pycache__/find_path.cpython-310.pyc differ diff --git a/ast2d/__pycache__/imageGraphicsView.cpython-310.pyc b/ast2d/__pycache__/imageGraphicsView.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aa5867d7d89d02a091b26f50864c32e42ffc379 Binary files /dev/null and b/ast2d/__pycache__/imageGraphicsView.cpython-310.pyc differ diff --git a/ast2d/__pycache__/labeledPointItem.cpython-310.pyc b/ast2d/__pycache__/labeledPointItem.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978aafc5bcc2653043b1d103e1935c2baf21bff9 Binary files /dev/null and b/ast2d/__pycache__/labeledPointItem.cpython-310.pyc differ diff --git a/ast2d/__pycache__/load_image.cpython-310.pyc b/ast2d/__pycache__/load_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963948684b5ede87065750dcf44efdfc98c3ea73 Binary files /dev/null and b/ast2d/__pycache__/load_image.cpython-310.pyc differ diff --git a/ast2d/__pycache__/mainWindow.cpython-310.pyc b/ast2d/__pycache__/mainWindow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71fe376075022673956a3c4d2a4f79e811b528a3 Binary files /dev/null and b/ast2d/__pycache__/mainWindow.cpython-310.pyc differ diff --git a/ast2d/__pycache__/panZoomGraphicsView.cpython-310.pyc b/ast2d/__pycache__/panZoomGraphicsView.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f12867c40be921a81b1059fd4ad7a040e88e610b Binary files /dev/null and b/ast2d/__pycache__/panZoomGraphicsView.cpython-310.pyc differ diff --git a/ast2d/__pycache__/preprocess_image.cpython-310.pyc b/ast2d/__pycache__/preprocess_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0dfa7c9fb3934e154036fdb3cdef6adb6b2c66c Binary files /dev/null and b/ast2d/__pycache__/preprocess_image.cpython-310.pyc differ diff --git a/ast2d/advancedSettingsWidget.py b/ast2d/advancedSettingsWidget.py new file mode 100644 index 0000000000000000000000000000000000000000..065a8e29b6a404c430db8d265fa872d923e53e4b --- /dev/null +++ b/ast2d/advancedSettingsWidget.py @@ -0,0 +1,174 @@ +from PyQt5.QtWidgets import ( + QPushButton, QVBoxLayout, QWidget, + QSlider, QLabel, QGridLayout, QSizePolicy +) +from PyQt5.QtGui import QPixmap, QImage, QShowEvent +from PyQt5.QtCore import Qt +import numpy as np +from typing import Optional + +class AdvancedSettingsWidget(QWidget): + """ + Shows toggle rainbow, circle editor, line smoothing slider, contrast slider, + plus two image previews (contrasted-blurred and cost). + The images maintain aspect ratio upon resize. + """ + def __init__(self, main_window, parent: Optional[QWidget] = None): + """ + Constructor. + """ + super().__init__(parent) + self._main_window = main_window + + self._last_cb_pix = None # store QPixmap for contrasted-blurred image + self._last_cost_pix = None # store QPixmap for cost image + + main_layout = QVBoxLayout() + self.setLayout(main_layout) + + # A small grid for controls + controls_layout = QGridLayout() + + # Rainbow toggle + self.btn_toggle_rainbow = QPushButton("Toggle Rainbow") + self.btn_toggle_rainbow.clicked.connect(self._on_toggle_rainbow) + controls_layout.addWidget(self.btn_toggle_rainbow, 0, 0) + + # Disk size calibration (Circle editor) + self.btn_circle_editor = QPushButton("Calibrate Kernel Size") + self.btn_circle_editor.clicked.connect(self._main_window.open_circle_editor) + controls_layout.addWidget(self.btn_circle_editor, 0, 1) + + # Line smoothing slider + label + self._lab_smoothing = QLabel("Line smoothing (3)") + controls_layout.addWidget(self._lab_smoothing, 1, 0) + self.line_smoothing_slider = QSlider(Qt.Horizontal) + self.line_smoothing_slider.setRange(3, 51) + self.line_smoothing_slider.setValue(3) + self.line_smoothing_slider.valueChanged.connect(self._on_line_smoothing_slider) + controls_layout.addWidget(self.line_smoothing_slider, 1, 1) + + # Contrast slider + label + self._lab_contrast = QLabel("Contrast (0.01)") + controls_layout.addWidget(self._lab_contrast, 2, 0) + self.contrast_slider = QSlider(Qt.Horizontal) + self.contrast_slider.setRange(1, 20) + self.contrast_slider.setValue(1) # i.e. 0.01 + self.contrast_slider.setSingleStep(1) + self.contrast_slider.valueChanged.connect(self._on_contrast_slider) + controls_layout.addWidget(self.contrast_slider, 2, 1) + + main_layout.addLayout(controls_layout) + + self.setMinimumWidth(350) + + # A vertical layout for the two images, each with a label above it + images_layout = QVBoxLayout() + + # Contrasted-blurred label + image + self.label_cb_title = QLabel("Contrasted Blurred Image") + self.label_cb_title.setAlignment(Qt.AlignCenter) + images_layout.addWidget(self.label_cb_title) + + self.label_contrasted_blurred = QLabel() + self.label_contrasted_blurred.setAlignment(Qt.AlignCenter) + self.label_contrasted_blurred.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + images_layout.addWidget(self.label_contrasted_blurred) + + # Cost image label + image + self.label_cost_title = QLabel("Current COST IMAGE") + self.label_cost_title.setAlignment(Qt.AlignCenter) + images_layout.addWidget(self.label_cost_title) + + self.label_cost_image = QLabel() + self.label_cost_image.setAlignment(Qt.AlignCenter) + self.label_cost_image.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + images_layout.addWidget(self.label_cost_image) + + main_layout.addLayout(images_layout) + + def showEvent(self, event: QShowEvent): + """ When shown, ask parent to resize to accommodate. """ + super().showEvent(event) + if self.parentWidget(): + self.parentWidget().adjustSize() + + def resizeEvent(self, event: QShowEvent): + """ + Keep the images at correct aspect ratio by re-scaling + stored pixmaps to the new label sizes. + """ + super().resizeEvent(event) + self._update_labels() + + def _update_labels(self): + """ + Re-scale stored pixmaps to the new label sizes. + """ + if self._last_cb_pix is not None: + scaled_cb = self._last_cb_pix.scaled( + self.label_contrasted_blurred.size(), + Qt.KeepAspectRatio, + Qt.SmoothTransformation + ) + self.label_contrasted_blurred.setPixmap(scaled_cb) + + if self._last_cost_pix is not None: + scaled_cost = self._last_cost_pix.scaled( + self.label_cost_image.size(), + Qt.KeepAspectRatio, + Qt.SmoothTransformation + ) + self.label_cost_image.setPixmap(scaled_cost) + + def _on_toggle_rainbow(self): + """ + Called when the rainbow toggle button is clicked. + """ + self._main_window.toggle_rainbow() + + def _on_line_smoothing_slider(self, value: int): + """ + Called when the line smoothing slider is moved. + """ + self._lab_smoothing.setText(f"Line smoothing ({value})") + self._main_window.image_view.set_savgol_window_length(value) + + def _on_contrast_slider(self, value: int): + """ + Called when the contrast slider is moved. + """ + clip_limit = value / 100.0 + self._lab_contrast.setText(f"Contrast ({clip_limit:.2f})") + self._main_window.update_contrast(clip_limit) + + def update_displays(self, contrasted_img_np: np.ndarray, cost_img_np: np.ndarray): + """ + Update the contrasted-blurred and cost images. + """ + cb_pix = self._np_array_to_qpixmap(contrasted_img_np) + cost_pix = self._np_array_to_qpixmap(cost_img_np, normalize=True) + + self._last_cb_pix = cb_pix + self._last_cost_pix = cost_pix + self._update_labels() + + def _np_array_to_qpixmap(self, arr: np.ndarray, normalize: bool = False) -> QPixmap: + """ + Convert a numpy array to a QPixmap. + """ + if arr is None: + return None + arr_ = arr.copy() + if normalize: + mn, mx = arr_.min(), arr_.max() + if abs(mx - mn) < 1e-12: + arr_[:] = 0 + else: + arr_ = (arr_ - mn) / (mx - mn) + arr_ = np.clip(arr_, 0, 1) + arr_255 = (arr_ * 255).astype(np.uint8) + + h, w = arr_255.shape + qimage = QImage(arr_255.data, w, h, w, QImage.Format_Grayscale8) + return QPixmap.fromImage(qimage) diff --git a/ast2d/circleEditorGraphicsView.py b/ast2d/circleEditorGraphicsView.py new file mode 100644 index 0000000000000000000000000000000000000000..ffebeb787a22f09262dd68a5f39122a5bbd94e6a --- /dev/null +++ b/ast2d/circleEditorGraphicsView.py @@ -0,0 +1,56 @@ +from PyQt5.QtWidgets import QGraphicsView, QWidget +from .panZoomGraphicsView import PanZoomGraphicsView +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QMouseEvent, QWheelEvent +from .draggableCircleItem import DraggableCircleItem +from typing import Optional + +# A specialized PanZoomGraphicsView for the circle editor (disk size calibration) +class CircleEditorGraphicsView(PanZoomGraphicsView): + def __init__(self, circle_editor_widget, parent: Optional[QWidget] = None): + """ + Constructor. + """ + super().__init__(parent) + self._circle_editor_widget = circle_editor_widget + + def mousePressEvent(self, event: QMouseEvent): + """ + If the user clicks on the circle, we let the circle item handle the event. + """ + if event.button() == Qt.LeftButton: + # Check if user clicked on the circle item + clicked_item = self.itemAt(event.pos()) + if clicked_item is not None: + # climb up parent chain + it = clicked_item + while it is not None and not hasattr(it, "boundingRect"): + it = it.parentItem() + + if isinstance(it, DraggableCircleItem): + # Let normal item-dragging occur, no pan + return QGraphicsView.mousePressEvent(self, event) + super().mousePressEvent(event) + + def wheelEvent(self, event: QWheelEvent): + """ + If the user scrolls the mouse wheel over the circle, we change the circle + """ + pos_in_widget = event.pos() + item_under = self.itemAt(pos_in_widget) + if item_under is not None: + it = item_under + while it is not None and not hasattr(it, "boundingRect"): + it = it.parentItem() + + if isinstance(it, DraggableCircleItem): + delta = event.angleDelta().y() + step = 1 if delta > 0 else -1 + old_r = it.radius() + new_r = max(1, old_r + step) + it.set_radius(new_r) + self._circle_editor_widget.update_slider_value(new_r) + event.accept() + return + + super().wheelEvent(event) \ No newline at end of file diff --git a/ast2d/circleEditorWidget.py b/ast2d/circleEditorWidget.py new file mode 100644 index 0000000000000000000000000000000000000000..280787e5f462f7573d1fbd9d863dc42bcda72b9b --- /dev/null +++ b/ast2d/circleEditorWidget.py @@ -0,0 +1,100 @@ +from PyQt5.QtWidgets import ( + QGraphicsScene, QGraphicsPixmapItem, QPushButton, + QHBoxLayout, QVBoxLayout, QWidget, QSlider, QLabel +) +from PyQt5.QtGui import QFont, QPixmap +from PyQt5.QtCore import Qt, QRectF, QSize +from .circleEditorGraphicsView import CircleEditorGraphicsView +from .draggableCircleItem import DraggableCircleItem +from typing import Optional, Callable + +class CircleEditorWidget(QWidget): + """ + A widget for the user to calibrate the disk size (kernel size) for the ridge detection. + """ + def __init__(self, pixmap: QPixmap, init_radius: int = 20, done_callback: Optional[Callable[[], None]] = None, parent: Optional[QWidget] = None): + """ + Constructor. + """ + super().__init__(parent) + self._pixmap = pixmap + self._done_callback = done_callback + self._init_radius = init_radius + + layout = QVBoxLayout(self) + self.setLayout(layout) + + # Add centered label above image + label_instructions = QLabel("Scale the dot to be of the size of your ridge") + label_instructions.setAlignment(Qt.AlignCenter) + big_font = QFont("Arial", 20) + big_font.setBold(True) + label_instructions.setFont(big_font) + layout.addWidget(label_instructions) + + # Show the image + self._graphics_view = CircleEditorGraphicsView(circle_editor_widget=self) + self._scene = QGraphicsScene(self) + self._graphics_view.setScene(self._scene) + layout.addWidget(self._graphics_view) + + self._image_item = QGraphicsPixmapItem(self._pixmap) + self._scene.addItem(self._image_item) + + # Put circle in center + cx = self._pixmap.width() / 2 + cy = self._pixmap.height() / 2 + self._circle_item = DraggableCircleItem(cx, cy, radius=self._init_radius, color=Qt.red) + self._scene.addItem(self._circle_item) + + # Fit in view + self._graphics_view.setSceneRect(QRectF(self._pixmap.rect())) + self._graphics_view.fitInView(self._image_item, Qt.KeepAspectRatio) + + ### Controls below + bottom_layout = QHBoxLayout() + layout.addLayout(bottom_layout) + + # label + slider + self._lbl_size = QLabel(f"size ({self._init_radius})") + bottom_layout.addWidget(self._lbl_size) + + self._slider = QSlider(Qt.Horizontal) + self._slider.setRange(1, 200) + self._slider.setValue(self._init_radius) + bottom_layout.addWidget(self._slider) + + # Done button + self._btn_done = QPushButton("Done") + bottom_layout.addWidget(self._btn_done) + + # Connect signals + self._slider.valueChanged.connect(self._on_slider_changed) + self._btn_done.clicked.connect(self._on_done_clicked) + + def _on_slider_changed(self, value: int): + """ + Handle slider value changes. + """ + self._circle_item.set_radius(value) + self._lbl_size.setText(f"size ({value})") + + def _on_done_clicked(self): + """ + Handle the user clicking the "Done" button. + """ + final_radius = self._circle_item.radius() + if self._done_callback is not None: + self._done_callback(final_radius) + + def update_slider_value(self, new_radius: int): + """ + Update the slider value. + """ + self._slider.blockSignals(True) + self._slider.setValue(new_radius) + self._slider.blockSignals(False) + self._lbl_size.setText(f"size ({new_radius})") + + def sizeHint(self): + return QSize(800, 600) diff --git a/ast2d/circle_edge_kernel.py b/ast2d/circle_edge_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..f8587459ca36433198857e31488018368068747f --- /dev/null +++ b/ast2d/circle_edge_kernel.py @@ -0,0 +1,34 @@ +import numpy as np +from typing import Optional + +def circle_edge_kernel(k_size: int = 5, radius: Optional[int] = None) -> np.ndarray: + """ + Create a k_size x k_size array whose values increase linearly + from 0 at the center to 1 at the circle boundary (radius). + + Args: + k_size: The size (width and height) of the kernel array. + radius: The circle's radius. By default, set to (k_size-1)/2. + + Returns: + kernel: The circle-edge-weighted kernel. + """ + if radius is None: + # By default, let the radius be half the kernel size + radius = (k_size - 1) / 2 + + # Create an empty kernel + kernel = np.zeros((k_size, k_size), dtype=float) + + # Coordinates of the center + center = radius # same as (k_size-1)/2 if radius is default + + # Fill the kernel + for y in range(k_size): + for x in range(k_size): + dist = np.sqrt((x - center)**2 + (y - center)**2) + if dist <= radius: + # Weight = distance / radius => 0 at center, 1 at boundary + kernel[y, x] = dist / radius + + return kernel \ No newline at end of file diff --git a/ast2d/compute_cost_image.py b/ast2d/compute_cost_image.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8b61b3b11ad9ffbde82d07ce6efa6238b2e6c8 --- /dev/null +++ b/ast2d/compute_cost_image.py @@ -0,0 +1,41 @@ +from skimage.feature import canny +from scipy.signal import convolve2d +from .compute_disk_size import compute_disk_size +from .load_image import load_image +from .preprocess_image import preprocess_image +from .circle_edge_kernel import circle_edge_kernel +import numpy as np + +def compute_cost_image(path: str, user_radius: int, sigma: int = 3, clip_limit: float = 0.01) -> np.ndarray: + """ + Compute the cost image for a given image path, user radius, and optional parameters. + + Args: + path: The path to the image file. + user_radius: The radius of the disk. + sigma: The standard deviation for Gaussian smoothing. + clip_limit: The limit for contrasting the image. + + Returns: + The cost image as a NumPy array. + """ + disk_size = compute_disk_size(user_radius) + + # Load image + image = load_image(path) + + # Apply smoothing + smoothed_img = preprocess_image(image, sigma=sigma, clip_limit=clip_limit) + + # Apply Canny edge detection + canny_img = canny(smoothed_img) + + # Perform disk convolution + binary_img = canny_img + kernel = circle_edge_kernel(k_size=disk_size) + convolved = convolve2d(binary_img, kernel, mode='same', boundary='fill') + + # Create cost image + cost_img = (convolved.max() - convolved)**4 # Invert edges: higher cost where edges are stronger + + return cost_img \ No newline at end of file diff --git a/ast2d/compute_disk_size.py b/ast2d/compute_disk_size.py new file mode 100644 index 0000000000000000000000000000000000000000..84f9b6d6d2938c13ddd3d728fc3cf53af48bc99c --- /dev/null +++ b/ast2d/compute_disk_size.py @@ -0,0 +1,14 @@ +import numpy as np + +def compute_disk_size(user_radius: int, upscale_factor: float = 1.2) -> int: + """ + Compute the size of the disk to be used in the cost image computation. + + Args: + user_radius: The radius in pixels. + upscale_factor: The factor by which the disk size will be upscaled. + + Returns: + The size of the disk. + """ + return int(np.ceil(upscale_factor * 2 * user_radius + 1) // 2 * 2 + 1) \ No newline at end of file diff --git a/ast2d/downscale.py b/ast2d/downscale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca95b390b444b9ffe271b3422f93045ec3ea7fd7 --- /dev/null +++ b/ast2d/downscale.py @@ -0,0 +1,39 @@ +import cv2 +import numpy as np +from typing import Tuple + +# Currently not implemented +def downscale(img: np.ndarray, points: Tuple[Tuple[int, int], Tuple[int, int]], scale_percent: int) -> Tuple[np.ndarray, Tuple[Tuple[int, int], Tuple[int, int]]]: + """ + Downscale an image and its corresponding points. + + Args: + img: The image. + points: The points to downscale. + scale_percent: The percentage to downscale to. E.g. scale_percent = 60 results in a new image 60% of the original image's size. + + Returns: + The downsampled image and the downsampled points. + """ + if scale_percent == 100: + return img, (tuple(points[0]), tuple(points[1])) + else: + # Compute new dimensions + width = int(img.shape[1] * scale_percent / 100) + height = int(img.shape[0] * scale_percent / 100) + new_dimensions = (width, height) + + # Downsample + downsampled_img = cv2.resize(img, new_dimensions, interpolation=cv2.INTER_AREA) + + # Scaling factors + scale_x = width / img.shape[1] + scale_y = height / img.shape[0] + + # Scale the points (x, y) + seed_xy = tuple(points[0]) + target_xy = tuple(points[1]) + scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y)) + scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y)) + + return downsampled_img, (scaled_seed_xy, scaled_target_xy) \ No newline at end of file diff --git a/ast2d/draggableCircleItem.py b/ast2d/draggableCircleItem.py new file mode 100644 index 0000000000000000000000000000000000000000..884dc41b0fa58eccff8cde148f13e42dab892b1e --- /dev/null +++ b/ast2d/draggableCircleItem.py @@ -0,0 +1,46 @@ +from PyQt5.QtWidgets import QGraphicsEllipseItem, QGraphicsItem +from PyQt5.QtGui import QPen, QBrush, QColor +from PyQt5.QtCore import Qt +from typing import Optional + +class DraggableCircleItem(QGraphicsEllipseItem): + """ + A QGraphicsEllipseItem that can be dragged around. + """ + def __init__(self, x: float, y: float, radius: float = 20, color: QColor = Qt.red, parent: Optional[QGraphicsItem] = None): + """ + Constructor. + """ + super().__init__(0, 0, 2*radius, 2*radius, parent) + self._r = radius + + pen = QPen(color) + brush = QBrush(color) + self.setPen(pen) + self.setBrush(brush) + + # Enable item-based dragging + self.setFlags(QGraphicsEllipseItem.ItemIsMovable | + QGraphicsEllipseItem.ItemIsSelectable | + QGraphicsEllipseItem.ItemSendsScenePositionChanges) + + # Position so that (x, y) is the center + self.setPos(x - radius, y - radius) + + def set_radius(self, r: float): + """ + Set the radius of the circle + """ + old_center = self.sceneBoundingRect().center() + self._r = r + self.setRect(0, 0, 2*r, 2*r) + new_center = self.sceneBoundingRect().center() + diff_x = old_center.x() - new_center.x() + diff_y = old_center.y() - new_center.y() + self.moveBy(diff_x, diff_y) + + def radius(self): + """ + Get the radius of the circle + """ + return self._r \ No newline at end of file diff --git a/ast2d/find_path.py b/ast2d/find_path.py new file mode 100644 index 0000000000000000000000000000000000000000..426563ddd4843ab49ab9c08f8424cbd21508d864 --- /dev/null +++ b/ast2d/find_path.py @@ -0,0 +1,17 @@ +from skimage.graph import route_through_array + +def find_path(cost_image, points): + + if len(points) != 2: + raise ValueError("Points should be a list of 2 points: seed and target.") + + seed_rc, target_rc = points + + path_rc, cost = route_through_array( + cost_image, + start=seed_rc, + end=target_rc, + fully_connected=True + ) + + return path_rc \ No newline at end of file diff --git a/ast2d/imageGraphicsView.py b/ast2d/imageGraphicsView.py new file mode 100644 index 0000000000000000000000000000000000000000..77bb7bb9cc93b3efab5df50de005ff38f3814405 --- /dev/null +++ b/ast2d/imageGraphicsView.py @@ -0,0 +1,492 @@ +from scipy.signal import savgol_filter +from PyQt5.QtWidgets import QGraphicsScene, QGraphicsPixmapItem +from PyQt5.QtGui import QPixmap, QColor +from PyQt5.QtCore import Qt, QRectF, QPoint +import math +import numpy as np +from .panZoomGraphicsView import PanZoomGraphicsView +from .labeledPointItem import LabeledPointItem +from .find_path import find_path + + +class ImageGraphicsView(PanZoomGraphicsView): + """ + A custom QGraphicsView for displaying and interacting with an image. + + This class extends PanZoomGraphicsView to provide additional functionality + for loading images, adding labeled anchor points, and computing paths + between points based on a cost image. + """ + + def __init__(self, parent=None): + super().__init__(parent) + self.scene = QGraphicsScene(self) + self.setScene(self.scene) + + # Image display + self.image_item = QGraphicsPixmapItem() + self.scene.addItem(self.image_item) + + self.anchor_points = [] + self.point_items = [] + self.full_path_points = [] + self._full_path_xy = [] + + self.dot_radius = 4 + self.path_radius = 1 + self.radius_cost_image = 2 + self._img_w = 0 + self._img_h = 0 + + self._mouse_pressed = False + self._press_view_pos = None + self._drag_threshold = 5 + self._was_dragging = False + self._dragging_idx = None + self._drag_offset = (0, 0) + self._drag_counter = 0 + + # Cost images + self.cost_image_original = None + self.cost_image = None + + # Rainbow toggle => start with OFF + self._rainbow_enabled = False + + # Smoothing parameters + self._savgol_window_length = 7 + + def set_rainbow_enabled(self, enabled: bool): + """Enable rainbow coloring of the path.""" + self._rainbow_enabled = enabled + self._rebuild_full_path() + + def toggle_rainbow(self): + """Toggle rainbow coloring of the path.""" + self._rainbow_enabled = not self._rainbow_enabled + self._rebuild_full_path() + + def set_savgol_window_length(self, wlen: int): + """Set the window length for Savitzky-Golay smoothing.""" + wlen = max(3, wlen) + if wlen % 2 == 0: + wlen += 1 + self._savgol_window_length = wlen + + self._rebuild_full_path() + + # -------------------------------------------------------------------- + # LOADING + # -------------------------------------------------------------------- + def load_image(self, path: str): + """Load an image from a file path.""" + pixmap = QPixmap(path) + if not pixmap.isNull(): + self.image_item.setPixmap(pixmap) + self.setSceneRect(QRectF(pixmap.rect())) + + self._img_w = pixmap.width() + self._img_h = pixmap.height() + + self._clear_all_points() + self.resetTransform() + self.fitInView(self.image_item, Qt.KeepAspectRatio) + + # By default, add S/E + s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h + e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h + self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6) + self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6) + + # -------------------------------------------------------------------- + # ANCHOR POINTS + # -------------------------------------------------------------------- + def _insert_anchor_point(self, idx, x: float, y: float, label="", removable=True, + z_val=0, radius=4): + """Insert an anchor point at a specific index.""" + x_clamped = self._clamp(x, radius, self._img_w - radius) + y_clamped = self._clamp(y, radius, self._img_h - radius) + + if idx < 0: + # Insert before E if there's at least 2 anchors + if len(self.anchor_points) >= 2: + idx = len(self.anchor_points) - 1 + else: + idx = len(self.anchor_points) + + self.anchor_points.insert(idx, (x_clamped, y_clamped)) + color = Qt.green if label in ("S", "E") else Qt.red + item = LabeledPointItem(x_clamped, y_clamped, + label=label, radius=radius, color=color, + removable=removable, z_value=z_val) + self.point_items.insert(idx, item) + self.scene.addItem(item) + + def _add_guide_point(self, x, y): + """Add a guide point to the path.""" + x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius) + y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius) + + self._revert_cost_to_original() + + if not self._full_path_xy: + self._insert_anchor_point(-1, x_clamped, y_clamped, + label="", removable=True, z_val=1, radius=self.dot_radius) + else: + self._insert_anchor_between_subpath(x_clamped, y_clamped) + + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() + + def _insert_anchor_between_subpath(self, x_new: float, y_new: float ): + """Insert an anchor point between existing anchor points.""" # If somehow we have no path yet + # If somehow we have no path yet + if not self._full_path_xy: + self._insert_anchor_point(-1, x_new, y_new) + return + + # Find nearest point in the current full path + best_idx = None + best_d2 = float('inf') + for i, (px, py) in enumerate(self._full_path_xy): + dx = px - x_new + dy = py - y_new + d2 = dx*dx + dy*dy + if d2 < best_d2: + best_d2 = d2 + best_idx = i + + if best_idx is None: + self._insert_anchor_point(-1, x_new, y_new) + return + + def approx_equal(xa, ya, xb, yb, tol=1e-3): + """Check if two points are approximately equal.""" + return (abs(xa - xb) < tol) and (abs(ya - yb) < tol) + + def is_anchor(coord): + """Check if a point is an anchor point.""" + cx, cy = coord + for (ax, ay) in self.anchor_points: + if approx_equal(ax, ay, cx, cy): + return True + return False + + # Walk left + left_anchor_pt = None + iL = best_idx + while iL >= 0: + px, py = self._full_path_xy[iL] + if is_anchor((px, py)): + left_anchor_pt = (px, py) + break + iL -= 1 + + # Walk right + right_anchor_pt = None + iR = best_idx + while iR < len(self._full_path_xy): + px, py = self._full_path_xy[iR] + if is_anchor((px, py)): + right_anchor_pt = (px, py) + break + iR += 1 + + # If we can't find distinct anchors on left & right, + # just insert before E. + if not left_anchor_pt or not right_anchor_pt: + self._insert_anchor_point(-1, x_new, y_new) + return + if left_anchor_pt == right_anchor_pt: + self._insert_anchor_point(-1, x_new, y_new) + return + + # Convert anchor coords -> anchor_points indices + left_idx = None + right_idx = None + for i, (ax, ay) in enumerate(self.anchor_points): + if approx_equal(ax, ay, left_anchor_pt[0], left_anchor_pt[1]): + left_idx = i + if approx_equal(ax, ay, right_anchor_pt[0], right_anchor_pt[1]): + right_idx = i + + if left_idx is None or right_idx is None: + self._insert_anchor_point(-1, x_new, y_new) + return + + # Insert between them + if left_idx < right_idx: + insert_idx = left_idx + 1 + else: + insert_idx = right_idx + 1 + + self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True, + z_val=1, radius=self.dot_radius) + + # -------------------------------------------------------------------- + # COST IMAGE + # -------------------------------------------------------------------- + def _revert_cost_to_original(self): + if self.cost_image_original is not None: + self.cost_image = self.cost_image_original.copy() + + def _apply_all_guide_points_to_cost(self): + if self.cost_image is None: + return + for i, (ax, ay) in enumerate(self.anchor_points): + if self.point_items[i].is_removable(): + self._lower_cost_in_circle(ax, ay, self.radius_cost_image) + + def _lower_cost_in_circle(self, x_f: float, y_f: float, radius: int): + """Lower the cost in a circle centered at (x_f, y_f).""" + if self.cost_image is None: + return + h, w = self.cost_image.shape + row_c = int(round(y_f)) + col_c = int(round(x_f)) + if not (0 <= row_c < h and 0 <= col_c < w): + return + global_min = self.cost_image.min() + r_s = max(0, row_c - radius) + r_e = min(h, row_c + radius + 1) + c_s = max(0, col_c - radius) + c_e = min(w, col_c + radius + 1) + for rr in range(r_s, r_e): + for cc in range(c_s, c_e): + dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2) + if dist <= radius: + self.cost_image[rr, cc] = global_min + + # -------------------------------------------------------------------- + # PATH BUILDING + # -------------------------------------------------------------------- + def _rebuild_full_path(self): + """Rebuild the full path based on the anchor points.""" + for item in self.full_path_points: + self.scene.removeItem(item) + self.full_path_points.clear() + self._full_path_xy.clear() + + if len(self.anchor_points) < 2 or self.cost_image is None: + return + + big_xy = [] + for i in range(len(self.anchor_points) - 1): + xA, yA = self.anchor_points[i] + xB, yB = self.anchor_points[i + 1] + sub_xy = self._compute_subpath_xy(xA, yA, xB, yB) + if i == 0: + big_xy.extend(sub_xy) + else: + if len(sub_xy) > 1: + big_xy.extend(sub_xy[1:]) + + if len(big_xy) >= self._savgol_window_length: + arr_xy = np.array(big_xy) + smoothed = savgol_filter( + arr_xy, + window_length=self._savgol_window_length, + polyorder=2, + axis=0 + ) + big_xy = smoothed.tolist() + + self._full_path_xy = big_xy[:] + + n_points = len(big_xy) + for i, (px, py) in enumerate(big_xy): + fraction = i / (n_points - 1) if n_points > 1 else 0 + color = Qt.red + if self._rainbow_enabled: + color = self._rainbow_color(fraction) + + path_item = LabeledPointItem(px, py, label="", + radius=self.path_radius, + color=color, + removable=False, + z_value=0) + self.full_path_points.append(path_item) + self.scene.addItem(path_item) + + # Keep anchor labels on top + for p_item in self.point_items: + if p_item._text_item: + p_item.setZValue(100) + + def _compute_subpath_xy(self, xA: float, yA: float, xB: float, yB: float): + """Compute a subpath between two points.""" + if self.cost_image is None: + return [] + h, w = self.cost_image.shape + rA, cA = int(round(yA)), int(round(xA)) + rB, cB = int(round(yB)), int(round(xB)) + rA = max(0, min(rA, h - 1)) + cA = max(0, min(cA, w - 1)) + rB = max(0, min(rB, h - 1)) + cB = max(0, min(cB, w - 1)) + try: + path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)]) + except ValueError as e: + print("Error in find_path:", e) + return [] + # Convert from (row, col) to (x, y) + return [(c, r) for (r, c) in path_rc] + + def _rainbow_color(self, fraction: float): + """Get a rainbow color.""" + hue = int(300 * fraction) + saturation = 255 + value = 255 + return QColor.fromHsv(hue, saturation, value) + + # -------------------------------------------------------------------- + # MOUSE EVENTS + # -------------------------------------------------------------------- + def mousePressEvent(self, event): + """Handle mouse press events for dragging a point or adding a point.""" + if event.button() == Qt.LeftButton: + self._mouse_pressed = True + self._was_dragging = False + self._press_view_pos = event.pos() + + idx = self._find_item_near(event.pos(), threshold=10) + if idx is not None: + self._dragging_idx = idx + self._drag_counter = 0 + scene_pos = self.mapToScene(event.pos()) + px, py = self.point_items[idx].get_pos() + self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py) + self.setCursor(Qt.ClosedHandCursor) + return + + elif event.button() == Qt.RightButton: + self._remove_point_by_click(event.pos()) + + super().mousePressEvent(event) + + def mouseMoveEvent(self, event): + """Handle mouse move events for dragging a point or dragging the view""" + if self._dragging_idx is not None: + scene_pos = self.mapToScene(event.pos()) + x_new = scene_pos.x() - self._drag_offset[0] + y_new = scene_pos.y() - self._drag_offset[1] + + r = self.point_items[self._dragging_idx]._r + x_clamped = self._clamp(x_new, r, self._img_w - r) + y_clamped = self._clamp(y_new, r, self._img_h - r) + self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped) + + self._drag_counter += 1 + # Update path every 4 moves + if self._drag_counter >= 4: + self._drag_counter = 0 + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped) + self._rebuild_full_path() + else: + if self._mouse_pressed and (event.buttons() & Qt.LeftButton): + dist = (event.pos() - self._press_view_pos).manhattanLength() + if dist > self._drag_threshold: + self._was_dragging = True + + super().mouseMoveEvent(event) + + def mouseReleaseEvent(self, event): + """Handle mouse release events for dragging a point or adding a point.""" + super().mouseReleaseEvent(event) + if event.button() == Qt.LeftButton and self._mouse_pressed: + self._mouse_pressed = False + self.setCursor(Qt.ArrowCursor) + + if self._dragging_idx is not None: + idx = self._dragging_idx + self._dragging_idx = None + self._drag_offset = (0, 0) + newX, newY = self.point_items[idx].get_pos() + self.anchor_points[idx] = (newX, newY) + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() + else: + # No drag => add point + if not self._was_dragging: + scene_pos = self.mapToScene(event.pos()) + x, y = scene_pos.x(), scene_pos.y() + self._add_guide_point(x, y) + + self._was_dragging = False + + def _remove_point_by_click(self, view_pos: QPoint): + """Remove a point by clicking on it.""" + idx = self._find_item_near(view_pos, threshold=10) + if idx is None: + return + if not self.point_items[idx].is_removable(): + return + + self.scene.removeItem(self.point_items[idx]) + self.point_items.pop(idx) + self.anchor_points.pop(idx) + + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() + + def _find_item_near(self, view_pos: QPoint, threshold=10): + """Find the index of an item near a given position.""" + scene_pos = self.mapToScene(view_pos) + x_click, y_click = scene_pos.x(), scene_pos.y() + + closest_idx = None + min_dist = float('inf') + for i, itm in enumerate(self.point_items): + d = itm.distance_to(x_click, y_click) + if d < min_dist: + min_dist = d + closest_idx = i + if closest_idx is not None and min_dist <= threshold: + return closest_idx + return None + + # -------------------------------------------------------------------- + # UTILS + # -------------------------------------------------------------------- + def _clamp(self, val, mn, mx): + return max(mn, min(val, mx)) + + def _clear_all_points(self): + """Clear all anchor points and guide points.""" + for it in self.point_items: + self.scene.removeItem(it) + self.point_items.clear() + self.anchor_points.clear() + + for p in self.full_path_points: + self.scene.removeItem(p) + self.full_path_points.clear() + self._full_path_xy.clear() + + def clear_guide_points(self): + """Clear all guide points.""" + i = 0 + while i < len(self.anchor_points): + if self.point_items[i].is_removable(): + self.scene.removeItem(self.point_items[i]) + del self.point_items[i] + del self.anchor_points[i] + else: + i += 1 + + for it in self.full_path_points: + self.scene.removeItem(it) + self.full_path_points.clear() + self._full_path_xy.clear() + + self._revert_cost_to_original() + self._apply_all_guide_points_to_cost() + self._rebuild_full_path() + + def get_full_path_xy(self): + """Returns the entire path as a list of (x, y) coordinates.""" + return self._full_path_xy \ No newline at end of file diff --git a/ast2d/labeledPointItem.py b/ast2d/labeledPointItem.py new file mode 100644 index 0000000000000000000000000000000000000000..f96bb8ff7875daead14c7ada43dc899e6ebad4b3 --- /dev/null +++ b/ast2d/labeledPointItem.py @@ -0,0 +1,87 @@ +import math +from PyQt5.QtWidgets import QGraphicsEllipseItem, QGraphicsTextItem +from PyQt5.QtGui import QPen, QBrush, QColor, QFont +from PyQt5.QtCore import Qt + + +class LabeledPointItem(QGraphicsEllipseItem): + """ + A QGraphicsEllipseItem subclass that represents a labeled point in a 2D space. + + This class creates a circular point. + The point can be customized with different colors, sizes, and labels, and can + be marked as removable. + """ + + def __init__(self, x: float, y: float, label: str ="", radius:int =4, + color=Qt.red, removable=True, z_value=0, parent=None): + super().__init__(0, 0, 2*radius, 2*radius, parent) + self._x = x + self._y = y + self._r = radius + self._removable = removable + + pen = QPen(color) + brush = QBrush(color) + self.setPen(pen) + self.setBrush(brush) + self.setZValue(z_value) + + self._text_item = None + if label: + self._text_item = QGraphicsTextItem(self) + self._text_item.setPlainText(label) + self._text_item.setDefaultTextColor(QColor("black")) + font = QFont("Arial", 14) + font.setBold(True) + self._text_item.setFont(font) + self._scale_text_to_fit() + + self.set_pos(x, y) + + def _scale_text_to_fit(self): + """Scales the text to fit inside the circle.""" + if not self._text_item: + return + self._text_item.setScale(1.0) + circle_diam = 2 * self._r + raw_rect = self._text_item.boundingRect() + text_w = raw_rect.width() + text_h = raw_rect.height() + if text_w > circle_diam or text_h > circle_diam: + scale_factor = min(circle_diam / text_w, circle_diam / text_h) + self._text_item.setScale(scale_factor) + self._center_label() + + def _center_label(self): + """Centers the text inside the circle.""" + if not self._text_item: + return + ellipse_w = 2 * self._r + ellipse_h = 2 * self._r + raw_rect = self._text_item.boundingRect() + scale_factor = self._text_item.scale() + scaled_w = raw_rect.width() * scale_factor + scaled_h = raw_rect.height() * scale_factor + tx = (ellipse_w - scaled_w) * 0.5 + ty = (ellipse_h - scaled_h) * 0.5 + self._text_item.setPos(tx, ty) + + def set_pos(self, x, y): + """Positions the circle so its center is at (x, y).""" + self._x = x + self._y = y + self.setPos(x - self._r, y - self._r) + + def get_pos(self): + """Returns the (x, y) coordinates of the center of the circle.""" + return (self._x, self._y) + + def distance_to(self, x_other, y_other): + """Returns the Euclidean distance from the center + of the circle to another circle.""" + return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2) + + def is_removable(self): + """Returns True if the point is removable, False otherwise.""" + return self._removable diff --git a/ast2d/load_image.py b/ast2d/load_image.py new file mode 100644 index 0000000000000000000000000000000000000000..d2fa7eadf42d6098f0cc14bee50a4f015d9d0240 --- /dev/null +++ b/ast2d/load_image.py @@ -0,0 +1,13 @@ +import cv2 + +def load_image(path: str) -> "numpy.ndarray": + """ + Loads an image from the specified file path in grayscale mode. + + Args: + path (str): The file path to the image. + + Returns: + numpy.ndarray: The loaded grayscale image. + """ + return cv2.imread(path, cv2.IMREAD_GRAYSCALE) \ No newline at end of file diff --git a/ast2d/mainWindow.py b/ast2d/mainWindow.py new file mode 100644 index 0000000000000000000000000000000000000000..94dd3239d8060017840dca4cca7748a04f2d79b1 --- /dev/null +++ b/ast2d/mainWindow.py @@ -0,0 +1,297 @@ +import math +import numpy as np +from PyQt5.QtWidgets import ( + QMainWindow, QPushButton, QHBoxLayout, + QVBoxLayout, QWidget, QFileDialog +) +from PyQt5.QtGui import QPixmap, QImage, QCloseEvent +from .compute_cost_image import compute_cost_image +from .preprocess_image import preprocess_image +from .advancedSettingsWidget import AdvancedSettingsWidget +from .imageGraphicsView import ImageGraphicsView +from .circleEditorWidget import CircleEditorWidget + +class MainWindow(QMainWindow): + def __init__(self): + """ + Initialize the main window for the application. + + This method sets up the main window, including the layout, widgets, and initial state. + It initializes various attributes related to the image processing and user interface. + """ + super().__init__() + self.setWindowTitle("Test GUI") + + self._last_loaded_pixmap = None + self._circle_calibrated_radius = 6 + self._last_loaded_file_path = None + + # Value for the contrast slider + self._current_clip_limit = 0.01 + + # Outer widget and layout + self._main_widget = QWidget() + self._main_layout = QHBoxLayout(self._main_widget) + + # Container for the image area and its controls + self._left_panel = QVBoxLayout() + + # Container widget for stretching the panel + self._left_container = QWidget() + self._left_container.setLayout(self._left_panel) + + self._main_layout.addWidget(self._left_container, 7) # 70% ratio of the full window + + # Advanced widget window + self._advanced_widget = AdvancedSettingsWidget(self) + self._advanced_widget.hide() + self._main_layout.addWidget(self._advanced_widget, 3) # 30% ratio of the full window + + self.setCentralWidget(self._main_widget) + + # The image view + self.image_view = ImageGraphicsView() + self._left_panel.addWidget(self.image_view) + + # Button row + btn_layout = QHBoxLayout() + self.btn_load_image = QPushButton("Load Image") + self.btn_load_image.clicked.connect(self.load_image) + btn_layout.addWidget(self.btn_load_image) + + self.btn_export_path = QPushButton("Export Path") + self.btn_export_path.clicked.connect(self.export_path) + btn_layout.addWidget(self.btn_export_path) + + self.btn_clear_points = QPushButton("Clear Points") + self.btn_clear_points.clicked.connect(self.clear_points) + btn_layout.addWidget(self.btn_clear_points) + + self.btn_advanced = QPushButton("Advanced Settings") + self.btn_advanced.setCheckable(True) + self.btn_advanced.clicked.connect(self._toggle_advanced_settings) + btn_layout.addWidget(self.btn_advanced) + + self._left_panel.addLayout(btn_layout) + + self.resize(1000, 600) + self._old_central_widget = None + self._editor = None + + def _toggle_advanced_settings(self, checked: bool): + """ + Toggles the visibility of the advanced settings widget. + """ + if checked: + self._advanced_widget.show() + else: + self._advanced_widget.hide() + # Force re-layout + self.adjustSize() + + + def open_circle_editor(self): + """ + Replace central widget with circle editor. + """ + if not self._last_loaded_pixmap: + print("No image loaded yet! Cannot open circle editor.") + return + + old_widget = self.takeCentralWidget() + self._old_central_widget = old_widget + + init_radius = self._circle_calibrated_radius + editor = CircleEditorWidget( + pixmap=self._last_loaded_pixmap, + init_radius=init_radius, + done_callback=self._on_circle_editor_done + ) + self._editor = editor + self.setCentralWidget(editor) + + + def _on_circle_editor_done(self, final_radius: int): + """ + Updates the calibrated radius, computes the cost image based on the new radius, + and updates the image view with the new cost image. + It also restores the previous central widget and cleans up the editor widget. + """ + self._circle_calibrated_radius = final_radius + print(f"Circle Editor done. Radius = {final_radius}") + + # Update cost image and path using new radius + if self._last_loaded_file_path: + cost_img = compute_cost_image( + self._last_loaded_file_path, + self._circle_calibrated_radius, + clip_limit=self._current_clip_limit + ) + self.image_view.cost_image_original = cost_img + self.image_view.cost_image = cost_img.copy() + self.image_view._apply_all_guide_points_to_cost() + self.image_view._rebuild_full_path() + self._update_advanced_images() + + # Swap back to central widget + editor_widget = self.takeCentralWidget() + if editor_widget is not None: + editor_widget.setParent(None) + + if self._old_central_widget is not None: + self.setCentralWidget(self._old_central_widget) + self._old_central_widget = None + + if self._editor is not None: + self._editor.deleteLater() + self._editor = None + + def toggle_rainbow(self): + """ + Toggle rainbow coloring of the path. + """ + self.image_view.toggle_rainbow() + + def load_image(self): + """ + Load an image and update the image view and cost image. + The supported image formats are: PNG, JPG, JPEG, BMP, and TIF. + """ + options = QFileDialog.Options() + file_path, _ = QFileDialog.getOpenFileName( + self, "Open Image", "", + "Images (*.png *.jpg *.jpeg *.bmp *.tif)", + options=options + ) + + if file_path: + self.image_view.load_image(file_path) + + cost_img = compute_cost_image( + file_path, + self._circle_calibrated_radius, + clip_limit=self._current_clip_limit + ) + self.image_view.cost_image_original = cost_img + self.image_view.cost_image = cost_img.copy() + + pm = QPixmap(file_path) + if not pm.isNull(): + self._last_loaded_pixmap = pm + + self._last_loaded_file_path = file_path + self._update_advanced_images() + + def update_contrast(self, clip_limit: float): + """ + Updates and applies the contrast value of the image. + """ + self._current_clip_limit = clip_limit + if self._last_loaded_file_path: + cost_img = compute_cost_image( + self._last_loaded_file_path, + self._circle_calibrated_radius, + clip_limit=clip_limit + ) + self.image_view.cost_image_original = cost_img + self.image_view.cost_image = cost_img.copy() + self.image_view._apply_all_guide_points_to_cost() + self.image_view._rebuild_full_path() + + self._update_advanced_images() + + def _update_advanced_images(self): + """ + Updates the advanced images display with the latest image. + If no image has been loaded, the method returns without making any updates. + """ + if not self._last_loaded_pixmap: + return + pm_np = self._qpixmap_to_gray_float(self._last_loaded_pixmap) + contrasted_blurred = preprocess_image( + pm_np, + sigma=3, + clip_limit=self._current_clip_limit + ) + cost_img_np = self.image_view.cost_image + self._advanced_widget.update_displays(contrasted_blurred, cost_img_np) + + def _qpixmap_to_gray_float(self, qpix: QPixmap) -> np.ndarray: + """ + Convert a QPixmap to a grayscale float array. + + Args: + qpix: The QPixmap to be converted. + + Returns: + A 2D numpy array representing the grayscale image. + """ + img = qpix.toImage() + img = img.convertToFormat(QImage.Format_ARGB32) + ptr = img.bits() + ptr.setsize(img.byteCount()) + arr = np.frombuffer(ptr, np.uint8).reshape((img.height(), img.width(), 4)) + rgb = arr[..., :3].astype(np.float32) + gray = rgb.mean(axis=2) / 255.0 + return gray + + def export_path(self): + """ + Exports the path as a CSV in the format: x, y, TYPE, + ensuring that each anchor influences exactly one path point. + """ + full_xy = self.image_view.get_full_path_xy() + if not full_xy: + print("No path to export.") + return + + anchor_points = self.image_view.anchor_points + + # Finds the index of the closest path point for each anchor point + user_placed_indices = set() + for ax, ay in anchor_points: + min_dist = float('inf') + closest_idx = None + for i, (px, py) in enumerate(full_xy): + dist = math.hypot(px - ax, py - ay) + if dist < min_dist: + min_dist = dist + closest_idx = i + if closest_idx is not None: + user_placed_indices.add(closest_idx) + + # Ask user for the CSV filename + options = QFileDialog.Options() + file_path, _ = QFileDialog.getSaveFileName( + self, "Export Path", "", + "CSV Files (*.csv);;All Files (*)", + options=options + ) + if not file_path: + return + + import csv + with open(file_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["x", "y", "TYPE"]) + + for i, (x, y) in enumerate(full_xy): + ptype = "USER-PLACED" if i in user_placed_indices else "PATH" + writer.writerow([x, y, ptype]) + + print(f"Exported path with {len(full_xy)} points to {file_path}") + + def clear_points(self): + """ + Clears points from the image. + """ + self.image_view.clear_guide_points() + + def closeEvent(self, event: QCloseEvent): + """ + Handle the window close event. + + Args: + event: The close event. + """ + super().closeEvent(event) \ No newline at end of file diff --git a/ast2d/panZoomGraphicsView.py b/ast2d/panZoomGraphicsView.py new file mode 100644 index 0000000000000000000000000000000000000000..5adb46f298fde323ae5c42a1cdec32d24a2694b8 --- /dev/null +++ b/ast2d/panZoomGraphicsView.py @@ -0,0 +1,49 @@ +from PyQt5.QtWidgets import QGraphicsView, QSizePolicy +from PyQt5.QtCore import Qt + +class PanZoomGraphicsView(QGraphicsView): + """ + A QGraphicsView subclass that supports panning and zooming with the mouse. + """ + def __init__(self, parent=None): + super().__init__(parent) + self.setDragMode(QGraphicsView.NoDrag) # We'll handle panning manually + self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse) + self._panning = False + self._pan_start = None + + # Expands layout + self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + + def wheelEvent(self, event): + """ Zoom in/out with mouse wheel. """ + zoom_in_factor = 1.25 + zoom_out_factor = 1 / zoom_in_factor + if event.angleDelta().y() > 0: + self.scale(zoom_in_factor, zoom_in_factor) + else: + self.scale(zoom_out_factor, zoom_out_factor) + event.accept() + + def mousePressEvent(self, event): + """ If left button: Start panning (unless overridden). """ + if event.button() == Qt.LeftButton: + self._panning = True + self._pan_start = event.pos() + self.setCursor(Qt.ClosedHandCursor) + super().mousePressEvent(event) + + def mouseMoveEvent(self, event): + """ If panning, translate the scene. """ + if self._panning and self._pan_start is not None: + delta = event.pos() - self._pan_start + self._pan_start = event.pos() + self.translate(delta.x(), delta.y()) + super().mouseMoveEvent(event) + + def mouseReleaseEvent(self, event): + """ End panning. """ + if event.button() == Qt.LeftButton: + self._panning = False + self.setCursor(Qt.ArrowCursor) + super().mouseReleaseEvent(event) diff --git a/ast2d/preprocess_image.py b/ast2d/preprocess_image.py new file mode 100644 index 0000000000000000000000000000000000000000..351988f1b00389c592d35e1913a0ca1221cb7224 --- /dev/null +++ b/ast2d/preprocess_image.py @@ -0,0 +1,23 @@ +from skimage.filters import gaussian +from skimage import exposure + + +def preprocess_image(image: "np.ndarray", sigma: int = 3, clip_limit: float = 0.01) -> "np.ndarray": + """ + Preprocess the input image by applying histogram equalization and Gaussian smoothing. + + Args: + image: (ndarray): Input image to be processed. + sigma: (float, optional): Standard deviation for Gaussian kernel. Default is 3. + clip_limit: (float, optional): Clipping limit for contrast enhancement. Default is 0.01. + Returns: + ndarray: The preprocessed image. + """ + # Applies histogram equalization to enhance contrast + image_contrasted = exposure.equalize_adapthist( + image, clip_limit=clip_limit) + + # Applies smoothing + smoothed_img = gaussian(image_contrasted, sigma=sigma) + + return smoothed_img diff --git a/data/AgamodonSlice.png b/data/AgamodonSlice.png new file mode 100644 index 0000000000000000000000000000000000000000..b8d0db7cc7dc30c9c26d15ad9e2f43466e273097 Binary files /dev/null and b/data/AgamodonSlice.png differ diff --git a/data/AngustifronsSlice35.png b/data/AngustifronsSlice35.png new file mode 100644 index 0000000000000000000000000000000000000000..17cadf804cd3af6d626674f06419108cf792f7f1 Binary files /dev/null and b/data/AngustifronsSlice35.png differ diff --git a/data/BipesSlice4.png b/data/BipesSlice4.png new file mode 100644 index 0000000000000000000000000000000000000000..6d17dc3048990d94b3f195b693a1b112e7eb3a5e Binary files /dev/null and b/data/BipesSlice4.png differ diff --git a/data/BipesSlice4NoCropping.png b/data/BipesSlice4NoCropping.png new file mode 100644 index 0000000000000000000000000000000000000000..42c1bf536eecb3458f1221869415059a4a6e9bc3 Binary files /dev/null and b/data/BipesSlice4NoCropping.png differ diff --git a/data/agamodonPath.npy b/data/agamodonPath.npy new file mode 100644 index 0000000000000000000000000000000000000000..5081e86649a488510afbf0606cf8d9465c813a70 Binary files /dev/null and b/data/agamodonPath.npy differ diff --git a/data/agamodonPoints.npy b/data/agamodonPoints.npy new file mode 100644 index 0000000000000000000000000000000000000000..75ec4b2945d7e7a49b6b72de2fddeca8bcee6615 Binary files /dev/null and b/data/agamodonPoints.npy differ diff --git a/data/agamodon_slice.png b/data/agamodon_slice.png new file mode 100644 index 0000000000000000000000000000000000000000..aa58983b107d16300d3425d4150d64bb568de087 Binary files /dev/null and b/data/agamodon_slice.png differ diff --git a/data/angustifronsPoints.npy b/data/angustifronsPoints.npy new file mode 100644 index 0000000000000000000000000000000000000000..188f483b5c7a7433d144acd9e2787678e9fceccd Binary files /dev/null and b/data/angustifronsPoints.npy differ diff --git a/data/angustifrons_slice.png b/data/angustifrons_slice.png new file mode 100644 index 0000000000000000000000000000000000000000..21323872af3c72094cf891a9ecec4d74bf82fdff Binary files /dev/null and b/data/angustifrons_slice.png differ diff --git a/data/baikaPoints.npy b/data/baikaPoints.npy new file mode 100644 index 0000000000000000000000000000000000000000..e5c79954c80b6cd87d793e8c03830da3619c268c Binary files /dev/null and b/data/baikaPoints.npy differ diff --git a/data/baika_slice.png b/data/baika_slice.png new file mode 100644 index 0000000000000000000000000000000000000000..97e142dc529b74aa19cd2d89c1bd4a283fd96990 Binary files /dev/null and b/data/baika_slice.png differ diff --git a/data/bipesPoints.npy b/data/bipesPoints.npy new file mode 100644 index 0000000000000000000000000000000000000000..166d2be24b69db24a4e5488ddea7261c5165b2b5 Binary files /dev/null and b/data/bipesPoints.npy differ diff --git a/data/bipes_slice.png b/data/bipes_slice.png new file mode 100644 index 0000000000000000000000000000000000000000..c99746284766bd5f1694e6ac12c0adef08b8f87c Binary files /dev/null and b/data/bipes_slice.png differ diff --git a/data/exportedPath.npy b/data/exportedPath.npy new file mode 100644 index 0000000000000000000000000000000000000000..a2d546e367aa3e18efefe88ecbdca9a09c7205cd Binary files /dev/null and b/data/exportedPath.npy differ diff --git a/data/test_image.jpg b/data/test_image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..632e369fa39c002c0eb0f96dac18b22be08cf418 Binary files /dev/null and b/data/test_image.jpg differ diff --git a/modules/main.py b/modules/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f76dbb1deeac59c4ef8488c7ab23d0788dd750 --- /dev/null +++ b/modules/main.py @@ -0,0 +1,13 @@ +import sys +from PyQt5.QtWidgets import QApplication +from mainWindow import MainWindow + +def main(): + app = QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..409d38a452073e31931809154549a11e5bb644cb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy>=1.23.3 +opencv_python>=4.9.0.80 +PyQt5>=5.15.11 +PyQt5_sip>=12.11.1 +scipy>=1.15.1 +scikit-image>=0.23.2 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..e34828040a293414d65dc1a20093de4b265b414a --- /dev/null +++ b/setup.py @@ -0,0 +1,44 @@ +from setuptools import find_packages, setup + +# Read the contents of your README file +with open("README.md", "r", encoding="utf-8") as f: + long_description = f.read() + +setup( + name="ast2d", + version="0.1.0", + author="Aske T. Rove, Christian L. Bjerregaard, Mikkel W. Breinstrup", + author_email="s224362@dtu.dk, s224389@dtu.dk, s224361@dtu.dk", + description="Interactive path tracing in 2D medical images", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://lab.compute.dtu.dk/QIM/tools/tracing_sutures", + packages=find_packages(), + include_package_data=True, + entry_points = { + 'console_scripts': [ + 'ast2d=ast2d:main' + ] + }, + classifiers=[ + "License :: MIT License", + "Programming Language :: Python :: 3", + "Development Status :: 3 - Alpha", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Image Processing", + "Topic :: Scientific/Engineering :: Visualization", + "Topic :: Software Development :: User Interfaces", + ], + python_requires=">=3.10", + install_requires=[ + "numpy<=1.26.4", + "opencv_python>=4.9.0.80", + "PyQt5>=5.15.11", + "PyQt5_sip>=12.11.1", + "scipy>=1.15.1", + "scikit-image>=0.23.2" + ] +) \ No newline at end of file