Skip to content
Snippets Groups Projects
Preprocessing.py 17.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • glia's avatar
    glia committed
    # -*- coding: utf-8 -*-
    """
    Updated Oct 18 2022
    
    @author: Qianliang Li (glia@dtu.dk)
    
    The following preprocessing steps were employed:
        1. Load the data
        2. 1-100Hz Bandpass filtering
           (the demonstration data below is sampled at 160 Hz, so 1-70Hz is used there)
        3. 50Hz Notch filtering
        4. Retrieving event information about eyes open/closed states
        5. Epoching to 4s non-overlapping segments
        6. Visual inspection of the data
            Bad channels and non-ocular artifacts are removed
        7. Robust re-reference to common average (without bad channels)
        8. Interpolation of bad channels
        9. ICA artifact removal of ocular and ECG artifacts
        10. AutoReject guided final visual inspection
    
    Due to privacy issues of clinical data, our data is not publicly available.
    For demonstration purposes, I am using a publicly available EEG dataset
    and treating it as resting-state eyes open/closed.
    Link to the demonstration data: www.bci2000.org
    """
    
    # Set working directory
    # The default is the author's path; set the EEG_WKDIR environment variable to
    # run the script from another location without editing it.
    import os
    wkdir = os.environ.get("EEG_WKDIR", "/home/glia/EEG")
    os.chdir(wkdir)
    
    # Load all libraries from the Preamble
    # (presumably provides mne, np, pd, plt, pickle and AutoReject, which the rest
    # of the script uses without importing them - verify against Preamble.py)
    from Preamble import *
    
    # Demonstration data: the publicly available EEGBCI dataset, 2 subjects.
    # The EEGBCI recordings are task based, but here they are treated as
    # "resting-state", with the first 2 runs playing Eyes Closed and Eyes Open.
    n_subjects = 2
    Subject_id = [1,2]
    # Download (if not cached) and collect one list of run filenames per subject
    from mne.datasets import eegbci
    files = [eegbci.load_data(subject, [1,2]) # the first 2 runs
             for subject in Subject_id[:n_subjects]]
    
    # # Original code to get filenames in a folder
    # data_path = "EEG_folder"
    # # Get filenames
    # files = []
    # for r, d, f in os.walk(data_path):
    #     for file in f:
    #         if ".bdf" in file:
    #             files.append(os.path.join(r, file))
    
    # Eye status: event id for each condition (ids chosen manually)
    anno_to_event = {'Eyes Closed': 1, 'Eyes Open': 2}
    # Condition names, in the dict's insertion order
    eye_status = [status for status in anno_to_event]
    n_eye_status = len(eye_status)
    
    # Epoch settings
    epoch_len = 4 # seconds per epoch
    n_trials = 2 # one eyes closed trial followed by one eyes open trial
    n_epochs_trial = int(60/epoch_len) # epochs in one 60 s trial
    
    # Montage settings
    # (for a custom montage file use: montage = mne.channels.read_custom_montage(filename))
    montage = mne.channels.make_standard_montage('standard_1005')
    
    # %% Load, filter and epoch (Steps 1 to 5)
    # Pre-allocate memory
    epochs = [0]*n_subjects
    for i in range(n_subjects):
        # MNE python supports many EEG formats. Make sure to use the proper one
        raw = mne.io.concatenate_raws([mne.io.read_raw_edf(f, preload=True) for f in files[i]])
        # Fix EEGBCI channel names
        eegbci.standardize(raw)
        # Set montage
        raw.set_montage(montage)
        # Only use the EEG channels
        raw.pick_types(meg=False, eeg=True, stim=False)
        # Bandpass filter; the 1 Hz high-pass also removes the DC offset and slow drifts
        raw.filter(1, 70, fir_design="firwin", verbose=0) # Due to sfreq = 160 I cannot lowpass at 100Hz
        # raw.filter(1, 100, fir_design="firwin", verbose=0) # original line
        # Notch filter to remove power-line noise (50Hz is common in Europe)
        raw.notch_filter(50, fir_design="firwin", verbose=0)
        # Epoch to non-overlapping segments of epoch_len seconds
        # The first trial (n_epochs_trial*epoch_len s, i.e. 60 s) is treated as
        # Eyes Closed and the following trial as Eyes Open
        sfreq = raw.info["sfreq"]
        n_events = n_trials*n_epochs_trial
        event = np.zeros((n_events,3), dtype=int) # manually made event array (n_events,3)
        # First column: onset sample of each epoch. MNE event samples include
        # raw.first_samp (0 for EDF recordings, but not for every format), and the
        # onsets are rounded rather than truncated so a non-integer sfreq is handled
        event[:,0] = np.round(np.arange(n_events)*epoch_len*sfreq).astype(int) + raw.first_samp
        # Third column: event id. Hardcoded based on data format
        event[:n_epochs_trial,2] = anno_to_event[eye_status[0]] # Eyes closed
        event[n_epochs_trial:,2] = anno_to_event[eye_status[1]] # Eyes open
        # Make the epochs. tmax is one sample short of epoch_len because both tmin
        # and tmax are included, e.g. 4 s at 160 Hz is 640 samples: 0 to 639
        epochs[i] = mne.Epochs(raw, event, event_id=anno_to_event,
                     tmin=0, tmax=epoch_len-1/sfreq, baseline=None, verbose=0).load_data()
        
        print("Subject:{} finished epoching ({}/{})".format(Subject_id[i],i+1,n_subjects))
    
    # Compute the number of epochs in each trial
    # Epochs.selection holds, for every epoch that was kept, its index into the
    # event array built during epoching. Trial t owns event indices
    # t*n_epochs_trial ... (t+1)*n_epochs_trial-1, so integer division by
    # n_epochs_trial gives each kept epoch's trial. Unlike slicing the epochs by
    # position, this stays correct when MNE has dropped epochs while epoching.
    n_epochs = pd.DataFrame(np.zeros((n_subjects,n_trials)))
    for i in range(n_subjects):
        trial_of_epoch = epochs[i].selection // n_epochs_trial
        n_epochs.iloc[i,:] = np.bincount(trial_of_epoch, minlength=n_trials)[:n_trials]
    # Calculate cumulative sum which is used for later indexing
    cumsum_n_epochs = np.cumsum(n_epochs, axis=1)
    # Insert 0 as first column
    cumsum_n_epochs.insert(0,"Start",0)
    
    # %% First visual inspection, re-referencing and ch interpolation (Steps 6 to 8)
    # Inspection policy:
    #   - when in doubt, the data was kept
    #   - ocular artifacts are left in, since ICA corrects for those later
    #   - bad epochs are dropped
    #   - bad channels are removed and interpolated per trial,
    #     i.e. per 1 min eyes open or eyes closed trial
    
    # Visualize each subject
    for subject_idx in (0, 1):
        epochs[subject_idx].plot(scalings=200e-6, n_epochs=15, n_channels=24)
    
    # Pre-allocate memory: 0 means nothing was marked for that subject
    bad_channels = [0 for _ in epochs]
    bad_epochs = [0 for _ in epochs]
    
    # Manually found bad data segments
    # bad_channels[n] has one entry per trial: 0 when that trial is fine, else a
    # channel name (or list of names). Subjects whose trials are all fine get no
    # line at all, e.g. bad_channels[0] stays 0.
    bad_epochs[0] = [6] # Epoch 6
    bad_channels[1] = ["T7",0] # T7 is bad in first eyes closed trial
    bad_epochs[1] = [12,13] # 2 bad epochs
    
    # Cautionary note: bad_epochs holds positional indices into the Epochs object,
    # not the epoch numbers displayed in the MNE viewer. The two usually coincide,
    # but not here: there is no Epoch 15 in this demonstration because it was not
    # full-length data given how the event times were defined, so dropping index 15
    # removes the epoch that the MNE viewer labels 16!
    
    # Work on independent copies: list.copy() would only copy the list, and the
    # in-place MNE calls below (set_eeg_reference, drop) would then modify epochs too
    cleaned_epochs = [ep.copy() for ep in epochs]
    
    # Subjects with / without manually marked bad channels, and with bad epochs.
    # A comprehension is used instead of np.array(..., dtype=object).nonzero():
    # when all subjects' lists have the same length np.array builds a 2D array and
    # nonzero() then yields repeated (or missing) subject indices.
    # Every subject is in exactly one of bad_ch_idx / good_ch_idx.
    bad_ch_idx = np.array([n for n, chs in enumerate(bad_channels) if chs], dtype=int)
    good_ch_idx = np.array([n for n, chs in enumerate(bad_channels) if not chs], dtype=int)
    bad_epochs_idx = np.array([n for n, eps in enumerate(bad_epochs) if eps], dtype=int)
    
    # Interpolate bad channels and re-reference to robust common average
    bad_ch_counter = 0 # count how many bad channels for summary
    for n in bad_ch_idx:
        # Position of every epoch of this subject
        epoch_number = np.arange(0,len(epochs[n]))
        # Pre-allocate memory
        temp_epochs = [0]*n_trials
        for trial in range(n_trials):
            # Retrieve trial
            trial_idx = (epoch_number >= cumsum_n_epochs.iloc[n,trial]) & (epoch_number < cumsum_n_epochs.iloc[n,trial+1])
            temp_epochs[trial] = epochs[n].copy().drop(np.invert(trial_idx))
            # Bad channel(s) of this trial: 0, a channel name or a list of names
            trial_bad_ch = bad_channels[n][trial]
            if isinstance(trial_bad_ch, str): # fix if only one ch is provided
                trial_bad_ch = [trial_bad_ch] # make it to list
            if trial_bad_ch == 0:
                # Do not perform interpolation, only re-referencing to average
                temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
            else:
                # Mark the bad channels so the average reference is computed without them
                temp_epochs[trial].info["bads"] = trial_bad_ch
                temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
                # Interpolate the bad channels
                temp_epochs[trial].interpolate_bads(reset_bads=True)
                # Increase counter
                bad_ch_counter += len(trial_bad_ch)
        # Concatenate temporary epochs and save
        cleaned_epochs[n] = mne.concatenate_epochs(temp_epochs, add_offset=False)
        # Notice that an offset is still added when using to_data_frame!
    
    # Re-reference all other data (that did not have bad channels) to common avg
    for n in good_ch_idx:
        cleaned_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
    
    # Drop bad epochs (positional indices, all dropped in a single call)
    bad_epoch_counter = 0
    for n in bad_epochs_idx:
        subject_bad_epochs = bad_epochs[n]
        cleaned_epochs[n].drop(subject_bad_epochs)
        bad_epoch_counter += len(subject_bad_epochs)
    
    # Summarize how many bad channels and epochs were manually defined
    # Total number of epochs (all subjects) before any manual dropping
    n_epochs_total = cumsum_n_epochs.iloc[:,-1].sum()
    print("Manually marked bad channels: {:.2f}%".format(
        bad_ch_counter/(n_trials*epochs[0].info["nchan"]*n_subjects)*100)) # 0.4% in this demonstration
    print("Manually rejected epochs: {:.2f}%".format(bad_epoch_counter/n_epochs_total*100))
    
    
    # %% ICA is performed to remove eye blinks, ECG and EOG artifacts
    # One ICA object per subject
    ica = [0]*len(cleaned_epochs)
    for n, subject_epochs in enumerate(cleaned_epochs):
        n_channels = subject_epochs.info["nchan"]
        # The common average reference lowers the matrix rank by one, and the
        # interpolated channels lower it further (each distinct channel once)
        if n in bad_ch_idx:
            interpolated = []
            for trial_bad in bad_channels[n]:
                if trial_bad == 0:
                    continue # nothing interpolated in this trial
                # a single channel may be given as a plain string
                interpolated.extend([trial_bad] if type(trial_bad) is str else trial_bad)
            matrix_rank = n_channels - 1 - len(np.unique(interpolated))
        else:
            matrix_rank = n_channels - 1
    
        ica[n] = mne.preprocessing.ICA(method="fastica", random_state=42, verbose=0,
                                       max_iter=500, n_components=matrix_rank)
        ica[n].fit(subject_epochs)
        print("{} out of {} ICAs processed".format(n+1,len(cleaned_epochs)))
    
    # Plot the components for visual inspection
    def ica_analysis(n, n_components_plot=20):
        """Open the plots used to manually identify artifact components of subject n.
    
        Closes all open figures, then plots the epochs the ICA was fitted on, the ICA
        source time courses (epochs 0-10, plus a zoomed-in view of epochs 5-8 that
        helps recognise ECG artifacts) and the scalp maps of the first components.
    
        Parameters
        ----------
        n : int
            Index into the module-level lists ``cleaned_epochs`` and ``ica``.
        n_components_plot : int
            Maximum number of component topographies to plot (default 20). It is
            capped at the number of fitted components, so an ICA with fewer than
            20 components no longer makes plot_components fail.
        """
        plt.close("all")
        # Plot the data the ICA was fitted on
        cleaned_epochs[n].plot(scalings=200e-6, n_epochs = 10)
        # Plot ICA components' time courses
        ica[n].plot_sources(cleaned_epochs[n], start = 0, stop = 10)
        ica[n].plot_sources(cleaned_epochs[n], start = 5, stop = 8) # zoomed in helps for ECG artifact recognition
        # Plot topographies, never asking for more components than were fitted
        n_plot = min(n_components_plot, ica[n].n_components_)
        ica[n].plot_components(picks=np.arange(0,n_plot))
    
    # Manual ICA decomposition visualization - change n to inspect another subject
    n = 1
    ica_analysis(n)
    
    # # Plot specific component for further inspection of the marked artifacts
    # ica[n].plot_properties(cleaned_epochs[n], picks = artifacts[n])
    
    # Pre-allocate memory
    artifacts = [0]*len(cleaned_epochs)
    # Manually determine artifacts (indices of the ICA components to exclude)
    artifacts[0] = [0, 5, 12] # eye blinks, eye blinks, eye movement
    artifacts[1] = [0, 2] # eye blinks, eye movement
    
    # Remove the artifact components from the signal
    # ICA.apply modifies its input in place, so each subject's epochs are copied
    # first; a plain list copy would make cleaned_epochs ICA-corrected as well
    corrected_epochs = [ep.copy() for ep in cleaned_epochs]
    
    for n in range(len(cleaned_epochs)):
        # Define the components with artifacts
        ica[n].exclude = artifacts[n]
        # Remove on corrected data
        ica[n].apply(corrected_epochs[n].load_data())
    
    # Inspect how the ICA worked
    n = 0
    corrected_epochs[n].plot(scalings=200e-6, n_epochs = 10)
    
    # %% Detect bad epochs automatically using AutoReject
    """
    AutoReject will use cross-validation to determine optimal Peak-to-Peak threshold
    This threshold will be determined for each channel for each subject
    The threshold is used to mark whether each ch and epoch are bad
    Interpolation is performed on bad epoch/ch segments
    If many neighboring ch are bad, then the algorithm will score the worst ch
    (based on peak-to-peak amplitude) and only interpolate the worst
    If too many ch in one epoch are bad, the epoch will be rejected
    
    I am not using Autoreject directly to interpolate, but as a guide.
    The marked bad ch and epochs are then manually inspected and determined
    whether they should be dropped or not.
    The algorithm is run for each eye status separately
    """
    
    # Suppress plots
    # import matplotlib
    # matplotlib.use("Agg")
    
    save_path = "./Autoreject_overview" # for autoreject overview
    # Create the output folder if needed, otherwise fig.savefig below fails
    os.makedirs(save_path, exist_ok=True)
    # Pre-allocate memory
    reject_log = [0]*len(corrected_epochs)
    dropped_epochs = [0]*len(corrected_epochs)
    ar = [0]*len(corrected_epochs)
    mean_threshold = [0]*len(corrected_epochs)
    for i in range(len(corrected_epochs)):
        reject_log0 = [0]*n_eye_status
        ar0 = [0]*n_eye_status
        mean_threshold0 = [0]*n_eye_status
        drop_epochs0 = [0]*n_eye_status
        for e in range(n_eye_status):
            ee = eye_status[e]
            # Initialize class
            ar0[e] = AutoReject(consensus=np.linspace(0,1,11), cv=10, n_jobs=8,
                                verbose=False, random_state=42)
            # Fit to data - but do not transform
            ar0[e].fit(corrected_epochs[i][ee])
            # Get rejection log
            reject_log0[e] = ar0[e].get_reject_log(corrected_epochs[i][ee])
            # Plot and save Autorejected epochs
            fig = reject_log0[e].plot(orientation="horizontal", show=False)
            fig.savefig(os.path.join(save_path,"AR_" + str(Subject_id[i]) + "_" + str(ee) + ".png"))
            # Close figure window
            plt.close(fig)
            # Save mean peak-to-peak voltage threshold used
            mean_threshold0[e] = np.mean(list(ar0[e].threshes_.values()))
            # Save suggested dropped epochs (per-epoch flags for this eye status)
            drop_epochs0[e] = reject_log0[e].bad_epochs
        # Concatenate dropped epochs. The flags are joined in eye_status order, which
        # matches the epoch order because all Eyes Closed epochs precede the Eyes Open
        # epochs in the event array used for epoching
        drop_epochs1 = np.concatenate(drop_epochs0)
        # Save
        dropped_epochs[i] = drop_epochs1.nonzero()[0]
        reject_log[i] = reject_log0
        ar[i] = ar0
        mean_threshold[i] = mean_threshold0
        print("{} out of {} subjects finished autorejecting".format(i+1,len(corrected_epochs)))
    
    # Overview of dropped epochs
    Drop_epochs_df = pd.DataFrame.from_records(dropped_epochs) # convert to dataframe
    Drop_epochs_df.insert(0, "Subject_id", Subject_id)
    
    # Re-enable plots
    # import matplotlib.pyplot as plt
    # %matplotlib qt
    
    print(Drop_epochs_df) # no further bad epochs after ICA
    # But for demonstration purposes we will add one more bad epoch in subject 2
    
    ### Dropped epochs are used as guide to manually inspect potential thresholded bad epochs
    # Visualize
    n = 1
    corrected_epochs[n].plot(scalings=100e-6, n_channels=31, n_epochs=15)
    
    # Copy before modifying: a plain assignment would alias the list, so the manual
    # edits below would silently overwrite the AutoReject suggestions in dropped_epochs
    bad_epochs_ar = [np.copy(suggested) for suggested in dropped_epochs]
    # Manually modify when appropriate
    # n = 0 - agreed to all
    bad_epochs_ar[1] = np.array([8], dtype="int64") # added 1 more than suggested
    # n = 1: T7 is an outlier and bad
    
    # Combine dropped epochs from first visual inspection and Autoreject
    # The AutoReject-guided indices refer to the epochs left after the first (manual)
    # drop; they are mapped back so both sets index the same, pre-drop epochs
    bad_epochs_comb = [0]*len(corrected_epochs)
    for i in range(len(corrected_epochs)):
        # If there are no dropped epochs from first inspection, just use second
        if isinstance(bad_epochs[i], int):
            bad_epochs_comb[i] = bad_epochs_ar[i]
            continue
        # Retrieve bad epochs from first manual inspection, sorted ascending:
        # the index fix below is only correct when applied in increasing order
        bad1 = np.sort(np.array(bad_epochs[i]))
        # Retrieve bad epochs from manual inspection of AR
        bad2 = bad_epochs_ar[i].copy()
        # Fix index due to iterative dropped epoch
        # E.g. if 45 is dropped, then all epochs above 45 will have 1 index lower and should be fixed
        for dropped_idx in bad1:
            bad2[bad2 >= dropped_idx] += 1
        # Concatenate from both
        bad_epochs_comb[i] = np.sort(np.concatenate([bad1,bad2]))
    
    # Convert to dataframe
    bad_epochs_comb_df = pd.DataFrame.from_records(bad_epochs_comb)
    bad_epochs_comb_df.insert(0, "Subject_ID", Subject_id)
    # Save (creating the output folder if needed)
    os.makedirs("./Preprocessing", exist_ok=True)
    bad_epochs_comb_df.to_pickle("./Preprocessing/dropped_epochs.pkl")
    
    # Drop bad epochs
    # Each subject's epochs are copied first; drop() works in place, so a plain list
    # copy would also remove these epochs from corrected_epochs
    final_epochs = [ep.copy() for ep in corrected_epochs]
    bad_epochs_df = pd.DataFrame.from_records(bad_epochs_ar)
    bad_epoch_counter2 = 0
    # Subjects with at least one AutoReject-guided bad epoch. Checking the lists
    # directly also works when nobody has any (bad_epochs_df then has no columns)
    bad_epochs_idx2 = np.array([n for n, eps in enumerate(bad_epochs_ar) if len(eps) > 0], dtype=int)
    for n in bad_epochs_idx2:
        subject_bad_epochs = bad_epochs_ar[n]
        final_epochs[n].drop(subject_bad_epochs)
        bad_epoch_counter2 += len(subject_bad_epochs)
    
    # Summarize how many epochs were rejected in total (manual + AutoReject guided)
    n_epochs_total = cumsum_n_epochs.iloc[:,-1].sum() # all epochs before any manual drop
    print("Rejected epochs in total: {:.2f}%".format(
        (bad_epoch_counter2+bad_epoch_counter)/n_epochs_total*100))
    
    # Re-reference and interpolate
    def re_reference_interpolate(input_epochs, n, bads):
        """Interpolate the channels ``bads`` of ``input_epochs[n]`` in place.
    
        The data are first referenced to Cz. If they were average referenced before
        (as they are at this point of the script), this cancels that average and with
        it the bad channels' contribution. They are then re-referenced to the average
        of all channels except ``bads``, and the bad channels are interpolated and
        unmarked again. Nothing is returned.
        """
        subject_epochs = input_epochs[n]
        subject_epochs.set_eeg_reference(ref_channels=["Cz"], verbose=0)
        subject_epochs.info["bads"] = bads
        subject_epochs.set_eeg_reference(ref_channels="average", verbose=0)
        subject_epochs.interpolate_bads(reset_bads=True)
    
    re_reference_interpolate(final_epochs, 1, ["T7"]) # n = 1: T7 judged bad in the AutoReject guided inspection
    
    # Save all bad channels from first visual inspection and autorejection guided
    # (subject index 1 now has T7 interpolated in every trial)
    bad_channels[1] = ["T7"]*n_trials
    # To keep track of which channels were interpolated (create the folder if needed)
    os.makedirs("./Preprocessing", exist_ok=True)
    with open("./Preprocessing/bad_ch.pkl", "wb") as filehandle:
        pickle.dump(bad_channels, filehandle)
    
    
    # %% Save the preprocessed epochs
    save_path = "./PreprocessedData"
    # Create the output folder if needed, otherwise saving fails
    os.makedirs(save_path, exist_ok=True)
    # One file per subject, named "<subject id>_preprocessed-epo.fif"
    for subject, subject_epochs in zip(Subject_id, final_epochs):
        fname = os.path.join(save_path, "{}_preprocessed-epo.fif".format(subject))
        subject_epochs.save(fname=fname, overwrite=True, verbose=0)
    
    # %% Evaluate dropped epochs to determine threshold for exclusion
    # The file was written with DataFrame.to_pickle, so it is read back with the
    # pandas counterpart. Unpickling can run arbitrary code: only load files this
    # pipeline produced itself.
    dropped_epochs_df = pd.read_pickle("./Preprocessing/dropped_epochs.pkl")
    
    # Convert the df to number of dropped epochs
    # In the original data I had multiple datasets and combined the dropped df here
    All_drop_epoch_df = dropped_epochs_df # originally using pd.concat
    
    def drop_epoch_counter(row,df_row):
        """Count the non-missing entries of row ``row`` of ``df_row``, skipping column 0.
    
        In this script column 0 holds the subject ID and the remaining columns hold
        dropped-epoch indices (missing where a subject dropped fewer epochs), so the
        count is that subject's number of dropped epochs.
        """
        values_after_id = df_row.iloc[row, 1:]
        return values_after_id.notnull().sum()
    
    # Number of dropped epochs for every subject (one row per subject)
    Number_dropped_epochs = [drop_epoch_counter(i, All_drop_epoch_df)
                             for i in range(All_drop_epoch_df.shape[0])]
    
    All_drop_epoch_df2 = pd.DataFrame({"Subject_ID":All_drop_epoch_df["Subject_ID"],
                                       "Number_drop_epoch":Number_dropped_epochs})
    
    # Plot histogram with one unit-wide bin per count. The edges run up to max+1:
    # the last histogram bin is closed on the right, so stopping at max would merge
    # the two largest counts (and leave no bin at all when max is 0)
    bin_seq = range(0,np.max(All_drop_epoch_df2["Number_drop_epoch"])+2)
    All_drop_epoch_df2.hist(column="Number_drop_epoch",figsize=(12,8),bins=bin_seq)
    
    # View all subjects with 15 or more dropped epochs
    print(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>=15,:])
    # Number of dropped subjects using above 30 as cutoff
    print(len(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>30,"Subject_ID"]))
    # Using cutoff for exclusion: 20% or more dropped epochs