Skip to content
Snippets Groups Projects
disk_live_wire_test.py 5.99 KiB
Newer Older
  • Learn to ignore specific revisions
  • import time
    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    from skimage import exposure
    from skimage.filters import gaussian
    from skimage.feature import canny
    from skimage.graph import route_through_array
    from scipy.signal import convolve2d
    
    #### Helper functions ####
    
    def load_image(path, type):
        """
        Read the image stored at `path` and return it as a single-channel grayscale array.

        `type` chooses how the file is decoded: 'gray' reads it directly in
        grayscale mode, while 'color' reads it in (BGR) color mode and then
        converts the result to grayscale.

        Raises
        ------
        ValueError
            If `type` is neither 'gray' nor 'color'.
        FileNotFoundError
            If OpenCV cannot read the file (cv2.imread returned None).
        """
        if type not in ('gray', 'color'):
            raise ValueError("type must be 'gray' or 'color'")

        read_flag = cv2.IMREAD_GRAYSCALE if type == 'gray' else cv2.IMREAD_COLOR
        loaded = cv2.imread(path, read_flag)
        if loaded is None:
            raise FileNotFoundError(f"Could not read {path}")

        # Color input is collapsed to one channel so callers always get grayscale.
        if type == 'color':
            loaded = cv2.cvtColor(loaded, cv2.COLOR_BGR2GRAY)
        return loaded
    
    def downscale(img, points, scale_percent):
        """
        Resize `img` to `scale_percent` percent of its size and map two (x, y) points onto it.

        Returns (resized_img, (seed_xy, target_xy)).

        At exactly 100 percent, the image is returned unchanged and the points
        are only converted to tuples. Otherwise:
        - the image is resized with cv2.INTER_AREA;
        - each point coordinate is multiplied by the realised scale factor
          (new integer size / old size);
        - the result is truncated with int().
        """
        seed_xy, target_xy = tuple(points[0]), tuple(points[1])
        if scale_percent == 100:
            return img, (seed_xy, target_xy)

        src_h, src_w = img.shape[0], img.shape[1]
        new_w = int(src_w * scale_percent / 100)
        new_h = int(src_h * scale_percent / 100)
        resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Use the ratio of the integer sizes actually produced, not the nominal percent.
        fx = new_w / src_w
        fy = new_h / src_h

        def _scale(pt):
            return (int(pt[0] * fx), int(pt[1] * fy))

        return resized, (_scale(seed_xy), _scale(target_xy))
    
    def compute_cost(image, sigma=3.0, disk_size=15):
        """
        Build a path-cost image from an input image for minimum-cost routing.

        Pipeline:
        1. Adaptive histogram equalization (exposure.equalize_adapthist).
        2. Gaussian smoothing.
        3. Canny edge detection.
        4. Convolution of the edge map with a circle-edge kernel (weight 0 at
           the kernel center, rising linearly to 1 at its rim).
        5. Inversion of that response, raised to the 4th power.

        Parameters
        ----------
        image : array
            Input image; presumably 2D grayscale (that is what main passes in).
        sigma : float
            Standard deviation of the Gaussian smoothing applied before Canny.
        disk_size : int
            Side length of the square circle-edge kernel.

        Returns
        -------
        cost_img : 2D float array
            ``(R.max() - R) ** 4``, where ``R`` is the kernel response. It is 0
            where the response is largest and grows as the response falls, so
            the cheapest pixels are those with the strongest response.
        canny_img : 2D array
            The Canny edge map that the cost was derived from.
        """
        # Boost local contrast before smoothing/edge detection.
        image_contrasted = exposure.equalize_adapthist(image, clip_limit=0.01)
        smoothed_img = gaussian(image_contrasted, sigma=sigma)
        canny_img = canny(smoothed_img)

        # The kernel is symmetric, so convolution equals correlation. Each
        # output pixel is the sum of nearby edge pixels, weighted by their
        # distance from that pixel (0 at the pixel itself, 1 at the kernel rim).
        kernel = circle_edge_kernel(k_size=disk_size)
        convolved = convolve2d(canny_img, kernel, mode='same', boundary='fill')

        # Invert so a strong response becomes a LOW cost. The 4th power widens
        # the gap between cheap and expensive pixels.
        cost_img = (convolved.max() - convolved) ** 4

        return cost_img, canny_img
    
    def backtrack_pixels_on_image(img_color, path_coords, bgr_color=(0, 0, 255)):
        """
        Paint every pixel listed in `path_coords` with `bgr_color`, modifying `img_color` in place.

        The image is indexed as img_color[row, col], so each entry of
        `path_coords` must be a (row, col) pair, i.e. (y, x). The color is
        written as given; the caller passes it in BGR order for OpenCV images.
        The same array object is returned for convenience.
        """
        for coord in path_coords:
            r, c = coord
            img_color[r, c] = bgr_color
        return img_color
    
    
    
    def circle_edge_kernel(k_size=5, radius=None):
        """
        Create a k_size x k_size kernel whose weights grow linearly with distance from its center.

        A cell at distance d from the kernel center gets weight d / radius when
        d <= radius: 0 at the center, 1 on the circle of the given radius.
        Cells outside the circle get 0.

        Parameters
        ----------
        k_size : int
            The size (width and height) of the kernel array.
        radius : float, optional
            The circle's radius. Defaults to (k_size - 1) / 2, the largest
            circle that fits inside the kernel.

        Returns
        -------
        kernel : 2D numpy array of shape (k_size, k_size), dtype float
            The circle-edge-weighted kernel. If radius <= 0 it is all zeros.
        """
        # Always center the circle on the kernel's middle cell, independent of
        # `radius`, so a custom radius gives a smaller centered circle rather
        # than one shifted toward the top-left corner.
        center = (k_size - 1) / 2
        if radius is None:
            radius = center

        kernel = np.zeros((k_size, k_size), dtype=float)
        if radius <= 0:
            # Only the exact center could satisfy dist <= 0, and its weight is 0;
            # returning early avoids a 0/0 (NaN) division.
            return kernel

        ys, xs = np.indices((k_size, k_size))
        dist = np.sqrt((xs - center) ** 2 + (ys - center) ** 2)
        inside = dist <= radius
        kernel[inside] = dist[inside] / radius  # 0 at center, 1 at the boundary
        return kernel
    
    
    
    
    
    
    #### Main Script ####
    def main():
        """
        Run the live-wire demo end to end.

        Steps:
        1. Load the image and the seed/target points.
        2. Optionally downscale both.
        3. Build the cost image with `compute_cost`.
        4. Find the minimum-cost path between the points with
           `route_through_array`.
        5. Show the cost image next to the input with the path drawn in red.
        """
        # Define input parameters
        image_path = 'agamodon_slice.png'
        image_type = 'gray'        # 'gray' or 'color'
        downscale_factor = 100     # % of original size
        points_path = 'agamodonPoints.npy'

        image = load_image(image_path, image_type)

        # Seed and target points, expected as [[x_seed, y_seed], [x_target, y_target]],
        # rounded to integer pixel coordinates. `np.int0` (an alias of np.intp) was
        # removed in NumPy 2.0, so cast explicitly instead.
        points = np.round(np.load(points_path)).astype(np.intp)

        # Downscale image and points; each scaled point is (x, y).
        scaled_image, (seed, target) = downscale(image, points, downscale_factor)

        # scikit-image indexes arrays as (row, col) = (y, x).
        seed_rc = (seed[1], seed[0])
        target_rc = (target[1], target[0])

        cost_image, _canny_img = compute_cost(scaled_image, disk_size=17)

        # perf_counter is the clock intended for measuring short intervals.
        start_time = time.perf_counter()
        path_rc, _path_cost = route_through_array(
            cost_image,
            start=seed_rc,
            end=target_rc,
            fully_connected=True,
        )
        elapsed = time.perf_counter() - start_time
        print(f"Elapsed time for pathfinding: {elapsed:.3f} seconds")

        # Convert the single-channel image to BGR so the path can be drawn in color.
        # (0, 0, 255) is red in OpenCV's BGR channel order.
        color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)
        color_img = backtrack_pixels_on_image(color_img, path_rc, bgr_color=(0, 0, 255))

        # Display results
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        plt.title("Cost image")
        plt.imshow(cost_image, cmap='gray')

        plt.subplot(1, 2, 2)
        plt.title("Path from Seed to Target")
        plt.imshow(color_img[..., ::-1])  # reverse channels: BGR -> RGB for pyplot
        plt.show()


    if __name__ == "__main__":
        main()