Skip to content
Snippets Groups Projects
sato_test.py 4.59 KiB
Newer Older
  • Learn to ignore specific revisions
  • import time
    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    from skimage.morphology import skeletonize
    from skimage.filters import gaussian, sato
    from skimage.feature import canny
    from skimage.graph import route_through_array
    
    #### Helper functions ####
    
    def load_image(path, type):
        """
        Read the image stored at `path` and return it as a single-channel grayscale array.

        `type` chooses how the file is read:
          * 'gray'  -> read directly in grayscale mode;
          * 'color' -> read as a BGR colour image, then converted to grayscale.

        Raises ValueError for any other `type`, and FileNotFoundError when OpenCV
        cannot read the file (cv2.imread returns None).
        """
        # Reject unknown modes before touching the filesystem.
        if type not in ('gray', 'color'):
            raise ValueError("type must be 'gray' or 'color'")

        read_flag = cv2.IMREAD_GRAYSCALE if type == 'gray' else cv2.IMREAD_COLOR
        img = cv2.imread(path, read_flag)
        if img is None:
            raise FileNotFoundError(f"Could not read {path}")

        if type == 'color':
            # Colour reads come back in BGR order; collapse them to one channel.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img
    
    def downscale(img, points, scale_percent):
        """
        Resize `img` to `scale_percent` % of its size and scale two (x, y) points to match.

        Parameters
        ----------
        img : image array; shape[0] is used as the height and shape[1] as the width.
        points : two (x, y) pairs, ``[seed_xy, target_xy]``.
        scale_percent : target size as a percentage of the original; must be > 0.
            At exactly 100 the image is returned unchanged (no resize happens).

        Returns
        -------
        (resized_img, (scaled_seed_xy, scaled_target_xy)). Scaled coordinates are
        converted with int(), i.e. truncated toward zero.

        Raises
        ------
        ValueError
            If `scale_percent` is not positive.
        """
        if scale_percent == 100:
            return img, (tuple(points[0]), tuple(points[1]))
        if scale_percent <= 0:
            raise ValueError(f"scale_percent must be positive, got {scale_percent}")

        src_height, src_width = img.shape[0], img.shape[1]

        # New dimensions. Clamp to one pixel so a tiny percentage cannot yield a
        # zero-sized target size, which cv2.resize rejects.
        width = max(1, int(src_width * scale_percent / 100))
        height = max(1, int(src_height * scale_percent / 100))

        # cv2.resize takes dsize as (width, height); INTER_AREA suits shrinking.
        downsampled_img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

        # Use the realised size ratio (after int truncation) so the points stay
        # consistent with the pixel grid of the resized image.
        scale_x = width / src_width
        scale_y = height / src_height

        def _scale_xy(xy):
            """Map one (x, y) point into the resized image's coordinates."""
            return (int(xy[0] * scale_x), int(xy[1] * scale_y))

        return downsampled_img, (_scale_xy(points[0]), _scale_xy(points[1]))
    
    def compute_cost(image, sigma=1.0, epsilon=1e-1, threshold=0.08):
        """
        Build a path-cost image that makes pixels on the detected ridge skeleton cheap.

        Pipeline: Gaussian smoothing -> Sato tubeness (ridge) filter -> threshold
        -> skeletonize. The cost is the reciprocal ``1 / (skeleton + epsilon)``:
        skeleton pixels cost ``1 / (1 + epsilon)`` and all other pixels cost
        ``1 / epsilon``, so a minimum-cost route is pulled along the skeleton.

        Parameters
        ----------
        image : grayscale image (as produced by `load_image` / `downscale`).
        sigma : standard deviation of the Gaussian pre-smoothing.
        epsilon : positive offset that keeps the reciprocal finite off the
            skeleton and sets the off-/on-skeleton cost ratio.
        threshold : Sato response above which a pixel is treated as ridge before
            skeletonization. The default 0.08 is the previously hard-coded value.

        Returns
        -------
        (cost_img, skeleton) : the float cost image and the skeleton mask returned
        by `skeletonize`.
        """
        smoothed_img = gaussian(image, sigma=sigma)

        # Sato responds to tubular/ridge-like structures. It is called with
        # scikit-image's defaults (black_ridges=True, i.e. dark ridges).
        ridge_img = sato(smoothed_img)

        ridge_mask = ridge_img > threshold
        skeleton = skeletonize(ridge_mask)

        # Reciprocal cost: LOW on skeleton pixels, HIGH elsewhere.
        cost_img = 1 / (skeleton + epsilon)
        return cost_img, skeleton
    
    def backtrack_pixels_on_image(img_color, path_coords, bgr_color=(0, 0, 255)):
        """
        Paint every pixel listed in `path_coords` with `bgr_color`.

        `img_color` is modified in place (indexed as img_color[row, col]) and the
        same object is returned. `path_coords` must therefore hold (row, col),
        i.e. (y, x), pairs. The default colour (0, 0, 255) is red in BGR order.
        """
        for pixel in path_coords:
            y, x = pixel
            img_color[y, x] = bgr_color
        return img_color
    
    #### Main Script ####
    
    def main():
        """
        Run the live-wire demo on the hard-coded test slice.

        Steps:
          1. Load the image and the stored seed/target points.
          2. Build the cost image with `compute_cost`.
          3. Find the minimum-cost route with `route_through_array` and print the
             pathfinding time.
          4. Show the ridge skeleton next to the image, with the route drawn in red.
        """
        # Define input parameters
        image_path = './tests/slice_60_volQ.png'
        image_type = 'gray'        # 'gray' or 'color'
        downscale_factor = 100     # % of original size
        points_path = './tests/LiveWireEndPoints.npy'

        # Load image
        image = load_image(image_path, image_type)

        # Load seed and target points, shape (2, 2): [[x_seed, y_seed], [x_target, y_target]].
        # np.int0 (an alias of np.intp) was removed in NumPy 2.0, so cast explicitly.
        points = np.round(np.load(points_path)).astype(np.intp)

        # Downscale image and points; each point comes back as (x, y)
        scaled_image, (seed, target) = downscale(image, points, downscale_factor)

        # scikit-image indexes arrays as (row, col) == (y, x)
        seed_rc = (seed[1], seed[0])
        target_rc = (target[1], target[0])

        # Cost image: low along the detected ridge skeleton, high elsewhere
        cost_image, skeleton = compute_cost(scaled_image)

        # Minimum-cost route; fully_connected=True also allows diagonal steps.
        # perf_counter is the high-resolution clock intended for measuring intervals.
        start_time = time.perf_counter()
        path_rc, _total_cost = route_through_array(
            cost_image,
            start=seed_rc,
            end=target_rc,
            fully_connected=True,
        )
        elapsed = time.perf_counter() - start_time

        print(f"Elapsed time for pathfinding: {elapsed:.3f} seconds")

        # Convert single-channel image to BGR so the path can be drawn in colour
        color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)

        # `path_rc` is a list of (row, col); (0, 0, 255) is red in OpenCV's BGR order
        color_img = backtrack_pixels_on_image(color_img, path_rc, bgr_color=(0, 0, 255))

        # Display results
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        plt.title("Sato Ridge Skeleton")
        plt.imshow(skeleton, cmap='gray')

        plt.subplot(1, 2, 2)
        plt.title("Path from Seed to Target")
        # Reverse the channel axis: BGR -> RGB for matplotlib
        plt.imshow(color_img[..., ::-1])
        plt.show()
    
    if __name__ == '__main__':
        # Start the demo only when this file is executed directly, not when imported.
        main()