信号处理--基于混合CNN和transfomer自注意力的多通道脑电信号的情绪分类的简单应用

目录

关于

工具

数据集

数据集简述

方法实现

数据读取

​编辑数据预处理

传统机器学习模型(逻辑回归,支持向量机,随机森林)

多层感知机模型

CNN+transfomer模型

代码获取


关于

  • 本实验利用结合了卷积神经网络 (CNN) 和 Transformer 组件的混合架构,实现基于 EEG 的有效情绪分类。
  • 尝试各种机器学习模型,包括逻辑回归、支持向量机 (SVM)、随机森林分类器和多层感知器 (MLP) 神经网络,以比较不同模型的性能。 

 图片来自于: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9991178

工具

数据集

数据集简述

脑电图数据是从两名受试者(1 名男性、1 名女性,年龄 20-22 岁)收集的,针对特定电影剪辑引发的六种情绪状态(积极、消极、中性)中的每一种状态。该数据集包括从脑电波中收集的 324,000 个数据点,这些数据点被重新采样到 150Hz。还收集了中性脑电波数据,作为代表受试者静息情绪状态的第三类数据。从四个电极(TP9、AF7、AF8、TP10)记录 EEG 数据,并进行处理以生成通过 1 秒滑动窗口提取的统计特征数据集。

图片来源:https://www.researchgate.net/figure/This-figure-shows-the-standard-locations-for-measuring-EEG-as-per-10-20-International_fig2_358644174 

方法实现

数据读取
raw_eeg_data = pd.read_csv('../data/features_raw.csv')
raw_eeg_data.head()

