diff --git a/Preprocessing.py b/Preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e15f86980195f3688a24e693b878ac3adbcc523
--- /dev/null
+++ b/Preprocessing.py
@@ -0,0 +1,431 @@
+# -*- coding: utf-8 -*-
+"""
+Updated Oct 18 2022
+
+@author: Qianliang Li (glia@dtu.dk)
+
+The following preprocessing steps were employed:
+    1. Load the data
+    2. 1-100Hz bandpass filtering
+    3. 50Hz notch filtering
+    4. Retrieving event information about eyes open/closed states
+    5. Epoching to 4s non-overlapping segments
+    6. Visual inspection of the data
+       Bad channels and non-ocular artifacts are removed
+    7. Robust re-referencing to common average (without bad channels)
+    8. Interpolation of bad channels
+    9. ICA removal of ocular and ECG artifacts
+    10. AutoReject-guided final visual inspection
+
+Due to privacy issues with clinical data, our data is not publicly available.
+For demonstration purposes, I am using a publicly available EEG dataset
+and treating it as resting-state eyes open/closed.
+Link to the demonstration data: www.bci2000.org
+"""
+
+# Set working directory
+import os
+wkdir = "/home/glia/EEG"
+os.chdir(wkdir)
+
+# Load all libraries from the Preamble
+from Preamble import *
+
+# To demonstrate the script, a publicly available EEG dataset is used
+# EEG recordings from 2 subjects are used as an example
+# The EEGBCI (BCI2000) dataset is task-based, but we will treat it as "resting-state"
+# and the 2 runs as Eyes Closed and Eyes Open
+n_subjects = 2
+Subject_id = [1,2]
+# Download and get filenames
+from mne.datasets import eegbci
+files = []
+for i in range(n_subjects):
+    raw_fnames = eegbci.load_data(Subject_id[i], [1,2]) # The first 2 runs
+    files.append(raw_fnames)
+
+# # Original code to get filenames in a folder
+# data_path = "EEG_folder"
+# # Get filenames
+# files = []
+# for r, d, f in os.walk(data_path):
+#     for file in f:
+#         if ".bdf" in file:
+#             files.append(os.path.join(r, file))
+
+# Eye status
+anno_to_event = {'Eyes Closed': 1, 'Eyes Open': 2} # manually defined event id
+eye_status = list(anno_to_event.keys())
+n_eye_status = len(eye_status)
+
+# Epoch settings
+epoch_len = 4 # length of epochs in seconds
+n_epochs_trial = int(60/epoch_len) # number of epochs in a trial
+n_trials = 2 # 1 eyes closed followed by 1 eyes open
+
+# Montage settings
+montage = mne.channels.make_standard_montage('standard_1005')
+#montage = mne.channels.read_custom_montage(filename) # custom montage file
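+
+# Optional sanity check (my addition, not part of the original pipeline):
+# plotting the montage confirms the expected sensor layout, and inspecting its
+# channel names helps catch naming mismatches that would make set_montage()
+# fail in the loading loop below. Sketch only.
+# montage.plot(kind="topomap", show_names=True)
+# print(montage.ch_names[:10])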
+
+# %% Load, filter and epoch (Steps 1 to 5)
+# Pre-allocate memory
+epochs = [0]*n_subjects
+for i in range(n_subjects):
+    # MNE-Python supports many EEG formats. Make sure to use the proper reader
+    raw = mne.io.concatenate_raws([mne.io.read_raw_edf(f, preload=True) for f in files[i]])
+    # Fix EEGBCI channel names
+    eegbci.standardize(raw)
+    # Set montage
+    raw.set_montage(montage)
+    # Only use the EEG channels
+    raw.pick_types(meg=False, eeg=True, stim=False)
+    # Bandpass filter (will also detrend, median = 0)
+    raw.filter(1, 70, fir_design="firwin", verbose=0) # sfreq = 160 (Nyquist = 80Hz), so I cannot lowpass at 100Hz
+    # raw.filter(1, 100, fir_design="firwin", verbose=0) # original line
+    # Notch filter to remove power-line noise (50Hz is common in Europe)
+    raw.notch_filter(50, fir_design="firwin", verbose=0)
+    # Epoch to 4 sec
+    # The first 60s are treated as Eyes Closed and the following 60s as Eyes Open
+    event = np.zeros((int(2*n_epochs_trial),3), dtype=int) # manually make event array (n_events,3)
+    event[:,0] = np.array(np.linspace(0,2*60-epoch_len,int(2*n_epochs_trial))*int(raw.info["sfreq"]), dtype=int) # first column: the sample index of each event
+    # Hardcoded based on data format
+    event[:n_epochs_trial,2] = 1 # Eyes closed
+    event[n_epochs_trial:,2] = 2 # Eyes open
+    # Make the epochs. (tmax is inclusive in MNE, so one sample (1/sfreq) is
+    # subtracted to get exactly epoch_len*sfreq time points per epoch)
+    epochs[i] = mne.Epochs(raw, event, event_id=anno_to_event,
+                 tmin=0, tmax=epoch_len-1/int(raw.info["sfreq"]), baseline=None, verbose=0).load_data()
+    print("Subject:{} finished epoching ({}/{})".format(Subject_id[i],i+1,n_subjects))
+
+# Compute the number of epochs in each trial
+n_epochs = pd.DataFrame(np.zeros((n_subjects,n_trials)))
+for i in range(n_subjects):
+    for t in range(n_trials):
+        try:
+            n_epochs_ec = len(epochs[i][int(t*n_epochs_trial):int((t+1)*n_epochs_trial)][eye_status[0]])
+        except KeyError: # no eyes closed epochs in this trial
+            n_epochs_ec = 0
+        try:
+            n_epochs_eo = len(epochs[i][int(t*n_epochs_trial):int((t+1)*n_epochs_trial)][eye_status[1]])
+        except KeyError: # no eyes open epochs in this trial
+            n_epochs_eo = 0
+        n_epochs.iloc[i,t] = np.max([n_epochs_ec,n_epochs_eo])
+# Calculate the cumulative sum, which is used for later indexing
+cumsum_n_epochs = np.cumsum(n_epochs, axis=1)
+# Insert 0 as first column
+cumsum_n_epochs.insert(0,"Start",0)
+
+# %% First visual inspection, re-referencing and ch interpolation (Steps 6 to 8)
+# When in doubt, I did not exclude the data
+# Ocular artifacts are not removed, as we will use ICA to correct for those
+# Bad epochs are dropped
+# Bad channels are removed and interpolated on a trial basis,
+# i.e. for each 1min eyes open or eyes closed trial
+
+# Visualize each
+epochs[0].plot(scalings=200e-6, n_epochs=15, n_channels=24)
+epochs[1].plot(scalings=200e-6, n_epochs=15, n_channels=24)
+
+# Pre-allocate memory
+bad_channels = [0]*len(epochs)
+bad_epochs = [0]*len(epochs)
+
+# Manually found bad data segments
+#bad_channels[0] = [0] # If all trials are fine I do not make a line with bad channels
+bad_epochs[0] = [6] # Epoch 6
+bad_channels[1] = ["T7",0] # T7 is bad in the first (eyes closed) trial
+bad_epochs[1] = [12,13] # 2 bad epochs
+
+# Cautionary note: bad_epochs uses positional indices, not the epoch numbers
+# shown in the MNE viewer. They are often the same, but not always, as in this
+# demonstration: there is no Epoch 15 because that epoch was not full length
+# (due to the way I defined the event times), so when you drop epoch index 15,
+# it is the epoch that MNE labeled 16 in this example!
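+
+# Hedged helper (my addition) for the caution above: epochs.selection holds
+# the original event number of each remaining epoch, which is the numbering
+# the note above says the viewer displays, so it can translate viewer numbers
+# into positional indices for epochs.drop().
+def viewer_to_index(epo, viewer_numbers):
+    sel = list(epo.selection) # original event numbers of the epochs still present
+    return [sel.index(v) for v in viewer_numbers]
+# e.g. viewer_to_index(epochs[i], [16]) would return [15] in the case above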
+
+# Pre-allocate memory
+# (note: list.copy() is shallow, so entries still refer to the same Epochs
+# objects until they are explicitly replaced below)
+cleaned_epochs = epochs.copy()
+
+# Interpolate bad channels and re-reference to robust common average
+bad_ch_counter = 0 # count how many bad channels for summary
+bad_ch_idx = np.array(bad_channels,dtype=object).nonzero()[0]
+for i in range(len(bad_ch_idx)):
+    n = bad_ch_idx[i]
+    # Get epoch number for each subject
+    epoch_number = np.arange(0,len(epochs[n]))
+    # Pre-allocate memory
+    temp_epochs = [0]*n_trials
+    for trial in range(n_trials):
+        # Retrieve trial
+        trial_idx = (epoch_number >= cumsum_n_epochs.iloc[n,trial]) & (epoch_number < cumsum_n_epochs.iloc[n,trial+1])
+        temp_epochs[trial] = epochs[n].copy().drop(np.invert(trial_idx))
+        # Set bad channel
+        trial_bad_ch = bad_channels[bad_ch_idx[i]][trial]
+        if trial_bad_ch == 0:
+            # Do not perform interpolation, only re-reference to average
+            temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
+        else:
+            if isinstance(trial_bad_ch, str): # fix if only one ch is provided
+                trial_bad_ch = [trial_bad_ch] # convert to list
+            temp_epochs[trial].info["bads"] = trial_bad_ch
+            # Re-reference to average (channels marked as bad are excluded)
+            temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
+            # Interpolate the bad channels
+            temp_epochs[trial].interpolate_bads(reset_bads=True)
+            # Increase counter
+            bad_ch_counter += len(trial_bad_ch)
+    # Concatenate temporary epochs and save
+    cleaned_epochs[n] = mne.concatenate_epochs(temp_epochs, add_offset=False)
+    # Notice that an offset is still added when using to_data_frame!
+
+# Re-reference all other data (that did not have bad channels) to common avg
+good_ch_idx = np.where(np.array(bad_channels,dtype=object) == 0)[0]
+for i in range(len(good_ch_idx)):
+    n = good_ch_idx[i]
+    cleaned_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
+
+# Drop bad epochs
+bad_epoch_counter = 0
+bad_epochs_idx = np.array(bad_epochs,dtype=object).nonzero()[0]
+for i in range(len(bad_epochs_idx)):
+    n = bad_epochs_idx[i]
+    subject_bad_epochs = bad_epochs[n]
+    cleaned_epochs[n].drop(subject_bad_epochs)
+    bad_epoch_counter += len(subject_bad_epochs)
+
+# Summarize how many bad channels and epochs were manually marked
+print((bad_ch_counter/(n_trials*epochs[0].info["nchan"]*n_subjects))*100) # 0.4% bad channels
+print((bad_epoch_counter/(141+150*(n_subjects-1)))*100) # 1% bad epochs rejected
+
+
+# %% ICA is performed to remove eye blinks, ECG and EOG artifacts
+# Make list to contain all ICAs
+ica = [0]*len(cleaned_epochs)
+# Make ICA objects
+for n in range(len(cleaned_epochs)):
+    # The matrix rank is reduced by 1 because of the common average reference
+    # If any channels were interpolated, the rank is further reduced
+    if n in bad_ch_idx:
+        inter_pol_ch = [i for i in bad_channels[n] if i != 0] # remove all 0
+        # Make a flat list to use np.unique
+        flat_list = []
+        for sublist in inter_pol_ch:
+            if isinstance(sublist, str): # fix if only one ch is provided
+                sublist = [sublist]
+            for item in sublist:
+                flat_list.append(item)
+        n_inter_pol_ch = len(np.unique(flat_list))
+        matrix_rank = cleaned_epochs[n].info["nchan"]-1-n_inter_pol_ch
+    else:
+        matrix_rank = cleaned_epochs[n].info["nchan"]-1
+
+    ica[n] = mne.preprocessing.ICA(method="fastica", random_state=42, verbose=0,
+                                   max_iter=500, n_components=matrix_rank)
+    ica[n].fit(cleaned_epochs[n])
+    print("{} out of {} ICAs processed".format(n+1,len(cleaned_epochs)))
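+
+# Optional, hedged addition: automated component scoring can complement the
+# manual inspection below. find_bads_eog correlates each IC with a frontal
+# channel used as an EOG proxy ("Fp1" is assumed to exist in this montage);
+# treat its output as suggestions, not as a replacement for visual checks.
+eog_suggestions = [0]*len(cleaned_epochs)
+for n in range(len(cleaned_epochs)):
+    eog_inds, eog_scores = ica[n].find_bads_eog(cleaned_epochs[n], ch_name="Fp1", verbose=0)
+    eog_suggestions[n] = eog_inds # candidate ocular components per subject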
+
+# Plot the components for visual inspection
+def ica_analysis(n):
+    plt.close("all")
+    # Plot original
+    cleaned_epochs[n].plot(scalings=200e-6, n_epochs=10)
+    # Plot ICA components
+    ica[n].plot_sources(cleaned_epochs[n], start=0, stop=10)
+    ica[n].plot_sources(cleaned_epochs[n], start=5, stop=8) # zooming in helps with ECG artifact recognition
+    ica[n].plot_components(picks=np.arange(0,20))
+
+# Manual ICA decomposition visualization
+n = 1; ica_analysis(n) # manually change n
+
+# # Plot a specific component for further inspection of the marked artifacts
+# ica[n].plot_properties(cleaned_epochs[n], picks=artifacts[n])
+
+# Pre-allocate memory
+artifacts = [0]*len(cleaned_epochs)
+# Manually determined artifacts
+artifacts[0] = [0, 5, 12] # eye blinks, eye blinks, eye movement
+artifacts[1] = [0, 2] # eye blinks, eye movement
+
+# Remove the artifact components from the signal
+# (note: this is a shallow list copy, and ica.apply() works in place, so the
+# entries of cleaned_epochs are modified as well)
+corrected_epochs = cleaned_epochs.copy()
+
+for n in range(len(cleaned_epochs)):
+    # Define the components with artifacts
+    ica[n].exclude = artifacts[n]
+    # Remove them from the data
+    ica[n].apply(corrected_epochs[n].load_data())
+
+# Inspect how the ICA worked
+n = 0; corrected_epochs[n].plot(scalings=200e-6, n_epochs=10)
+
+# %% Detect bad epochs automatically using AutoReject
+"""
+AutoReject uses cross-validation to determine the optimal peak-to-peak threshold
+The threshold is determined for each channel for each subject
+It is used to mark whether each channel/epoch segment is bad
+Interpolation is performed on bad segments
+If many neighboring channels are bad, the algorithm scores them
+(based on peak-to-peak amplitude) and only interpolates the worst ones
+If too many channels in one epoch are bad, the epoch is rejected
+
+I am not using AutoReject directly to interpolate, but as a guide.
+The marked bad channels and epochs are manually inspected before deciding
+whether they should be dropped or not.
+
+The algorithm is run for each eye status separately
+"""
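+
+# For contrast (my addition, hedged): letting AutoReject clean the data
+# directly would be a one-liner with fit_transform, which interpolates bad
+# segments and drops bad epochs without manual review. Sketch only; this
+# pipeline deliberately keeps a human in the loop instead.
+# ar_direct = AutoReject(n_jobs=8, random_state=42, verbose=False)
+# epochs_auto = ar_direct.fit_transform(corrected_epochs[0]["Eyes Closed"])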
+
+# Suppress plots
+# import matplotlib
+# matplotlib.use("Agg")
+
+save_path = "./Autoreject_overview" # for the autoreject overview figures
+os.makedirs(save_path, exist_ok=True) # make sure the folder exists
+# Pre-allocate memory
+reject_log = [0]*len(corrected_epochs)
+dropped_epochs = [0]*len(corrected_epochs)
+ar = [0]*len(corrected_epochs)
+mean_threshold = [0]*len(corrected_epochs)
+for i in range(len(corrected_epochs)):
+    reject_log0 = [0]*n_eye_status
+    ar0 = [0]*n_eye_status
+    mean_threshold0 = [0]*n_eye_status
+    drop_epochs0 = [0]*n_eye_status
+    for e in range(n_eye_status):
+        ee = eye_status[e]
+        # Initialize class
+        ar0[e] = AutoReject(consensus=np.linspace(0,1,11), cv=10, n_jobs=8,
+                            verbose=False, random_state=42)
+        # Fit to the data - but do not transform it
+        ar0[e].fit(corrected_epochs[i][ee])
+        # Get rejection log
+        reject_log0[e] = ar0[e].get_reject_log(corrected_epochs[i][ee])
+        # Plot and save the autorejected epochs
+        fig = reject_log0[e].plot(orientation="horizontal", show=False)
+        fig.savefig(os.path.join(save_path,"AR_" + str(Subject_id[i]) + "_" + str(ee) + ".png"))
+        # Close figure window
+        plt.close(fig)
+        # Save mean peak-to-peak voltage threshold used
+        mean_threshold0[e] = np.mean(list(ar0[e].threshes_.values()))
+        # Save suggested dropped epochs
+        drop_epochs0[e] = reject_log0[e].bad_epochs
+    # Concatenate dropped epochs
+    drop_epochs1 = np.concatenate(drop_epochs0)
+    # Save
+    dropped_epochs[i] = drop_epochs1.nonzero()[0]
+    reject_log[i] = reject_log0
+    ar[i] = ar0
+    mean_threshold[i] = mean_threshold0
+    print("{} out of {} subjects finished autorejecting".format(i+1,len(corrected_epochs)))
+
+# Overview of dropped epochs
+Drop_epochs_df = pd.DataFrame.from_records(dropped_epochs) # convert to dataframe
+Drop_epochs_df.insert(0, "Subject_id", Subject_id)
+
+# Re-enable plots
+# import matplotlib.pyplot as plt
+# %matplotlib qt
+
+print(Drop_epochs_df) # no further bad epochs after ICA
+# But for demonstration purposes we will add one more bad epoch in subject 2
+
+### Dropped epochs are used as a guide to manually inspect the threshold-flagged bad epochs
+# Visualize
+n = 1; corrected_epochs[n].plot(scalings=100e-6, n_channels=31, n_epochs=15)
+
+bad_epochs_ar = dropped_epochs.copy() # make a copy before modifying
+# Manually modify when appropriate
+# n = 0 - agreed with all suggestions
+bad_epochs_ar[1] = np.array([8], dtype="int64") # added 1 more than suggested
+# n = 1: T7 is an outlier and bad
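+
+# Worked example (my addition) of the index correction performed in the loop
+# below: epochs marked in this second round are indexed relative to data where
+# the first round's epochs were already dropped. If the first round dropped
+# epoch 6, an epoch now seen at index 8 was at index 9 originally:
+# bad1_demo = np.array([6]); bad2_demo = np.array([8])
+# bad2_demo[bad2_demo >= 6] += 1 # -> array([9]); combined: [6, 9]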
+
+# Combine dropped epochs from the first visual inspection and AutoReject
+bad_epochs_comb = [0]*len(corrected_epochs)
+for i in range(len(corrected_epochs)):
+    # Retrieve bad epochs from first manual inspection
+    bad1 = np.array(bad_epochs[i])
+    # If there are no dropped epochs from the first inspection, just use the second
+    if type(bad_epochs[i]) == int:
+        bad_epochs_comb[i] = bad_epochs_ar[i]
+        continue
+    else:
+        # Retrieve bad epochs from the manual inspection of the AR suggestions
+        bad2 = bad_epochs_ar[i].copy()
+        # Fix the indices to account for the epochs already dropped in round 1
+        # E.g. if epoch 45 was dropped, all epochs after it shifted one index
+        # down, so second-round indices >= 45 must be increased by 1
+        for drops in range(len(bad1)):
+            bad2[bad2 >= bad1[drops]] += 1
+        # Concatenate from both
+        bad_epochs_comb[i] = np.sort(np.concatenate([bad1,bad2]))
+
+# Convert to dataframe
+bad_epochs_comb_df = pd.DataFrame.from_records(bad_epochs_comb)
+bad_epochs_comb_df.insert(0, "Subject_ID", Subject_id)
+# Save
+os.makedirs("./Preprocessing", exist_ok=True) # make sure the folder exists
+bad_epochs_comb_df.to_pickle("./Preprocessing/dropped_epochs.pkl")
+
+# Drop bad epochs
+final_epochs = corrected_epochs.copy() # shallow list copy, see note above
+bad_epochs_df = pd.DataFrame.from_records(bad_epochs_ar)
+bad_epoch_counter2 = 0
+bad_epochs_idx2 = np.where(bad_epochs_df.iloc[:,0].notnull())[0]
+for i in range(len(bad_epochs_idx2)):
+    n = bad_epochs_idx2[i]
+    subject_bad_epochs = bad_epochs_ar[n]
+    final_epochs[n].drop(subject_bad_epochs)
+    bad_epoch_counter2 += len(subject_bad_epochs)
+
+# Summarize how many epochs were rejected in total
+print((bad_epoch_counter2+bad_epoch_counter)/(141+150*(n_subjects-1))*100) # in total 1.37% epochs rejected
+
+# Re-reference and interpolate
+def re_reference_interpolate(input_epochs, n, bads):
+    # Re-reference from average back to Cz
+    input_epochs[n].set_eeg_reference(ref_channels=["Cz"], verbose=0)
+    # Set bads
+    input_epochs[n].info["bads"] = bads
+    # Re-reference to average without the bads
+    input_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
+    # Interpolate
+    input_epochs[n].interpolate_bads(reset_bads=True)
+
+re_reference_interpolate(final_epochs, 1, ["T7"]) # n = 1
+
+# Save all bad channels from the first visual inspection and the AutoReject-guided one
+bad_channels[1] = ["T7"]*n_trials
+# To keep track of which channels were interpolated
+with open("./Preprocessing/bad_ch.pkl", "wb") as filehandle:
+    pickle.dump(bad_channels, filehandle)
+
+
+# %% Save the preprocessed epochs
+save_path = "./PreprocessedData"
+os.makedirs(save_path, exist_ok=True) # make sure the folder exists
+for n in range(len(final_epochs)):
+    final_epochs[n].save(fname=os.path.join(save_path,"{}_preprocessed-epo.fif".format(Subject_id[n])),
+                         overwrite=True, verbose=0)
+
+# %% Evaluate dropped epochs to determine a threshold for exclusion
+dropped_epochs_df = pd.read_pickle("./Preprocessing/dropped_epochs.pkl") # originally np.load(..., allow_pickle=True)
+
+# Convert the df to the number of dropped epochs
+# In the original data I had multiple datasets and combined the dropped dfs here
+All_drop_epoch_df = dropped_epochs_df # originally using pd.concat
+
+def drop_epoch_counter(row, df):
+    res = np.sum(df.iloc[row,1:].notnull())
+    return res
+
+Number_dropped_epochs = [0]*All_drop_epoch_df.shape[0]
+for i in range(All_drop_epoch_df.shape[0]):
+    Number_dropped_epochs[i] = drop_epoch_counter(i,All_drop_epoch_df)
+
+All_drop_epoch_df2 = pd.DataFrame({"Subject_ID":All_drop_epoch_df["Subject_ID"],
+                                   "Number_drop_epoch":Number_dropped_epochs})
+
+# Plot histogram
+bin_seq = range(0,np.max(All_drop_epoch_df2["Number_drop_epoch"])+1)
+All_drop_epoch_df2.hist(column="Number_drop_epoch",figsize=(12,8),bins=bin_seq)
+
+# View all subjects with 15 or more dropped epochs
+print(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>=15,:])
+# Number of dropped subjects using above 30 as cutoff
+print(len(list(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>30,"Subject_ID"])))
+# Cutoff used for exclusion: 20% or more dropped epochs
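+
+# Hedged sketch (my addition) of applying the 20% rule programmatically. It
+# assumes every subject started with n_trials*n_epochs_trial = 30 epochs;
+# adjust the denominator if trials are missing or of unequal length.
+total_epochs = n_trials*n_epochs_trial
+excluded_ids = All_drop_epoch_df2.loc[
+    All_drop_epoch_df2["Number_drop_epoch"]/total_epochs >= 0.20, "Subject_ID"]
+print("Excluded subjects:", list(excluded_ids))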