import numpy as np

def psign(p1, p2, p3):
    """Return the 2D cross product of ``p1 - p3`` and ``p2 - p3``.

    Only the first two components of each point are read. The sign of the
    result tells on which side of the line through ``p2`` and ``p3`` the
    point ``p1`` lies; zero means the three points are collinear.
    """
    dx1 = p1[0] - p3[0]
    dy1 = p1[1] - p3[1]
    dx2 = p2[0] - p3[0]
    dy2 = p2[1] - p3[1]
    return dx1 * dy2 - dx2 * dy1

def PointInTriangle(pt, v1, v2, v3):
    """Tell whether ``pt`` lies inside or on the border of a triangle.

    The point is tested against each of the three edges (v1-v2, v2-v3,
    v3-v1) with ``psign``. It is reported as contained unless the three
    side tests disagree, i.e. unless at least one is strictly negative
    while another is strictly positive. Zero results (point on an edge
    line) never cause a rejection.

    Returns:
        bool: ``True`` when ``pt`` is inside or on the boundary.
    """
    sides = (psign(pt, v1, v2), psign(pt, v2, v3), psign(pt, v3, v1))

    some_negative = any(side < 0 for side in sides)
    some_positive = any(side > 0 for side in sides)

    # Mixed signs mean the point is outside of at least one edge.
    return not (some_negative and some_positive)

def get_edges(tri):
    """Return the three edges of a triangle as vertex pairs.

    The pairs are built from positions (0, 1), (0, 2) and (1, 2) of ``tri``
    in that order; the elements are not reordered, so the pairs are only
    sorted when ``tri`` itself is sorted.
    """
    corner_pairs = ((0, 1), (0, 2), (1, 2))
    return [(tri[start], tri[end]) for start, end in corner_pairs]

def default_add(d, e, V, ti):
    """Register triangle ``ti`` as a user of edge ``e`` in the edge table ``d``.

    Each entry of ``d`` has the form ``[length, [triangle indices]]``. When
    ``e`` is already a key, ``ti`` is appended to its triangle list.
    Otherwise a new entry is created whose length is the Euclidean distance
    between the two vertices of ``V`` addressed by ``e``.

    Args:
        d: Edge table, modified in place.
        e: Pair of vertex indices used as the dictionary key.
        V: Indexable collection of vertex coordinates (array-like rows).
        ti: Index of the triangle that uses the edge.

    Returns:
        The same dictionary ``d`` that was passed in.

    Raises:
        ValueError: If a new edge joins a vertex index to itself.
    """
    # Known edge: only the triangle list grows.
    if e in d:
        d[e][1].append(ti)
        return d

    start, end = e
    edge_length = np.linalg.norm(V[start] - V[end])
    if start == end:
        raise ValueError('Adding edge with same index:', e, ti)

    d[e] = [edge_length, [ti]]
    return d

def get_all_edges(V, T):
    """Build the edge table of a triangle mesh.

    Args:
        V: Vertex coordinates, indexable by the vertex indices stored in
            ``T`` (rows must support subtraction, e.g. a NumPy array).
        T: Iterable of triangles, each a sequence of three vertex indices.

    Returns:
        dict: Maps each edge ``(a, b)`` with ``a <= b`` to
        ``[length, [triangle indices]]`` where the list holds the position
        in ``T`` of every triangle using that edge.

    Raises:
        ValueError: Propagated from ``default_add`` when a triangle repeats
            a vertex index.
    """
    E_dict = {}
    for n, tri in enumerate(T):
        for e in get_edges(tri):
            # Edges are undirected: normalise the key so (a, b) and (b, a)
            # share one entry even when the triangle is not stored sorted.
            # The edge length itself is computed inside default_add.
            E_dict = default_add(E_dict, tuple(sorted(e)), V, n)

    return E_dict

