import numpy as np
from EllipsoidFit import ellipsoidFit
import pickle
from scipy.interpolate import interp1d
from scipy.optimize import minimize
from tqdm import tqdm
import matplotlib.pyplot as plt
from time import sleep

def cumulative_interp(k):
    """Build linear interpolants of the running sum of per-slice drift steps.

    Parameters
    ----------
    k : array-like, shape (n, >=2)
        Per-slice steps; row j is the step at slice index j. Only columns 0
        and 1 are used. scipy's linear ``interp1d`` needs ``n >= 2``.

    Returns
    -------
    tuple of two callables
        ``(f0, f1)`` where ``fc(z)`` linearly interpolates
        ``np.cumsum(k, axis=0)[:, c]`` over slice indices ``0 .. n-1`` and
        extrapolates linearly outside that range.
    """
    running_total = np.cumsum(k, axis=0)
    slice_index = np.arange(len(k))

    def _column_interpolant(column):
        # fill_value='extrapolate' lets callers query z outside [0, n-1].
        return interp1d(slice_index, running_total[:, column], kind='linear', fill_value='extrapolate')

    return tuple(_column_interpolant(column) for column in (0, 1))

def calculate_drift(E):
    """Return the drift components ``(sx, sy)`` encoded in an ellipsoid fit.

    ``E`` must be a 3-tuple ``(v, radii, center)``; only ``v`` is used.
    ``v`` must hold exactly six coefficients, read in order as
    A, B, C, D, E, F (presumably the ordering produced by
    ``EllipsoidFit.ellipsoidFit`` -- TODO confirm). C is not used.

    Both components share the denominator ``A*B - D**2``:
        sx = (D*F - B*E) / (A*B - D**2)
        sy = (D*E - A*F) / (A*B - D**2)
    """
    coefficients, _radii, _center = E
    a, b, _c, d, e, f = coefficients
    denominator = a*b - d**2
    drift_x = (d*f - b*e) / denominator
    drift_y = (d*e - a*f) / denominator
    return drift_x, drift_y

def driftEstimate(ellipsoids, z_bounds, drift_spread=7, minimum_N=2):
    """Estimate per-slice drift by pooling the shear of nearby ellipsoid fits.

    For every fitted ellipsoid the drift components ``(sx, sy)`` are computed
    (same formula as ``calculate_drift``). They are added to every slice bin
    within ``drift_spread`` slices of the ellipsoid's rounded centre z. Each
    bin is then summarised by its mean and standard deviation.

    Parameters
    ----------
    ellipsoids : iterable of ``(v, radii, center)`` tuples
        ``v`` holds six coefficients unpacked as ``A, B, C, D, E, F``
        (presumably the output of ``ellipsoidFit`` -- confirm the ordering).
        ``center[2]`` is the centre's z coordinate. ``radii`` is unused.
    z_bounds : tuple of int
        Half-open slice range ``[z_lower, z_upper)``; one bin per slice.
    drift_spread : int
        Half-width, in slices, of the window each ellipsoid contributes to.
    minimum_N : int
        Minimum number of contributions for a bin to be summarised. Empty
        bins are always treated as missing, even when ``minimum_N <= 0``.

    Returns
    -------
    drift_estimate : ndarray, shape (n_bins, 2)
        Per-bin mean ``(sx, sy)``; ``[0, 0]`` for bins below the threshold.
    drift_std : ndarray, shape (n_bins, 2)
        Per-bin population std; ``[-1, -1]`` for bins below the threshold.
    drift_bins_N : ndarray, shape (n_bins,)
        Number of contributions per bin.
    drift_components : list of [sx, sy]
        Drift components of each ellipsoid, in input order.

    With an empty slice range the first two arrays have shape ``(0,)``.
    """
    z_lower, z_upper = z_bounds

    drift_bins = [[] for _ in range(z_lower, z_upper)]
    drift_components = []

    for v, _radii, center in ellipsoids:
        A, B, _C, D, E, F = v
        denom = A*B - D**2
        sx = (D*F - B*E) / denom
        sy = (D*E - A*F) / denom

        drift_component = [sx, sy]
        drift_components.append(drift_component)

        center_k = int(round(center[2]))
        # Clip the window [center_k - spread, center_k + spread] to the bounds.
        first = max(center_k - drift_spread, z_lower)
        last = min(center_k + drift_spread + 1, z_upper)
        for k in range(first, last):
            drift_bins[k - z_lower].append(drift_component)

    # np.mean/np.std of an empty bin yield a scalar nan, which would make the
    # stacked result ragged, so an empty bin is never summarised.
    threshold = max(minimum_N, 1)

    drift_bins_means = []
    drift_bins_std = []
    for bin_group in drift_bins:
        if len(bin_group) >= threshold:
            values = np.array(bin_group)
            drift_bins_means.append(np.mean(values, axis=0))
            drift_bins_std.append(np.std(values, axis=0))
        else:
            # Sentinels for "not enough data"; callers can also check drift_bins_N.
            drift_bins_means.append(np.array([0, 0]))
            drift_bins_std.append(np.array([-1, -1]))

    # Interpolation, filling missing values etc., here (not implemented yet)

    drift_estimate = np.array(drift_bins_means)
    drift_std = np.array(drift_bins_std)
    drift_bins_N = np.array([len(bin_group) for bin_group in drift_bins])

    return drift_estimate, drift_std, drift_bins_N, drift_components

