    # -*- coding: utf-8 -*-
    """
    Updated Oct 18 2022
    
    @author: Qianliang Li (glia@dtu.dk)
    
    The following preprocessing steps were employed:
        1. Load the data
        2. 1-100Hz Bandpass filtering
        3. 50Hz Notch filtering
        4. Retrieving event information about eyes open/closed states
        5. Epoching to 4s non-overlapping segments
        6. Visual inspection of the data
            Bad channels and non-ocular artifacts are removed
        7. Robust re-reference to common average (without bad channels)
        8. Interpolation of bad channels
        9. ICA artifact removal of ocular and ECG artifacts
        10. AutoReject guided final visual inspection
    
    Due to privacy issues with clinical data, our data is not publicly available.
    For demonstration purposes, I am using a publicly available EEG dataset
    and treating it as resting-state eyes open/closed.
    Link to the demonstration data: www.bci2000.org
    """
    
    # Set working directory
    import os
    wkdir = "/home/glia/EEG"
    os.chdir(wkdir)
    
    # Load all libraries from the Preamble
    from Preamble import *
    
    # To demonstrate the script, a publicly available EEG dataset is used
    # EEG recordings from 2 subjects are used as an example
    # The EEGBCI (BCI2000) dataset is task-based, but we will treat it as "resting-state"
    # and the 2 runs as Eyes Closed and Eyes Open
    n_subjects = 2
    Subject_id = [1,2]
    # Download and get filenames
    from mne.datasets import eegbci
    files = []
    for i in range(n_subjects):
        raw_fnames = eegbci.load_data(Subject_id[i], [1,2]) # The first 2 runs
        files.append(raw_fnames)
    
    # # Original code to get filenames in a folder
    # data_path = "EEG_folder"
    # # Get filenames
    # files = []
    # for r, d, f in os.walk(data_path):
    #     for file in f:
    #         if ".bdf" in file:
    #             files.append(os.path.join(r, file))
    
    # Eye status
    anno_to_event = {'Eyes Closed': 1, 'Eyes Open': 2} # manually defined event id
    eye_status = list(anno_to_event.keys())
    n_eye_status = len(eye_status)
    
    # Epoch settings
    epoch_len = 4 # length of epochs in seconds
    n_epochs_trial = int(60/epoch_len) # number of epochs in a trial
    n_trials = 2 # 1 eyes closed followed by 1 eyes open
    
    # Montage settings
    montage = mne.channels.make_standard_montage('standard_1005')
    #montage = mne.channels.read_custom_montage(filename) # custom montage file
    
    # %% Load, filter and epoch (Steps 1 to 5)
    # Pre-allocate memory
    epochs = [0]*n_subjects
    for i in range(n_subjects):
        # MNE python supports many EEG formats. Make sure to use the proper one
        raw = mne.io.concatenate_raws([mne.io.read_raw_edf(f, preload=True) for f in files[i]])
        # Fix EEGBCI channel names
        eegbci.standardize(raw)
        # Set montage
        raw.set_montage(montage)
        # Only use the EEG channels
        raw.pick_types(meg=False, eeg=True, stim=False)
        # Bandpass filter (will also detrend, median = 0)
        raw.filter(1, 70, fir_design="firwin", verbose=0) # Since sfreq = 160Hz (Nyquist 80Hz) I cannot lowpass at 100Hz, so 70Hz is used
        # raw.filter(1, 100, fir_design="firwin", verbose=0) # original line
        # Notch filter to remove power-line noise (50Hz is common in Europe)
        raw.notch_filter(50, fir_design="firwin", verbose=0)
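        # Side note (an assumption about this demo dataset, not the clinical data):
        # the BCI2000/EEGBCI recordings are from the USA, where line noise is at
        # 60Hz, so for the demo data itself a 60Hz notch would be more natural:
        # raw.notch_filter(60, fir_design="firwin", verbose=0)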
        # Epoch to 4 sec
        # The first 60s are treated as Eyes Closed and following 60s as Eyes Open
        event = np.zeros((int(2*n_epochs_trial),3), dtype=int) # manually make event array (n_events,3)
        event[:,0] = np.array(np.linspace(0,2*60-epoch_len,int(2*n_epochs_trial))*int(raw.info["sfreq"]), dtype=int) # first column, the time for events
        # Hardcoded based on data format
        event[:n_epochs_trial,2] = 1 # Eyes closed
        event[n_epochs_trial:,2] = 2 # Eyes open
        # Make the epochs. The -1/int(raw.info["sfreq"]) is needed because Python
        # indexing starts at 0, so e.g. 30000 samples are indexed 0 to 29999
        epochs[i] = mne.Epochs(raw, event, event_id=anno_to_event,
                     tmin=0, tmax=epoch_len-1/int(raw.info["sfreq"]), baseline=None, verbose=0).load_data()
        
        print("Subject:{} finished epoching ({}/{})".format(Subject_id[i],i+1,n_subjects))
    
    # Compute the number of epochs in each trial
    n_epochs = pd.DataFrame(np.zeros((n_subjects,n_trials)))
    for i in range(n_subjects):
        for t in range(n_trials):
            # Select only the epochs belonging to trial t
            trial_slice = slice(int(t*n_epochs_trial), int((t+1)*n_epochs_trial))
            try:
                n_epochs_ec = len(epochs[i][trial_slice][eye_status[0]])
            except KeyError: # in case the event key is not present
                n_epochs_ec = 0
            try:
                n_epochs_eo = len(epochs[i][trial_slice][eye_status[1]])
            except KeyError: # in case the event key is not present
                n_epochs_eo = 0
            n_epochs.iloc[i,t] = np.max([n_epochs_ec,n_epochs_eo])
    # Calculate cumulative sum which is used for later indexing
    cumsum_n_epochs = np.cumsum(n_epochs, axis=1)
    # Insert 0 as first column
    cumsum_n_epochs.insert(0,"Start",0)
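    # For example, with the full 15 epochs in each trial, a subject's row reads
    # Start=0, 15, 30, so trial t spans epoch indices
    # [cumsum_n_epochs[t], cumsum_n_epochs[t+1])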
    
    # %% First visual inspection, re-referencing and ch interpolation (Steps 6 to 8)
    # When in doubt, I did not exclude the data
    # Ocular artifacts are not removed, as we will use ICA to correct for those
    # Bad epochs are dropped
    # Bad channels are removed and interpolated on a trial basis
    # i.e. for each 1min eyes open or eyes closed trial
    
    # Visualize each
    epochs[0].plot(scalings=200e-6, n_epochs=15, n_channels=24)
    epochs[1].plot(scalings=200e-6, n_epochs=15, n_channels=24)
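    # In the interactive MNE browser, clicking on a data segment marks that epoch
    # for rejection and clicking on a channel name marks the channel as bad; the
    # bad channels/epochs listed below were noted down manually from this inspection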
    
    # Pre-allocate memory
    bad_channels = [0]*len(epochs)
    bad_epochs = [0]*len(epochs)
    
    # Manually found bad data segments
    #bad_channels[0] = [0] # if all trials are fine, no bad-channel entry is made
    bad_epochs[0] = [6] # Epoch 6
    bad_channels[1] = ["T7",0] # T7 is bad in first eyes closed trial
    bad_epochs[1] = [12,13] # 2 bad epochs
    
    # Cautionary note: bad_epochs uses the epoch index, not the epoch number shown
    # in the MNE viewer. They are often the same, but not always, as in this
    # demonstration: there is no Epoch 15 because it was not full length (due to
    # the way I defined the event times), so dropping epoch idx 15 actually drops
    # the epoch that MNE labeled as 16!
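    # The mapping between viewer labels and indices can be checked via the
    # selection attribute, e.g. epochs[1].selection lists the retained event numbers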
    
    # Pre-allocate memory
    cleaned_epochs = epochs.copy()
    
    # Interpolate bad channels and re-reference to robust common average
    bad_ch_counter = 0 # count how many bad channels for summary
    bad_ch_idx = np.array(bad_channels,dtype=object).nonzero()[0]
    for i in range(len(bad_ch_idx)):
        n = bad_ch_idx[i]
        n_epochs_trial = int(60/epoch_len)
        # Get epoch number for each subject
        epoch_number = np.arange(0,len(epochs[n]))
        # Pre-allocate memory
        temp_epochs = [0]*n_trials
        for trial in range(n_trials):
            # Retrieve trial
            trial_idx = (epoch_number >= cumsum_n_epochs.iloc[n,trial]) & (epoch_number < cumsum_n_epochs.iloc[n,trial+1])
            temp_epochs[trial] = epochs[n].copy().drop(np.invert(trial_idx))
            # Set bad channel
            trial_bad_ch = bad_channels[bad_ch_idx[i]][trial]
            if trial_bad_ch == 0:
                # Do not perform interpolation, only re-referencing to average
                temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
            else:
                if type(trial_bad_ch) == str: # fix if only one ch is provided
                    trial_bad_ch = [trial_bad_ch] # make it to list
                temp_epochs[trial].info["bads"] = trial_bad_ch
                # Re-reference to common average (channels marked as bad are excluded)
                temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
                # Interpolate the bad channels
                temp_epochs[trial].interpolate_bads(reset_bads=True)
                # Increase counter
                bad_ch_counter += len(trial_bad_ch)
        # Concatenate temporary epoch and save
        cleaned_epochs[n] = mne.concatenate_epochs(temp_epochs, add_offset=False)
        # Notice that an offset is still added when using to_data_frame!
    
    # Re-reference all other data (that did not have bad channels) to common avg
    good_ch_idx = np.where(np.array(bad_channels,dtype=object) == 0)[0]
    for i in range(len(good_ch_idx)):
        n = good_ch_idx[i]
        cleaned_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
    
    # Drop bad epochs
    bad_epoch_counter = 0
    bad_epochs_idx = np.array(bad_epochs,dtype=object).nonzero()[0]
    for i in range(len(bad_epochs_idx)):
        n = bad_epochs_idx[i]
        subject_bad_epochs = bad_epochs[n]
        cleaned_epochs[n].drop(subject_bad_epochs)
        bad_epoch_counter += len(subject_bad_epochs)
    
    # Summarize how many bad channels and epochs were manually marked
    (bad_ch_counter/(n_trials*epochs[0].info["nchan"]*n_subjects))*100 # 0.4% bad channels
    (bad_epoch_counter/(141+150*(n_subjects-1)))*100 # 1% bad epochs rejected
    
    
    # %% ICA is performed to remove eye blinks, ECG and EOG artifacts
    # Make list to contain all ICAs
    ica = [0]*len(cleaned_epochs)
    # Make ICA objects
    for n in range(len(cleaned_epochs)):
        # Matrix rank is reduced by 1 because I used common average reference
        # If any channels were interpolated the rank is further reduced
        if n in bad_ch_idx:
            inter_pol_ch = [i for i in bad_channels[n] if i != 0] # remove all 0
            # Make a flat list to use np.unique
            flat_list = []
            for sublist in inter_pol_ch:
                if type(sublist) == str: # fix if only one ch is provided
                    sublist = [sublist]
                for item in sublist:
                    flat_list.append(item)
            n_inter_pol_ch = len(np.unique(flat_list))
            matrix_rank = cleaned_epochs[n].info["nchan"]-1-n_inter_pol_ch
        else:
            matrix_rank = cleaned_epochs[n].info["nchan"]-1
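        # Cross-check (a sketch, not part of the original pipeline): MNE can also
        # estimate the rank directly from the data
        # print(mne.compute_rank(cleaned_epochs[n], verbose=0))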
    
        ica[n] = mne.preprocessing.ICA(method="fastica", random_state=42, verbose=0,
                                           max_iter=500, n_components=matrix_rank)
        ica[n].fit(cleaned_epochs[n])
        print("{} out of {} ICAs processed".format(n+1,len(cleaned_epochs)))
    
    # Plot the components for visual inspection
    def ica_analysis(n):
        plt.close("all")
        # Plot original
        cleaned_epochs[n].plot(scalings=200e-6, n_epochs = 10)
        # Plot ICA components
        ica[n].plot_sources(cleaned_epochs[n], start = 0, stop = 10)
        ica[n].plot_sources(cleaned_epochs[n], start = 5, stop = 8) # zoomed in helps for ECG artifact recognition
        ica[n].plot_components(picks=np.arange(0,20))
    
    # Manual ICA decomposition visualization
    n = 1; ica_analysis(n) # manually change n
    
    # # Plot specific component for further inspection of the marked artifacts
    # ica[n].plot_properties(cleaned_epochs[n], picks = artifacts[n])
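
    # As an automated cross-check of the manual picks (a sketch; this EEGBCI data
    # has no EOG channel, so the frontal EEG channel Fp1 is used as a proxy):
    # eog_inds, eog_scores = ica[n].find_bads_eog(cleaned_epochs[n], ch_name="Fp1")
    # print(eog_inds) # candidate ocular components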
    
    # Pre-allocate memory
    artifacts = [0]*len(cleaned_epochs)
    # Manually determine artifacts
    artifacts[0] = [0, 5, 12] # eye blinks, eye blinks, eye movement
    artifacts[1] = [0, 2] # eye blinks, eye movement
    
    # Remove the artifact components from the signal
    corrected_epochs = cleaned_epochs.copy()
    
    for n in range(len(cleaned_epochs)):
        # Define the components with artifacts    
        ica[n].exclude = artifacts[n]
        # Remove on corrected data
        ica[n].apply(corrected_epochs[n].load_data())
    
    # Inspect how the ICA worked
    n=0; corrected_epochs[n].plot(scalings=200e-6, n_epochs = 10)
    
    # %% Detect bad epochs automatically using AutoReject
    """
    AutoReject uses cross-validation to determine the optimal peak-to-peak threshold
    This threshold is determined for each channel for each subject
    The threshold is used to mark whether each channel and epoch segment is bad
    Interpolation is performed on bad epoch/channel segments
    If many neighboring channels are bad, the algorithm scores the worst channels
    (based on peak-to-peak amplitude) and only interpolates the worst ones
    If too many channels in one epoch are bad, the epoch is rejected
    
    I am not using AutoReject directly to interpolate, but as a guide.
    The marked bad channels and epochs are then manually inspected, and it is
    determined whether they should be dropped or not.
    The algorithm is run for each eye status separately
    """
    
    # Suppress plots
    # import matplotlib
    # matplotlib.use("Agg")
    
    save_path = "./Autoreject_overview" # for autoreject overview
    # Pre-allocate memory
    reject_log = [0]*len(corrected_epochs)
    dropped_epochs = [0]*len(corrected_epochs)
    ar = [0]*len(corrected_epochs)
    mean_threshold = [0]*len(corrected_epochs)
    for i in range(len(corrected_epochs)):
        reject_log0 = [0]*n_eye_status
        ar0 = [0]*n_eye_status
        mean_threshold0 = [0]*n_eye_status
        drop_epochs0 = [0]*n_eye_status
        for e in range(n_eye_status):
            ee = eye_status[e]
            # Initialize class
            ar0[e] = AutoReject(consensus=np.linspace(0,1,11), cv=10, n_jobs=8,
                                verbose=False, random_state=42)
            # Fit to data - but do not transform
            ar0[e].fit(corrected_epochs[i][ee])
            # Get rejection log
            reject_log0[e] = ar0[e].get_reject_log(corrected_epochs[i][ee])
            # Plot and save Autorejected epochs
            fig = reject_log0[e].plot(orientation="horizontal", show=False)
            fig.savefig(os.path.join(save_path,"AR_" + str(Subject_id[i]) + "_" + str(ee) + ".png"))
            # Close figure window
            plt.close(fig)
            # Save mean peak-to-peak voltage threshold used
            mean_threshold0[e] = np.mean(list(ar0[e].threshes_.values()))
            # Save suggested dropped epochs
            drop_epochs0[e] = reject_log0[e].bad_epochs
        # Concatenate dropped epochs
        drop_epochs1 = np.concatenate(drop_epochs0)
        # Save
        dropped_epochs[i] = drop_epochs1.nonzero()[0]
        reject_log[i] = reject_log0
        ar[i] = ar0
        mean_threshold[i] = mean_threshold0
        print("{} out of {} subjects finished autorejecting".format(i+1,len(corrected_epochs)))
    
    # Overview of dropped epochs
    Drop_epochs_df = pd.DataFrame.from_records(dropped_epochs) # convert to dataframe
    Drop_epochs_df.insert(0, "Subject_id", Subject_id)
    
    # Re-enable plots
    # import matplotlib.pyplot as plt
    # %matplotlib qt
    
    print(Drop_epochs_df) # no further bad epochs detected after ICA
    # But for demonstration purposes we will add one more bad epoch in subject 2
    
    ### The AutoReject-flagged epochs are used as a guide to manually inspect potentially bad epochs
    # Visualize
    n=1 ; corrected_epochs[n].plot(scalings=100e-6, n_channels=31, n_epochs=15)
    
    bad_epochs_ar = dropped_epochs # make copy before modifying
    # Manually modify when appropriate
    # n = 0: agreed with all suggestions
    # n = 1: added one more epoch than suggested (and channel T7 is an outlier and bad)
    bad_epochs_ar[1] = np.array([8], dtype="int64")
    
    # Combine dropped epochs from first visual inspection and Autoreject
    bad_epochs_comb = [0]*len(corrected_epochs)
    for i in range(len(corrected_epochs)):
        # Retrieve bad epochs from first manual inspection
        bad1 = np.array(bad_epochs[i])
        # If there are no dropped epochs from first inspection, just use second
        if type(bad_epochs[i]) == int:
            bad_epochs_comb[i] = bad_epochs_ar[i]
            continue
        else:
            # Retrieve bad epochs from the manual inspection of the AutoReject guide
            bad2 = bad_epochs_ar[i].copy()
            # Fix indices to account for the epochs already dropped in the first
            # inspection. E.g. if epoch 45 was dropped, all epochs above 45 shifted
            # one index down, so they are shifted back before combining
            for drops in range(len(bad1)):
                bad2[bad2 >= bad1[drops]] += 1
            # Concatenate from both
            bad_epochs_comb[i] = np.sort(np.concatenate([bad1,bad2]))
    
    # Convert to dataframe
    bad_epochs_comb_df = pd.DataFrame.from_records(bad_epochs_comb)
    bad_epochs_comb_df.insert(0, "Subject_ID", Subject_id)
    # Save (create the output folder first if it does not exist)
    os.makedirs("./Preprocessing", exist_ok=True)
    bad_epochs_comb_df.to_pickle("./Preprocessing/dropped_epochs.pkl")
    
    # Drop bad epochs
    final_epochs = corrected_epochs.copy()
    bad_epochs_df = pd.DataFrame.from_records(bad_epochs_ar)
    bad_epoch_counter2 = 0
    bad_epochs_idx2 = np.where(bad_epochs_df.iloc[:,0].notnull())[0]
    for i in range(len(bad_epochs_idx2)):
        n = bad_epochs_idx2[i]
        subject_bad_epochs = bad_epochs_ar[n]
        final_epochs[n].drop(subject_bad_epochs)
        bad_epoch_counter2 += len(subject_bad_epochs)
    
    # Summarize how many epochs were rejected in total
    (bad_epoch_counter2+bad_epoch_counter)/(141+150*(n_subjects-1))*100 # in total 1.37% epochs rejected
    
    # Re-reference and interpolate
    def re_reference_interpolate(input_epochs, n, bads):
        # Re-reference to Cz first, so the average reference can be recomputed without the bads
        input_epochs[n].set_eeg_reference(ref_channels=["Cz"], verbose=0)
        # Set bads
        input_epochs[n].info["bads"] = bads
        # Re-reference to average without bads
        input_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
        # Interpolate
        input_epochs[n].interpolate_bads(reset_bads=True)
    
    re_reference_interpolate(final_epochs, 1, ["T7"]) # n = 1
    
    # Record all bad channels from the first visual inspection and the
    # AutoReject-guided inspection
    bad_channels[1] = ["T7"]*n_trials
    # To keep track of which channels were interpolated
    with open("./Preprocessing/bad_ch.pkl", "wb") as filehandle:
        pickle.dump(bad_channels, filehandle)
    
    
    # %% Save the preprocessed epochs
    save_path = "./PreprocessedData"
    for n in range(len(final_epochs)):
        final_epochs[n].save(fname = os.path.join(save_path,str("{}_preprocessed".format(Subject_id[n]) + "-epo.fif")),
                        overwrite=True, verbose=0)
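
    # Sanity check (a sketch): a saved file can be read back with mne.read_epochs
    # check = mne.read_epochs(os.path.join(save_path, "1_preprocessed-epo.fif"), verbose=0)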
    
    # %% Evaluate dropped epochs to determine threshold for exclusion
    dropped_epochs_df = pd.read_pickle("./Preprocessing/dropped_epochs.pkl")
    
    # Convert the df to number of dropped epochs
    # In the original data I had multiple datasets and combined the dropped df here
    All_drop_epoch_df = dropped_epochs_df # originally using pd.concat
    
    def drop_epoch_counter(row, df):
        res = np.sum(df.iloc[row,1:].notnull())
        return res
    
    Number_dropped_epochs = [0]*All_drop_epoch_df.shape[0]
    for i in range(All_drop_epoch_df.shape[0]):
        Number_dropped_epochs[i] = drop_epoch_counter(i,All_drop_epoch_df)
    
    All_drop_epoch_df2 = pd.DataFrame({"Subject_ID":All_drop_epoch_df["Subject_ID"],
                                       "Number_drop_epoch":Number_dropped_epochs})
    
    # Plot histogram
    bin_seq = range(0,np.max(All_drop_epoch_df2["Number_drop_epoch"])+1)
    All_drop_epoch_df2.hist(column="Number_drop_epoch",figsize=(12,8),bins=bin_seq)
    
    # View all subjects with more than 15 dropped epochs
    All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>=15,:]
    # Number of dropped subjects using above 30 as cutoff
    len(list(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>30,"Subject_ID"]))
    # Using cutoff for exclusion: 20% or more dropped epochs
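
    # A sketch of applying that criterion (assuming a nominal 150 epochs per
    # subject before rejection, as used in the percentages above):
    # exclude_ids = All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]
    #                                      >= 0.2*150, "Subject_ID"]
    # print(list(exclude_ids))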