基于MNE的EEGNet 神经网络的脑电信号分类实战(附完整源码)

利用MNE中的EEG数据,进行EEGNet神经网络的脑电信号分类实现:

代码:

代码主要包括一下几个步骤:
1)从MNE中加载脑电信号,并进行相应的预处理操作,得到训练集、验证集以及测试集,每个集中都包括数据和标签;
2)基于tensorflow构建EEGNet网络模型;
3)编译模型,配置损失函数、优化器和评估指标等,并进行模型训练和预测;
4)绘制训练集和验证集的损失曲线以及训练集和验证集的准确度曲线。
代码如下:

import mne
import os
from pathlib import Path
import numpy as np
from keras.src.utils import np_utils

from mne import io
from mne.datasets import sample
import matplotlib.pyplot as plt
import pathlib

from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
from keras.src.callbacks import ModelCheckpoint

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression


def EEGNet(nb_classes, Chans=64, Samples=128,
           dropoutRate=0.5, kernelLength=64,
           F1=8, D=2, F2=16, norm_rate=0.25,
           dropout_type='Dropout'):
    """
    EEGNet模型的实现。

    参数:
    - nb_classes: int, 输出类别的数量。
    - Chans: int, 通道数,默认为64。
    - Samples: int, 每个通道的样本数,默认为128。
    - dropoutRate: float, Dropout率,默认为0.5。
    - kernelLength: int, 卷积核的长度,默认为64。
    - F1: int, 第一个卷积层的滤波器数量,默认为8。
    - D: int, 深度乘法器,默认为2。
    - F2: int, 第二个卷积层的滤波器数量,默认为16。
    - norm_rate: float, 权重范数约束,默认为0.25。
    - dropout_type: str, Dropout类型,默认为'Dropout'。

    返回:
    - Model: Keras模型对象。
    """

    # 根据dropout_type参数确定使用哪种Dropout方式
    if dropout_type == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropout_type == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropout_type must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # 定义模型的输入层
    input1 = Input(shape=(Chans, Samples, 1))

    # 第一个卷积块
    block1 = Conv2D(F1, (1, kernelLength), padding='same',
                       input_shape=(Chans, Samples, 1),
                       use_bias=False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias=False,
                               depth_multiplier=D,
                               depthwise_constraint=max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropoutRate)(block1)

    # 第二个卷积块
    block2 = SeparableConv2D(F2, (1, 16),
                               use_bias=False, padding='same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropoutRate)(block2)

    # 将卷积块的输出展平以便输入到全连接层
    flatten = Flatten(name='flatten')(block2)

    # 定义全连接层
    dense = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name='softmax')(dense)

    # 创建并返回模型
    return Model(inputs=input1, outputs=softmax)


def get_data4EEGNet(kernels, chans, samples):
    """
    为EEGNet模型准备数据。

    该函数从指定的文件路径中读取原始EEG数据和事件数据,进行预处理,
    包括滤波、选择通道、分割数据集,并将数据集按给定的通道、核数和样本数进行重塑。

    参数:
    kernels - 数据集中的核数量。
    chans - 数据集中的通道数量。
    samples - 数据集中的样本数量。

    返回:
    X_train, X_validate, X_test, y_train, y_validate, y_test - 分别是训练、验证和测试数据集,
    以及相应的标签。
    """
    # 设置图像数据格式,确保数据维度顺序正确
    K.set_image_data_format('channels_last')

    # 定义数据路径
    data_path = Path("C:\\Users\\72671\\mne_data\\MNE-sample-data")

    # 定义原始数据和事件数据的文件路径
    raw_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif")
    event_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif")

    # 定义时间范围和事件ID
    tmin, tmax = -0., 1
    event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

    # 读取并预处理原始数据
    raw = io.Raw(raw_fname, preload=True, verbose=False)
    raw.filter(2, None, method='iir')
    events = mne.read_events(event_fname)

    # 设置无效通道并选择所需通道类型
    raw.info['bads'] = ['MEG 2443']
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                           exclude='bads')

    # 创建epochs数据集
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False, picks=picks, baseline=None,
                        preload=True, verbose=False)
    labels = epochs.events[:, -1]

    # 获取数据并进行缩放
    X = epochs.get_data(copy=False) * 1e6
    y = labels

    # 分割数据集为训练、验证和测试集
    X_train = X[0:144, ]
    y_train = y[0:144]
    X_validate = X[144:216, ]
    y_validate = y[144:216]
    X_test = X[216:, ]
    y_test = y[216:]

    # 将训练、验证和测试数据集中的标签转换为one-hot编码
    # 减1是因为标签通常从1开始计数,而one-hot编码需要从0开始
    y_train = np_utils.to_categorical(y_train-1)
    y_validate = np_utils.to_categorical(y_validate-1)
    y_test = np_utils.to_categorical(y_test-1)


    # 重塑数据集以匹配EEGNet模型的输入要求
    X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
    X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
    X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

    # 返回准备好的数据集
    return X_train, X_validate, X_test, y_train, y_validate, y_test