def drift_points(points, c_interp):
    """Shift every point cloud in x and y by the drift evaluated at its z.

    Parameters
    ----------
    points : iterable of ndarray, shape (m, >=3)
        Point clouds; column 2 is the z coordinate passed to the interpolants.
    c_interp : pair of callables
        ``(f0, f1)`` -- e.g. from ``cumulative_interp`` -- added to columns 0
        and 1 respectively.

    Returns
    -------
    list of ndarray
        New, shifted copies; the input arrays are not modified.
    """
    shift_x, shift_y = c_interp

    def _shift(cloud):
        moved = np.copy(cloud)
        z = cloud[:, 2]
        moved[:, 0] += shift_x(z)
        moved[:, 1] += shift_y(z)
        return moved

    return [_shift(cloud) for cloud in points]

def driftEstimate2(points, z_bounds, *, verbose=True):
    """Estimate per-slice drift steps by jointly minimising ellipsoid shear.

    All slice steps except the first, which is pinned to zero, are optimised
    together with ``scipy.optimize.minimize`` (default method). The
    objective works in three steps:

    1. Shift every point cloud by the cumulative drift (``cumulative_interp``
       followed by ``drift_points``).
    2. Refit an ellipsoid to each cloud with at least 10 points.
    3. Sum ``sx**2 + sy**2`` over those refits.

    Parameters
    ----------
    points : list of ndarray
        Point clouds; column 2 is used as z by ``drift_points``.
    z_bounds : tuple of int
        ``(z_lower, z_upper)``; ``n = z_upper - z_lower`` slices are estimated.
    verbose : bool, keyword-only
        When true (the default, matching the previous behaviour), print the
        first five x steps after each iteration and ``sx, sy`` of every refit.

    Returns
    -------
    ndarray, shape (n, 2), float64
        Per-slice drift steps; row 0 is always zero.

    Raises
    ------
    ValueError
        If fewer than two slices are requested, leaving nothing to optimise.

    NOTE(review): ``cumulative_interp`` indexes the steps by ``0 .. n-1``,
    while ``drift_points`` evaluates at the raw z, so ``z_bounds[0]`` is
    effectively assumed to be 0. Confirm before using a non-zero lower bound.
    """
    n = z_bounds[1] - z_bounds[0]
    if n < 2:
        raise ValueError('z_bounds must span at least two slices')

    # float64, not float32: minimize() estimates the gradient by finite
    # differences with steps of ~1.5e-8, which float32 rounding would erase.
    x0 = np.zeros(2 * (n - 1))

    def callback(xk):
        print(xk[:10:2])

    def energy(x):
        k = np.zeros((n, 2))
        k[1:, :] = np.reshape(x, (n - 1, 2))
        c_interp = cumulative_interp(k)

        total = 0
        for X in drift_points(points, c_interp):
            # Too few points for a stable ellipsoid fit; skip the cloud.
            if X.shape[0] >= 10:
                sx, sy = calculate_drift(ellipsoidFit(X))
                if verbose:
                    print(sx, sy)
                total += sx**2 + sy**2

        return total

    res = minimize(energy, x0, callback=callback if verbose else None)

    steps = np.zeros((n, 2))
    steps[1:, :] = np.reshape(res.x, (n - 1, 2))
    return steps

