"""
Helper functions for snakes (closed contours).
Note that a snake here is a 2-by-N array: each column is one point.
"""

import numpy as np
import scipy.interpolate

def distribute_points(snake):
    """ Resample a closed snake so its points are equidistant along the contour.

    Arguments: snake represented by a 2-by-N array; the last point is
        connected back to the first.
    Returns: a new 2-by-N array of N points spaced at equal arc length,
        the first of which coincides with the first input point.
    """
    num_points = snake.shape[1]
    closed = np.hstack((snake, snake[:, 0:1]))

    # Edge vectors between consecutive points, including the closing edge.
    steps = np.diff(closed, axis=1)
    edge_lengths = np.sqrt((steps ** 2).sum(axis=0))
    arc_length = np.concatenate(([0.0], np.cumsum(edge_lengths)))

    position_at = scipy.interpolate.interp1d(arc_length, closed)
    # Builtin sum accumulates sequentially, like cumsum, so this equals
    # arc_length[-1] and every target stays inside the interpolation range.
    perimeter = sum(edge_lengths)
    return position_at(perimeter * np.arange(num_points) / num_points)

def is_crossing(p1, p2, p3, p4):
    """ Check if the line segments (p1, p2) and (p3, p4) cross.

    Solves p1 + t*(p2 - p1) == p3 + u*(p4 - p3) for the segment parameters
    t and u. The segments cross only when both parameters lie strictly
    inside (0, 1). Touching at an endpoint is not counted as crossing, and
    neither are parallel or collinear segments.

    Arguments: p1, p2, p3, p4 -- 2-element points (arrays or sequences).
    Returns: True if the segments cross, otherwise False.
    """
    p1, p2, p3, p4 = (np.asarray(p, dtype=float) for p in (p1, p2, p3, p4))
    d21 = p2 - p1
    d43 = p4 - p3
    d31 = p3 - p1
    det = d21[0] * d43[1] - d21[1] * d43[0]  # 2D cross product d21 x d43
    if det == 0.0:
        # Parallel, collinear or degenerate segments: no single crossing point.
        return False
    # Cramer's rule never divides by individual components of d21, so
    # horizontal and vertical segments are handled like any other.
    t = (d31[0] * d43[1] - d31[1] * d43[0]) / det
    u = (d31[0] * d21[1] - d31[1] * d21[0]) / det
    return bool(0 < t < 1 and 0 < u < 1)

def is_counterclockwise(snake):
    """ Check if points are ordered counterclockwise.

    Uses the shoelace formula over the closed contour. The edge from the
    last point back to the first is included, which keeps the result
    independent of translation. The closed sum of
    (x[i+1] - x[i]) * (y[i+1] + y[i]) is minus twice the signed area, so a
    negative sum means counterclockwise order when row 0 is x and row 1 is
    y (with y pointing up).

    Arguments: snake represented by a 2-by-N array.
    Returns: True if the points are ordered counterclockwise.
    """
    successor = np.roll(snake, -1, axis=1)  # column i holds point i+1 (wrapping)
    return np.dot(successor[0] - snake[0], successor[1] + snake[1]) < 0

def remove_intersections(snake):
    """ Reorder snake points to remove self-intersections.

        For every pair of segments (i, i+1) and (j, j+1) with j >= i + 2
        that is_crossing reports as crossing, the points of the shorter of
        the two loops formed by the crossing are reversed in place. This
        replaces the crossing segments with (i, j) and (i+1, j+1). The pairs
        are examined in a single forward pass over the current, already
        modified, point order. Pairs visited earlier are not re-checked
        after later reversals.

        Arguments: snake represented by a 2-by-N array. The input array is
            not modified (np.append returns a copy).
        Returns: snake, a 2-by-N array with the same points reordered. The
            point order is flipped if is_counterclockwise reports that it is
            not counterclockwise.
    """
    # Close the contour: append a copy of the first point so that the pair
    # of columns (N-1, N) is the closing edge back to point 0.
    pad_snake = np.append(snake, snake[:,0].reshape(2,1), axis=1)
    pad_n = pad_snake.shape[1]
    # n is the number of original (unpadded) points.
    n = pad_n - 1 
    
    for i in range(pad_n - 3):
        # Start at i + 2 so adjacent segments (sharing point i+1) are skipped.
        for j in range(i + 2, pad_n - 1):
            pts = pad_snake[:, [i, i + 1, j, j + 1]]
            if is_crossing(pts[:, 0], pts[:, 1], pts[:, 2], pts[:, 3]):
                # Reverse vertices of smallest loop
                rb = i + 1 # Reverse begin
                re = j     # Reverse end
                if j - i > n // 2:
                    # Other loop is smallest: reverse j+1 .. i+n instead;
                    # indices wrap around via % n below.
                    rb = j + 1
                    re = i + n                    
                # Swap from both ends toward the middle to reverse the span.
                while rb < re:
                    ia = rb % n
                    rb = rb + 1                    
                    ib = re % n
                    re = re - 1                    
                    pad_snake[:, [ia, ib]] = pad_snake[:, [ib, ia]]                    
                # Point 0 may have moved; keep the padding column in sync.
                pad_snake[:,-1] = pad_snake[:,0]                
    # Drop the padding column.
    snake = pad_snake[:, :-1]
    if is_counterclockwise(snake):
        return snake
    else:
        return np.flip(snake, axis=1)