-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path eeg_signal_processing.py
61 lines (56 loc) · 2.01 KB
/
eeg_signal_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import mne
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.signal import correlate
def plot_eeg_signals(df, sfreq=2048, ch_types='misc'):
    """Wrap *df* in an MNE ``RawArray``, show its trace plot and print a summary.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per channel and one row per sample. Column labels become
        the MNE channel names (converted to ``str``, which MNE requires).
    sfreq : float, optional
        Sampling frequency in Hz. Defaults to 2048, the rate used elsewhere in
        this script.
    ch_types : str or list of str, optional
        Channel type(s) forwarded to ``mne.create_info``. Defaults to
        ``'misc'``, which is also MNE's own default, so omitting it keeps the
        previous behaviour. NOTE(review): if the columns are EEG in volts,
        ``'eeg'`` would give proper scaling in ``raw.plot()`` -- confirm units.

    Returns
    -------
    mne.io.RawArray
        The constructed raw object, so callers can reuse it.
    """
    # MNE rejects non-string channel names, so coerce every label.
    ch_names = [str(name) for name in df.columns]
    # RawArray expects shape (n_channels, n_times); the frame is (n_times, n_channels).
    data = df.to_numpy(dtype=float).T
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    raw = mne.io.RawArray(data=data, info=info)
    raw.plot()
    plt.show()
    print(raw)
    print(raw.info)
    return raw
def compute_freq_bands(data, fs, bands=None):
    """Return the mean FFT magnitude of *data* within each EEG frequency band.

    Parameters
    ----------
    data : array-like
        1-D signal, e.g. one channel as a pandas Series or NumPy array.
    fs : float
        Sampling frequency in Hz.
    bands : dict[str, tuple[float, float]], optional
        Mapping of band name to ``(low_hz, high_hz)``. Both edges are
        inclusive, so a bin lying exactly on a shared boundary (e.g. 4 Hz)
        counts toward both neighbouring bands. Defaults to the classic
        Delta/Theta/Alpha/Beta/Gamma split below.

    Returns
    -------
    pandas.DataFrame
        Columns ``band`` (band names in *bands* order) and ``val`` (mean
        magnitude of the one-sided FFT over the bins in that band). ``val`` is
        NaN for a band containing no FFT bins, e.g. a band above the Nyquist
        frequency ``fs / 2``.

    Raises
    ------
    ValueError
        If *data* is not one-dimensional.
    """
    if bands is None:
        bands = {'Delta': (0, 4),
                 'Theta': (4, 8),
                 'Alpha': (8, 12),
                 'Beta': (12, 30),
                 'Gamma': (30, 45)}

    signal = np.asarray(data)
    # rfft works along the last axis while the frequency grid below is sized
    # from the sample count, so anything but 1-D input would silently misalign.
    if signal.ndim != 1:
        raise ValueError('data must be 1-D, got shape {}'.format(signal.shape))

    # Magnitude of the one-sided spectrum (non-negative frequencies only).
    fft_vals = np.abs(np.fft.rfft(signal))
    # Frequency in Hz of each rfft bin.
    fft_freq = np.fft.rfftfreq(signal.size, 1.0 / fs)

    band_means = []
    for low, high in bands.values():
        in_band = (fft_freq >= low) & (fft_freq <= high)
        # A band starting at 0 Hz (Delta by default) includes the DC bin, i.e.
        # any constant offset of the signal; demean upstream if that matters.
        # Empty bands yield NaN explicitly instead of np.mean's RuntimeWarning.
        band_means.append(fft_vals[in_band].mean() if in_band.any() else np.nan)

    return pd.DataFrame({'band': list(bands), 'val': band_means})
def main(filename='./eeg_math_subj.csv', fs=2048, ncols=3):
    """Load an EEG CSV, plot the raw traces, then plot per-channel band amplitudes.

    Parameters
    ----------
    filename : str, optional
        Path to the CSV; each column other than the annotations column is
        treated as a channel.
    fs : float, optional
        Sampling frequency in Hz passed to ``compute_freq_bands``.
    ncols : int, optional
        Number of columns in the grid of per-channel bar charts. The number
        of rows is derived from the channel count, so every channel gets its
        own subplot.
    """
    df_eeg = pd.read_csv(filename)
    # Drop the annotations column (the name literally includes the quotes).
    # It is presumably an EDF export artifact rather than a signal channel.
    # errors='ignore' lets files without it load too.
    df_eeg = df_eeg.drop(columns=["'EDF Annotations'"], errors='ignore')
    print(df_eeg.head(5))
    plot_eeg_signals(df_eeg)

    columns = list(df_eeg.columns)
    nrows = max(1, (len(columns) + ncols - 1) // ncols)
    params = {
        'axes.titlesize': 8,
        'axes.labelsize': 5}
    # The font sizes must be in effect *before* the axes are created: axis
    # label Text objects take their size from rcParams at construction time.
    with plt.rc_context(params):
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
        flat_axes = axes.ravel()
        # Row-major placement: channel i goes to row i // ncols, column i % ncols.
        for ax, col in zip(flat_axes, columns):
            df_pr = compute_freq_bands(df_eeg[col], fs)
            df_pr.plot(kind='bar', x='band', y='val', legend=False, ax=ax)
            ax.set_title(str(col))
        # Hide grid cells left over when the channel count is not a multiple of ncols.
        for ax in flat_axes[len(columns):]:
            ax.set_visible(False)
        fig.tight_layout()
        plt.show()


if __name__ == "__main__":
    main()