def split_long_edges(V, T):
    """Refine a triangle mesh by bisecting long edges, then smooth it.

    The longest edge is repeatedly split at its midpoint until no edge is
    longer than the median edge length of the *initial* mesh. Every
    triangle using the split edge is replaced by two triangles sharing the
    new midpoint vertex. Afterwards each interior vertex (one that touches
    no edge used by exactly one triangle) is relaxed 100 times towards the
    mean of its edge-connected neighbours with weight 0.3.

    Args:
        V: Array of vertex coordinates, one row per vertex. It is not
            modified; a floating-point copy is used internally.
        T: List of triangles, each a sorted list of three vertex indices.
            The list is modified in place (triangles are replaced and
            appended) and also returned.

    Returns:
        tuple: ``(V, T)`` with the new vertex array (original vertices
        first, midpoints appended) and the refined triangle list.
    """
    # Work on a private floating-point copy: smoothing assigns into V, which
    # must neither alter the caller's array nor truncate integer coordinates.
    V = np.array(V)
    if V.dtype.kind in 'iub':
        V = V.astype(float)

    E_dict = get_all_edges(V, T)
    if not E_dict:
        # No triangles: nothing to split, and no neighbours to smooth with.
        return V, T

    target_length = np.median([length for length, _ in E_dict.values()])

    while True:
        # Find the longest edge; the first one encountered wins on ties.
        max_length = 0
        e = None
        for cur_e, (cur_length, _) in E_dict.items():
            if cur_length > max_length:
                max_length = cur_length
                e = cur_e

        if e is None or not max_length > target_length:
            break

        a, b = e
        Ti = E_dict[e][1]

        # The midpoint of (a, b) becomes the new vertex with index n.
        n = len(V)
        V = np.concatenate([V, [(V[a] + V[b]) / 2.0]], axis=0)

        for ti in sorted(Ti, reverse=True):
            # m is the corner of triangle ti opposite to the edge (a, b).
            m = [i for i in T[ti] if i != a and i != b][0]

            # Triangle (a, b, m) becomes (m, n, a) in place, and (m, n, b)
            # is appended as a new triangle with index tj.
            tj = len(T)
            T[ti] = sorted([m, n, a])
            T.append(sorted([m, n, b]))

            an = tuple(sorted([n, a]))
            bn = tuple(sorted([n, b]))
            bm = tuple(sorted([m, b]))
            mn = tuple(sorted([m, n]))

            # Edge (b, m) now belongs to tj instead of ti.
            E_dict[bm][1] = [tk for tk in E_dict[bm][1] if tk != ti]

            E_dict = default_add(E_dict, an, V, ti)
            E_dict = default_add(E_dict, bn, V, tj)
            E_dict = default_add(E_dict, bm, V, tj)
            E_dict = default_add(E_dict, mn, V, ti)
            E_dict = default_add(E_dict, mn, V, tj)

        # The split edge itself no longer exists.
        del E_dict[e]

    # Vertices on an edge used by a single triangle lie on the boundary and
    # stay fixed; everything else is treated as interior.
    interior_vertex_indices = set(range(len(V)))
    for e in E_dict.keys():
        if len(E_dict[e][1]) == 1:
            interior_vertex_indices = interior_vertex_indices - set(e)

    # Collect the neighbours of each interior vertex once (in edge-table
    # order) instead of rescanning every edge on every smoothing pass.
    neighbours = {i: [] for i in interior_vertex_indices}
    for p, q in E_dict:
        if p in neighbours:
            neighbours[p].append(q)
        if q in neighbours:
            neighbours[q].append(p)

    # Smooth interior triangulated points. Updates are applied immediately,
    # so later vertices in a pass see the already-moved positions.
    for _ in range(100):
        for i in interior_vertex_indices:
            adjacent = neighbours[i]
            if not adjacent:
                # A vertex referenced by no edge has no mean to move
                # towards; leave it alone rather than writing NaN.
                continue
            Vnew = np.mean(V[adjacent], axis=0)
            V[i] = 0.7 * V[i] + 0.3 * Vnew

    return V, T

