From e7a9ba66e7abc1036a63e558f70a8a5e799a7c03 Mon Sep 17 00:00:00 2001
From: glia <glia@dtu.dk>
Date: Mon, 24 Oct 2022 15:03:02 +0200
Subject: [PATCH] Preprocessing

---
 Preprocessing.py | 431 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 431 insertions(+)
 create mode 100644 Preprocessing.py

diff --git a/Preprocessing.py b/Preprocessing.py
new file mode 100644
index 0000000..7e15f86
--- /dev/null
+++ b/Preprocessing.py
@@ -0,0 +1,431 @@
+# -*- coding: utf-8 -*-
+"""
+Updated Oct 18 2022
+
+@author: Qianliang Li (glia@dtu.dk)
+
+The following preprocessing steps were employed:
+    1. Load the data
+    2. 1-100Hz Bandpass filtering
+    3. 50Hz Notch filtering
+    4. Retrieving event information about eyes open/closed states
+    5. Epoching to 4s non-overlapping segments
+    6. Visual inspection of the data
+        Bad channels and non-ocular artifacts are removed
+    7. Robust re-reference to common average (without bad channels)
+    8. Interpolation of bad channels
+    9. ICA removal of ocular and ECG artifacts
+    10. AutoReject guided final visual inspection
+
+Due to privacy issues with clinical data, our data is not publicly available.
+For demonstration purposes, I am using a publicly available EEG dataset
+and treating it as resting-state eyes open/closed.
+Link to the demonstration data: www.bci2000.org
+"""
+
+# Set working directory
+import os
+wkdir = "/home/glia/EEG"
+os.chdir(wkdir)
+
+# Load all libraries from the Preamble
+from Preamble import *
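+# (Preamble.py is not shown here; as a hedged assumption it is expected to
+# provide at least the libraries used below, e.g.:
+# import mne
+# import numpy as np
+# import pandas as pd
+# import matplotlib.pyplot as plt
+# import pickle
+# from autoreject import AutoReject
+# )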
+
+# To demonstrate the script, a publicly available EEG dataset is used
+# EEG recordings from 2 subjects are used as an example
+# The EEGBCI dataset is task-based, but we will treat it as "resting-state"
+# and the 2 runs as Eyes Closed and Eyes Open
+n_subjects = 2
+Subject_id = [1,2]
+# Download and get filenames
+from mne.datasets import eegbci
+files = []
+for i in range(n_subjects):
+    raw_fnames = eegbci.load_data(Subject_id[i], [1,2]) # The first 2 runs
+    files.append(raw_fnames)
+
+# # Original code to get filenames in a folder
+# data_path = "EEG_folder"
+# # Get filenames
+# files = []
+# for r, d, f in os.walk(data_path):
+#     for file in f:
+#         if ".bdf" in file:
+#             files.append(os.path.join(r, file))
+
+# Eye status
+anno_to_event = {'Eyes Closed': 1, 'Eyes Open': 2} # manually defined event id
+eye_status = list(anno_to_event.keys())
+n_eye_status = len(eye_status)
+
+# Epoch settings
+epoch_len = 4 # length of epochs in seconds
+n_epochs_trial = int(60/epoch_len) # number of epochs in a trial
+n_trials = 2 # 1 eyes closed followed by 1 eyes open
+
+# Montage settings
+montage = mne.channels.make_standard_montage('standard_1005')
+#montage = mne.channels.read_custom_montage(filename) # custom montage file
+
+# %% Load, filter and epoch (Steps 1 to 5)
+# Pre-allocate memory
+epochs = [0]*n_subjects
+for i in range(n_subjects):
+    # MNE-Python supports many EEG formats; make sure to use the proper reader
+    raw = mne.io.concatenate_raws([mne.io.read_raw_edf(f, preload=True) for f in files[i]])
+    # Fix EEGBCI channel names
+    eegbci.standardize(raw)
+    # Set montage
+    raw.set_montage(montage)
+    # Only use the EEG channels
+    raw.pick_types(meg=False, eeg=True, stim=False)
+    # Bandpass filter (will also detrend, median = 0)
+    raw.filter(1, 70, fir_design="firwin", verbose=0) # Due to sfreq = 160 Hz (Nyquist = 80 Hz) I cannot lowpass at 100 Hz
+    # raw.filter(1, 100, fir_design="firwin", verbose=0) # original line
+    # Notch filter to remove power-line noise (50Hz is common in Europe)
+    raw.notch_filter(50, fir_design="firwin", verbose=0)
+    # Epoch to 4 sec
+    # The first 60s are treated as Eyes Closed and following 60s as Eyes Open
+    event = np.zeros((int(2*n_epochs_trial),3), dtype=int) # manually make event array (n_events,3)
+    event[:,0] = np.array(np.linspace(0,2*60-epoch_len,int(2*n_epochs_trial))*int(raw.info["sfreq"]), dtype=int) # first column, the time for events
+    # Hardcoded based on data format
+    event[:n_epochs_trial,2] = 1 # Eyes closed
+    event[n_epochs_trial:,2] = 2 # Eyes open
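+    # With sfreq = 160 Hz and epoch_len = 4 s, the first rows of event are:
+    # [[   0, 0, 1], [ 640, 0, 1], [1280, 0, 1], ...] (onsets in samples, 4 s apart)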
+    # Make the epochs. (The -1/int(raw.info["sfreq"]) is needed because Python
+    # indexing starts at 0, so e.g. 30000 points are indexed 0 to 29999)
+    epochs[i] = mne.Epochs(raw, event, event_id=anno_to_event,
+                 tmin=0, tmax=epoch_len-1/int(raw.info["sfreq"]),baseline=None, verbose=0).load_data()
+    
+    print("Subject:{} finished epoching ({}/{})".format(Subject_id[i],i+1,n_subjects))
+
+# Compute the number of epochs in each trial
+n_epochs = pd.DataFrame(np.zeros((n_subjects,n_trials)))
+for i in range(n_subjects):
+    for t in range(n_trials):
+        try:
+            n_epochs_ec = len(epochs[i][int(t*n_epochs_trial):int((t+1)*n_epochs_trial)][eye_status[0]])
+        except KeyError: # the event id is not present in this trial
+            n_epochs_ec = 0
+        try:
+            n_epochs_eo = len(epochs[i][int(t*n_epochs_trial):int((t+1)*n_epochs_trial)][eye_status[1]])
+        except KeyError:
+            n_epochs_eo = 0
+        n_epochs.iloc[i,t] = np.max([n_epochs_ec,n_epochs_eo])
+# Calculate cumulative sum which is used for later indexing
+cumsum_n_epochs = np.cumsum(n_epochs, axis=1)
+# Insert 0 as first column
+cumsum_n_epochs.insert(0,"Start",0)
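+# Sanity check: each row of cumsum_n_epochs now reads [0, n_trial1, n_trial1+n_trial2],
+# so trial t of subject i spans epoch indices [cumsum[t], cumsum[t+1])
+print(cumsum_n_epochs)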
+
+# %% First visual inspection, re-referencing and ch interpolation (Steps 6 to 8)
+# When in doubt, I did not exclude the data
+# Ocular artifacts are not removed, as we will use ICA to correct for those
+# Bad epochs are dropped
+# Bad channels are removed and interpolated on a trial basis
+# i.e. for each 1min eyes open or eyes closed trial
+
+# Visualize each
+epochs[0].plot(scalings=200e-6, n_epochs=15, n_channels=24)
+epochs[1].plot(scalings=200e-6, n_epochs=15, n_channels=24)
+
+# Pre-allocate memory
+bad_channels = [0]*len(epochs)
+bad_epochs = [0]*len(epochs)
+
+# Manually found bad data segments
+#bad_channels[0] = [0] # if all trials are fine, no bad channel line is added
+bad_epochs[0] = [6] # Epoch 6
+bad_channels[1] = ["T7",0] # T7 is bad in first eyes closed trial
+bad_epochs[1] = [12,13] # 2 bad epochs
+
+# Cautionary note: bad_epochs use the epoch index, not the epoch number in the
+# MNE viewer. They are often the same, but not always, as in this demonstration:
+# there is no Epoch 15 because that segment was not full length (due to the way
+# I defined the event times), so when you drop epoch index 15, it is the epoch
+# that MNE labeled as 16 in this example!
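+# One way to check this mapping: mne.Epochs stores the original event row of
+# every surviving epoch in .selection, and .drop_log explains any drops
+print(epochs[1].selection) # original event indices of the kept epochs
+print(epochs[1].drop_log) # one tuple per original event; empty tuple = kept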
+
+# Pre-allocate memory
+cleaned_epochs = epochs.copy()
+
+# Interpolate bad channels and re-reference to robust common average
+bad_ch_counter = 0 # count how many bad channels for summary
+bad_ch_idx = np.array(bad_channels,dtype=object).nonzero()[0]
+for i in range(len(bad_ch_idx)):
+    n = bad_ch_idx[i]
+    n_epochs_trial = int(60/epoch_len)
+    # Get epoch number for each subject
+    epoch_number = np.arange(0,len(epochs[n]))
+    # Pre-allocate memory
+    temp_epochs = [0]*n_trials
+    for trial in range(n_trials):
+        # Retrieve trial
+        trial_idx = (epoch_number >= cumsum_n_epochs.iloc[n,trial]) & (epoch_number < cumsum_n_epochs.iloc[n,trial+1])
+        temp_epochs[trial] = epochs[n].copy().drop(np.invert(trial_idx))
+        # Set bad channel
+        trial_bad_ch = bad_channels[bad_ch_idx[i]][trial]
+        if trial_bad_ch == 0:
+            # Do not perform interpolation, only re-referencing to average
+            temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
+        else:
+            if type(trial_bad_ch) == str: # fix if only one ch is provided
+                trial_bad_ch = [trial_bad_ch] # make it to list
+            temp_epochs[trial].info["bads"] = trial_bad_ch
+            # Re-reference to common average (MNE excludes the marked bads from the average)
+            temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
+            # Interpolate the bad channels
+            temp_epochs[trial].interpolate_bads(reset_bads=True)
+            # Increase counter
+            bad_ch_counter += len(trial_bad_ch)
+    # Concatenate temporary epoch and save
+    cleaned_epochs[n] = mne.concatenate_epochs(temp_epochs, add_offset=False)
+    # Notice that an offset is still added when using to_data_frame!
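+    # A hedged way to see this is to inspect the epoch column of the dataframe:
+    # cleaned_epochs[n].to_data_frame()["epoch"].unique() # numbered with an offset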
+
+# Re-reference all other data (that did not have bad channels) to common avg
+good_ch_idx = np.where(np.array(bad_channels,dtype=object) == 0)[0]
+for i in range(len(good_ch_idx)):
+    n = good_ch_idx[i]
+    cleaned_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
+
+# Drop bad epochs
+bad_epoch_counter = 0
+bad_epochs_idx = np.array(bad_epochs,dtype=object).nonzero()[0]
+for i in range(len(bad_epochs_idx)):
+    n = bad_epochs_idx[i]
+    subject_bad_epochs = bad_epochs[n]
+    cleaned_epochs[n].drop(subject_bad_epochs)
+    bad_epoch_counter += len(subject_bad_epochs)
+
+# Summarize how many bad channels and epochs there were manually defined
+print((bad_ch_counter/(n_trials*epochs[0].info["nchan"]*n_subjects))*100) # 0.4% bad channels
+print((bad_epoch_counter/(141+150*(n_subjects-1)))*100) # 1% bad epochs rejected
+
+
+# %% ICA is performed to remove ocular (blinks, eye movements) and ECG artifacts
+# Make list to contain all ICAs
+ica = [0]*len(cleaned_epochs)
+# Make ICA objects
+for n in range(len(cleaned_epochs)):
+    # The matrix rank is reduced by 1 because of the common average reference
+    # If any channels were interpolated the rank is further reduced
+    if n in bad_ch_idx:
+        inter_pol_ch = [ch for ch in bad_channels[n] if ch != 0] # remove all 0
+        # Make a flat list to use np.unique
+        flat_list = []
+        for sublist in inter_pol_ch:
+            if type(sublist) == str: # fix if only one ch is provided
+                sublist = [sublist]
+            for item in sublist:
+                flat_list.append(item)
+        n_inter_pol_ch = len(np.unique(flat_list))
+        matrix_rank = cleaned_epochs[n].info["nchan"]-1-n_inter_pol_ch
+    else:
+        matrix_rank = cleaned_epochs[n].info["nchan"]-1
+
+    ica[n] = mne.preprocessing.ICA(method="fastica", random_state=42, verbose=0,
+                                       max_iter=500, n_components=matrix_rank)
+    ica[n].fit(cleaned_epochs[n])
+    print("{} out of {} ICAs processed".format(n+1,len(cleaned_epochs)))
+
+# Plot the components for visual inspection
+def ica_analysis(n):
+    plt.close("all")
+    # Plot original
+    cleaned_epochs[n].plot(scalings=200e-6, n_epochs = 10)
+    # Plot ICA - components
+    ica[n].plot_sources(cleaned_epochs[n], start = 0, stop = 10)
+    ica[n].plot_sources(cleaned_epochs[n], start = 5, stop = 8) # zoomed in helps for ECG artifact recognition
+    ica[n].plot_components(picks=np.arange(0,20))
+
+# Manual ICA decomposition visualization
+n = 1; ica_analysis(n) # manually change n
+
+# # Plot specific component for further inspection of the marked artifacts
+# ica[n].plot_properties(cleaned_epochs[n], picks = artifacts[n])
+
+# Pre-allocate memory
+artifacts = [0]*len(cleaned_epochs)
+# Manually determine artifacts
+artifacts[0] = [0, 5, 12] # eye blinks, eye blinks, eye movement
+artifacts[1] = [0, 2] # eye blinks, eye movement
+
+# Remove the artifact components from the signal
+corrected_epochs = cleaned_epochs.copy()
+
+for n in range(len(cleaned_epochs)):
+    # Define the components with artifacts    
+    ica[n].exclude = artifacts[n]
+    # Remove on corrected data
+    ica[n].apply(corrected_epochs[n].load_data())
+
+# Inspect how the ICA worked
+n=0; corrected_epochs[n].plot(scalings=200e-6, n_epochs = 10)
+
+# %% Detect bad epochs automatically using AutoReject
+"""
+AutoReject uses cross-validation to determine the optimal peak-to-peak threshold
+This threshold is determined per channel, per subject
+The threshold is used to mark whether each channel in each epoch is bad
+Interpolation is performed on the bad epoch/channel segments
+If many neighboring channels are bad, the algorithm ranks them
+(based on peak-to-peak amplitude) and only interpolates the worst ones
+If too many channels in one epoch are bad, the whole epoch is rejected
+
+I am not using AutoReject directly to interpolate, but as a guide.
+The marked bad channels and epochs are then manually inspected to determine
+whether they should be dropped or not.
+The algorithm is run for each eye status separately
+"""
+
+# Suppress plots
+# import matplotlib
+# matplotlib.use("Agg")
+
+save_path = "./Autoreject_overview" # for autoreject overview
+# Pre-allocate memory
+reject_log = [0]*len(corrected_epochs)
+dropped_epochs = [0]*len(corrected_epochs)
+ar = [0]*len(corrected_epochs)
+mean_threshold = [0]*len(corrected_epochs)
+for i in range(len(corrected_epochs)):
+    reject_log0 = [0]*n_eye_status
+    ar0 = [0]*n_eye_status
+    mean_threshold0 = [0]*n_eye_status
+    drop_epochs0 = [0]*n_eye_status
+    for e in range(n_eye_status):
+        ee = eye_status[e]
+        # Initialize class
+        ar0[e] = AutoReject(consensus=np.linspace(0,1,11), cv=10, n_jobs=8,
+                            verbose=False, random_state=42)
+        # Fit to data - but do not transform
+        ar0[e].fit(corrected_epochs[i][ee])
+        # Get rejection log
+        reject_log0[e] = ar0[e].get_reject_log(corrected_epochs[i][ee])
+        # Plot and save Autorejected epochs
+        fig = reject_log0[e].plot(orientation="horizontal", show=False)
+        fig.savefig(os.path.join(save_path,"AR_" + str(Subject_id[i]) + "_" + str(ee) + ".png"))
+        # Close figure window
+        plt.close(fig)
+        # Save mean peak-to-peak voltage threshold used
+        mean_threshold0[e] = np.mean(list(ar0[e].threshes_.values()))
+        # Save suggested dropped epochs
+        drop_epochs0[e] = reject_log0[e].bad_epochs
+    # Concatenate dropped epochs
+    drop_epochs1 = np.concatenate(drop_epochs0)
+    # Save
+    dropped_epochs[i] = drop_epochs1.nonzero()[0]
+    reject_log[i] = reject_log0
+    ar[i] = ar0
+    mean_threshold[i] = mean_threshold0
+    print("{} out of {} subjects finished autorejecting".format(i+1,len(corrected_epochs)))
+
+# Overview of dropped epochs
+Drop_epochs_df = pd.DataFrame.from_records(dropped_epochs) # convert to dataframe
+Drop_epochs_df.insert(0, "Subject_id", Subject_id)
+
+# Re-enable plots
+# import matplotlib.pyplot as plt
+# %matplotlib qt
+
+print(Drop_epochs_df) # no further bad epochs after ICA
+# But for demonstration purposes we will add one more bad epoch in subject 2
+
+### Dropped epochs are used as a guide to manually inspect potentially bad epochs flagged by the thresholds
+# Visualize
+n=1 ; corrected_epochs[n].plot(scalings=100e-6, n_channels=31, n_epochs=15)
+
+bad_epochs_ar = dropped_epochs.copy() # make a copy before modifying
+# Manually modify when appropriate
+# n = 0 - agreed to all
+bad_epochs_ar[1] = np.array([8], dtype="int64") # added 1 more than suggested
+# n = 1: T7 is an outlier and bad
+
+# Combine dropped epochs from first visual inspection and Autoreject
+bad_epochs_comb = [0]*len(corrected_epochs)
+for i in range(len(corrected_epochs)):
+    # If there were no dropped epochs from the first inspection, just use the second
+    if type(bad_epochs[i]) == int:
+        bad_epochs_comb[i] = bad_epochs_ar[i]
+        continue
+    else:
+        # Retrieve bad epochs from the first manual inspection
+        bad1 = np.array(bad_epochs[i])
+        # Retrieve bad epochs from the manually verified AutoReject inspection
+        bad2 = bad_epochs_ar[i].copy()
+        # Fix the indices shifted by the epochs already dropped in the first inspection
+        # E.g. if epoch 45 was dropped, all epochs above 45 now have an index 1 lower
+        for drops in range(len(bad1)):
+            bad2[bad2 >= bad1[drops]] += 1
+        # Concatenate from both
+        bad_epochs_comb[i] = np.sort(np.concatenate([bad1,bad2]))
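+        # Worked example: if the first inspection dropped bad1 = [6] and the
+        # AR-guided inspection flags index 8 in the re-indexed data, then
+        # 8 >= 6 shifts it to 9, giving bad_epochs_comb = [6, 9] in original numbering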
+
+# Convert to dataframe
+bad_epochs_comb_df = pd.DataFrame.from_records(bad_epochs_comb)
+bad_epochs_comb_df.insert(0, "Subject_ID", Subject_id)
+# Save
+os.makedirs("./Preprocessing", exist_ok=True) # make sure the folder exists
+bad_epochs_comb_df.to_pickle("./Preprocessing/dropped_epochs.pkl")
+
+# Drop bad epochs
+final_epochs = corrected_epochs.copy()
+bad_epochs_df = pd.DataFrame.from_records(bad_epochs_ar)
+bad_epoch_counter2 = 0
+bad_epochs_idx2 = np.where(bad_epochs_df.iloc[:,0].notnull())[0]
+for i in range(len(bad_epochs_idx2)):
+    n = bad_epochs_idx2[i]
+    subject_bad_epochs = bad_epochs_ar[n]
+    final_epochs[n].drop(subject_bad_epochs)
+    bad_epoch_counter2 += len(subject_bad_epochs)
+
+# Summarize how many bad epochs were rejected in total
+print((bad_epoch_counter2+bad_epoch_counter)/(141+150*(n_subjects-1))*100) # in total 1.37% epochs rejected
+
+# Re-reference and interpolate
+def re_reference_interpolate(input_epochs, n, bads):
+    # Re-reference to Cz from average
+    input_epochs[n].set_eeg_reference(ref_channels=["Cz"], verbose=0)
+    # Set bads
+    input_epochs[n].info["bads"] = bads
+    # Re-reference to average without bads
+    input_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
+    # Interpolate
+    input_epochs[n].interpolate_bads(reset_bads=True)
+
+re_reference_interpolate(final_epochs, 1, ["T7"]) # n = 1
+
+# Save all bad channels from first visual inspection and autorejection guided
+bad_channels[1] = ["T7"]*n_trials
+# To keep track of which channels were interpolated
+with open("./Preprocessing/bad_ch.pkl", "wb") as filehandle:
+    pickle.dump(bad_channels, filehandle)
+
+
+# %% Save the preprocessed epochs
+save_path = "./PreprocessedData"
+for n in range(len(corrected_epochs)):
+    final_epochs[n].save(fname = os.path.join(save_path,str("{}_preprocessed".format(Subject_id[n]) + "-epo.fif")),
+                    overwrite=True, verbose=0)
+
+# %% Evaluate dropped epochs to determine threshold for exclusion
+dropped_epochs_df = pd.read_pickle("./Preprocessing/dropped_epochs.pkl")
+
+# Convert the df to number of dropped epochs
+# In the original data I had multiple datasets and combined the dropped df here
+All_drop_epoch_df = dropped_epochs_df # originally using pd.concat
+
+def drop_epoch_counter(row, df):
+    # Count the non-null entries in a row (each entry is one dropped epoch index)
+    res = np.sum(df.iloc[row,1:].notnull())
+    return res
+
+Number_dropped_epochs = [0]*All_drop_epoch_df.shape[0]
+for i in range(All_drop_epoch_df.shape[0]):
+    Number_dropped_epochs[i] = drop_epoch_counter(i,All_drop_epoch_df)
+
+All_drop_epoch_df2 = pd.DataFrame({"Subject_ID":All_drop_epoch_df["Subject_ID"],
+                                   "Number_drop_epoch":Number_dropped_epochs})
+
+# Plot histogram
+bin_seq = range(0,np.max(All_drop_epoch_df2["Number_drop_epoch"])+1)
+All_drop_epoch_df2.hist(column="Number_drop_epoch",figsize=(12,8),bins=bin_seq)
+
+# View all subjects with 15 or more dropped epochs
+print(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>=15,:])
+# Number of excluded subjects using more than 30 dropped epochs as cutoff
+print(len(list(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>30,"Subject_ID"])))
+# Using cutoff for exclusion: 20% or more dropped epochs
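+
+# A minimal sketch of that exclusion rule (hedged: n_total_epochs and
+# exclusion_cutoff are illustrative names; 150 epochs per subject is taken
+# from the original clinical data, adjust it to your recording)
+n_total_epochs = 150 # assumed number of epochs per subject before cleaning
+exclusion_cutoff = 0.20 # exclude subjects with 20% or more dropped epochs
+excluded = All_drop_epoch_df2.loc[
+    All_drop_epoch_df2["Number_drop_epoch"] >= exclusion_cutoff*n_total_epochs,
+    "Subject_ID"]
+print(list(excluded)) # subjects that would be excluded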
-- 
GitLab