Skip to content
Snippets Groups Projects
live_wire.py 5.43 KiB
Newer Older
  • Learn to ignore specific revisions
  • Christian's avatar
    Christian committed
    import time
    import cv2
    import numpy as np
    import heapq
    import matplotlib.pyplot as plt
    from scipy.ndimage import convolve
    from skimage.filters import gaussian
    from skimage.feature import canny
    
    #### Helper functions ####
    def neighbors_8(x, y, width, height):
        """Yield the in-bounds pixels of the 3x3 window around (x, y), excluding (x, y).

        Pixels are produced column by column (x-1, then x, then x+1), and within
        each column from y-1 to y+1.
        """
        for dx in (-1, 0, 1):
            col = x + dx
            if not 0 <= col < width:
                continue  # whole column lies outside the image
            for dy in (-1, 0, 1):
                row = y + dy
                # (dx or dy) is falsy only for the centre pixel itself.
                if (dx or dy) and 0 <= row < height:
                    yield col, row
    
    def dijkstra(cost_img, seed):
        """
        Dijkstra's algorithm on a 2D pixel grid with 8-connectivity.

        Stepping between two neighbouring pixels (diagonals included) costs the
        mean of their two ``cost_img`` values; diagonal steps are not weighted
        by sqrt(2).

        Args:
          cost_img (np.array): 2D array of per-pixel costs, indexed [y, x].
          seed (tuple): (x, y) starting coordinate.

        Returns:
          dist (np.float64): array of minimal cumulative cost from seed to each
            pixel (np.inf where unreachable).
          parent (np.int32): array of shape (H, W, 2) storing the (px, py)
            predecessor of each pixel; (-1, -1) for the seed and unreached pixels.

        Raises:
          IndexError: if seed lies outside the image.
        """
        height, width = cost_img.shape
        sx, sy = int(seed[0]), int(seed[1])
        # Negative indices would silently wrap around in numpy, so check explicitly.
        if not (0 <= sx < width and 0 <= sy < height):
            raise IndexError(f"seed {tuple(seed)} is outside the {width}x{height} image")

        # float64 rather than float32: per-pixel costs can be huge (compute_cost
        # gives non-edge pixels ~1/epsilon), and at those magnitudes float32
        # rounding swamps the small edge costs and corrupts the parent choices.
        dist = np.full((height, width), np.inf, dtype=np.float64)
        dist[sy, sx] = 0.0

        parent = np.full((height, width, 2), -1, dtype=np.int32)

        visited = np.zeros((height, width), dtype=bool)
        pq = [(0.0, sx, sy)]  # (distance, x, y)

        while pq:
            curr_dist, cx, cy = heapq.heappop(pq)
            if visited[cy, cx]:
                continue  # stale heap entry; a cheaper one was already settled
            visited[cy, cx] = True

            curr_cost = cost_img[cy, cx]
            for nx, ny in neighbors_8(cx, cy, width, height):
                if visited[ny, nx]:
                    continue
                # Edge weight is the average of the two endpoint pixel costs.
                ndist = curr_dist + 0.5 * (curr_cost + cost_img[ny, nx])
                if ndist < dist[ny, nx]:
                    dist[ny, nx] = ndist
                    parent[ny, nx] = (cx, cy)
                    heapq.heappush(pq, (ndist, nx, ny))

        return dist, parent
    
    def backtrack_path(parent, start, end):
        """
        Reconstruct the path from 'start' to 'end' by following 'parent' links back from 'end'.

        Args:
          parent (np.array): shape (H, W, 2), storing (px, py) for each pixel,
            with (-1, -1) marking a pixel that has no predecessor.
          start (tuple): (x, y) start coordinate.
          end (tuple): (x, y) end coordinate.

        Returns:
          path (list of (x, y)): from start to end inclusive, as tuples of ints.

        Raises:
          IndexError: if 'end' lies outside the parent array.
          ValueError: if the parent chain from 'end' never reaches 'start'
            (e.g. 'end' unreachable, or 'start' is not the seed 'parent' was built from).
        """
        height, width = parent.shape[:2]
        start = (int(start[0]), int(start[1]))
        current = (int(end[0]), int(end[1]))
        # Negative indices would silently wrap around in numpy, so check explicitly.
        if not (0 <= current[0] < width and 0 <= current[1] < height):
            raise IndexError(f"end {current} is outside the {width}x{height} image")

        path = [current]
        # A valid chain visits each pixel at most once, so H*W checks suffice;
        # running out means the links contain a cycle.
        for _ in range(height * width):
            if current == start:
                path.reverse()
                return path
            px, py = parent[current[1], current[0]]
            if px < 0 or py < 0:
                break  # reached a root pixel without passing through 'start'
            current = (int(px), int(py))
            path.append(current)

        raise ValueError(f"no path from {start} to {current} found in parent array")
    
    def compute_cost(image, sigma=3.0, epsilon=1e-5):
        """Build a per-pixel cost map from the Canny edges of a smoothed image.

        The image is blurred with a Gaussian of width ``sigma`` and passed to
        skimage's ``canny``. The cost is ``1 / (edges + epsilon)``, so edge
        pixels cost about 1 and every other pixel about ``1 / epsilon``.

        NOTE(review): skimage's ``canny`` also smooths internally, so the image
        is effectively smoothed twice — confirm this is intended.

        Returns:
          (cost_img, canny_img): the cost map and the raw Canny edge output.
        """
        edges = canny(gaussian(image, sigma=sigma))
        return 1 / (edges + epsilon), edges
    
    def load_image(path, type):
        """Read an image from ``path`` and return it as a single-channel grayscale array.

        Args:
          path: file path passed to ``cv2.imread``.
          type: 'gray' reads the file directly in grayscale; 'color' reads it as a
            3-channel BGR image and converts that to grayscale. (The name shadows
            the builtin but is kept for callers passing it by keyword.)

        Raises:
          ValueError: if ``type`` is neither 'gray' nor 'color'.
          FileNotFoundError: if ``cv2.imread`` cannot read the file (returns None).
        """
        if type not in ('gray', 'color'):
            raise ValueError("type must be 'gray' or 'color'")

        flag = cv2.IMREAD_GRAYSCALE if type == 'gray' else cv2.IMREAD_COLOR
        img = cv2.imread(path, flag)
        if img is None:
            raise FileNotFoundError(f"Could not read {path}")

        if type == 'color':
            # OpenCV decodes colour images in BGR channel order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img
    
    def downscale(img, points, scale_percent):
        """Resize ``img`` to ``scale_percent`` % of its size and rescale two (x, y) points to match.

        Args:
          img (np.ndarray): image; ``img.shape[:2]`` is read as (height, width).
          points: two (x, y) coordinates, ``points[0]`` the seed and ``points[1]``
            the target.
          scale_percent (float): output size as a percentage of the input;
            exactly 100 returns ``img`` itself unchanged.

        Returns:
          (resized_img, (seed_xy, target_xy)) with the points expressed in the
          resized image's pixel coordinates.

        Raises:
          ValueError: if ``scale_percent`` is not positive.
        """
        seed_xy = tuple(points[0])
        target_xy = tuple(points[1])
        if scale_percent == 100:
            return img, (seed_xy, target_xy)
        if scale_percent <= 0:
            raise ValueError("scale_percent must be positive")

        original_height, original_width = img.shape[:2]
        # Clamp to at least one pixel: cv2.resize rejects an empty output size,
        # which very small percentages would otherwise produce.
        width = max(1, int(original_width * scale_percent / 100))
        height = max(1, int(original_height * scale_percent / 100))

        # INTER_AREA is OpenCV's recommended interpolation for shrinking.
        resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

        # Use the per-axis ratio actually applied; integer truncation of
        # width/height makes it differ slightly from scale_percent / 100.
        scale_x = width / original_width
        scale_y = height / original_height

        def _scale(pt):
            """Map an (x, y) point into the resized image's pixel grid."""
            return (int(pt[0] * scale_x), int(pt[1] * scale_y))

        return resized, (_scale(seed_xy), _scale(target_xy))
    
    
    # Define the following
    image_path = './tests/slice_60_volQ.png'
    image_type = 'gray'  # 'gray' or 'color'
    downscale_factor = 100  # % of original size wanted
    points_path = './tests/LiveWireEndPoints.npy'

    # Load image
    image = load_image(image_path, image_type)
    # Load points. np.int0 (an alias of np.intp) was removed in NumPy 2.0,
    # so cast to np.intp directly.
    points = np.round(np.load(points_path)).astype(np.intp)

    # Downscale image and points
    scaled_image, scaled_points = downscale(image, points, downscale_factor)
    seed, target = scaled_points

    # Compute cost image (canny_img is the edge map the cost is derived from)
    cost_image, canny_img = compute_cost(scaled_image)

    # Find path and time it (perf_counter is monotonic and high-resolution)
    start_time = time.perf_counter()
    dist, parent = dijkstra(cost_image, seed)
    path = backtrack_path(parent, seed, target)
    end_time = time.perf_counter()

    print(f"Elapsed time for pathfinding: {end_time - start_time:.3f} seconds")

    color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)
    for (x, y) in path:
        color_img[y, x] = (0, 0, 255)  # red in BGR order (colour of path)

    plt.figure(figsize=(20,8))
    plt.subplot(1,2,1)
    # The panel shows the Canny edge map (low-cost pixels), not the cost values.
    plt.title("Canny Edge Map (low-cost pixels)")
    plt.imshow(canny_img, cmap='gray')

    plt.subplot(1,2,2)
    plt.title("Path from Seed to Target")
    plt.imshow(color_img[..., ::-1])  # BGR->RGB for plotting
    plt.show()