# plot the F8 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F8'])
plt.title('F8 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()


# plot the F7 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F7'])
plt.title('F7 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()
数据预处理
X = eeg_emotions_data.drop(['label'], axis=1)
y = eeg_emotions_data['label']

# Encoding categorical data
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

labelencoder_emotions = LabelEncoder()
y = labelencoder_emotions.fit_transform(y)


# Standardizing the features in the dataset
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

X = scaler.fit_transform(X)

传统机器学习模型(逻辑回归,支持向量机,随机森林)
from sklearn.linear_model import LogisticRegression
import pickle

# Create a logistic regression classifier
model = LogisticRegression(random_state=2003, multi_class='multinomial', max_iter=1000)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.svm import SVC

# Create a model: a support vector classifier
model = SVC(kernel='rbf', gamma='auto', C=1.0, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.ensemble import RandomForestClassifier

# Create a random forest Classifier.
model = RandomForestClassifier(n_estimators=100, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

在传统机器模型,我们可以发现随机森林的性能表现最好; 

多层感知机模型
class EEGClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super(EEGClassifier, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


input_dim = 2548  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGClassifier(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

CNN+transfomer模型
class EEGConformer(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(EEGConformer, self).__init__()

        # CNN
        self.conv1 = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))
        self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, input_dim), stride=(1, 1))
        self.batchnorm = nn.BatchNorm2d(40)

        # Transformer
        self.layernorm1 = nn.LayerNorm(40)
        self.multiheadattention = nn.MultiheadAttention(40, 1)
        self.layernorm2 = nn.LayerNorm(40)

        self.feedworward_block = nn.Sequential(
            nn.Linear(40, 32),
            nn.GELU(),
            nn.Dropout(p=0.1),
            nn.Linear(32, 40)
        )

        # MLP
        self.fc1 = nn.Linear(40, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # CNN
        x = x.unsqueeze(1).unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.batchnorm(x)

        # Transformer
        x = x.squeeze()
        x = self.layernorm1(x)
        attn_out = self.multiheadattention(x, x, x)
        x = x + nn.Dropout(0.1)(attn_out[0])
        x = self.layernorm2(x)
        x = self.feedworward_block(x)
        x = nn.Dropout(p=0.1)(x)

        # MLP
        x = self.fc1(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc2(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.3)(x)
        x = self.fc3(x)
        
        return x

input_dim = 2524  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGConformer(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

代码获取

后台私信,注明来意和文章名称;

其他问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/491257.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt 作业 24/3/26

1、实现闹钟 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTime> #include <QLineEdit>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent …

FPGA之状态机学习

作为一名逻辑工程师&#xff0c;掌握和应用状态机设计是必不可少的。能够灵活的应用状态机是对逻辑工程师最基本的要求&#xff0c;状态机设计的好坏能够直接影响到设计系统的稳定性&#xff0c;所以学会状态机是非常的重要。 1.状态机的概念 状态机通过不同的状态迁移来完成特…

STM32之HAL开发——串口配置(源码)

串口收发原理框图&#xff08;F1系列&#xff09; 注意&#xff1a;数据寄存器有俩个一个是收一个是发&#xff0c;但是在标准库或者HAL库中没有特别区分开来是俩个寄存器&#xff01; USART 初始化结构体详解 HAL 库函数对每个外设都建立了一个初始化结构体&#xff0c;比如 …

如何快速在ESXi中嵌套部署一台ESXi服务器?

正文共&#xff1a;1234 字 26 图&#xff0c;预估阅读时间&#xff1a;2 分钟 我们之前介绍过VMWare ESXi服务器镜像的定制&#xff08;VMware ESXi部署镜像定制&#xff09;和部署&#xff08;惠普VMware ESXI 6.7定制版部署&#xff09;&#xff0c;但是还没有介绍过ESXi版本…

YOLOv8改进 | 主干篇 | 修复官方去除掉PP-HGNetV2的通道缩放功能(轻量又涨点,全网独家整理)

一、本文介绍 本文给大家带来的改进机制是大家在跑RT-DETR提供的HGNetV2时的一个通道缩放功能&#xff08;官方在前几个版本去除掉的一个功能&#xff09;&#xff0c;其中HGNetV2当我们将其集成在YOLOv8n的模型上作为特征提取主干的时候参数量仅为230W 计算量为6.7GFLOPs该网…

【机器学习之---数学】随机游走

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 随机游走 1. 概念 1.1 例1 在你的饮食俱乐部度过了一个富有成效的晚上后&#xff0c;你在不太清醒的状态下离开了。因此&#xff0c;你会醉醺醺地在展…

【opencv】实时位姿估计(real_time_pose_estimation)—3D模型注册

相机成像原理图 物体网格、关键点&#xff08;局内点、局外点&#xff09;图像 box.ply resized_IMG_3875.JPG 主程序main_registration.cpp 主要实现了利用OpenCV库进行3D模型的注册。主要步骤包括加载3D网格模型、使用鼠标事件选择对应的3D点进行2D到3D的注册、利用solvePnP算…

在django中使用kindeditor出现转圈问题

在django中使用kindeditor出现转圈问题 【一】基础检查 【1】前端检查 确保修改了uploadJson的默认地址 该地址需要在路由层有映射关系 确认有加载官方文件 kindeditor-all-min.js确保有传递csrfmiddlewaretoken 或者后端关闭了csrf验证 <textarea name"content&qu…

无人驾驶矿卡整体解决方案(5g物联网通信方案)

​无人驾驶矿卡是智能矿山的重要组成部分,通过远程操控替代人工驾驶,可以显著提高采矿效率和作业安全性。但要实现无人驾驶矿卡,需要依赖于可靠高效的通信网络,来传输现场视频、控制指令和运行数据。以下是某大型煤矿在部署无人驾驶矿卡时,所采用的星创易联物联网整体解决方案。…

如何区分模型文件是稳定扩散模型和LORA模型

区分模型文件是否为稳定扩散模型&#xff08;Stable Diffusion Models&#xff09;或LORA模型&#xff08;LowRank Adaptation&#xff09;通常需要对模型的结构和内容有一定的了解。以下是一些方法来区分这两种模型文件&#xff1a; 1. 文件格式和结构 稳定扩散模型&#xff1…

词根词缀基础

一&#xff0e;词根词缀方法&#xff1a; 1. 类似中文的偏旁部首&#xff08;比如“休”单人旁木→一个人靠木头上休息&#xff09; 2. 把单词拆分后&#xff0c;每一个部分都有它自己的意思&#xff0c;拼凑在一起就构成了这个单词的意思 3. 一个规律&#xff0c;适用大部分…

基于nodejs+vue多媒体素材管理系统python-flask-django-php

该系统采用了nodejs技术、express 框架&#xff0c;连接MySQL数据库&#xff0c;具有较高的信息传输速率与较强的数据处理能力。包含管理员、教师和用户三个层级的用户角色&#xff0c;系统管理员可以对个人中心、用户管理、教师管理、资源类型管理、资源信息管理、素材类型管理…

论文阅读-《Lite Pose: Efficient Architecture Design for 2D Human Pose Estimation》

摘要 这篇论文主要研究了2D人体姿态估计的高效架构设计。姿态估计在以人为中心的视觉应用中发挥着关键作用&#xff0c;但由于基于HRNet的先进姿态估计模型计算成本高昂&#xff08;每帧超过150 GMACs&#xff09;&#xff0c;难以在资源受限的边缘设备上部署。因此&#xff0…

(三)Ribbon负载均衡

1.1.负载均衡原理 SpringCloud底层其实是利用了一个名为Ribbon的组件&#xff0c;来实现负载均衡功能的。 1.2.源码跟踪 为什么我们只输入了service名称就可以访问了呢&#xff1f;之前还要获取ip和端口。 显然有人帮我们根据service名称&#xff0c;获取到了服务实例的ip和…

GitLab更新失败(Ubuntu)

在Ubuntu下使用apt更新gitlab报错如下&#xff1a; An error occurred during the signature verification.The repository is not updated and the previous index files will be used.GPG error: ... Failed to fetch https://packages.gitlab.com/gitlab/gitlab-ee/ubuntu/d…

Leetcode 3.26

Leetcode Hot 100 一级目录1.每日温度 堆1.数组中的第K个最大元素知识点&#xff1a;排序复杂度知识点&#xff1a;堆的实现 2.前 K 个高频元素知识点&#xff1a;优先队列 一级目录 1.每日温度 每日温度 思路是维护一个递减栈&#xff0c;存储的是当前元素的位置。 遍历整个…

web学习笔记(四十五)Node.js

目录 1. Node.js 1.1 什么是Node.js 1.2 为什么要学node.js 1.3 node.js的使用场景 1.4 Node.js 环境的安装 1.5 如何查看自己安装的node.js的版本 1.6 常用终端命令 2. fs 文件系统模块 2.1引入fs核心模块 2.2 读取指定文件的内容 2.3 向文件写入指定内容 2.4 创…

app自动化-Appium学习笔记

使用Appium&#xff0c;优点&#xff1a; 1、支持语言比较多&#xff0c;例如&#xff1a;Java、Python、Javascript、PHP、C#等语言 2、支持跨应用&#xff08;windows、mac、linux&#xff09; 3、适用平台Android、iOS 4、支持Native App(原生app)、Web App、Hybird App…

canvas画图写文字,有0.5像素左右的位置偏差,无解决办法,希望有知道问题的大神告知一下

提示&#xff1a;canvas画图写文字 文章目录 前言一、写文字总结 前言 一、写文字 test.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-widt…

Fragment 与 ViewPager的联合应用(2)

5.创建底部布局bottom_layout <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:orientation"horizontal"android:layout_width"match_parent"android:layout_height"55dp"android:background&qu…