diff --git a/live_wire.py b/live_wire.py
index 49d5cb20ffe56fa9d5a8429218ee30b927242f61..4f8c86f9830f9cd1a17197e968ce166d72d63a8e 100644
--- a/live_wire.py
+++ b/live_wire.py
@@ -1,97 +1,18 @@
 import time
 import cv2
 import numpy as np
-import heapq
 import matplotlib.pyplot as plt
-from scipy.ndimage import convolve
+
 from skimage.filters import gaussian
 from skimage.feature import canny
+from skimage.graph import route_through_array
 
 #### Helper functions ####
-def neighbors_8(x, y, width, height):
-    """Return the 8-connected neighbors of (x, y)."""
-    for nx in (x-1, x, x+1):
-        for ny in (y-1, y, y+1):
-            if 0 <= nx < width and 0 <= ny < height:
-                if not (nx == x and ny == y):
-                    yield nx, ny
-
-def dijkstra(cost_img, seed):
-    """
-    Dijkstra's algorithm on a 2D grid, using cost_img as the per-pixel cost.
-    
-    Args:
-      cost_img (np.array): 2D array of costs (float).
-      seed (tuple): (x, y) starting coordinate.
-    
-    Returns:
-      dist (np.float32): array of minimal cumulative cost from seed to each pixel.
-      parent (np.int32): array storing predecessor of each pixel for path reconstruction.
-    """
-    height, width = cost_img.shape
-
-    # Initialize dist and parent
-    dist = np.full((height, width), np.inf, dtype=np.float32)
-    dist[seed[1], seed[0]] = 0.0
-
-    parent = -1 * np.ones((height, width, 2), dtype=np.int32)
 
-    visited = np.zeros((height, width), dtype=bool)
-    pq = [(0.0, seed[0], seed[1])]  # (distance, x, y)
-
-    while pq:
-        curr_dist, cx, cy = heapq.heappop(pq)
-        if visited[cy, cx]:
-            continue
-        visited[cy, cx] = True
-
-        for nx, ny in neighbors_8(cx, cy, width, height):
-            if visited[ny, nx]:
-                continue
-            # We can take an average or sum—here, let's just sum the cost
-            move_cost = 0.5 * (cost_img[cy, cx] + cost_img[ny, nx])
-            ndist = curr_dist + move_cost
-            if ndist < dist[ny, nx]:
-                dist[ny, nx] = ndist
-                parent[ny, nx] = (cx, cy)
-                heapq.heappush(pq, (ndist, nx, ny))
-
-    return dist, parent
-
-def backtrack_path(parent, start, end):
+def load_image(path, type):
     """
-    Reconstruct path from 'end' back to 'start' using 'parent' array.
-    
-    Args:
-      parent (np.array): shape (H, W, 2), storing (px, py) for each pixel.
-      start (tuple): (x, y) start coordinate.
-      end (tuple): (x, y) end coordinate.
-    
-    Returns:
-      path (list of (x, y)): from start to end inclusive.
+    Load an image in either gray or color mode (then convert color to gray).
     """
-    path = []
-    current = end
-    while True:
-        path.append(current)
-        if current == start:
-            break
-        px, py = parent[current[1], current[0]]
-        current = (px, py)
-
-    path.reverse()
-    return path
-
-def compute_cost(image, sigma=3.0, epsilon=1e-5):
-
-    smoothed_img = gaussian(image, sigma=sigma)
-    canny_img = canny(smoothed_img)
-    cost_img = 1 / (canny_img + epsilon)
-
-    return cost_img, canny_img  
-
-def load_image(path, type):
-    # Load image
     if type == 'gray':
         img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
         if img is None:
@@ -104,82 +25,111 @@ def load_image(path, type):
             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
     else:
         raise ValueError("type must be 'gray' or 'color'")
-    
     return img
 
 def downscale(img, points, scale_percent):
+    """
+    Downsample `img` to `scale_percent` size and scale the given points accordingly.
+    Returns (downsampled_img, (scaled_seed, scaled_target)).
+    """
     if scale_percent == 100:
         return img, (tuple(points[0]), tuple(points[1]))
     else:
+        # Compute new dimensions
         width = int(img.shape[1] * scale_percent / 100)
         height = int(img.shape[0] * scale_percent / 100)
         new_dimensions = (width, height)
 
-        # Downsample the image
+        # Downsample
         downsampled_img = cv2.resize(img, new_dimensions, interpolation=cv2.INTER_AREA)
 
-        ### SCALE POINTS
-        # Original image dimensions
-        original_width = img.shape[1]
-        original_height = img.shape[0]
-
-        # Downsampled image dimensions
-        downsampled_width = width
-        downsampled_height = height
-
         # Scaling factors