def polygon_triangulate(P):
    """Triangulate a 2D polygon by ear clipping and refine the result.

    The polygon is first oriented so that its summed signed exterior angle
    is non-negative. Ears are then clipped one at a time: among all convex
    corners whose triangle contains no other polygon point, the one with
    the shortest pair of adjacent sides is chosen. The resulting mesh is
    passed through ``split_long_edges``.

    Args:
        P: Polygon vertices in order, array-like of shape ``(N, 2)``.

    Returns:
        tuple: ``(exterior_angles, P, T)`` where ``exterior_angles`` is the
        list of signed exterior angles in degrees (one per input corner),
        ``P`` is the vertex array returned by ``split_long_edges`` and ``T``
        the list of triangles as sorted vertex-index triples. For fewer
        than three points ``P`` is returned unrefined and ``T`` is empty.

    Raises:
        ValueError: If no valid ear can be found, e.g. for degenerate or
            self-intersecting input.
    """
    P = np.asarray(P)
    N = len(P)
    # Append a zero z-coordinate so the cross product is taken in 3D.
    P3 = np.pad(P, ((0, 0), (0, 1)))

    exterior_angles = []
    T = []
    for p1, p2, p3 in zip(P3, np.roll(P3, -1, axis=0), np.roll(P3, -2, axis=0)):
        v1 = p2 - p1
        v2 = p3 - p2
        cross = np.cross(v1, v2)

        dot = np.clip(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), -1, 1)

        # Signed turning angle at p2, converted from radians to degrees.
        angle = np.sign(cross[2]) * np.arccos(dot)
        exterior_angles.append(angle / (2 * np.pi) * 360)

    if N < 3:
        # Nothing to triangulate; skip refinement, which needs triangles.
        return exterior_angles, P, T

    indices = np.arange(N)

    # Walk the polygon in the direction that makes convex corners turn left.
    if np.sum(exterior_angles) < 0:
        indices = indices[::-1]

    # Iteratively add a triangle and remove the index of the point which can
    # no longer be part of another triangle.
    while len(indices) > 2:
        best = None
        best_length = np.inf

        # Iterate over the remaining points to find the best candidate.
        for n, (i, j, k) in enumerate(zip(np.roll(indices, 1), indices, np.roll(indices, -1))):
            v1 = P3[j] - P3[i]
            v2 = P3[k] - P3[j]

            # Only the sign of the turn matters. Testing the cross product
            # directly avoids arcsin returning NaN when rounding pushes the
            # normalised value slightly above 1.
            if not np.cross(v1, v2)[2] > 0:
                continue

            # Pick the shortest triangle which is interior.
            length = np.linalg.norm(v1) + np.linalg.norm(v2)
            if not length < best_length:
                continue

            # Skip the candidate if any other polygon point lies inside it.
            contains_point = any(
                PointInTriangle(P[nn], P[i], P[j], P[k])
                for nn in range(N)
                if nn != i and nn != j and nn != k
            )
            if contains_point:
                continue

            best_length = length
            best = (n, i, j, k)

        if best is None:
            # Fail here instead of recording a degenerate triangle.
            raise ValueError(
                'No valid ear found; the polygon may be degenerate or self-intersecting'
            )

        best_n, best_i, best_j, best_k = best
        T.append(sorted([best_i, best_j, best_k]))
        indices = np.delete(indices, best_n)

    P, T = split_long_edges(P, T)

    return exterior_angles, P, T

class PlotTriangulation:
    """Draw a triangulation into the current matplotlib figure.

    Instantiating the class performs the drawing; no instance attributes
    are stored. Each triangle is filled with a colour taken from the
    class-level ``colors`` table and, optionally, triangles and vertices
    are annotated with their indices.
    """

    # Random RGB rows in [0.2, 1); triangle j uses row j modulo 2048.
    colors = np.random.uniform(0.2,1,(2048, 3))

    def __init__(self, V, T, add_labels=True):
        """Plot vertices ``V`` and triangles ``T``.

        Args:
            V: Array of 2D vertex coordinates, one row per vertex.
            T: Sequence of triangles given as vertex-index triples.
            add_labels: When true, write each triangle index at the
                triangle centroid and each vertex index at the vertex.
        """
        import matplotlib.pyplot as plt

        points = np.copy(V)

        # Swap the two coordinate columns when column 1 holds the larger
        # maximum value.
        column_max = np.max(points, axis=0)
        if np.argmax(column_max) == 1:
            points = points[:, ::-1]

        # Shift the points to start at a 5% margin from the origin.
        lowest = np.min(points, axis=0)
        margins = (np.max(points, axis=0) - lowest) * 0.05
        points = points - lowest + margins

        # Blank background image sized to hold all points plus the margin.
        n_rows = np.max(points[:, 1] + 1 + margins[1]).astype(np.int32)
        n_cols = np.max(points[:, 0] + 1 + margins[0]).astype(np.int32)
        background = np.zeros((n_rows, n_cols))

        plt.imshow(background, cmap='gray')
        plt.axis('off')
        plt.subplots_adjust(bottom=0, top=1, left=0, right=1)

        palette = PlotTriangulation.colors
        for index, tri in enumerate(T):
            corners = points[tri, :]
            patch = plt.Polygon(corners, alpha=0.7, color=palette[index % 2048])
            plt.gca().add_patch(patch)
            if add_labels:
                plt.text(np.mean(points[tri, 0]), np.mean(points[tri, 1]), str(index))

        if not add_labels:
            return

        for index, point in enumerate(points):
            plt.text(point[0], point[1], str(index), color='orange')