#########################################################################
# 定义模型参数
kernels, chans, samples = 1, 60, 151
# 获取预处理后的EEG数据集
X_train, X_validate, X_test, y_train, y_validate, y_test = get_data4EEGNet(kernels, chans, samples)

# 初始化EEGNet模型
model = EEGNet(nb_classes=4, Chans=chans, Samples=samples, dropoutRate=0.5,
               kernelLength=32, F1=8, D=2, F2=16, dropout_type='Dropout')

# 编译模型,配置损失函数、优化器和评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 设置模型检查点以保存最佳模型
checkpointer = ModelCheckpoint(filepath='./models/EEGNet_best_model.h5', verbose=1, save_best_only=True)
# 定义类别权重
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}

# 训练模型
fittedModel = model.fit(X_train, y_train, batch_size=32, epochs=500, verbose=2,
                        validation_data=(X_validate, y_validate),
                        callbacks=[checkpointer], class_weight=class_weights)

# 加载最佳模型权重
model.load_weights('./models/EEGNet_best_model.h5')

# 对测试集进行预测
probs = model.predict(X_test)
# 获取预测标签
preds = probs.argmax(axis=-1)
# 计算分类准确率
acc = np.mean(preds == y_test.argmax(axis=-1))

# 输出分类准确率
print("Classification accuracy: %f " % (acc))


# 获取训练历史
history = fittedModel.history

# 绘制训练集和验证集的损失曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# 绘制训练集和验证集的准确度曲线
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

效果如下:

在这里插入图片描述

参考资料:

论文链接: EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces(Journal of Neural Engineering,SCI JCR2,Impact Factor:4.141)
Github链接: the Army Research Laboratory (ARL) EEGModels project

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/938827.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SCHEMA find old payroll result

这几天原来HK客户要我帮忙看一个问题,在看HK雇佣条例时,发现又假期是取前12个月的工资,后来查看标准函数发现两个有用的operation,PLOOP与IMPRE 下图是2012年6月工资核算,现在循环着前面4个月。 输入 输出 2012年5月…

【ARM Trace32(劳特巴赫) 使用介绍 1 -- Trace32 debug 工具安装详细步骤】

文章目录 Trace32 工具解压查看安装手册准备安装设置安装目录指定安装目录选择安装类型:选择平台架构开始安装Trace32 应用打开使用界面Trace32 工具解压 使用 7-zip 解压两次: 查看安装手册 安装步骤按照文档中的 1、2、3 进行: 在解压文件中找到安装工具,如下: 准备…

Flux Tools 结构简析

Flux Tools 结构简析 BFL 这次一共发布了 Canny、Depth、Redux、Fill 四个 Tools 模型系列,分别对应我们熟悉的 ControlNets、Image Variation(IP Adapter)和 Inpainting 三种图片条件控制方法。虽然实现功能是相同的,但是其具体…

【物联网技术与应用】实验3:七彩LED灯闪烁

实验3 七彩LED灯闪烁 【实验介绍】 七彩LED灯上电后,7色动闪光LED模块可自动闪烁内置颜色。它可以用来制作相当吸引人的灯光效果。 【实验组件】 ● Arduino Uno主板* 1 ● USB数据线* 1 ● 7彩LED模块*1 ● 面包板*1 ● 9V方型电池*1 ● 跳线若干 【实验原…

Web 安全必读:跨站脚本攻击 (XSS) 原理与防御指南

文章目录 原理解析:触发方式 文件内容中的xss文件名中的xssHTTP请求中的xss其他 分类: 根据攻击脚本存储的方式根据脚本是否通过服务器处理根据持久性 常见的js触发标签 无过滤情况有过滤情况 xss-labs通关 level1-level10level11-level20 XSS&#x…

Set集合进行!contains判断IDEA提示Unnecessary ‘contains()‘ check

之前写过一个代码&#xff0c;用到了Set集合&#xff0c;判断了如果某个元素不存在就添加到集合中。今天翻看代码又看到了IDEAUnnecessary contains() check爆黄提示。 来一段测试代码&#xff1a; public class SetTest {public static void main(String[] args) {Set<Int…

Mapper代理开发

引入 Mybatis入门方式中&#xff0c;以下代码仍存在硬编码问题 Mapper 代理开发&#xff1a; 目的&#xff1a; 解决原生方式中的硬编码 简化后期执行sql ------下图中&#xff0c;第一段代码是原生硬编码代码块&#xff0c;第二个是引入了Mapper代理开发的代码块。 Mapper代…

abc 384 D(子数组->前缀和) +E(bfs 扩展的时候 按照数值去扩展)