def driftEstimate3(points, z_bounds, *, n_passes=1, verbose=True):
    """Estimate per-slice drift steps by coordinate descent on ellipsoid shear.

    Slice 0's step is fixed at zero. In each of ``n_passes`` sweeps, slices
    ``1 .. n-1`` are visited in order. Only that slice's step is optimised
    with ``scipy.optimize.minimize``, while all other steps keep their
    current values. The objective works in three steps:

    1. Shift every point cloud by the cumulative drift (``cumulative_interp``
       followed by ``drift_points``).
    2. Refit an ellipsoid to each cloud with at least 10 points.
    3. Sum ``sx**2 + sy**2`` over those refits.

    Parameters
    ----------
    points : list of ndarray
        Point clouds; column 2 is used as z by ``drift_points``.
    z_bounds : tuple of int
        ``(z_lower, z_upper)``; ``n = z_upper - z_lower`` slices are estimated.
    n_passes : int, keyword-only
        Number of full sweeps over the slices (previously hard-coded to 1).
    verbose : bool, keyword-only
        When true (the default, matching the previous behaviour), print the
        step array before each sweep and after each slice update.

    Returns
    -------
    ndarray, shape (n, 2), float64
        Per-slice drift steps; row 0 is always zero.

    NOTE(review): ``cumulative_interp`` indexes the steps by ``0 .. n-1``,
    while ``drift_points`` evaluates at the raw z, so ``z_bounds[0]`` is
    effectively assumed to be 0. Confirm before using a non-zero lower bound.
    """
    n = z_bounds[1] - z_bounds[0]

    # float64, not float32: minimize() estimates the gradient by finite
    # differences with steps of ~1.5e-8, which float32 rounding would erase.
    k = np.zeros((n, 2))

    def _slice_energy(j):
        """Return the objective as a function of slice j's step only."""
        def energy(x):
            # k is updated in place between slices, so always read its
            # current values and override just row j.
            k_in = np.copy(k)
            k_in[j, :] = x
            c_interp = cumulative_interp(k_in)

            total = 0
            for X in drift_points(points, c_interp):
                # Too few points for a stable ellipsoid fit; skip the cloud.
                if X.shape[0] >= 10:
                    sx, sy = calculate_drift(ellipsoidFit(X))
                    total += sx**2 + sy**2

            return total

        return energy

    for i in tqdm(range(n_passes), desc='n passes'):

        if verbose:
            print('k before:', k)

        for j in tqdm(range(1, n), desc='slices'):
            res = minimize(_slice_energy(j), k[j, :].copy())
            k[j, :] = res.x
            if verbose:
                print('k after', i, ':', k)

    return k

