import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage import exposure
from skimage.filters import gaussian
from skimage.feature import canny
from skimage.graph import route_through_array
from scipy.signal import convolve2d

### Disk live wire cost image

def compute_disk_size(user_radius, upscale_factor=1.2):
    """
    Derive an odd kernel size from a user-supplied radius.

    Parameters
    ----------
    user_radius : int or float
        Radius chosen by the user, in pixels.
    upscale_factor : float, optional
        Multiplier applied to the diameter ``2 * user_radius``.

    Returns
    -------
    int
        ``ceil(upscale_factor * 2 * user_radius + 1)`` rounded down to an
        even number, plus one -- so the result is always odd.
    """
    scaled_extent = np.ceil(upscale_factor * 2 * user_radius + 1)
    # Drop to the nearest even value at or below, then step up to the next odd
    even_part = scaled_extent // 2 * 2
    return int(even_part + 1)


def load_image(path):
    """
    Read the image file at `path` as a single-channel grayscale image.

    Parameters
    ----------
    path : str
        Location of the image file.

    Returns
    -------
    The result of ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``. Note that
    ``cv2.imread`` reports an unreadable file by returning ``None`` instead
    of raising, and that value is passed through unchanged.
    """
    grayscale = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    return grayscale

def preprocess_image(image, sigma=3, clip_limit=0.01):
    """
    Enhance the contrast of `image` and then blur it.

    Parameters
    ----------
    image : array
        Input image, as accepted by ``skimage.exposure.equalize_adapthist``.
    sigma : float, optional
        Standard deviation passed to the Gaussian filter.
    clip_limit : float, optional
        Clipping limit passed to the adaptive histogram equalization.

    Returns
    -------
    The Gaussian-smoothed, contrast-equalized image.
    """
    # Adaptive histogram equalization first, Gaussian smoothing second
    equalized = exposure.equalize_adapthist(image, clip_limit=clip_limit)
    return gaussian(equalized, sigma=sigma)


def compute_cost_image(path, user_radius, sigma=3, clip_limit=0.01):
    """
    Build a path-finding cost image from the image file at `path`.

    The image is loaded as grayscale, contrast-enhanced and smoothed, run
    through Canny edge detection, and the resulting edge map is convolved
    with a circle-edge kernel whose size is derived from `user_radius`.
    The convolved response is then inverted and raised to the fourth power.

    Parameters
    ----------
    path : str
        Location of the image file.
    user_radius : int or float
        Radius used to size the convolution kernel (see `compute_disk_size`).
    sigma : float, optional
        Gaussian smoothing strength used during preprocessing.
    clip_limit : float, optional
        Clipping limit used for the adaptive histogram equalization.

    Returns
    -------
    cost_img : 2D numpy array
        ``(convolved.max() - convolved) ** 4``. The cost is zero where the
        convolved edge response is strongest and grows as it weakens.

    Raises
    ------
    OSError
        If the image at `path` could not be read.
    """
    disk_size = compute_disk_size(user_radius)

    image = load_image(path)
    if image is None:
        # load_image (cv2.imread) returns None on failure instead of raising;
        # fail here with a clear message rather than deep inside skimage.
        raise OSError("Could not read image from {!r}".format(path))

    # Contrast enhancement + smoothing, then edge detection
    smoothed_img = preprocess_image(image, sigma=sigma, clip_limit=clip_limit)
    canny_img = canny(smoothed_img)

    # Spread each edge pixel over a disk-shaped neighbourhood; borders are
    # zero-padded (boundary='fill') and the output keeps the input shape.
    kernel = circle_edge_kernel(k_size=disk_size)
    convolved = convolve2d(canny_img, kernel, mode='same', boundary='fill')

    # Invert the response: LOWER cost where the edge response is stronger,
    # so that a minimum-cost path is attracted to edges.
    cost_img = (convolved.max() - convolved)**4

    return cost_img


