# -*- coding: utf-8 -*-
"""
Updated Oct 18 2022

@author: Qianliang Li (glia@dtu.dk)

This preamble contains the code to load all the relevant libraries.
Refer to requirements.txt for the specific versions used.
"""

# Libraries
import os
import sys
import re
import warnings
import time
import pickle
import concurrent.futures # for multiprocessing

import numpy as np # Arrays and mathematical computations
import matplotlib.pyplot as plt # Plotting
import mne # EEG library
import scipy # Signal processing
import sklearn # Machine learning
import nitime # Time series analysis
import nolds # DFA exponent
import statsmodels # multipletests (multiple testing correction)
import pysparcl # Sparse K-means
import fooof # Peak Alpha Freq and 1/f exponents
import pandas as pd # Dataframes
import seaborn as sns # Plotting library
import autoreject # Automatic EEG artifact detection
import mlxtend # Sequential Forward Selection

from mne.time_frequency import psd_multitaper # deprecated/removed in later MNE versions
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs, corrmap)
from mne.stats import spatio_temporal_cluster_test, permutation_cluster_test
from mne.channels import find_ch_adjacency
from mne.connectivity import spectral_connectivity # moved to the separate mne-connectivity package in later MNE versions

import nitime.analysis as nta
import nitime.timeseries as nts
import nitime.utils as ntsu
from nitime.viz import drawmatrix_channels, drawmatrix_channels_modified # the _modified variant is a custom addition (see notes below)

from sklearn import preprocessing
from sklearn import manifold
from sklearn.svm import LinearSVC, SVC
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression, Ridge, LassoCV, RidgeCV, LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold, GridSearchCV, StratifiedGroupKFold
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import RFECV
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, make_scorer

import matplotlib.gridspec as gridspec
from matplotlib import cm

from statsmodels.tsa.stattools import adfuller
from statsmodels.formula.api import mixedlm

from autoreject import AutoReject

from mlxtend.evaluate import permutation_test
from mlxtend.feature_selection import SequentialFeatureSelector as SFS
from mlxtend.plotting import plot_sequential_feature_selection as plot_sfs

from tqdm import tqdm # progress bars
from mayavi import mlab # Plotting with MNE
from mpl_toolkits.mplot3d import Axes3D # registers 3D projections

# Non-library scripts
# EEG microstate package by von Wegner & Lauf, 2018
from eeg_microstates import * # downloaded from https://github.com/Frederic-vW/eeg_microstates
# minimum Redundancy Maximum Relevance (mRMR) script by Kiran Karra
from feature_select import * # downloaded from https://github.com/stochasticresearch/featureselect/blob/master/python/feature_select.py

plt.style.use('ggplot') # plotting style

### Some functions in the libraries were modified, by defining the modified
# versions directly in the respective .py files of those libraries

# Modified Kmeans in eeg_microstates
# Modified T_empirical in eeg_microstates
# Modified sparcl cluster_permute

# # For eeg_microstates.py
# def kmeans_return_all(data, n_maps, n_runs=10, maxerr=1e-6, maxiter=500):
#     """Modified K-means clustering as detailed in:
#     [1] Pascual-Marqui et al., IEEE TBME (1995) 42(7):658--665
#     [2] Murray et al., Brain Topography(2008) 20:249--264.
#     Variables named as in [1], step numbering as in Table I.

#     Args:
#         data: numpy.array, shape = (number of time samples, number of EEG channels)
#         n_maps: number of microstate maps
#         n_runs: number of K-means runs (optional)
#         maxerr: maximum error for convergence (optional)
#         maxiter: maximum number of iterations (optional)
#     Returns:
#         maps: microstate maps (number of maps x number of channels)
#         L: sequence of microstate labels
#         gfp_peaks: indices of local GFP maxima
#         gev: global explained variance (0..1)
#         cv: value of the cross-validation criterion
#     """
#     n_t = data.shape[0]
#     n_ch = data.shape[1]
#     data = data - data.mean(axis=1, keepdims=True)

#     # GFP peaks
#     gfp = np.std(data, axis=1)
#     gfp_peaks = locmax(gfp)
#     gfp_values = gfp[gfp_peaks]
#     gfp2 = np.sum(gfp_values**2) # normalizing constant in GEV
#     n_gfp = gfp_peaks.shape[0]

#     # clustering of GFP peak maps only
#     V = data[gfp_peaks, :]
#     sumV2 = np.sum(V**2)

#     # store results for each k-means run
#     cv_list =   []  # cross-validation criterion for each k-means run
#     gev_list =  []  # GEV of each map for each k-means run
#     gevT_list = []  # total GEV values for each k-means run
#     maps_list = []  # microstate maps for each k-means run
#     L_list =    []  # microstate label sequence for each k-means run
#     for run in range(n_runs):
#         # initialize random cluster centroids (indices w.r.t. n_gfp)
#         rndi = np.random.permutation(n_gfp)[:n_maps]
#         maps = V[rndi, :]
#         # normalize row-wise (across EEG channels)
#         maps /= np.sqrt(np.sum(maps**2, axis=1, keepdims=True))
#         # initialize
#         n_iter = 0
#         var0 = 1.0
#         var1 = 0.0
#         # convergence criterion: variance estimate (step 6)
#         while ( (np.abs((var0-var1)/var0) > maxerr) & (n_iter < maxiter) ):
#             # (step 3) microstate sequence (= current cluster assignment)
#             C = np.dot(V, maps.T)
#             C /= (n_ch*np.outer(gfp[gfp_peaks], np.std(maps, axis=1)))
#             L = np.argmax(C**2, axis=1)
#             # (step 4)
#             for k in range(n_maps):
#                 Vt = V[L==k, :]
#                 # (step 4a)
#                 Sk = np.dot(Vt.T, Vt)
#                 # (step 4b)
#                 evals, evecs = np.linalg.eig(Sk)
#                 v = evecs[:, np.argmax(np.abs(evals))]
#                 maps[k, :] = v/np.sqrt(np.sum(v**2))
#             # (step 5)
#             var1 = var0
#             var0 = sumV2 - np.sum(np.sum(maps[L, :]*V, axis=1)**2)
#             var0 /= (n_gfp*(n_ch-1))
#             n_iter += 1
#         if (n_iter < maxiter):
#             print("\t\tK-means run {:d}/{:d} converged after {:d} iterations.".format(run+1, n_runs, n_iter))
#         else:
#             print("\t\tK-means run {:d}/{:d} did NOT converge after {:d} iterations.".format(run+1, n_runs, maxiter))

#         # CROSS-VALIDATION criterion for this run (step 8)
#         C_ = np.dot(data, maps.T)
#         C_ /= (n_ch*np.outer(gfp, np.std(maps, axis=1)))
#         L_ = np.argmax(C_**2, axis=1)
#         var = np.sum(data**2) - np.sum(np.sum(maps[L_, :]*data, axis=1)**2)
#         var /= (n_t*(n_ch-1))
#         cv = var * (n_ch-1)**2/(n_ch-n_maps-1.)**2

#         # GEV (global explained variance) of cluster k
#         gev = np.zeros(n_maps)
#         for k in range(n_maps):
#             r = L==k
#             gev[k] = np.sum(gfp_values[r]**2 * C[r,k]**2)/gfp2
#         gev_total = np.sum(gev)

#         # store
#         cv_list.append(cv)
#         gev_list.append(gev)
#         gevT_list.append(gev_total)
#         maps_list.append(maps)
#         L_list.append(L_)

#     # select best run
#     k_opt = np.argmin(cv_list)
#     #k_opt = np.argmax(gevT_list)
#     maps = maps_list[k_opt]
#     # ms_gfp = ms_list[k_opt] # microstate sequence at GFP peaks
#     gev = gev_list[k_opt]
#     L_ = L_list[k_opt]
#     # lowest cv criterion
#     cv_min = np.min(cv_list)

#     return maps, L_, gfp_peaks, gev, cv_min
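
# # A minimal usage sketch (hypothetical data and settings; assumes
# # kmeans_return_all was added to eeg_microstates.py and is therefore
# # available through the star import in the preamble):
# eeg_example = np.random.randn(2400, 32) # stand-in for (samples, channels) EEG
# maps, labels, gfp_peaks, gev, cv_min = kmeans_return_all(eeg_example, n_maps=4)
# print("Total GEV: {:.3f}, CV criterion: {:.4f}".format(gev.sum(), cv_min))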

# # For eeg_microstates.py
# def T_empirical(data, n_clusters, gap_idx = []):
#     """Modified empirical transition matrix to take gap_idx argument

#     Args:
#         data: numpy.array, sequence of microstate labels
#         n_clusters: number of microstate clusters
#         gap_idx: indices of gaps in data; transitions starting at these
#             indices are excluded from T
#     Returns:
#         T: empirical transition matrix
#     """
#     T = np.zeros((n_clusters, n_clusters))
#     n = len(data)
#     for i in range(n-1):
#         # Do not count transitions between gaps in data
#         if i in gap_idx:
#             continue
#         else:
#             T[data[i], data[i+1]] += 1.0
#     p_row = np.sum(T, axis=1)
#     for i in range(n_clusters):
#         if ( p_row[i] != 0.0 ):
#             for j in range(n_clusters):
#                 T[i,j] /= p_row[i]  # normalize row sums to 1.0
#     return T
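
# # A minimal usage sketch (hypothetical labels; assumes T_empirical was added
# # to eeg_microstates.py). When recordings are concatenated, passing the index
# # of the last sample of each recording as gap_idx keeps the spurious
# # cross-recording transition out of T:
# labels_example = np.array([0, 0, 1, 2, 2, 3, 1, 1]) # two recordings joined at index 3/4
# T = T_empirical(labels_example, n_clusters=4, gap_idx=[3]) # skip transition 3 -> 4
# print(T) # each visited row sums to 1.0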

# # From pysparcl.cluster.permute
# def permute_modified(x, k=None, nperms=25, wbounds=None, nvals=10, centers=None,
#             verbose=False):
#     # Modified to also return sdgaps (std. dev. of the permuted log objectives)
#     n, p = x.shape
#     if k is None and centers is None:
#         raise ValueError('k and centers are None.')
#     if k is not None and centers is not None:
#         if centers.shape[0] != k or centers.shape[1] != p:
#             raise ValueError('Invalid shape of centers.')
#     if wbounds is None:
#         wbounds = np.exp(
#             np.linspace(np.log(1.2), np.log(np.sqrt(p) * 0.9), nvals))
#     if wbounds.min() <= 1 or len(wbounds) < 2:
#         raise ValueError('len(wbounds) and each wbound must be > 1.')

#     permx = np.zeros((nperms, n, p))
#     nnonzerows = None
#     for i in range(nperms):
#         for j in range(p):
#             permx[i, :, j] = np.random.permutation(x[:, j])
#     tots = None
#     out = kmeans(x, k, wbounds, centers=centers, verbose=verbose)

#     for i in range(len(out)):
#         nnonzerows = utils._cbind(nnonzerows, np.sum(out[i]['ws'] != 0))
#         bcss = subfunc._get_wcss(x, out[i]['cs'])[1]
#         tots = utils._cbind(tots, np.sum(out[i]['ws'] * bcss))
#     permtots = np.zeros((len(wbounds), nperms))
#     for i in range(nperms):
#         perm_out = kmeans(
#             permx[i], k, wbounds, centers=centers, verbose=verbose)
#         for j in range(len(perm_out)):
#             perm_bcss = subfunc._get_wcss(permx[i], perm_out[j]['cs'])[1]
#             permtots[j, i] = np.sum(perm_out[j]['ws'] * perm_bcss)

#     sdgaps = np.std(np.log(permtots), axis=1)
#     gaps = np.log(tots) - np.log(permtots).mean(axis=1)
#     bestw = wbounds[gaps.argmax()]
#     out = {'bestw': bestw, 'gaps': gaps, 'sdgaps': sdgaps, 'wbounds': wbounds,
#            'nnonzerows': nnonzerows}
#     return out
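
# # A minimal usage sketch (hypothetical data; assumes permute_modified
# # replaces permute inside pysparcl.cluster, where kmeans, utils and subfunc
# # are in scope). sdgaps gives the spread of the permuted gap statistics:
# x_example = np.random.randn(60, 20) # stand-in for (observations, features)
# res = permute_modified(x_example, k=3, nperms=25)
# print("Best wbound: {:.2f} (gap SD at best: {:.3f})".format(
#     res['bestw'], res['sdgaps'][res['gaps'].argmax()]))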