from sklearn import decomposition
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np


def PCA(X, y, n_components, savepath="./Figures/", show=True,
        labels=(("CTRL", 0), ("PTSD", 1))):
    """Standardize ``X``, reduce it with PCA and plot the first three components in 3-D.

    The figure is written to ``<savepath>/PCA_3D.png`` at 300 dpi and then,
    optionally, displayed.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input features; standardized (zero mean, unit variance per feature)
        with ``StandardScaler`` before PCA.
    y : array-like of shape (n_samples,)
        Per-sample class labels, used to colour the points and to locate the
        per-class text annotations.
    n_components : int, float, str or None
        Forwarded unchanged to ``sklearn.decomposition.PCA``. The resulting
        projection must keep at least three components, because three axes
        are plotted.
    savepath : str, default "./Figures/"
        Directory the PNG is written to; created if it does not exist.
    show : bool, default True
        Whether to call ``plt.show()`` after the figure has been saved.
    labels : sequence of (name, label) pairs, default (("CTRL", 0), ("PTSD", 1))
        Each ``name`` is drawn at the centroid of the samples whose ``y``
        equals ``label``. Labels that do not occur in ``y`` are skipped.

    Returns
    -------
    str
        A fixed status message.

    Raises
    ------
    ValueError
        If the PCA projection has fewer than three components.
    """
    import os

    # A plain list would make ``y == label`` a single bool instead of a mask.
    y = np.asarray(y)

    # Standardize so features with large raw variance do not dominate the components.
    X = StandardScaler().fit_transform(X)

    pca = decomposition.PCA(n_components=n_components)
    pca.fit(X)
    X = pca.transform(X)

    # Check after fitting: n_components may be a float/"mle"/None, so the real
    # output dimension is only known here. Fail before any figure is created.
    if X.shape[1] < 3:
        raise ValueError(
            f"A 3-D PCA plot needs at least 3 components, got {X.shape[1]}"
        )

    fig = plt.figure(1, figsize=(4, 3))
    plt.clf()

    ax = fig.add_subplot(111, projection="3d", elev=48, azim=134)
    ax.set_position([0, 0, 0.95, 1])

    for name, label in labels:
        mask = y == label
        # The mean of an empty selection is NaN (plus a RuntimeWarning); skip absent classes.
        if not np.any(mask):
            continue
        ax.text3D(
            X[mask, 0].mean(),
            X[mask, 1].mean(),
            X[mask, 2].mean(),
            name,
            horizontalalignment="center",
            bbox=dict(alpha=0.5, edgecolor="w", facecolor="w"),
        )

    ax.scatter(
        X[:, 0],
        X[:, 1],
        X[:, 2],
        c=y,
        cmap=plt.cm.nipy_spectral,
        edgecolor="k",
    )

    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.zaxis.set_ticklabels([])

    # Save BEFORE showing: after a blocking plt.show() returns, the figure has
    # been closed, and a subsequent savefig would write an empty image.
    os.makedirs(savepath, exist_ok=True)
    fig.savefig(os.path.join(savepath, "PCA_3D.png"), dpi=300)

    if show:
        plt.show()

    return "PCA was complete without errors - check the plot in your chosen path"