if __name__ == '__main__':
    # Ad-hoc experiment driver; nothing below is importable API.

    # Disabled experiment (guarded by `if False`): compare driftEstimate and
    # driftEstimate3 on manually annotated vesicle points against the drift
    # listed in transforms.txt. Flip to True to run it; the hard-coded
    # data_path must exist locally.
    if False:
        import matplotlib.pyplot as plt  # redundant: already imported at module level

        data_path = '/home/dith/arngorf/code/fibsem_tools/data/image_transforms/synthetic_linear/'
        transforms_path = data_path + 'transforms.txt'
        points_path = data_path + 'vesicle_ellipsoid_manual_points_fixed.txt'

        # Each line of transforms.txt is space-separated floats; only the
        # first two columns are kept, as the reference per-slice drift.
        true_drift = []

        with open(transforms_path, 'r') as f:
            for line in f.readlines():
                formated_line = list(map(float, line.rstrip().split(' ')))
                true_drift.append([formated_line[0], formated_line[1]])

        true_drift = np.array(true_drift)

        # Each line of the points file is "vidx x y z". Consecutive lines with
        # the same vidx form one (m, 3) array per vesicle; a new group starts
        # whenever vidx changes. This presumably assumes the file is sorted or
        # grouped by vidx -- TODO confirm. An empty file would raise IndexError
        # on the final conversion below.
        vesicle_points = []

        cur_vidx = -1
        with open(points_path, 'r') as f:
            for line in f.readlines():
                fline = line.rstrip().split(' ')
                vidx, x, y, z = int(fline[0]), float(fline[1]), float(fline[2]), float(fline[3])
                if vidx != cur_vidx:
                    cur_vidx = vidx
                    # Freeze the finished group into an array before starting a new one.
                    if len(vesicle_points) > 0:
                        vesicle_points[-1] = np.array(vesicle_points[-1])
                    vesicle_points.append([])

                vesicle_points[-1].append([x,y,z])

        vesicle_points[-1] = np.array(vesicle_points[-1])

        # Fit an ellipsoid to every vesicle with at least 10 points.
        ellipsoids = []

        for X in vesicle_points:
            if X.shape[0] >= 10:
                E = ellipsoidFit(X)
                ellipsoids.append(E)

        results = driftEstimate(ellipsoids, (0, 200))
        results2 = driftEstimate3(vesicle_points, (0, 200))

        # NOTE(review): the top panel plots true_drift[:,0] against estimated
        # column 1, and the bottom panel true_drift[:,1] against column 0.
        # This may be a deliberate axis-convention swap -- confirm.
        plt.subplot(2,1,1)
        plt.plot(true_drift[:,0], label='true', color='black', ls='--')
        plt.plot(results[0][:,1], label='original')
        plt.plot(results2[:,1], label='new')
        plt.xlim(0,199)

        plt.subplot(2,1,2)
        plt.plot(true_drift[:,1], label='true', color='black', ls='--')
        plt.plot(results[0][:,0], label='original')
        plt.plot(results2[:,0], label='new')
        plt.xlim(0,199)

        plt.show()

    # Synthetic data: 10 noisy spheres (a = b = c = 5) centred at z = 3 + vidx,
    # with x sheared by sx per unit z and y by sy per unit z.
    sx = 0.25
    sy = 0.0

    a = 5
    b = 5
    c = 5

    vesicle_test_points = []

    for vidx in range(10):
        vesicle_test_points.append([])

        # 8 x 8 grid of spherical angles (last linspace value dropped), giving
        # 64 points per vesicle, each coordinate jittered by U(-0.2, 0.2).
        # The RNG is not seeded, so results differ between runs.
        for u in np.linspace(0.1234, 2*np.pi, 9)[:-1]:
            for v in np.linspace(0.1234, np.pi, 9)[:-1]:
                x = a * np.cos(u) * np.sin(v) + np.random.uniform(-0.2, 0.2)
                y = b * np.sin(u) * np.sin(v) + np.random.uniform(-0.2, 0.2)
                z = c * np.cos(v) + np.random.uniform(-0.2, 0.2)

                z += 3 + vidx

                x += sx * z
                y += sy * z

                vesicle_test_points[-1].append([x,y,z])

        vesicle_test_points[-1] = np.array(vesicle_test_points[-1])

    # Fit the first vesicle (coordinates rounded to 4 decimals) once, to get
    # its drift before any correction.
    X = np.round(vesicle_test_points[0],4)
    print(X)
    E = ellipsoidFit(X)
    E_X_sx, E_X_sy = calculate_drift(E)

    # 4 x 4 grid: apply a constant per-slice x step of -2 + i*0.25
    # (i = 0..15, i.e. -2.0 .. 1.75) and scatter x against z before and after.
    # With a constant step, cumulative_interp is linear in z, so this applies
    # a shear of that step per unit z (plus a constant offset of one step).
    for i in range(16):

        plt.subplot(4,4,i+1)

        k = np.ones((10,2))  # both columns are overwritten on the next two lines
        k[:,0] = -2 + i*0.25
        k[:,1] = 0
        #k[4,0] = 2
        c_interp = cumulative_interp(k)
        new_points = drift_points([X], c_interp)

        plt.scatter(X[:,0], X[:,2], label='before, sx = ' + str(E_X_sx), marker='o')

        Y = new_points[-1]
        print(Y)
        E = ellipsoidFit(Y)
        E_Y_sx, E_Y_sy = calculate_drift(E)
        plt.scatter(Y[:,0], Y[:,2], label='after, sx = ' + str(E_Y_sx), marker='x')

    # NOTE(review): plt.legend() runs only after the loop, so only the last
    # subplot gets a legend.
    plt.legend()
    plt.show()

    # Sweep the constant x step over [-3, 3] (1000 values) and plot the squared
    # x drift of the ellipsoid refitted to the corrected points. The loop
    # variable rebinds sx, the shear used to generate the data above.
    sxs = np.linspace(-3, 3, 1000)
    sxs_fit_sq = []

    for sx in sxs:

        k = np.ones((10,2))  # both columns are overwritten on the next two lines
        k[:,0] = sx
        k[:,1] = 0
        #k[4,0] = 2
        c_interp = cumulative_interp(k)
        new_points = drift_points([X], c_interp)

        Y = new_points[-1]
        E = ellipsoidFit(Y)
        E_Y_sx, E_Y_sy = calculate_drift(E)

        sxs_fit_sq.append(E_Y_sx**2)

    plt.plot(sxs, sxs_fit_sq)
    plt.show()

    # Run the coordinate-descent estimator on the synthetic data over slices
    # 0..5. Points with z outside that range receive linearly extrapolated
    # drift (cumulative_interp uses fill_value='extrapolate').
    r = driftEstimate3(vesicle_test_points, (0, 6))

    print(np.round(r,2))