D 做出来的很开心&#xff0c;好像本来就应该做出来>< 思路&#xff1a; 对于连续的子序列&#xff08;也就是 子数组&#xff09; 一般都和 前缀和 后缀和 有关系 区间[l r] 可以用 前缀 S_r -S{l-1} tar来表示。(对于两个元素等于一个数字&#xff0c;就可以枚举一个&…

【2024版】最新推荐好用的XSS漏洞扫描利用工具_xss扫描工具

工具介绍 toxssin 是一种开源渗透测试工具&#xff0c;可自动执行跨站脚本 (XSS) 漏洞利用过程。它由一个 https 服务器组成&#xff0c;它充当为该工具 (toxin.js) 提供动力的恶意 JavaScript 有效负载生成的流量的解释器。 安装与使用 1、安装需要的依赖库 git clone http…

web网页前后端交互方式

参考该文&#xff0c; 一、前端通过表单<form>向后端发送数据 前端是通过html中的<form>表单&#xff0c;设置method属性定义发送表单数据的方式是get还是post。 如使用get方式&#xff0c;则提交的数据会在url中显示&#xff1b;如使用post方式&#xff0c;提交…

Visual studio的AI插件-通义灵码

通义灵码 TONGYI Lingma 兼容 Visual Studio、Visual Studio Code、JetBrains IDEs 等主流 IDE&#xff1b;支持 Java、Python、Go、C/C、C#、JavaScript、TypeScript、PHP、Ruby、Rust、Scala 等主流编程语言。 安装 打开扩展管理器&#xff0c;搜送“TONGYI Lingma”&…

[C++]类的继承

一、什么是继承 1.定义&#xff1a; 在 C 中&#xff0c;继承是一种机制&#xff0c;允许一个类&#xff08;派生类&#xff09;继承另一个类&#xff08;基类&#xff09;的成员&#xff08;数据和函数&#xff09;。继承使得派生类能够直接访问基类的公有和保护成员&#xf…

【系统】Mac crontab 无法退出编辑模式问题

【系统】Mac crontab 无法退出编辑模式问题 背景一、问题回答1.定位原因&#xff1a;2.确认编辑器类型3.确保编辑器进入正确3.1 确认是否有crontab调度任务3.2 进入编辑器并确保编辑器正常3.3 保存操作 4.确认crontab任务存在 二、后续 背景 之前写过一篇&#xff1a;【系统】…

WPF系列一:窗口设置无边框

WindowStyle 设置&#xff1a;WindowStyle"None"&#xff0c;窗口无法拖拽&#xff0c;但可纵向和横向拉伸 <Window x:Class"WPFDemo.MainWindow.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x&quo…

售电公司办理全指南,开启能源新征程

一、售电公司的重要性 售电公司在能源市场中起着至关重要的作用&#xff0c;成为连接发电企业与终端用户的关键桥梁。随着经济的发展和生活水平的提高&#xff0c;电力需求持续增长&#xff0c;售电公司能够为满足这不断增长的需求提供更多选择。同时&#xff0c;在新电力技术…

JS使用random随机数实现简单的四则算数验证

1.效果图 2.代码实现 index.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</ti…

【记录】Django解决与VUE跨域问题

1 梗概 这里记录Django与VUE的跨域问题解决方法&#xff0c;主要修改内容是在 Django 中。当然其他的前端项目 Django 也可以这样处理。 2 安装辅助包 pip install django-cors-headers3 配置 settings.py INSTALLED_APPS [ # ... corsheaders, # ... ] 为了响应…

Sub-GHz无线通信技术,打造LPWAN的“最优解”

无线通信是指利用电磁波信号可以在自由空间中传播的特性进行信息交换的一种通信技术。近些年来&#xff0c;从便捷的手机通信到智能互联的家居系统&#xff0c;再到“万物互联”的物联网时代&#xff0c;无线通信技术都以其无处不在的身影&#xff0c;展现出了强大的生命力和无…

2D gaussian splatting的配置和可视化

继3D gaussian splatting&#xff0c;2D gaussian splatting除了渲染新视角&#xff0c;还能够生成mesh模型。 2D gaussian splatting的配置 两者的运行环境基本一致 GitHub - hbb1/2d-gaussian-splatting: [SIGGRAPH24] 2D Gaussian Splatting for Geometrically Accurate …

6、AI测试辅助-测试报告编写(生成Bug分析柱状图)

AI测试辅助-测试报告编写&#xff08;生成Bug分析柱状图&#xff09; 一、测试报告1. 创建测试报告2. 报告补充优化2.1 Bug图表分析 3. 风险评估 总结 一、测试报告 测试报告内容应该包含&#xff1a; 1、测试结论 2、测试执行情况 3、测试bug结果分析 4、风险评估 5、改进措施…