import numpy as np

def fix_triangle_normals(V, F):
    """Reorder every triangle of ``F`` in place so it is counter-clockwise in xy.

    For each face, the z-component of the cross product of its first two
    edges (``p1 - p0`` and ``p2 - p1``) is computed from that face's own
    vertices. If it is negative, the face's index order is reversed.
    Zero-area faces are left unchanged.

    Args:
        V: Vertex coordinates, indexable as ``V[face]``. Only columns 0 and 1
            (x and y) are used, so presumably shape ``(num_vertices, 2)``.
        F: Sequence of vertex-index triples (list of lists or integer
            array). Modified in place.

    Returns:
        The same ``F`` object, with clockwise faces reversed.
    """
    V = np.asarray(V)
    for i, face in enumerate(F):
        vertices = V[face, :]
        e1 = vertices[1] - vertices[0]
        e2 = vertices[2] - vertices[1]
        # z-component of cross(e1, e2) == twice the signed area of the face;
        # negative means the vertices are listed clockwise.
        signed_area2 = e1[0] * e2[1] - e1[1] * e2[0]
        if signed_area2 < 0:
            F[i] = F[i][::-1]
    return F

def check_vertex_ordering(V):
    """Tell whether the first four points of ``V`` form a positively oriented tetrahedron.

    The sign of the scalar triple product
    ``((V[1] - V[0]) x (V[2] - V[0])) . (V[3] - V[0])`` is tested; that
    product is six times the signed volume. Coplanar points give False.
    """
    apex = V[0]
    edges = [V[k] - apex for k in (1, 2, 3)]
    triple_product = np.dot(np.cross(edges[0], edges[1]), edges[2])
    return triple_product > 0

def tetrahedralize(V_tri, F, z0, z1):
    """Extrude a triangulated 2D mesh into one layer of tetrahedra.

    Each triangle becomes a prism between heights ``min(z0, z1)`` and
    ``max(z0, z1)``, and each prism is split into three tetrahedra. The
    split is chosen from the sorted global vertex indices of the face. As a
    result, two prisms sharing an edge always cut their common quad face
    along the same diagonal, which keeps the tetrahedral mesh conforming.

    Args:
        V_tri: Vertex coordinates of the triangle mesh. One column is
            appended for z, and ``check_vertex_ordering`` needs 3D points,
            so presumably shape ``(N, 2)``.
        F: Triangles as vertex-index triples into ``V_tri``. Re-oriented in
            place by ``fix_triangle_normals`` (kept for callers that observe
            that side effect; the tetrahedra below do not depend on it).
        z0, z1: Heights of the two layers; their order does not matter.

    Returns:
        ``(V_tet, T)``. ``V_tet`` stacks the ``N`` vertices at the lower
        height followed by the ``N`` vertices at the upper height. ``T`` is a
        list of 4-index lists, each ordered so that ``check_vertex_ordering``
        is True (positive volume).

    Raises:
        ValueError: If a tetrahedron has zero volume, e.g. ``z0 == z1`` or a
            degenerate (zero-area) triangle in ``F``.
    """
    if z0 > z1:
        z0, z1 = z1, z0

    F = fix_triangle_normals(V_tri, F)

    N = len(V_tri)
    V_tet = np.concatenate([
        np.pad(V_tri, ((0, 0), (0, 1)), constant_values=z0),
        np.pad(V_tri, ((0, 0), (0, 1)), constant_values=z1),
    ])

    T = []
    for face in F:
        # With i < j < k, every quad face of the prism is cut from the
        # higher-index bottom vertex to the lower-index top vertex. That rule
        # depends only on the edge, so neighbouring prisms agree on it.
        i, j, k = sorted(face)
        for tet in ([i, j, k, i + N],
                    [j, k, i + N, j + N],
                    [k, i + N, j + N, k + N]):
            if not check_vertex_ordering(V_tet[tet]):
                # Swapping two vertices negates the signed volume.
                tet[0], tet[1] = tet[1], tet[0]
                if not check_vertex_ordering(V_tet[tet]):
                    raise ValueError(
                        f"zero-volume tetrahedron for face {list(face)} "
                        f"(z0={z0}, z1={z1})"
                    )
            T.append(tet)

    return V_tet, T