import time
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage import exposure
from skimage.filters import gaussian
from skimage.feature import canny
from skimage.graph import route_through_array

#### Helper functions ####

def compute_cost_image(path, sigma=3):
    """
    Load the image at `path` as grayscale and convert it into a cost image
    suitable for minimum-cost path finding.

    The processing itself (adaptive histogram equalization, Gaussian
    smoothing, Canny edge detection, reciprocal of the edge map) is delegated
    to `compute_cost` so that both entry points stay in sync.

    Parameters:
        path: location of the image file, passed to `cv2.imread`.
        sigma: standard deviation of the Gaussian smoothing.

    Returns:
        The cost image: roughly 1 on detected edge pixels and 1 / 1e-5 on all
        other pixels, so cheap paths run along edges.

    Raises:
        OSError: if OpenCV could not read the file (`cv2.imread` returned None).
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside scikit-image.
    if image is None:
        raise OSError("Could not read image from {!r}".format(path))

    # Same pipeline and epsilon as before; only the cost image is returned.
    cost_img, _ = compute_cost(image, sigma=sigma, epsilon=1e-5)

    return cost_img


def find_path(cost_image, points):
    """
    Find the minimum-cost route through `cost_image` between two points.

    `points` must hold exactly two entries, the seed and the target, each
    given as array indices (row, col). Diagonal moves are allowed.

    Returns the list of indices making up the route; the total cost reported
    by `route_through_array` is discarded.

    Raises ValueError if `points` does not contain exactly two entries.
    """
    if len(points) != 2:
        raise ValueError("Points should be a list of 2 points: seed and target.")

    origin, destination = points

    # route_through_array returns (path, total_cost); only the path is needed.
    route = route_through_array(
        cost_image,
        start=origin,
        end=destination,
        fully_connected=True,
    )[0]
    return route




def downscale(img, points, scale_percent):
    """
    Resize `img` to `scale_percent` percent of its size and scale the given
    points accordingly.

    Parameters:
        img: image array; its first two dimensions are (height, width).
        points: sequence whose first two entries are the seed and the target,
            each given as (x, y), i.e. (column, row).
        scale_percent: target size as a percentage of the original. A value
            of 100 returns `img` itself, untouched.

    Returns:
        (resized_img, (scaled_seed_xy, scaled_target_xy)), with the scaled
        points truncated to integer coordinates.

    Raises:
        ValueError: if `scale_percent` is not a positive number.
    """
    if scale_percent == 100:
        return img, (tuple(points[0]), tuple(points[1]))

    # `not > 0` also rejects NaN, which would otherwise fail later in int().
    if not scale_percent > 0:
        raise ValueError(
            "scale_percent must be a positive number, got {!r}".format(scale_percent)
        )

    orig_height, orig_width = img.shape[0], img.shape[1]

    # Clamp to at least one pixel: a very small scale_percent would otherwise
    # truncate a dimension to 0, which cv2.resize rejects.
    width = max(1, int(orig_width * scale_percent / 100))
    height = max(1, int(orig_height * scale_percent / 100))

    # cv2.resize expects the target size as (width, height).
    downsampled_img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Use the effective ratios (after integer truncation of the dimensions)
    # so scaled points always stay inside the resized image.
    scale_x = width / orig_width
    scale_y = height / orig_height

    def _scale(point_xy):
        # Scale one (x, y) point and truncate to integer pixel coordinates.
        x, y = tuple(point_xy)[:2]
        return (int(x * scale_x), int(y * scale_y))

    return downsampled_img, (_scale(points[0]), _scale(points[1]))

def compute_cost(image, sigma=3.0, epsilon=1e-5):
    """
    Turn an image into a cost image for minimum-cost path finding.

    Steps: adaptive histogram equalization (clip_limit=0.01), Gaussian
    smoothing with `sigma`, Canny edge detection, then the reciprocal of the
    edge map offset by `epsilon`.

    Because the Canny result is a binary edge map, the cost is about 1 on
    edge pixels and 1 / epsilon everywhere else, so the cheapest routes
    follow detected edges.

    Returns (cost_img, canny_img).
    """
    # Boost local contrast before looking for edges.
    equalized = exposure.equalize_adapthist(image, clip_limit=0.01)

    # Smooth, then detect edges on the smoothed image.
    edge_map = canny(gaussian(equalized, sigma=sigma))

    # epsilon keeps the division finite on non-edge pixels (edge_map == 0).
    inverse_edges = 1.0 / (edge_map + epsilon)

    return inverse_edges, edge_map

def backtrack_pixels_on_image(img_color, path_coords, bgr_color=(0, 0, 255)):
    """
    Paint every pixel of a path onto `img_color` and return the image.

    `img_color` is modified in place (it is expected to already be a BGR
    image); the same object is returned for convenience.
    `path_coords` is an iterable of (row, col) pairs, i.e. (y, x).
    `bgr_color` is the value written to each path pixel (default: red in BGR).
    """
    for pixel in path_coords:
        r, c = pixel
        img_color[r, c] = bgr_color
    return img_color

def export_path(path_coords, path_name):
    """
    Save the path coordinates to disk in NumPy's binary format.

    `path_coords` is converted to an array by `np.save` and written to
    `path_name` (NumPy appends ".npy" to a filename that lacks it).
    Nothing is returned.
    """
    np.save(path_name, path_coords)