def find_path(cost_image, points):
    """
    Find the minimum-cost route through `cost_image` between two points.

    Parameters
    ----------
    cost_image : 2D array
        Per-pixel traversal costs.
    points : sequence of length 2
        The seed and the target, in that order, each passed as-is to
        ``route_through_array`` as ``start`` and ``end``.

    Returns
    -------
    The path returned by ``route_through_array`` (its total cost is discarded).

    Raises
    ------
    ValueError
        If `points` does not contain exactly two entries.
    """
    if len(points) != 2:
        raise ValueError("Points should be a list of 2 points: seed and target.")

    start_rc, end_rc = points

    # route_through_array yields (path, total_cost); only the path is returned
    route, _total_cost = route_through_array(
        cost_image, start=start_rc, end=end_rc, fully_connected=True
    )
    return route


def circle_edge_kernel(k_size=5, radius=None):
    """
    Create a k_size x k_size array whose values increase linearly
    from 0 at the center to 1 at the circle boundary (radius).

    The circle is always centered on the middle of the kernel, i.e. at
    ``(k_size - 1) / 2`` along both axes, regardless of `radius`. Entries
    farther than `radius` from the center are 0.

    Parameters
    ----------
    k_size : int
        The size (width and height) of the kernel array.
    radius : float, optional
        The circle's radius. By default, set to (k_size-1)/2.

    Returns
    -------
    kernel : 2D numpy array of shape (k_size, k_size)
        The circle-edge-weighted kernel. If the radius is not positive
        (e.g. ``k_size == 1``) the kernel is all zeros.
    """
    # The geometric center of the array; must not depend on the radius,
    # otherwise a custom radius would shift the disk off-center.
    center = (k_size - 1) / 2

    if radius is None:
        # By default, let the circle touch the kernel's edges
        radius = center

    kernel = np.zeros((k_size, k_size), dtype=float)

    # A degenerate circle has no interior to weight; returning zeros also
    # avoids the 0/0 division that would otherwise produce NaN.
    if radius <= 0:
        return kernel

    # Distance of every cell from the center, computed in one shot
    rows, cols = np.ogrid[:k_size, :k_size]
    dist = np.sqrt((cols - center)**2 + (rows - center)**2)

    # Weight = distance / radius => 0 at center, 1 at boundary
    inside = dist <= radius
    kernel[inside] = dist[inside] / radius

    return kernel


# Other helper functions (further helpers may still be added)
def downscale(img, points, scale_percent):
    """
    Resize `img` to `scale_percent` percent of its size and scale the given
    points accordingly.

    Parameters
    ----------
    img : array
        Image to resize; its first two dimensions are (height, width).
    points : sequence
        ``points[0]`` is the seed and ``points[1]`` the target, each given
        as ``(x, y)``.
    scale_percent : int or float
        Target size as a percentage of the original. A value of 100 returns
        `img` itself, with the points only converted to tuples.

    Returns
    -------
    (resized_img, (scaled_seed, scaled_target))
        For any `scale_percent` other than 100 the scaled points are
        truncated to integers.

    Raises
    ------
    ValueError
        If `scale_percent` is not positive.
    """
    if scale_percent <= 0:
        raise ValueError(
            "scale_percent must be positive, got {!r}".format(scale_percent)
        )

    seed_xy = tuple(points[0])
    target_xy = tuple(points[1])

    if scale_percent == 100:
        return img, (seed_xy, target_xy)

    src_height, src_width = img.shape[0], img.shape[1]

    # Clamp to at least one pixel: a very small scale_percent would otherwise
    # truncate to 0, and cv2.resize rejects an empty target size.
    width = max(1, int(src_width * scale_percent / 100))
    height = max(1, int(src_height * scale_percent / 100))

    # cv2.resize expects the target size as (width, height)
    downsampled_img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Use the effective (post-truncation) factors so the points stay
    # consistent with the actual resized dimensions.
    scale_x = width / src_width
    scale_y = height / src_height

    scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y))
    scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y))

    return downsampled_img, (scaled_seed_xy, scaled_target_xy)