Skip to content
Snippets Groups Projects
Commit db70eff6 authored by s224362's avatar s224362
Browse files

merging into main

parents 67169bfd c88c0c77
Branches
No related tags found
No related merge requests found
Showing
with 1118 additions and 0 deletions
from skimage.feature import canny
from scipy.signal import convolve2d
from .compute_disk_size import compute_disk_size
from .load_image import load_image
from .preprocess_image import preprocess_image
from .circle_edge_kernel import circle_edge_kernel
import numpy as np
def compute_cost_image(path: str, user_radius: int, sigma: int = 3, clip_limit: float = 0.01) -> np.ndarray:
"""
Compute the cost image for a given image path, user radius, and optional parameters.
Args:
path: The path to the image file.
user_radius: The radius of the disk.
sigma: The standard deviation for Gaussian smoothing.
clip_limit: The limit for contrasting the image.
Returns:
The cost image as a NumPy array.
"""
disk_size = compute_disk_size(user_radius)
# Load image
image = load_image(path)
# Apply smoothing
smoothed_img = preprocess_image(image, sigma=sigma, clip_limit=clip_limit)
# Apply Canny edge detection
canny_img = canny(smoothed_img)
# Perform disk convolution
binary_img = canny_img
kernel = circle_edge_kernel(k_size=disk_size)
convolved = convolve2d(binary_img, kernel, mode='same', boundary='fill')
# Create cost image
cost_img = (convolved.max() - convolved)**4 # Invert edges: higher cost where edges are stronger
return cost_img
\ No newline at end of file
import numpy as np
def compute_disk_size(user_radius: int, upscale_factor: float = 1.2) -> int:
"""
Compute the size of the disk to be used in the cost image computation.
Args:
user_radius: The radius in pixels.
upscale_factor: The factor by which the disk size will be upscaled.
Returns:
The size of the disk.
"""
return int(np.ceil(upscale_factor * 2 * user_radius + 1) // 2 * 2 + 1)
\ No newline at end of file
import cv2
import numpy as np
from typing import Tuple
# Currently not implemented
def downscale(img: np.ndarray, points: Tuple[Tuple[int, int], Tuple[int, int]], scale_percent: int) -> Tuple[np.ndarray, Tuple[Tuple[int, int], Tuple[int, int]]]:
"""
Downscale an image and its corresponding points.
Args:
img: The image.
points: The points to downscale.
scale_percent: The percentage to downscale to. E.g. scale_percent = 60 results in a new image 60% of the original image's size.
Returns:
The downsampled image and the downsampled points.
"""
if scale_percent == 100:
return img, (tuple(points[0]), tuple(points[1]))
else:
# Compute new dimensions
width = int(img.shape[1] * scale_percent / 100)
height = int(img.shape[0] * scale_percent / 100)
new_dimensions = (width, height)
# Downsample
downsampled_img = cv2.resize(img, new_dimensions, interpolation=cv2.INTER_AREA)
# Scaling factors
scale_x = width / img.shape[1]
scale_y = height / img.shape[0]
# Scale the points (x, y)
seed_xy = tuple(points[0])
target_xy = tuple(points[1])
scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y))
scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y))
return downsampled_img, (scaled_seed_xy, scaled_target_xy)
\ No newline at end of file
from PyQt5.QtWidgets import QGraphicsEllipseItem, QGraphicsItem
from PyQt5.QtGui import QPen, QBrush, QColor
from PyQt5.QtCore import Qt
from typing import Optional
class DraggableCircleItem(QGraphicsEllipseItem):
"""
A QGraphicsEllipseItem that can be dragged around.
"""
def __init__(self, x: float, y: float, radius: float = 20, color: QColor = Qt.red, parent: Optional[QGraphicsItem] = None):
"""
Constructor.
"""
super().__init__(0, 0, 2*radius, 2*radius, parent)
self._r = radius
pen = QPen(color)
brush = QBrush(color)
self.setPen(pen)
self.setBrush(brush)
# Enable item-based dragging
self.setFlags(QGraphicsEllipseItem.ItemIsMovable |
QGraphicsEllipseItem.ItemIsSelectable |
QGraphicsEllipseItem.ItemSendsScenePositionChanges)
# Position so that (x, y) is the center
self.setPos(x - radius, y - radius)
def set_radius(self, r: float):
"""
Set the radius of the circle
"""
old_center = self.sceneBoundingRect().center()
self._r = r
self.setRect(0, 0, 2*r, 2*r)
new_center = self.sceneBoundingRect().center()
diff_x = old_center.x() - new_center.x()
diff_y = old_center.y() - new_center.y()
self.moveBy(diff_x, diff_y)
def radius(self):
"""
Get the radius of the circle
"""
return self._r
\ No newline at end of file
from skimage.graph import route_through_array
def find_path(cost_image, points):
if len(points) != 2:
raise ValueError("Points should be a list of 2 points: seed and target.")
seed_rc, target_rc = points
path_rc, cost = route_through_array(
cost_image,
start=seed_rc,
end=target_rc,
fully_connected=True
)
return path_rc
\ No newline at end of file
from scipy.signal import savgol_filter
from PyQt5.QtWidgets import QGraphicsScene, QGraphicsPixmapItem
from PyQt5.QtGui import QPixmap, QColor
from PyQt5.QtCore import Qt, QRectF, QPoint
import math
import numpy as np
from .panZoomGraphicsView import PanZoomGraphicsView
from .labeledPointItem import LabeledPointItem
from .find_path import find_path
class ImageGraphicsView(PanZoomGraphicsView):
"""
A custom QGraphicsView for displaying and interacting with an image.
This class extends PanZoomGraphicsView to provide additional functionality
for loading images, adding labeled anchor points, and computing paths
between points based on a cost image.
"""
def __init__(self, parent=None):
super().__init__(parent)
self.scene = QGraphicsScene(self)
self.setScene(self.scene)
# Image display
self.image_item = QGraphicsPixmapItem()
self.scene.addItem(self.image_item)
self.anchor_points = []
self.point_items = []
self.full_path_points = []
self._full_path_xy = []
self.dot_radius = 4
self.path_radius = 1
self.radius_cost_image = 2
self._img_w = 0
self._img_h = 0
self._mouse_pressed = False
self._press_view_pos = None
self._drag_threshold = 5
self._was_dragging = False
self._dragging_idx = None
self._drag_offset = (0, 0)
self._drag_counter = 0
# Cost images
self.cost_image_original = None
self.cost_image = None
# Rainbow toggle => start with OFF
self._rainbow_enabled = False
# Smoothing parameters
self._savgol_window_length = 7
def set_rainbow_enabled(self, enabled: bool):
"""Enable rainbow coloring of the path."""
self._rainbow_enabled = enabled
self._rebuild_full_path()
def toggle_rainbow(self):
"""Toggle rainbow coloring of the path."""
self._rainbow_enabled = not self._rainbow_enabled
self._rebuild_full_path()
def set_savgol_window_length(self, wlen: int):
"""Set the window length for Savitzky-Golay smoothing."""
wlen = max(3, wlen)
if wlen % 2 == 0:
wlen += 1
self._savgol_window_length = wlen
self._rebuild_full_path()
# --------------------------------------------------------------------
# LOADING
# --------------------------------------------------------------------
def load_image(self, path: str):
"""Load an image from a file path."""
pixmap = QPixmap(path)
if not pixmap.isNull():
self.image_item.setPixmap(pixmap)
self.setSceneRect(QRectF(pixmap.rect()))
self._img_w = pixmap.width()
self._img_h = pixmap.height()
self._clear_all_points()
self.resetTransform()
self.fitInView(self.image_item, Qt.KeepAspectRatio)
# By default, add S/E
s_x, s_y = 0.15 * self._img_w, 0.5 * self._img_h
e_x, e_y = 0.85 * self._img_w, 0.5 * self._img_h
self._insert_anchor_point(-1, s_x, s_y, label="S", removable=False, z_val=100, radius=6)
self._insert_anchor_point(-1, e_x, e_y, label="E", removable=False, z_val=100, radius=6)
# --------------------------------------------------------------------
# ANCHOR POINTS
# --------------------------------------------------------------------
def _insert_anchor_point(self, idx, x: float, y: float, label="", removable=True,
z_val=0, radius=4):
"""Insert an anchor point at a specific index."""
x_clamped = self._clamp(x, radius, self._img_w - radius)
y_clamped = self._clamp(y, radius, self._img_h - radius)
if idx < 0:
# Insert before E if there's at least 2 anchors
if len(self.anchor_points) >= 2:
idx = len(self.anchor_points) - 1
else:
idx = len(self.anchor_points)
self.anchor_points.insert(idx, (x_clamped, y_clamped))
color = Qt.green if label in ("S", "E") else Qt.red
item = LabeledPointItem(x_clamped, y_clamped,
label=label, radius=radius, color=color,
removable=removable, z_value=z_val)
self.point_items.insert(idx, item)
self.scene.addItem(item)
def _add_guide_point(self, x, y):
"""Add a guide point to the path."""
x_clamped = self._clamp(x, self.dot_radius, self._img_w - self.dot_radius)
y_clamped = self._clamp(y, self.dot_radius, self._img_h - self.dot_radius)
self._revert_cost_to_original()
if not self._full_path_xy:
self._insert_anchor_point(-1, x_clamped, y_clamped,
label="", removable=True, z_val=1, radius=self.dot_radius)
else:
self._insert_anchor_between_subpath(x_clamped, y_clamped)
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def _insert_anchor_between_subpath(self, x_new: float, y_new: float ):
"""Insert an anchor point between existing anchor points.""" # If somehow we have no path yet
# If somehow we have no path yet
if not self._full_path_xy:
self._insert_anchor_point(-1, x_new, y_new)
return
# Find nearest point in the current full path
best_idx = None
best_d2 = float('inf')
for i, (px, py) in enumerate(self._full_path_xy):
dx = px - x_new
dy = py - y_new
d2 = dx*dx + dy*dy
if d2 < best_d2:
best_d2 = d2
best_idx = i
if best_idx is None:
self._insert_anchor_point(-1, x_new, y_new)
return
def approx_equal(xa, ya, xb, yb, tol=1e-3):
"""Check if two points are approximately equal."""
return (abs(xa - xb) < tol) and (abs(ya - yb) < tol)
def is_anchor(coord):
"""Check if a point is an anchor point."""
cx, cy = coord
for (ax, ay) in self.anchor_points:
if approx_equal(ax, ay, cx, cy):
return True
return False
# Walk left
left_anchor_pt = None
iL = best_idx
while iL >= 0:
px, py = self._full_path_xy[iL]
if is_anchor((px, py)):
left_anchor_pt = (px, py)
break
iL -= 1
# Walk right
right_anchor_pt = None
iR = best_idx
while iR < len(self._full_path_xy):
px, py = self._full_path_xy[iR]
if is_anchor((px, py)):
right_anchor_pt = (px, py)
break
iR += 1
# If we can't find distinct anchors on left & right,
# just insert before E.
if not left_anchor_pt or not right_anchor_pt:
self._insert_anchor_point(-1, x_new, y_new)
return
if left_anchor_pt == right_anchor_pt:
self._insert_anchor_point(-1, x_new, y_new)
return
# Convert anchor coords -> anchor_points indices
left_idx = None
right_idx = None
for i, (ax, ay) in enumerate(self.anchor_points):
if approx_equal(ax, ay, left_anchor_pt[0], left_anchor_pt[1]):
left_idx = i
if approx_equal(ax, ay, right_anchor_pt[0], right_anchor_pt[1]):
right_idx = i
if left_idx is None or right_idx is None:
self._insert_anchor_point(-1, x_new, y_new)
return
# Insert between them
if left_idx < right_idx:
insert_idx = left_idx + 1
else:
insert_idx = right_idx + 1
self._insert_anchor_point(insert_idx, x_new, y_new, label="", removable=True,
z_val=1, radius=self.dot_radius)
# --------------------------------------------------------------------
# COST IMAGE
# --------------------------------------------------------------------
def _revert_cost_to_original(self):
if self.cost_image_original is not None:
self.cost_image = self.cost_image_original.copy()
def _apply_all_guide_points_to_cost(self):
if self.cost_image is None:
return
for i, (ax, ay) in enumerate(self.anchor_points):
if self.point_items[i].is_removable():
self._lower_cost_in_circle(ax, ay, self.radius_cost_image)
def _lower_cost_in_circle(self, x_f: float, y_f: float, radius: int):
"""Lower the cost in a circle centered at (x_f, y_f)."""
if self.cost_image is None:
return
h, w = self.cost_image.shape
row_c = int(round(y_f))
col_c = int(round(x_f))
if not (0 <= row_c < h and 0 <= col_c < w):
return
global_min = self.cost_image.min()
r_s = max(0, row_c - radius)
r_e = min(h, row_c + radius + 1)
c_s = max(0, col_c - radius)
c_e = min(w, col_c + radius + 1)
for rr in range(r_s, r_e):
for cc in range(c_s, c_e):
dist = math.sqrt((rr - row_c)**2 + (cc - col_c)**2)
if dist <= radius:
self.cost_image[rr, cc] = global_min
# --------------------------------------------------------------------
# PATH BUILDING
# --------------------------------------------------------------------
def _rebuild_full_path(self):
"""Rebuild the full path based on the anchor points."""
for item in self.full_path_points:
self.scene.removeItem(item)
self.full_path_points.clear()
self._full_path_xy.clear()
if len(self.anchor_points) < 2 or self.cost_image is None:
return
big_xy = []
for i in range(len(self.anchor_points) - 1):
xA, yA = self.anchor_points[i]
xB, yB = self.anchor_points[i + 1]
sub_xy = self._compute_subpath_xy(xA, yA, xB, yB)
if i == 0:
big_xy.extend(sub_xy)
else:
if len(sub_xy) > 1:
big_xy.extend(sub_xy[1:])
if len(big_xy) >= self._savgol_window_length:
arr_xy = np.array(big_xy)
smoothed = savgol_filter(
arr_xy,
window_length=self._savgol_window_length,
polyorder=2,
axis=0
)
big_xy = smoothed.tolist()
self._full_path_xy = big_xy[:]
n_points = len(big_xy)
for i, (px, py) in enumerate(big_xy):
fraction = i / (n_points - 1) if n_points > 1 else 0
color = Qt.red
if self._rainbow_enabled:
color = self._rainbow_color(fraction)
path_item = LabeledPointItem(px, py, label="",
radius=self.path_radius,
color=color,
removable=False,
z_value=0)
self.full_path_points.append(path_item)
self.scene.addItem(path_item)
# Keep anchor labels on top
for p_item in self.point_items:
if p_item._text_item:
p_item.setZValue(100)
def _compute_subpath_xy(self, xA: float, yA: float, xB: float, yB: float):
"""Compute a subpath between two points."""
if self.cost_image is None:
return []
h, w = self.cost_image.shape
rA, cA = int(round(yA)), int(round(xA))
rB, cB = int(round(yB)), int(round(xB))
rA = max(0, min(rA, h - 1))
cA = max(0, min(cA, w - 1))
rB = max(0, min(rB, h - 1))
cB = max(0, min(cB, w - 1))
try:
path_rc = find_path(self.cost_image, [(rA, cA), (rB, cB)])
except ValueError as e:
print("Error in find_path:", e)
return []
# Convert from (row, col) to (x, y)
return [(c, r) for (r, c) in path_rc]
def _rainbow_color(self, fraction: float):
"""Get a rainbow color."""
hue = int(300 * fraction)
saturation = 255
value = 255
return QColor.fromHsv(hue, saturation, value)
# --------------------------------------------------------------------
# MOUSE EVENTS
# --------------------------------------------------------------------
def mousePressEvent(self, event):
"""Handle mouse press events for dragging a point or adding a point."""
if event.button() == Qt.LeftButton:
self._mouse_pressed = True
self._was_dragging = False
self._press_view_pos = event.pos()
idx = self._find_item_near(event.pos(), threshold=10)
if idx is not None:
self._dragging_idx = idx
self._drag_counter = 0
scene_pos = self.mapToScene(event.pos())
px, py = self.point_items[idx].get_pos()
self._drag_offset = (scene_pos.x() - px, scene_pos.y() - py)
self.setCursor(Qt.ClosedHandCursor)
return
elif event.button() == Qt.RightButton:
self._remove_point_by_click(event.pos())
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
"""Handle mouse move events for dragging a point or dragging the view"""
if self._dragging_idx is not None:
scene_pos = self.mapToScene(event.pos())
x_new = scene_pos.x() - self._drag_offset[0]
y_new = scene_pos.y() - self._drag_offset[1]
r = self.point_items[self._dragging_idx]._r
x_clamped = self._clamp(x_new, r, self._img_w - r)
y_clamped = self._clamp(y_new, r, self._img_h - r)
self.point_items[self._dragging_idx].set_pos(x_clamped, y_clamped)
self._drag_counter += 1
# Update path every 4 moves
if self._drag_counter >= 4:
self._drag_counter = 0
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self.anchor_points[self._dragging_idx] = (x_clamped, y_clamped)
self._rebuild_full_path()
else:
if self._mouse_pressed and (event.buttons() & Qt.LeftButton):
dist = (event.pos() - self._press_view_pos).manhattanLength()
if dist > self._drag_threshold:
self._was_dragging = True
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
"""Handle mouse release events for dragging a point or adding a point."""
super().mouseReleaseEvent(event)
if event.button() == Qt.LeftButton and self._mouse_pressed:
self._mouse_pressed = False
self.setCursor(Qt.ArrowCursor)
if self._dragging_idx is not None:
idx = self._dragging_idx
self._dragging_idx = None
self._drag_offset = (0, 0)
newX, newY = self.point_items[idx].get_pos()
self.anchor_points[idx] = (newX, newY)
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
else:
# No drag => add point
if not self._was_dragging:
scene_pos = self.mapToScene(event.pos())
x, y = scene_pos.x(), scene_pos.y()
self._add_guide_point(x, y)
self._was_dragging = False
def _remove_point_by_click(self, view_pos: QPoint):
"""Remove a point by clicking on it."""
idx = self._find_item_near(view_pos, threshold=10)
if idx is None:
return
if not self.point_items[idx].is_removable():
return
self.scene.removeItem(self.point_items[idx])
self.point_items.pop(idx)
self.anchor_points.pop(idx)
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def _find_item_near(self, view_pos: QPoint, threshold=10):
"""Find the index of an item near a given position."""
scene_pos = self.mapToScene(view_pos)
x_click, y_click = scene_pos.x(), scene_pos.y()
closest_idx = None
min_dist = float('inf')
for i, itm in enumerate(self.point_items):
d = itm.distance_to(x_click, y_click)
if d < min_dist:
min_dist = d
closest_idx = i
if closest_idx is not None and min_dist <= threshold:
return closest_idx
return None
# --------------------------------------------------------------------
# UTILS
# --------------------------------------------------------------------
def _clamp(self, val, mn, mx):
return max(mn, min(val, mx))
def _clear_all_points(self):
"""Clear all anchor points and guide points."""
for it in self.point_items:
self.scene.removeItem(it)
self.point_items.clear()
self.anchor_points.clear()
for p in self.full_path_points:
self.scene.removeItem(p)
self.full_path_points.clear()
self._full_path_xy.clear()
def clear_guide_points(self):
"""Clear all guide points."""
i = 0
while i < len(self.anchor_points):
if self.point_items[i].is_removable():
self.scene.removeItem(self.point_items[i])
del self.point_items[i]
del self.anchor_points[i]
else:
i += 1
for it in self.full_path_points:
self.scene.removeItem(it)
self.full_path_points.clear()
self._full_path_xy.clear()
self._revert_cost_to_original()
self._apply_all_guide_points_to_cost()
self._rebuild_full_path()
def get_full_path_xy(self):
"""Returns the entire path as a list of (x, y) coordinates."""
return self._full_path_xy
\ No newline at end of file
import math
from PyQt5.QtWidgets import QGraphicsEllipseItem, QGraphicsTextItem
from PyQt5.QtGui import QPen, QBrush, QColor, QFont
from PyQt5.QtCore import Qt
class LabeledPointItem(QGraphicsEllipseItem):
"""
A QGraphicsEllipseItem subclass that represents a labeled point in a 2D space.
This class creates a circular point.
The point can be customized with different colors, sizes, and labels, and can
be marked as removable.
"""
def __init__(self, x: float, y: float, label: str ="", radius:int =4,
color=Qt.red, removable=True, z_value=0, parent=None):
super().__init__(0, 0, 2*radius, 2*radius, parent)
self._x = x
self._y = y
self._r = radius
self._removable = removable
pen = QPen(color)
brush = QBrush(color)
self.setPen(pen)
self.setBrush(brush)
self.setZValue(z_value)
self._text_item = None
if label:
self._text_item = QGraphicsTextItem(self)
self._text_item.setPlainText(label)
self._text_item.setDefaultTextColor(QColor("black"))
font = QFont("Arial", 14)
font.setBold(True)
self._text_item.setFont(font)
self._scale_text_to_fit()
self.set_pos(x, y)
def _scale_text_to_fit(self):
"""Scales the text to fit inside the circle."""
if not self._text_item:
return
self._text_item.setScale(1.0)
circle_diam = 2 * self._r
raw_rect = self._text_item.boundingRect()
text_w = raw_rect.width()
text_h = raw_rect.height()
if text_w > circle_diam or text_h > circle_diam:
scale_factor = min(circle_diam / text_w, circle_diam / text_h)
self._text_item.setScale(scale_factor)
self._center_label()
def _center_label(self):
"""Centers the text inside the circle."""
if not self._text_item:
return
ellipse_w = 2 * self._r
ellipse_h = 2 * self._r
raw_rect = self._text_item.boundingRect()
scale_factor = self._text_item.scale()
scaled_w = raw_rect.width() * scale_factor
scaled_h = raw_rect.height() * scale_factor
tx = (ellipse_w - scaled_w) * 0.5
ty = (ellipse_h - scaled_h) * 0.5
self._text_item.setPos(tx, ty)
def set_pos(self, x, y):
"""Positions the circle so its center is at (x, y)."""
self._x = x
self._y = y
self.setPos(x - self._r, y - self._r)
def get_pos(self):
"""Returns the (x, y) coordinates of the center of the circle."""
return (self._x, self._y)
def distance_to(self, x_other, y_other):
"""Returns the Euclidean distance from the center
of the circle to another circle."""
return math.sqrt((self._x - x_other)**2 + (self._y - y_other)**2)
def is_removable(self):
"""Returns True if the point is removable, False otherwise."""
return self._removable
import cv2
def load_image(path: str) -> "numpy.ndarray":
"""
Loads an image from the specified file path in grayscale mode.
Args:
path (str): The file path to the image.
Returns:
numpy.ndarray: The loaded grayscale image.
"""
return cv2.imread(path, cv2.IMREAD_GRAYSCALE)
\ No newline at end of file
import math
import numpy as np
from PyQt5.QtWidgets import (
QMainWindow, QPushButton, QHBoxLayout,
QVBoxLayout, QWidget, QFileDialog
)
from PyQt5.QtGui import QPixmap, QImage, QCloseEvent
from .compute_cost_image import compute_cost_image
from .preprocess_image import preprocess_image
from .advancedSettingsWidget import AdvancedSettingsWidget
from .imageGraphicsView import ImageGraphicsView
from .circleEditorWidget import CircleEditorWidget
class MainWindow(QMainWindow):
def __init__(self):
"""
Initialize the main window for the application.
This method sets up the main window, including the layout, widgets, and initial state.
It initializes various attributes related to the image processing and user interface.
"""
super().__init__()
self.setWindowTitle("Test GUI")
self._last_loaded_pixmap = None
self._circle_calibrated_radius = 6
self._last_loaded_file_path = None
# Value for the contrast slider
self._current_clip_limit = 0.01
# Outer widget and layout
self._main_widget = QWidget()
self._main_layout = QHBoxLayout(self._main_widget)
# Container for the image area and its controls
self._left_panel = QVBoxLayout()
# Container widget for stretching the panel
self._left_container = QWidget()
self._left_container.setLayout(self._left_panel)
self._main_layout.addWidget(self._left_container, 7) # 70% ratio of the full window
# Advanced widget window
self._advanced_widget = AdvancedSettingsWidget(self)
self._advanced_widget.hide()
self._main_layout.addWidget(self._advanced_widget, 3) # 30% ratio of the full window
self.setCentralWidget(self._main_widget)
# The image view
self.image_view = ImageGraphicsView()
self._left_panel.addWidget(self.image_view)
# Button row
btn_layout = QHBoxLayout()
self.btn_load_image = QPushButton("Load Image")
self.btn_load_image.clicked.connect(self.load_image)
btn_layout.addWidget(self.btn_load_image)
self.btn_export_path = QPushButton("Export Path")
self.btn_export_path.clicked.connect(self.export_path)
btn_layout.addWidget(self.btn_export_path)
self.btn_clear_points = QPushButton("Clear Points")
self.btn_clear_points.clicked.connect(self.clear_points)
btn_layout.addWidget(self.btn_clear_points)
self.btn_advanced = QPushButton("Advanced Settings")
self.btn_advanced.setCheckable(True)
self.btn_advanced.clicked.connect(self._toggle_advanced_settings)
btn_layout.addWidget(self.btn_advanced)
self._left_panel.addLayout(btn_layout)
self.resize(1000, 600)
self._old_central_widget = None
self._editor = None
def _toggle_advanced_settings(self, checked: bool):
"""
Toggles the visibility of the advanced settings widget.
"""
if checked:
self._advanced_widget.show()
else:
self._advanced_widget.hide()
# Force re-layout
self.adjustSize()
def open_circle_editor(self):
"""
Replace central widget with circle editor.
"""
if not self._last_loaded_pixmap:
print("No image loaded yet! Cannot open circle editor.")
return
old_widget = self.takeCentralWidget()
self._old_central_widget = old_widget
init_radius = self._circle_calibrated_radius
editor = CircleEditorWidget(
pixmap=self._last_loaded_pixmap,
init_radius=init_radius,
done_callback=self._on_circle_editor_done
)
self._editor = editor
self.setCentralWidget(editor)
def _on_circle_editor_done(self, final_radius: int):
"""
Updates the calibrated radius, computes the cost image based on the new radius,
and updates the image view with the new cost image.
It also restores the previous central widget and cleans up the editor widget.
"""
self._circle_calibrated_radius = final_radius
print(f"Circle Editor done. Radius = {final_radius}")
# Update cost image and path using new radius
if self._last_loaded_file_path:
cost_img = compute_cost_image(
self._last_loaded_file_path,
self._circle_calibrated_radius,
clip_limit=self._current_clip_limit
)
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
self.image_view._apply_all_guide_points_to_cost()
self.image_view._rebuild_full_path()
self._update_advanced_images()
# Swap back to central widget
editor_widget = self.takeCentralWidget()
if editor_widget is not None:
editor_widget.setParent(None)
if self._old_central_widget is not None:
self.setCentralWidget(self._old_central_widget)
self._old_central_widget = None
if self._editor is not None:
self._editor.deleteLater()
self._editor = None
def toggle_rainbow(self):
"""
Toggle rainbow coloring of the path.
"""
self.image_view.toggle_rainbow()
def load_image(self):
"""
Load an image and update the image view and cost image.
The supported image formats are: PNG, JPG, JPEG, BMP, and TIF.
"""
options = QFileDialog.Options()
file_path, _ = QFileDialog.getOpenFileName(
self, "Open Image", "",
"Images (*.png *.jpg *.jpeg *.bmp *.tif)",
options=options
)
if file_path:
self.image_view.load_image(file_path)
cost_img = compute_cost_image(
file_path,
self._circle_calibrated_radius,
clip_limit=self._current_clip_limit
)
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
pm = QPixmap(file_path)
if not pm.isNull():
self._last_loaded_pixmap = pm
self._last_loaded_file_path = file_path
self._update_advanced_images()
def update_contrast(self, clip_limit: float):
"""
Updates and applies the contrast value of the image.
"""
self._current_clip_limit = clip_limit
if self._last_loaded_file_path:
cost_img = compute_cost_image(
self._last_loaded_file_path,
self._circle_calibrated_radius,
clip_limit=clip_limit
)
self.image_view.cost_image_original = cost_img
self.image_view.cost_image = cost_img.copy()
self.image_view._apply_all_guide_points_to_cost()
self.image_view._rebuild_full_path()
self._update_advanced_images()
def _update_advanced_images(self):
"""
Updates the advanced images display with the latest image.
If no image has been loaded, the method returns without making any updates.
"""
if not self._last_loaded_pixmap:
return
pm_np = self._qpixmap_to_gray_float(self._last_loaded_pixmap)
contrasted_blurred = preprocess_image(
pm_np,
sigma=3,
clip_limit=self._current_clip_limit
)
cost_img_np = self.image_view.cost_image
self._advanced_widget.update_displays(contrasted_blurred, cost_img_np)
def _qpixmap_to_gray_float(self, qpix: QPixmap) -> np.ndarray:
"""
Convert a QPixmap to a grayscale float array.
Args:
qpix: The QPixmap to be converted.
Returns:
A 2D numpy array representing the grayscale image.
"""
img = qpix.toImage()
img = img.convertToFormat(QImage.Format_ARGB32)
ptr = img.bits()
ptr.setsize(img.byteCount())
arr = np.frombuffer(ptr, np.uint8).reshape((img.height(), img.width(), 4))
rgb = arr[..., :3].astype(np.float32)
gray = rgb.mean(axis=2) / 255.0
return gray
def export_path(self):
"""
Exports the path as a CSV in the format: x, y, TYPE,
ensuring that each anchor influences exactly one path point.
"""
full_xy = self.image_view.get_full_path_xy()
if not full_xy:
print("No path to export.")
return
anchor_points = self.image_view.anchor_points
# Finds the index of the closest path point for each anchor point
user_placed_indices = set()
for ax, ay in anchor_points:
min_dist = float('inf')
closest_idx = None
for i, (px, py) in enumerate(full_xy):
dist = math.hypot(px - ax, py - ay)
if dist < min_dist:
min_dist = dist
closest_idx = i
if closest_idx is not None:
user_placed_indices.add(closest_idx)
# Ask user for the CSV filename
options = QFileDialog.Options()
file_path, _ = QFileDialog.getSaveFileName(
self, "Export Path", "",
"CSV Files (*.csv);;All Files (*)",
options=options
)
if not file_path:
return
import csv
with open(file_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["x", "y", "TYPE"])
for i, (x, y) in enumerate(full_xy):
ptype = "USER-PLACED" if i in user_placed_indices else "PATH"
writer.writerow([x, y, ptype])
print(f"Exported path with {len(full_xy)} points to {file_path}")
def clear_points(self):
"""
Clears points from the image.
"""
self.image_view.clear_guide_points()
def closeEvent(self, event: QCloseEvent):
"""
Handle the window close event.
Args:
event: The close event.
"""
super().closeEvent(event)
\ No newline at end of file
from PyQt5.QtWidgets import QGraphicsView, QSizePolicy
from PyQt5.QtCore import Qt
class PanZoomGraphicsView(QGraphicsView):
"""
A QGraphicsView subclass that supports panning and zooming with the mouse.
"""
def __init__(self, parent=None):
super().__init__(parent)
self.setDragMode(QGraphicsView.NoDrag) # We'll handle panning manually
self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
self._panning = False
self._pan_start = None
# Expands layout
self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
def wheelEvent(self, event):
""" Zoom in/out with mouse wheel. """
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
if event.angleDelta().y() > 0:
self.scale(zoom_in_factor, zoom_in_factor)
else:
self.scale(zoom_out_factor, zoom_out_factor)
event.accept()
def mousePressEvent(self, event):
""" If left button: Start panning (unless overridden). """
if event.button() == Qt.LeftButton:
self._panning = True
self._pan_start = event.pos()
self.setCursor(Qt.ClosedHandCursor)
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
""" If panning, translate the scene. """
if self._panning and self._pan_start is not None:
delta = event.pos() - self._pan_start
self._pan_start = event.pos()
self.translate(delta.x(), delta.y())
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
""" End panning. """
if event.button() == Qt.LeftButton:
self._panning = False
self.setCursor(Qt.ArrowCursor)
super().mouseReleaseEvent(event)
from skimage.filters import gaussian
from skimage import exposure
def preprocess_image(image: "np.ndarray", sigma: int = 3, clip_limit: float = 0.01) -> "np.ndarray":
"""
Preprocess the input image by applying histogram equalization and Gaussian smoothing.
Args:
image: (ndarray): Input image to be processed.
sigma: (float, optional): Standard deviation for Gaussian kernel. Default is 3.
clip_limit: (float, optional): Clipping limit for contrast enhancement. Default is 0.01.
Returns:
ndarray: The preprocessed image.
"""
# Applies histogram equalization to enhance contrast
image_contrasted = exposure.equalize_adapthist(
image, clip_limit=clip_limit)
# Applies smoothing
smoothed_img = gaussian(image_contrasted, sigma=sigma)
return smoothed_img
data/AgamodonSlice.png

87.1 KiB

data/AngustifronsSlice35.png

91.3 KiB

data/BipesSlice4.png

44.1 KiB

data/BipesSlice4NoCropping.png

57.3 KiB

File added
File added
data/agamodon_slice.png

87.1 KiB

File added
data/angustifrons_slice.png

91.3 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment