Skip to content
Snippets Groups Projects
DriftEstimate2.py 8 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy as np
    from EllipsoidFit import ellipsoidFit
    import pickle
    from scipy.interpolate import interp1d
    from scipy.optimize import minimize
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    from time import sleep
    
    def cumulative_interp(k):
        """Build linear interpolants of the running sum of per-slice increments.

        Row ``i`` of ``k`` is the increment at integer slice position ``i``.
        The running sum over rows (``np.cumsum`` along axis 0) is interpolated
        linearly in the slice coordinate, separately for columns 0 and 1. Values
        outside ``[0, len(k) - 1]`` are extrapolated linearly.

        Parameters
        ----------
        k : array_like, shape (n, >=2)
            Per-slice increments; only columns 0 and 1 are used.

        Returns
        -------
        tuple of two ``interp1d`` callables
            ``(cum_0, cum_1)``. Each maps a (possibly fractional) slice
            coordinate to the cumulative value of column 0 or column 1.
        """
        running_total = np.cumsum(k, axis=0)
        slice_positions = np.arange(len(k))
        return tuple(
            interp1d(slice_positions, running_total[:, column], kind='linear', fill_value='extrapolate')
            for column in (0, 1)
        )
    
    def calculate_drift(E):
        """Return the drift pair ``(sx, sy)`` derived from an ellipsoid fit.

        ``E`` must unpack as ``(v, radii, center)``, and ``v`` must unpack into
        exactly six coefficients ``(A, B, C, D, E, F)``. Otherwise a ValueError
        is raised. Only ``v`` is used:

            sx = (D*F - B*E) / (A*B - D**2)
            sy = (D*E - A*F) / (A*B - D**2)

        NOTE(review): which quadric terms A..F multiply is defined by
        ``ellipsoidFit`` (not visible here); confirm the ordering there.
        """
        coeffs, _radii, _center = E
        a, b, _c, d, e, f = coeffs

        denominator = a * b - d ** 2
        sx = (d * f - b * e) / denominator
        sy = (d * e - a * f) / denominator
        return sx, sy
    
    def driftEstimate(ellipsoids, z_bounds, drift_spread=7, minimum_N=2):
        """Estimate per-slice (x, y) drift by pooling nearby ellipsoid fits.

        Each ellipsoid yields one drift estimate ``[sx, sy]`` via
        ``calculate_drift``. That estimate is added to the bin of every integer
        slice ``k`` with ``|k - round(center_z)| <= drift_spread`` and
        ``z_lower <= k < z_upper``. Each slice's drift is the mean of its bin.

        Parameters
        ----------
        ellipsoids : iterable
            Fit results, each unpacking as ``(v, radii, center)``. ``v`` holds six
            coefficients and ``center[2]`` is the z coordinate of the centre.
        z_bounds : tuple of int
            Half-open slice range ``(z_lower, z_upper)``.
        drift_spread : int, default 7
            Slices on each side of an ellipsoid's centre that receive its estimate.
        minimum_N : int, default 2
            Minimum number of contributions a slice needs before its mean/std
            are used.

        Returns
        -------
        drift_estimate : ndarray, shape (z_upper - z_lower, 2)
            Mean ``[sx, sy]`` per slice, or ``[0, 0]`` for under-populated slices.
        drift_std : ndarray, shape (z_upper - z_lower, 2)
            Standard deviation per slice, or the sentinel ``[-1, -1]`` for
            under-populated slices.
        drift_bins_N : ndarray of int, shape (z_upper - z_lower,)
            Number of contributions per slice.
        drift_components : list of [sx, sy]
            Per-ellipsoid drift estimates in input order.
        """
        z_lower, z_upper = z_bounds

        drift_bins = [[] for _ in range(z_lower, z_upper)]
        drift_components = []

        for ellipsoid in ellipsoids:
            v, radii, center = ellipsoid
            # Same formula as calculate_drift; reuse it rather than duplicating it.
            drift_component = list(calculate_drift((v, radii, center)))
            drift_components.append(drift_component)

            # Spread this estimate over slices within drift_spread of the centre,
            # clipped to [z_lower, z_upper) instead of filtering each candidate.
            center_k = int(round(center[2]))
            first = max(center_k - drift_spread, z_lower)
            stop = min(center_k + drift_spread + 1, z_upper)
            for k in range(first, stop):
                drift_bins[k - z_lower].append(drift_component)

        populated = [len(bin_group) >= minimum_N for bin_group in drift_bins]

        # Float fill values keep the dtype consistent even when every slice is
        # under-populated. reshape keeps shape (0, 2) for an empty z-range.
        drift_estimate = np.array([
            np.mean(np.array(bin_group), axis=0) if ok else np.zeros(2)
            for bin_group, ok in zip(drift_bins, populated)
        ]).reshape(-1, 2)
        drift_std = np.array([
            np.std(np.array(bin_group), axis=0) if ok else np.full(2, -1.0)
            for bin_group, ok in zip(drift_bins, populated)
        ]).reshape(-1, 2)
        drift_bins_N = np.array([len(bin_group) for bin_group in drift_bins], dtype=int)

        # TODO: interpolate / fill under-populated slices here.

        return drift_estimate, drift_std, drift_bins_N, drift_components
    
    def drift_points(points, c_interp):
        """Shift every point cloud in x and y by the drift evaluated at its z.

        Parameters
        ----------
        points : iterable of ndarray, each with at least 3 columns
            Point clouds; column 2 is the coordinate the shift is evaluated at.
        c_interp : pair of callables
            ``(shift_x, shift_y)``, e.g. as returned by ``cumulative_interp``.
            Their values at column 2 are added to columns 0 and 1.

        Returns
        -------
        list of ndarray
            Shifted copies, in input order. The inputs are not modified.
        """
        shift_x, shift_y = c_interp
        shifted = []
        for cloud in points:
            z = cloud[:, 2]
            moved = np.copy(cloud)
            moved[:, 0] += shift_x(z)
            moved[:, 1] += shift_y(z)
            shifted.append(moved)
        return shifted
    
    def driftEstimate2(points, z_bounds):
        """Jointly optimise per-slice drift increments so refitted ellipsoids show no drift.

        Slice 0's increment is pinned to zero. The remaining ``n - 1``
        increments (two components each) are optimised together with
        ``scipy.optimize.minimize``. The objective does three things:
        it shifts every point cloud by the cumulative drift
        (``cumulative_interp`` + ``drift_points``), refits an ellipsoid to each
        cloud with at least 10 points, and sums ``sx**2 + sy**2`` from
        ``calculate_drift``.

        Parameters
        ----------
        points : sequence of ndarray, each with at least 3 columns
            Point clouds; column 2 is the slice (z) coordinate.
        z_bounds : tuple of int
            ``(z_lower, z_upper)``; ``n = z_upper - z_lower`` slices are estimated.

        Returns
        -------
        ndarray, shape (n, 2), float32
            Per-slice drift increments; row 0 is always zero.
        """
        n = z_bounds[1] - z_bounds[0]

        # Only the pinned slice (or none) exists: there is nothing to optimise.
        if n <= 1:
            return np.zeros((max(n, 0), 2), dtype=np.float32)

        # minimize expects a 1-D x0; recent SciPy rejects x0.ndim != 1.
        x0 = np.zeros(2 * (n - 1), dtype=np.float32)

        def callback(xk):
            # Progress: even entries of the first ten = column 0 of slices 1..5.
            print(xk[:10:2])

        def energy(x):
            k = np.zeros((n, 2), dtype=np.float32)
            k[1:, :] = x.reshape((n - 1, 2))
            new_points = drift_points(points, cumulative_interp(k))

            total = 0
            for X in new_points:
                if X.shape[0] >= 10:
                    sx, sy = calculate_drift(ellipsoidFit(X))
                    print(sx, sy)
                    total += sx**2 + sy**2
            return total

        res = minimize(energy, x0, callback=callback)

        x = np.zeros(2 * n, dtype=np.float32)
        x[2:] = res.x
        return x.reshape((n, 2))
    
    def driftEstimate3(points, z_bounds, n_passes=1):
        """Estimate per-slice drift increments by coordinate descent.

        Slice 0 stays pinned at zero. On each pass, every slice ``j`` in
        ``1 .. n-1`` is optimised in turn with ``scipy.optimize.minimize``, while
        the other slices keep their current values. The objective does three
        things: it shifts the point clouds by the cumulative drift
        (``cumulative_interp`` + ``drift_points``), refits an ellipsoid to every
        cloud with at least 10 points, and sums ``sx**2 + sy**2`` from
        ``calculate_drift``.

        Parameters
        ----------
        points : sequence of ndarray, each with at least 3 columns
            Point clouds; column 2 is the slice (z) coordinate.
        z_bounds : tuple of int
            ``(z_lower, z_upper)``; ``n = z_upper - z_lower`` slices are estimated.
        n_passes : int, default 1
            Number of full sweeps over slices ``1 .. n-1``.

        Returns
        -------
        ndarray, shape (n, 2), float32
            Per-slice drift increments; row 0 is always zero.
        """
        n = z_bounds[1] - z_bounds[0]

        k = np.zeros((n, 2), dtype=np.float32)

        def energy(x, j):
            # Evaluate the objective with slice j's increment replaced by x.
            # k is updated in place between slices, so this sees current values.
            k_in = np.copy(k)
            k_in[j, :] = x
            new_points = drift_points(points, cumulative_interp(k_in))

            total = 0
            for X in new_points:
                if X.shape[0] >= 10:
                    sx, sy = calculate_drift(ellipsoidFit(X))
                    total += sx**2 + sy**2
            return total

        for i in tqdm(range(n_passes), desc='n passes'):

            print('k before:', k)

            for j in tqdm(range(1, n), desc='slices'):
                res = minimize(energy, k[j, :], args=(j,))
                k[j, :] = res.x
                print('k after', i, ':', k)

        return k
    
    if __name__ == '__main__':
        # Manual experiments for the drift estimators above; every section
        # plots with matplotlib and blocks on plt.show().
        from itertools import groupby

        # Set to True to compare driftEstimate / driftEstimate3 against the
        # known drift stored with a dataset on disk.
        # NOTE(review): data_path is machine-specific; adjust before enabling.
        RUN_GROUND_TRUTH_COMPARISON = False

        if RUN_GROUND_TRUTH_COMPARISON:
            data_path = '/home/dith/arngorf/code/fibsem_tools/data/image_transforms/synthetic_linear/'
            transforms_path = data_path + 'transforms.txt'
            points_path = data_path + 'vesicle_ellipsoid_manual_points_fixed.txt'

            # transforms.txt: space-separated floats per line; the first two
            # columns are kept (presumably one row per slice; confirm).
            with open(transforms_path, 'r', encoding='utf-8') as f:
                true_drift = np.array([
                    list(map(float, line.rstrip().split(' ')))[:2]
                    for line in f
                    if line.strip()
                ])

            # Points file: "vidx x y z" per line; consecutive lines sharing a
            # vidx form one vesicle's point cloud.
            with open(points_path, 'r', encoding='utf-8') as f:
                rows = [line.rstrip().split(' ') for line in f if line.strip()]
            vesicle_points = [
                np.array([[float(r[1]), float(r[2]), float(r[3])] for r in group])
                for _, group in groupby(rows, key=lambda r: int(r[0]))
            ]

            ellipsoids = [ellipsoidFit(X) for X in vesicle_points if X.shape[0] >= 10]

            results = driftEstimate(ellipsoids, (0, 200))
            results2 = driftEstimate3(vesicle_points, (0, 200))

            # NOTE(review): the top panel pairs true column 0 with estimated
            # column 1 (the bottom panel the reverse); confirm this swap matches
            # the column order of transforms.txt.
            for panel, (true_col, est_col) in enumerate([(0, 1), (1, 0)], start=1):
                plt.subplot(2, 1, panel)
                plt.plot(true_drift[:, true_col], label='true', color='black', ls='--')
                plt.plot(results[0][:, est_col], label='original')
                plt.plot(results2[:, est_col], label='new')
                plt.xlim(0, 199)
                plt.legend()

            plt.show()

        # Synthetic data: 10 noisy spheres (a = b = c = 5), each sampled on an
        # 8 x 8 (u, v) grid, offset by 3 + vidx in z and sheared by
        # x += sx * z, y += sy * z.
        sx = 0.25
        sy = 0.0

        a = 5
        b = 5
        c = 5

        vesicle_test_points = []

        for vidx in range(10):
            cloud = []
            for u in np.linspace(0.1234, 2*np.pi, 9)[:-1]:
                for v in np.linspace(0.1234, np.pi, 9)[:-1]:
                    x = a * np.cos(u) * np.sin(v) + np.random.uniform(-0.2, 0.2)
                    y = b * np.sin(u) * np.sin(v) + np.random.uniform(-0.2, 0.2)
                    z = c * np.cos(v) + np.random.uniform(-0.2, 0.2)

                    z += 3 + vidx

                    x += sx * z
                    y += sy * z

                    cloud.append([x, y, z])
            vesicle_test_points.append(np.array(cloud))

        # Fit the first vesicle (rounded to 4 decimals) to get its drift value.
        X = np.round(vesicle_test_points[0], 4)
        print(X)
        E = ellipsoidFit(X)
        E_X_sx, E_X_sy = calculate_drift(E)

        # 4 x 4 panels: apply a constant per-slice x-increment of -2 + 0.25*i
        # and compare the cloud (x vs z) and its fitted sx before/after.
        for i in range(16):

            plt.subplot(4, 4, i + 1)

            k = np.zeros((10, 2))
            k[:, 0] = -2 + i*0.25
            c_interp = cumulative_interp(k)
            new_points = drift_points([X], c_interp)

            plt.scatter(X[:, 0], X[:, 2], label='before, sx = ' + str(E_X_sx), marker='o')

            Y = new_points[-1]
            print(Y)
            E = ellipsoidFit(Y)
            E_Y_sx, E_Y_sy = calculate_drift(E)
            plt.scatter(Y[:, 0], Y[:, 2], label='after, sx = ' + str(E_Y_sx), marker='x')
            # One legend per panel; a single call after the loop labelled only
            # the last panel.
            plt.legend()

        plt.show()

        # Sweep a constant per-slice x-increment and plot the squared sx of the refit.
        sxs = np.linspace(-3, 3, 1000)
        sxs_fit_sq = []

        for shift in sxs:
            k = np.zeros((10, 2))
            k[:, 0] = shift
            c_interp = cumulative_interp(k)
            Y = drift_points([X], c_interp)[-1]
            E = ellipsoidFit(Y)
            E_Y_sx, E_Y_sy = calculate_drift(E)
            sxs_fit_sq.append(E_Y_sx**2)

        plt.plot(sxs, sxs_fit_sq)
        plt.show()

        r = driftEstimate3(vesicle_test_points, (0, 6))

        print(np.round(r, 2))