-        scale_x = downsampled_width / original_width
-        scale_y = downsampled_height / original_height
+        scale_x = width / img.shape[1]
+        scale_y = height / img.shape[0]
 
-        # Original points
+        # Scale the points (x, y)
         seed_xy = tuple(points[0])
         target_xy = tuple(points[1])
-
-        # Scale the points
         scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y))
         scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y))
 
         return downsampled_img, (scaled_seed_xy, scaled_target_xy)
 
+def compute_cost(image, sigma=3.0, epsilon=1e-5):
+    """
+    Smooth the image, run Canny edge detection, then invert the edge map into a cost image.
+    """
+    smoothed_img = gaussian(image, sigma=sigma)
+    canny_img = canny(smoothed_img)
+    cost_img = 1.0 / (canny_img + epsilon)  # Invert edges: higher cost where edges are stronger
+    return cost_img, canny_img
 
-# Define the following
-image_path = './tests/slice_60_volQ.png'
-image_type = 'gray'  # 'gray' or 'color'
-downscale_factor = 100  # % of original size wanted
-points_path = './tests/LiveWireEndPoints.npy'
-
-
-
-# Load image
-image = load_image(image_path, image_type)
-# Load points
-points = np.int0(np.round(np.load(points_path)))
-
-# Downscale image and points
-scaled_image, scaled_points = downscale(image, points, downscale_factor)
-seed, target = scaled_points
-
-# Compute cost image
-cost_image, canny_img = compute_cost(scaled_image)
-
-# Find path and time it
-start_time = time.time()
-dist, parent = dijkstra(cost_image, seed)
-path = backtrack_path(parent, seed, target)
-end_time = time.time()
-
-print(f"Elapsed time for pathfinding: {end_time - start_time:.3f} seconds")
+def backtrack_pixels_on_image(img_color, path_coords, bgr_color=(0, 0, 255)):
+    """
+    Color the path on the (already converted BGR) image in the specified color.
+    `path_coords` should be a list of (row, col) or (y, x).
+    """
+    for (row, col) in path_coords:
+        img_color[row, col] = bgr_color
+    return img_color
 
-color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)
-for (x, y) in path:
-    color_img[y, x] = (0, 0, 255)  # red (color of path)
+#### Main Script ####
 
-plt.figure(figsize=(20,8))
-plt.subplot(1,2,1)
-plt.title("Cost Image")
-plt.imshow(canny_img, cmap='gray')
+def main():
+    # Define input parameters
+    image_path = './tests/slice_60_volQ.png'
+    image_type = 'gray'        # 'gray' or 'color'
+    downscale_factor = 100     # % of original size
+    points_path = './tests/LiveWireEndPoints.npy'
 
-plt.subplot(1,2,2)
-plt.title("Path from Seed to Target")
-plt.imshow(color_img[..., ::-1])  # BGR->RGB for plotting
-plt.show()
+    # Load image
+    image = load_image(image_path, image_type)
+
+    # Load seed and target points
+    points = np.int0(np.round(np.load(points_path)))  # shape: (2, 2), i.e. [[x_seed, y_seed], [x_target, y_target]]
+
+    # Downscale image and points
+    scaled_image, scaled_points = downscale(image, points, downscale_factor)
+    seed, target = scaled_points  # Each is (x, y)
+
+    # Convert to row,col for scikit-image (which uses (row, col) = (y, x))
+    seed_rc = (seed[1], seed[0])
+    target_rc = (target[1], target[0])
+
+    # Compute cost image
+    cost_image, canny_img = compute_cost(scaled_image)
+
+    # Find path using route_through_array
+    # route_through_array expects: route_through_array(image, start, end, fully_connected=True/False)
+    start_time = time.time()
+    path_rc, cost = route_through_array(
+        cost_image, 
+        start=seed_rc, 
+        end=target_rc, 
+        fully_connected=True
+    )
+    end_time = time.time()
+
+    print(f"Elapsed time for pathfinding: {end_time - start_time:.3f} seconds")
+
+    # Convert single-channel image to BGR for coloring
+    color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)
+
+    # Draw path. `path_rc` is a list of (row, col).
+    # If you want to mark it in red, do (0,0,255) because OpenCV uses BGR format.
+    color_img = backtrack_pixels_on_image(color_img, path_rc, bgr_color=(0, 0, 255))
+
+    # Display results
+    plt.figure(figsize=(20, 8))
+    plt.subplot(1, 2, 1)
+    plt.title("Canny Edges")
+    plt.imshow(canny_img, cmap='gray')
+
+    plt.subplot(1, 2, 2)
+    plt.title("Path from Seed to Target")
+    # Convert BGR->RGB for pyplot
+    plt.imshow(color_img[..., ::-1])
+    plt.show()
+
+if __name__ == "__main__":
+    main()