from collections import defaultdict
import numpy as np

def to_key(val):
    """Round *val* to the nearest integer and return it as an ``int``.

    Uses the built-in ``round``, so exact halves go to the nearest even
    integer (e.g. 2.5 -> 2, 3.5 -> 4).
    """
    nearest = round(val)
    return int(nearest)

def p_dist(p1, p2):
    """Return the Euclidean distance between points *p1* and *p2*.

    Both arguments may be any sequences (or arrays) of matching length;
    they are converted to NumPy arrays before subtracting.
    """
    delta = np.subtract(p1, p2)
    return np.linalg.norm(delta)

class VesiclePointManager:
    """Store 3-D points grouped by vesicle id, with per-axis section indexes.

    Every point is kept in four places so it can be looked up either by
    vesicle or by its rounded coordinate along one axis:

    * ``_points_vidx[vidx]``    -- ``(x, y, z)`` tuples of vesicle *vidx*.
    * ``_points[0][to_key(z)]`` -- ``(x, y, vidx, z)`` tuples.
    * ``_points[1][to_key(y)]`` -- ``(x, z, vidx, y)`` tuples.
    * ``_points[2][to_key(x)]`` -- ``(y, z, vidx, x)`` tuples.

    In the axis indexes the first two entries are the in-section
    coordinates, followed by the vesicle id and the coordinate along the
    indexed axis. New points go to the "current" vesicle (:attr:`vidx`).
    """

    def __init__(self):
        # axis -> rounded coordinate along that axis -> list of point tuples
        self._points = {0: defaultdict(list),
                        1: defaultdict(list),
                        2: defaultdict(list)}

        # vesicle id -> list of (x, y, z) tuples, in insertion order
        self._points_vidx = defaultdict(list)

        # id of the vesicle that add_point() currently appends to
        self._vidx = 0

    @property
    def vidx(self):
        """Id of the vesicle that new points are currently added to."""
        return self._vidx

    def new_vesicle(self):
        """Switch to a fresh vesicle id, one past the largest stored id.

        With no stored vesicles the id becomes 0. Calling this repeatedly
        without adding points yields the same id each time, because only
        ids that already own a point entry are considered.
        """
        self._vidx = max(self._points_vidx, default=-1) + 1

    def add_point(self, x, y, z):
        """Add the point (x, y, z) to the current vesicle and all axis indexes."""
        vidx = self._vidx
        self._points[0][to_key(z)].append((x, y, vidx, z))
        self._points[1][to_key(y)].append((x, z, vidx, y))
        self._points[2][to_key(x)].append((y, z, vidx, x))
        self._points_vidx[vidx].append((x, y, z))

    def _find_nearest_point(self, x, y, z, max_dist):
        """Return ``(point, vidx, dist)`` for the stored point nearest (x, y, z).

        Only points strictly closer than *max_dist* are candidates; if none
        qualifies, ``(None, None, max_dist)`` is returned. On ties the first
        point encountered wins (vesicles in dict order, points in insertion
        order).
        """
        query = [x, y, z]
        best_p, best_vidx, best_dist = None, None, max_dist
        for vidx, pts in self._points_vidx.items():
            for p in pts:
                dist = p_dist(p, query)
                if dist < best_dist:
                    best_p, best_vidx, best_dist = p, vidx, dist
        return best_p, best_vidx, best_dist

    def delete_nearest_point(self, x, y, z, max_dist=3):
        """Remove the stored point nearest to (x, y, z) from every index.

        Args:
            x, y, z: Query coordinates.
            max_dist: Only points strictly closer than this are candidates
                (defaults to 3, the previously hard-coded radius).

        Prints what was found, or that nothing was close enough. If the
        point's entry cannot be located in all four indexes, nothing is
        deleted and a message is printed, so the indexes stay in sync.
        """
        min_p, min_vidx, min_dist = self._find_nearest_point(x, y, z, max_dist)

        if min_p is None:
            print('no point was close enough')
            return

        print('found:', min_p, min_vidx, min_dist)
        px, py, pz = min_p

        # Resolve every position before deleting anything, so a missing entry
        # (ValueError from list.index) leaves all four structures untouched.
        targets = (
            (self._points_vidx[min_vidx], (px, py, pz)),
            (self._points[0][to_key(pz)], (px, py, min_vidx, pz)),
            (self._points[1][to_key(py)], (px, pz, min_vidx, py)),
            (self._points[2][to_key(px)], (py, pz, min_vidx, px)),
        )
        try:
            positions = [lst.index(entry) for lst, entry in targets]
        except ValueError:
            print('point missing from an index, nothing deleted:', min_p, min_vidx)
            return

        for (lst, _), pos in zip(targets, positions):
            del lst[pos]

    def set_to_nearest_vesicle(self, x, y, z, max_dist=3):
        """Make the vesicle owning the point nearest to (x, y, z) current.

        Only points strictly closer than *max_dist* (default 3) count; if
        none does, the current vesicle id is left unchanged.
        """
        _, nearest_vidx, _ = self._find_nearest_point(x, y, z, max_dist)
        if nearest_vidx is not None:
            self._vidx = nearest_vidx

    def get_vesicle_points(self):
        """Return one NumPy array of ``(x, y, z)`` rows per stored vesicle.

        Arrays are in dict (first-insertion) order of vesicle ids. A vesicle
        whose points were all deleted yields an empty array.
        """
        return [np.array(pts) for pts in self._points_vidx.values()]

    def get_vesicle_points_with_vidx(self):
        """Return every stored point as a ``[vidx, x, y, z]`` list."""
        return [[vidx, x, y, z]
                for vidx, pts in self._points_vidx.items()
                for x, y, z in pts]

    def get_points_in_section(self, idx, axis):
        """Return the point tuples whose rounded *axis* coordinate equals *idx*.

        *axis* 0, 1 and 2 index by z, y and x respectively (tuple layouts
        are listed in the class docstring). When the section exists, the
        internal list itself is returned (not a copy). Otherwise a new empty
        list is returned and the index is not modified, so querying does
        not grow the dictionaries.
        """
        return self._points[axis].get(idx, [])