Skip to content
Snippets Groups Projects
Commit e7a9ba66 authored by glia's avatar glia
Browse files

Preprocessing

parent ae4cc071
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""
Updated Oct 18 2022
@author: Qianliang Li (glia@dtu.dk)
The following preprocessing steps were employed:
1. Load the data
2. 1-100Hz Bandpass filtering
3. 50Hz Notch filtering
4. Retrieving event information about eyes open/closed states
5. Epoching to 4s non-overlapping segments
6. Visual inspection of the data
Bad channels and non-ocular artifacts are removed
7. Robust re-reference to common average (without bad channels)
8. Interpolation of bad channels
9. ICA artifact removal of ocular and ECG artifacts
10. AutoReject guided final visual inspection
Due to privacy issues of clinical data, our data is not publically available.
For demonstration purposes, I am using a publically available EEG dataset
and treating it as resting-state eyes open/closed
Link to the demonstration data: www.bci2000.org
"""
# Set working directory
import os
wkdir = "/home/glia/EEG"
os.chdir(wkdir)
# Load all libraries from the Preamble
from Preamble import *
# To demonstrate the script, a publically available EEG dataset are used
# EEG recordings from 2 subjects are used as an example
# The EEGBCI200 is task based, but we will treat it as "resting-state"
# And the 2 runs as Eyes Closed and Eyes Open
n_subjects = 2
Subject_id = [1,2]
# Download and get filenames
from mne.datasets import eegbci
files = []
for i in range(n_subjects):
raw_fnames = eegbci.load_data(Subject_id[i], [1,2]) # The first 2 runs
files.append(raw_fnames)
# # Original code to get filenames in a folder
# data_path = "EEG_folder"
# # Get filenames
# files = []
# for r, d, f in os.walk(data_path):
# for file in f:
# if ".bdf" in file:
# files.append(os.path.join(r, file))
# Eye status
anno_to_event = {'Eyes Closed': 1, 'Eyes Open': 2} # manually defined event id
eye_status = list(anno_to_event.keys())
n_eye_status = len(eye_status)
# Epoch settings
epoch_len = 4 # length of epochs in seconds
n_epochs_trial = int(60/epoch_len) # number of epochs in a trial
n_trials = 2 # 1 eyes closed followed by 1 eyes open
# Montage settings
montage = mne.channels.make_standard_montage('standard_1005')
#montage = mne.channels.read_custom_montage(filename) # custom montage file
# %% Load, filter and epoch (Steps 1 to 5)
# Pre-allocate memory
epochs = [0]*n_subjects
for i in range(n_subjects):
# MNE python supports many EEG formats. Make sure to use the proper one
raw = mne.io.concatenate_raws([mne.io.read_raw_edf(f, preload=True) for f in files[i]])
# Fix EEGBCI channel names
eegbci.standardize(raw)
# Set montage
raw.set_montage(montage)
# Only use the EEG channels
raw.pick_types(meg=False, eeg=True, stim=False)
# Bandpass filter (will also detrend, median = 0)
raw.filter(1, 70, fir_design="firwin", verbose=0) # Due to sfreq = 160 I cannot lowpass at 100Hz
# raw.filter(1, 100, fir_design="firwin", verbose=0) # original line
# Notch filter to remove power-line noise (50Hz is common in Europe)
raw.notch_filter(50, fir_design="firwin", verbose=0)
# Epoch to 4 sec
# The first 60s are treated as Eyes Closed and following 60s as Eyes Open
event = np.zeros((int(2*n_epochs_trial),3), dtype=int) # manually make event array (n_events,3)
event[:,0] = np.array(np.linspace(0,2*60-epoch_len,int(2*n_epochs_trial))*int(raw.info["sfreq"]), dtype=int) # first column, the time for events
# Hardcoded based on data format
event[:n_epochs_trial,2] = 1 # Eyes closed
event[n_epochs_trial:,2] = 2 # Eyes open
# Make the epochs. (The -1/int(raw[n].info["sfreq"]) is needed because python start with 0, so 30000 points is 0 to 29999
epochs[i] = mne.Epochs(raw, event, event_id=anno_to_event,
tmin=0, tmax=epoch_len-1/int(raw.info["sfreq"]),baseline=None, verbose=0).load_data()
print("Subject:{} finished epoching ({}/{})".format(Subject_id[i],i+1,n_subjects))
# Compute the number of epochs in each trial
n_epochs = pd.DataFrame(np.zeros((n_subjects,n_trials)))
for i in range(n_subjects):
for t in range(n_trials):
try:
n_epochs_ec = len(epochs[i][t:int((t+1)*n_epochs_trial)][eye_status[0]])
except:
n_epochs_ec = 0
try:
n_epochs_eo = len(epochs[i][t:int((t+1)*n_epochs_trial)][eye_status[1]])
except:
n_epochs_eo = 0
n_epochs.iloc[i,t] = np.max([n_epochs_ec,n_epochs_eo])
# Calculate cumulative sum which is used for later indexing
cumsum_n_epochs = np.cumsum(n_epochs, axis=1)
# Insert 0 as first column
cumsum_n_epochs.insert(0,"Start",0)
# %% First visual inspection, re-referencing and ch interpolation (Steps 6 to 8)
# When in doubt, I did not exclude the data
# Ocular artifacts are not removed, as we will use ICA to correct for those
# Bad epochs are dropped
# Bad channels are removed and interpolated on a trial basis
# i.e. for each 1min eyes open or eyes closed trial
# Visualize each
epochs[0].plot(scalings=200e-6, n_epochs=15, n_channels=24)
epochs[1].plot(scalings=200e-6, n_epochs=15, n_channels=24)
# Pre-allocate memory
bad_channels = [0]*len(epochs)
bad_epochs = [0]*len(epochs)
# Manually found bad data segments
#bad_channels[0] = [0] # If all trials are fine I do not make a line with bad channels
bad_epochs[0] = [6] # Epoch 6
bad_channels[1] = ["T7",0] # T7 is bad in first eyes closed trial
bad_epochs[1] = [12,13] # 2 bad epochs
# Cautionary note: bad_epochs are using index. Not the epoch number in MNE-viewer
# They are often the same, but not always, like in this demonstration
# where there are no Epoch 15 because it was not full length data due to
# the way I defined the event times, so when you drop Epoch idx 15, it is the
# epoch that MNE labeled as 16 in this example!
# Pre-allocate memory
cleaned_epochs = epochs.copy()
# Interpolate bad channels and re-reference to robust common average
bad_ch_counter = 0 # count how many bad channels for summary
bad_ch_idx = np.array(bad_channels,dtype=object).nonzero()[0]
for i in range(len(bad_ch_idx)):
n = bad_ch_idx[i]
n_epochs_trial = int(60/epoch_len)
# Get epoch number for each subject
epoch_number = np.arange(0,len(epochs[n]))
# Pre-allocate memory
temp_epochs = [0]*n_trials
for trial in range(n_trials):
# Retrieve trial
trial_idx = (epoch_number >= cumsum_n_epochs.iloc[n,trial]) & (epoch_number < cumsum_n_epochs.iloc[n,trial+1])
temp_epochs[trial] = epochs[n].copy().drop(np.invert(trial_idx))
# Set bad channel
trial_bad_ch = bad_channels[bad_ch_idx[i]][trial]
if trial_bad_ch == 0:
# Do not perform interpolation, only re-referencing to average
temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
else:
if type(trial_bad_ch) == str: # fix if only one ch is provided
trial_bad_ch = [trial_bad_ch] # make it to list
temp_epochs[trial].info["bads"] = trial_bad_ch
# Re-reference to Cz
temp_epochs[trial].set_eeg_reference(ref_channels="average", verbose=0)
# Interpolate the bad channels
temp_epochs[trial].interpolate_bads(reset_bads=True)
# Increase counter
bad_ch_counter += len(trial_bad_ch)
# Concatenate temporary epoch and save
cleaned_epochs[n] = mne.concatenate_epochs(temp_epochs, add_offset=False)
# Notice that an offset is still added when using to_data_frame!
# Re-reference all other data (that did not have bad channels) to common avg
good_ch_idx = np.where(np.array(bad_channels,dtype=object) == 0)[0]
for i in range(len(good_ch_idx)):
n = good_ch_idx[i]
cleaned_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
# Drop bad epochs
bad_epoch_counter = 0
bad_epochs_idx = np.array(bad_epochs,dtype=object).nonzero()[0]
for i in range(len(bad_epochs_idx)):
n = bad_epochs_idx[i]
subject_bad_epochs = bad_epochs[n]
cleaned_epochs[n].drop(subject_bad_epochs)
bad_epoch_counter += len(subject_bad_epochs)
# Summarize how many bad channels and epochs there were manually defined
(bad_ch_counter/(n_trials*epochs[0].info["nchan"]*n_subjects))*100 # 0.4% bad channels
(bad_epoch_counter/(141+150*(n_subjects-1)))*100 # 1% bad epochs rejected
# %% ICA is performed to remove eye blinks, ECG and EOG artifacts
# Make list to contain all ICAs
ica = [0]*len(cleaned_epochs)
# Make ICA objects
for n in range(len(cleaned_epochs)):
# Matrix rank is -1 because I used common average reference
# If any channels were interpolated the rank is further reduced
if any([i in [n] for i in bad_ch_idx]):
inter_pol_ch = [i for i in bad_channels[n] if i != 0] # remove all 0
# Make a flat list to use np.unique
flat_list = []
for sublist in inter_pol_ch:
if type(sublist) == str: # fix if only one ch is provided
sublist = [sublist]
for item in sublist:
flat_list.append(item)
n_inter_pol_ch = len(np.unique(flat_list))
matrix_rank = cleaned_epochs[n].info["nchan"]-1-n_inter_pol_ch
else:
matrix_rank = cleaned_epochs[n].info["nchan"]-1
ica[n] = mne.preprocessing.ICA(method="fastica", random_state=42, verbose=0,
max_iter=500, n_components=matrix_rank)
ica[n].fit(cleaned_epochs[n])
print("{} out of {} ICAs processed".format(n+1,len(cleaned_epochs)))
# Plot the components for visual inspection
def ica_analysis(n):
plt.close("all")
# Plot original
cleaned_epochs[n].plot(scalings=200e-6, n_epochs = 10)
# Plot ICA - compontents
ica[n].plot_sources(cleaned_epochs[n], start = 0, stop = 10)
ica[n].plot_sources(cleaned_epochs[n], start = 5, stop = 8) # zoomed in helps for ECG artifact recognition
ica[n].plot_components(picks=np.arange(0,20))
# Manual ICA decomposition visualization
n = 1; ica_analysis(n) # manually change n
# # Plot specific component for further inspection of the marked artifacts
# ica[n].plot_properties(cleaned_epochs[n], picks = artifacts[n])
# Pre-allocate memeory
artifacts = [0]*len(cleaned_epochs)
# Manually determine artifacts
artifacts[0] = [0, 5, 12] # eye blinks, eye blinks, eye movement
artifacts[1] = [0, 2] # eye blinks, eye movement
# Remove the artifact components from the signal
corrected_epochs = cleaned_epochs.copy()
for n in range(len(cleaned_epochs)):
# Define the components with artifacts
ica[n].exclude = artifacts[n]
# Remove on corrected data
ica[n].apply(corrected_epochs[n].load_data())
# Inspect how the ICA worked
n=0; corrected_epochs[n].plot(scalings=200e-6, n_epochs = 10)
# %% Detect bad epochs automatically using AutoReject
"""
AutoReject will use cross-validation to determine optimal Peak-to-Peak threshold
This threshold will be determined for each channel for each subject
The threshold is used to mark whether each ch and epoch are bad
Interpolation is performed on bad epoch/ch segments
If many neighboring ch are bad, then the algorithm will score the worst ch
(based on peak-to-peak amplitude) and only interpolate the worst
If too many ch in one epoch are bad, the epoch will be rejected
I am not using Autoreject directly to interpolate, but as a guide.
The marked bad ch and epochs are then manually inspected and determined
whether they should be dropped or not.
The algorithm is run for each eye status separately
"""
# Suppress plots
# import matplotlib
# matplotlib.use("Agg")
save_path = "./Autoreject_overview" # for autoreject overview
# Pre-allocate memory
reject_log = [0]*len(corrected_epochs)
dropped_epochs = [0]*len(corrected_epochs)
ar = [0]*len(corrected_epochs)
mean_threshold = [0]*len(corrected_epochs)
for i in range(len(corrected_epochs)):
reject_log0 = [0]*n_eye_status
ar0 = [0]*n_eye_status
mean_threshold0 = [0]*n_eye_status
drop_epochs0 = [0]*n_eye_status
for e in range(n_eye_status):
ee = eye_status[e]
# Initialize class
ar0[e] = AutoReject(consensus=np.linspace(0,1,11), cv=10, n_jobs=8,
verbose=False, random_state=42)
# Fit to data - but do not transform
ar0[e].fit(corrected_epochs[i][ee])
# Get rejection log
reject_log0[e] = ar0[e].get_reject_log(corrected_epochs[i][ee])
# Plot and save Autorejected epochs
fig = reject_log0[e].plot(orientation="horizontal", show=False)
fig.savefig(os.path.join(save_path,"AR_" + str(Subject_id_concat[i]) + "_" + str(ee) + ".png"))
# Close figure window
plt.close(fig)
# Save mean peak-to-peak voltage threshold used
mean_threshold0[e] = np.mean(list(ar0[e].threshes_.values()))
# Save suggested dropped epochs
drop_epochs0[e] = reject_log0[e].bad_epochs
# Concatenate dropped epochs
drop_epochs1 = np.concatenate(drop_epochs0)
# Save
dropped_epochs[i] = drop_epochs1.nonzero()[0]
reject_log[i] = reject_log0
ar[i] = ar0
mean_threshold[i] = mean_threshold0
print("{} out of {} subjects finished autorejecting".format(i+1,len(corrected_epochs)))
# Overview of dropped epochs
Drop_epochs_df = pd.DataFrame.from_records(dropped_epochs) # convert to dataframe
Drop_epochs_df.insert(0, "Subject_id", Subject_id)
# Re-enable plots
# import matplotlib.pyplot as plt
# %matplotlib qt
print(Drop_epochs_df) # no further bad epochs after ICA
# But for demonstration purposes we will add one more bad epochs in subject 2
### Dropped epochs are used as guide to manually inspect potential thresholded bad epochs
# Visualize
n=1 ; corrected_epochs[n].plot(scalings=100e-6, n_channels=31, n_epochs=15)
bad_epochs_ar = dropped_epochs # make copy before modifying
# Manually modify when appropriate
# n = 0 - agreed to all
bad_epochs_ar[1] = np.array([8], dtype="int64") # added 1 more than suggested
# n = 1: T7 is an outlier and bad
# Combine dropped epochs from first visual inspection and Autoreject
bad_epochs_comb = [0]*len(corrected_epochs)
for i in range(len(corrected_epochs)):
# Retrieve bad epochs from first manual inspection
bad1 = np.array(bad_epochs[i])
# If there are no dropped epochs from first inspection, just use second
if type(bad_epochs[i]) == int:
bad_epochs_comb[i] = bad_epochs_ar[i]
continue
else:
# Retrieve bad epochs from manual of AR
bad2 = bad_epochs_ar[i].copy()
# Fix index due to iterative dropped epoch
# E.g. if 45 is dropped, then all epochs above 45 will have 1 index lower and should be fixed
for drops in range(len(bad1)):
bad2[bad2 >= bad1[drops]] += 1
# Concatenate from both
bad_epochs_comb[i] = np.sort(np.concatenate([bad1,bad2]))
# Convert to dataframe
bad_epochs_comb_df = pd.DataFrame.from_records(bad_epochs_comb)
bad_epochs_comb_df.insert(0, "Subject_ID", Subject_id)
# Save
bad_epochs_comb_df.to_pickle("./Preprocessing/dropped_epochs.pkl")
# Drop bad epochs
final_epochs = corrected_epochs.copy()
bad_epochs_df = pd.DataFrame.from_records(bad_epochs_ar)
bad_epoch_counter2 = 0
bad_epochs_idx2 = np.where(bad_epochs_df.iloc[:,0].notnull())[0]
for i in range(len(bad_epochs_idx2)):
n = bad_epochs_idx2[i]
subject_bad_epochs = bad_epochs_ar[n]
final_epochs[n].drop(subject_bad_epochs)
bad_epoch_counter2 += len(subject_bad_epochs)
# Summarize how many bad channels and epochs there were manually defined
(bad_epoch_counter2+bad_epoch_counter)/(141+150*(n_subjects-1))*100 # in total 1.37% epochs rejected
# Re-reference and interpolate
def re_reference_interpolate(input_epochs, n, bads):
# Re-reference to Cz from average
input_epochs[n].set_eeg_reference(ref_channels=["Cz"], verbose=0)
# Set bads
input_epochs[n].info["bads"] = bads
# Re-reference to average without bads
input_epochs[n].set_eeg_reference(ref_channels="average", verbose=0)
# Interpolate
input_epochs[n].interpolate_bads(reset_bads=True)
re_reference_interpolate(final_epochs, 1, ["T7"]) # n = 1
# Save all bad channels from first visual inspection and autorejection guided
bad_channels[1] = ["T7"]*n_trials
# To keep track of which channels were interpolated
with open("./Preprocessing/bad_ch.pkl", "wb") as filehandle:
pickle.dump(bad_channels, filehandle)
# %% Save the preprocessed epochs
save_path = "./PreprocessedData"
for n in range(len(corrected_epochs)):
final_epochs[n].save(fname = os.path.join(save_path,str("{}_preprocessed".format(Subject_id[n]) + "-epo.fif")),
overwrite=True, verbose=0)
# %% Evaluate dropped epochs to determine threshold for exclusion
dropped_epochs_df = np.load("./Preprocessing/dropped_epochs.pkl", allow_pickle=True)
# Convert the df to number of dropped epochs
# In the original data I had multiple datasets and combined the dropped df here
All_drop_epoch_df = dropped_epochs_df # originally using pd.concat
def drop_epoch_counter(row,df_row):
res = np.sum(df_row.iloc[row,1:].notnull())
return res
Number_dropped_epochs = [0]*All_drop_epoch_df.shape[0]
for i in range(All_drop_epoch_df.shape[0]):
Number_dropped_epochs[i] = drop_epoch_counter(i,All_drop_epoch_df)
All_drop_epoch_df2 = pd.DataFrame({"Subject_ID":All_drop_epoch_df["Subject_ID"],
"Number_drop_epoch":Number_dropped_epochs})
# Plot histogram
bin_seq = range(0,np.max(All_drop_epoch_df2["Number_drop_epoch"])+1)
All_drop_epoch_df2.hist(column="Number_drop_epoch",figsize=(12,8),bins=bin_seq)
# View all subjects with more than 15 dropped epochs
All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>=15,:]
# Number of dropped subjects using above 30 as cutoff
len(list(All_drop_epoch_df2.loc[All_drop_epoch_df2["Number_drop_epoch"]>30,"Subject_ID"]))
# Using cutoff for exclusion: 20% or more dropped epochs
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment