diff --git a/live_wire.py b/live_wire.py new file mode 100644 index 0000000000000000000000000000000000000000..49d5cb20ffe56fa9d5a8429218ee30b927242f61 --- /dev/null +++ b/live_wire.py @@ -0,0 +1,185 @@ +import time +import cv2 +import numpy as np +import heapq +import matplotlib.pyplot as plt +from scipy.ndimage import convolve +from skimage.filters import gaussian +from skimage.feature import canny + +#### Helper functions #### +def neighbors_8(x, y, width, height): + """Return the 8-connected neighbors of (x, y).""" + for nx in (x-1, x, x+1): + for ny in (y-1, y, y+1): + if 0 <= nx < width and 0 <= ny < height: + if not (nx == x and ny == y): + yield nx, ny + +def dijkstra(cost_img, seed): + """ + Dijkstra's algorithm on a 2D grid, using cost_img as the per-pixel cost. + + Args: + cost_img (np.array): 2D array of costs (float). + seed (tuple): (x, y) starting coordinate. + + Returns: + dist (np.float32): array of minimal cumulative cost from seed to each pixel. + parent (np.int32): array storing predecessor of each pixel for path reconstruction. + """ + height, width = cost_img.shape + + # Initialize dist and parent + dist = np.full((height, width), np.inf, dtype=np.float32) + dist[seed[1], seed[0]] = 0.0 + + parent = -1 * np.ones((height, width, 2), dtype=np.int32) + + visited = np.zeros((height, width), dtype=bool) + pq = [(0.0, seed[0], seed[1])] # (distance, x, y) + + while pq: + curr_dist, cx, cy = heapq.heappop(pq) + if visited[cy, cx]: + continue + visited[cy, cx] = True + + for nx, ny in neighbors_8(cx, cy, width, height): + if visited[ny, nx]: + continue + # We can take an average or sum—here, let's just sum the cost + move_cost = 0.5 * (cost_img[cy, cx] + cost_img[ny, nx]) + ndist = curr_dist + move_cost + if ndist < dist[ny, nx]: + dist[ny, nx] = ndist + parent[ny, nx] = (cx, cy) + heapq.heappush(pq, (ndist, nx, ny)) + + return dist, parent + +def backtrack_path(parent, start, end): + """ + Reconstruct path from 'end' back to 'start' using 'parent' array. + + Args: + parent (np.array): shape (H, W, 2), storing (px, py) for each pixel. + start (tuple): (x, y) start coordinate. + end (tuple): (x, y) end coordinate. + + Returns: + path (list of (x, y)): from start to end inclusive. + """ + path = [] + current = end + while True: + path.append(current) + if current == start: + break + px, py = parent[current[1], current[0]] + current = (px, py) + + path.reverse() + return path + +def compute_cost(image, sigma=3.0, epsilon=1e-5): + + smoothed_img = gaussian(image, sigma=sigma) + canny_img = canny(smoothed_img) + cost_img = 1 / (canny_img + epsilon) + + return cost_img, canny_img + +def load_image(path, type): + # Load image + if type == 'gray': + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + if img is None: + raise FileNotFoundError(f"Could not read {path}") + elif type == 'color': + img = cv2.imread(path, cv2.IMREAD_COLOR) + if img is None: + raise FileNotFoundError(f"Could not read {path}") + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + else: + raise ValueError("type must be 'gray' or 'color'") + + return img + +def downscale(img, points, scale_percent): + if scale_percent == 100: + return img, (tuple(points[0]), tuple(points[1])) + else: + width = int(img.shape[1] * scale_percent / 100) + height = int(img.shape[0] * scale_percent / 100) + new_dimensions = (width, height) + + # Downsample the image + downsampled_img = cv2.resize(img, new_dimensions, interpolation=cv2.INTER_AREA) + + ### SCALE POINTS + # Original image dimensions + original_width = img.shape[1] + original_height = img.shape[0] + + # Downsampled image dimensions + downsampled_width = width + downsampled_height = height + + # Scaling factors + scale_x = downsampled_width / original_width + scale_y = downsampled_height / original_height + + # Original points + seed_xy = tuple(points[0]) + target_xy = tuple(points[1]) + + # Scale the points + scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y)) + scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y)) + + return downsampled_img, (scaled_seed_xy, scaled_target_xy) + + +# Define the following +image_path = './tests/slice_60_volQ.png' +image_type = 'gray' # 'gray' or 'color' +downscale_factor = 100 # % of original size wanted +points_path = './tests/LiveWireEndPoints.npy' + + + +# Load image +image = load_image(image_path, image_type) +# Load points +points = np.int0(np.round(np.load(points_path))) + +# Downscale image and points +scaled_image, scaled_points = downscale(image, points, downscale_factor) +seed, target = scaled_points + +# Compute cost image +cost_image, canny_img = compute_cost(scaled_image) + +# Find path and time it +start_time = time.time() +dist, parent = dijkstra(cost_image, seed) +path = backtrack_path(parent, seed, target) +end_time = time.time() + +print(f"Elapsed time for pathfinding: {end_time - start_time:.3f} seconds") + +color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR) +for (x, y) in path: + color_img[y, x] = (0, 0, 255) # red (color of path) + +plt.figure(figsize=(20,8)) +plt.subplot(1,2,1) +plt.title("Cost Image") +plt.imshow(canny_img, cmap='gray') + +plt.subplot(1,2,2) +plt.title("Path from Seed to Target") +plt.imshow(color_img[..., ::-1]) # BGR->RGB for plotting +plt.show() diff --git a/tests/LiveWireEndPoints.npy b/tests/LiveWireEndPoints.npy new file mode 100644 index 0000000000000000000000000000000000000000..75ec4b2945d7e7a49b6b72de2fddeca8bcee6615 Binary files /dev/null and b/tests/LiveWireEndPoints.npy differ