【信号处理】基于DGGAN的单通道脑电信号增强和情绪检测(tensorflow)

关于

情绪检测,是脑科学研究中的一个常见和热门的方向。在进行情绪检测的分类中,真实数据不足,经常导致情绪检测模型的性能不佳。因此,对数据进行增强,成为了一个提升下游任务的重要的手段。本项目通过DCGAN模型实现脑电信号的扩充。

 图片来源:https://www.medicalnewstoday.com/articles/seizure-eeg

工具

数据

方法实现

DCGAN速递:https://arxiv.org/abs/1511.06434

数据加载和预处理
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import accuracy_score
from model_DCGAN import DCGAN
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from sklearn.utils import shuffle
from sklearn.ensemble import GradientBoostingClassifier

use_feature_reduction = True

tf.keras.backend.clear_session()

df=pd.read_csv('dataset/emotions.csv')

encode = ({'NEUTRAL': 0, 'POSITIVE': 1, 'NEGATIVE': 2} )
#new dataset with replaced values
df_encoded = df.replace(encode)

print(df_encoded.head())
print(df_encoded['label'].value_counts()),

x=df_encoded.drop(["label"]  ,axis=1)
y = df_encoded.loc[:,'label'].values

scaler = StandardScaler()
scaler.fit(x)
x = scaler.transform(x)
y = to_categorical(y)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 4)

if use_feature_reduction:
    # Feature reduction part
    est = GradientBoostingClassifier(n_estimators=10, learning_rate=0.1, random_state=0).fit(x_train,
                                                                                             y_train.argmax(-1))

    # Obtain feature importance results from Gradient Boosting Regressor
    feature_importance = est.feature_importances_
    epsilon_feature = 1e-2
    x_train = x_train[:, feature_importance > epsilon_feature]
    x_test = x_test[:, feature_importance > epsilon_feature]
设置DCGAN优化器

# setup optimzers
gen_optim = Adam(1e-4, beta_1=0.5)
disc_optim = RMSprop(5e-4)
 训练GAN生成类别0脑电数据
# generate samples for class 0
generator_class = 0
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_0 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_0, acc_history_class_0, grads_history_class_0 = dcgan.train(x_train_class_0, epochs=100)
print("Class 0 fake samples are generating")
generator_class_0 = dcgan.generator
generated_samples_class_0, _ = dcgan.generate_fake_data(N=len(x_train_class_0))
  训练GAN生成类别1脑电数据
# generate samples for class 1
generator_class = 1
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_1 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_1, acc_history_class_1, grads_history_class_1 = dcgan.train(x_train_class_1, epochs=100)
print("Class 1 fake samples are generating")
generator_class_1 = dcgan.generator
generated_samples_class_1, _ = dcgan.generate_fake_data(N=len(x_train_class_1))
 训练GAN生成类别2脑电数据
# generate samples for class 2
generator_class = 2
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_2 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_2, acc_history_class_2, grads_history_class_2 = dcgan.train(x_train_class_2,epochs=100)
print("Class 2 fake samples are generating")
generator_class_2 = dcgan.generator
generated_samples_class_2, _ = dcgan.generate_fake_data(N=len(x_train_class_2))
合成数据融入真实训练数据集
generated_samples = np.concatenate((generated_samples_class_0,
                                    generated_samples_class_1,
                                    generated_samples_class_2),axis=0)
generated_y =np.concatenate((np.zeros((len(x_train_class_0),),dtype=np.int32),
                             np.ones((len(x_train_class_1),),dtype=np.int32),
                             2 * np.ones((len(x_train_class_2),),dtype=np.int32)),axis=0)

generated_y = to_categorical(generated_y)

x_train_all = np.concatenate((x_train,generated_samples),axis=0)
y_train_all = np.concatenate((y_train,generated_y), axis=0)


#shuffle training data
x_train_all, y_train_all = shuffle(x_train_all,y_train_all)
 基于数据增强的LSTM模型情绪检测
model = Sequential()
model.add(LSTM(64, input_shape=(1,x_train_all.shape[2]),activation="relu",return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(32,activation="sigmoid"))
model.add(Dropout(0.2))
model.add(Dense(3, activation='sigmoid'))
model.compile(loss = 'categorical_crossentropy', optimizer = "adam", metrics = ['accuracy'])
model.summary()


history = model.fit(x_train_all, y_train_all, epochs = 250, validation_data= (x_test, y_test))
score, acc = model.evaluate(x_test, y_test)

pred = model.predict(x_test)
predict_classes = np.argmax(pred,axis=1)
expected_classes = np.argmax(y_test,axis=1)
print(expected_classes.shape)
print(predict_classes.shape)
correct = accuracy_score(expected_classes,predict_classes)
print(f"Test Accuracy: {correct}")

已附DCGAN模型

相关项目和代码问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/495243.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【动手学深度学习】深入浅出深度学习之线性神经网络

目录 🌞一、实验目的 🌞二、实验准备 🌞三、实验内容 🌼1. 线性回归 🌻1.1 矢量化加速 🌻1.2 正态分布与平方损失 🌼2. 线性回归的从零开始实现 🌻2.1. 生成数据集 &#x…

泛微表单添加自定义按钮

页面效果&#xff1a; 点击按钮&#xff0c;将参数字段对应的值传入链接中。 表单配置如下&#xff1a; 然后插入js代码块&#xff0c;代码如下&#xff1a; <script> jQuery(document).ready(function(){ //在表单的按钮单元格插入自定义属性&#xff1a;ID&#xff1…

三级等保建设技术方案-Word

1信息系统详细设计方案 1.1安全建设需求分析 1.1.1网络结构安全 1.1.2边界安全风险与需求分析 1.1.3运维风险需求分析 1.1.4关键服务器管理风险分析 1.1.5关键服务器用户操作管理风险分析 1.1.6数据库敏感数据运维风险分析 1.1.7“人机”运维操作行为风险综合分析 1.2…

云能耗管理系统在某高校建筑系统平台的开发与应用

摘要&#xff1a;依据本项目依托某学院的电能计量管理系统、给水计量监管系统以及供热计量管理系统等基础平台&#xff0c;制订了高等学校建筑能耗综合管理系统平台应用的总体框架和方案&#xff0c;该系统可以对校园建筑的各种用能情况进行实时监测、统计能耗、进行能效分析&a…

DVWA-CSRF通关教程-完结

DVWA-CSRF通关教程-完结 文章目录 DVWA-CSRF通关教程-完结Low页面使用源码分析漏洞利用 Medium源码分析漏洞利用 High源码分析漏洞利用 impossible源码分析 Low 页面使用 当前页面上&#xff0c;是一个修改admin密码的页面&#xff0c;只需要输入新密码和重复新密码&#xff…

全局UI方法-弹窗三-文本滑动选择器弹窗(TextPickDialog)

1、描述 根据指定的选择范围创建文本选择器&#xff0c;展示在弹窗上。 2、接口 TextPickDialog(options?: TextPickDialogOptions) 3、TextPickDialogOptions 参数名称 参数类型 必填 参数描述 rang string[] | Resource 是 设置文本选择器的选择范围。 selected nu…

聚酰亚胺PI材料难于粘接,用什么胶水粘接?那么让我们先一步步的从认识它开始(十一): 聚酰亚胺PI纤维

聚酰亚胺PI纤维 聚酰亚胺PI纤维是由聚酰亚胺制成的纤维材料&#xff0c;是一种非常高性能的工程纤维材料。它具有极强的耐温性、耐化学性、耐热性和耐磨性等特点。具体来说&#xff0c;聚酰亚胺PI纤维的耐温性能非常突出&#xff0c;可以承受高达300℃以上的高温条件而不流失或…

热爱负压自动排渣放水器向光而行

你还很年轻&#xff0c;将来你会遇到很多人&#xff0c; 经历很多事&#xff0c;得到很多&#xff0c;也会失去很多&#xff0c; 但无论如何有两样东西&#xff0c;你绝不能丢弃&#xff0c; 一个叫良心&#xff0c;另一个叫理想。 热爱负压自动排渣放水器向光而行 一、概述 【…

目标检测+车道线识别+追踪

一种方法&#xff1a; 车道线检测-canny边缘检测-霍夫变换 一、什么是霍夫变换 霍夫变换&#xff08;Hough Transform&#xff09;是一种在图像处理和计算机视觉中广泛使用的特征检测技术&#xff0c;主要用于识别图像中的几何形状&#xff0c;尤其是直线、圆和椭圆等常见形状…

[密码学] 密码学基础

目录 一 为什么要加密? 二 常见的密码算法 三 密钥 四 密码学常识 五 密码信息威胁 六 凯撒密码 一 为什么要加密? 在互联网的通信中&#xff0c;数据是通过很多计算机或者通信设备相互转发&#xff0c;才能够到达目的地,所以在这个转发的过程中&#xff0c;如果通信包…

MySql实战--普通索引和唯一索引,应该怎么选择

在前面的基础篇文章中&#xff0c;我给你介绍过索引的基本概念&#xff0c;相信你已经了解了唯一索引和普通索引的区别。今天我们就继续来谈谈&#xff0c;在不同的业务场景下&#xff0c;应该选择普通索引&#xff0c;还是唯一索引&#xff1f; 假设你在维护一个市民系统&…

C++堆详细讲解

介绍 二叉堆是一种基础数据结构&#xff0c;主要应用于求出一组数据中的最大最小值。C 的STL中的优先队列就是使用二叉堆。 堆的性质 : 1 . 堆是一颗完全二叉树 ; 2 . 堆分为大根堆 和 小根堆(这里不讨论那些更高级的如:二叉堆&#xff0c;二叉堆&#xff0c;左偏树等等) …

《手把手教你》系列技巧篇(五十八)-java+ selenium自动化测试-分页测试(详细教程)

1.简介 前几天&#xff0c;有人私信里留言问宏哥&#xff0c;分页怎么自动化测试了&#xff0c;完了给他说了说思路&#xff0c;不知道最后搞定没有&#xff0c;索性宏哥就写一篇文章来讲解和介绍如何处理分页。 2.测试场景 对分页来说&#xff0c;我们最感兴趣的和测试的无非…

主流公链 - Monero

Monero: 加密货币的隐私标杆 1. 简介 Monero&#xff08;XMR&#xff09;&#xff0c;世界语中货币的意思&#xff0c;是一种去中心化的加密货币&#xff0c;旨在提供隐私和匿名性。与比特币等公开区块链不同&#xff0c;Monero专注于隐私保护&#xff0c;使用户的交易记录和余…

快速上手Pytrch爬虫之爬取某应图片壁纸

一、前置知识 1 爬虫简介 网络爬虫&#xff08;又被称作网络蜘蛛、网络机器人&#xff0c;在某些社区中也经常被称为网页追逐者)可以按照指定的规则&#xff08;网络爬虫的算法&#xff09;自动浏览或抓取网络中的信息。 1.1 Web网页存在方式 表层网页指的是不需要提交表单&…

网络: 套接字

套接字: 在网络上进行进程间通信 网络字节序与主机字节序的转化 sockaddr sockaddr struct sockaddr {sa_family_t sa_family; // 地址族char sa_data[14]; // 地址数据&#xff0c;具体内容与地址族相关 };sockaddr_in :主要是地址类型, 端口号, IP地址. 基于IPv4编程…

Linux:文件增删 文件压缩指令

Linux&#xff1a;文件增删 & 文件压缩指令 文件增删touch指令mkdir指令cp指令rm指令rmdir指令 文件压缩zip & unzip 指令tar指令 文件增删 touch指令 功能&#xff1a;touch命令参数可更改文档或目录的日期时间&#xff0c;包括存取时间和更改时间&#xff0c;或者新…

vitepress builld报错

问题&#xff1a;build时报错&#xff1a;document/window is not defined。 背景&#xff1a;使用vitepress展示自定义的组件&#xff0c;之前build是没有问题了&#xff0c;由于新增了qr-code以及quill富文本组件&#xff0c;导致打包时报错。 原因&#xff1a;vitepress官…

密码学 总结

群 环 域 群 group G是一个集合&#xff0c;在此集合上定义代数运算*&#xff0c;若满足下列公理&#xff0c;则称G为群。 1.封闭性 a ∈ G , b ∈ G a\in G,b\in G a∈G,b∈G> a ∗ b ∈ G a*b\in G a∗b∈G 2.G中有恒等元素e&#xff0c;使得任何元素与e运算均为元素本…

vue实现把Ox格式颜色值转换成rgb渐变颜色值(开箱即用)

图示&#xff1a; 核心代码&#xff1a; //将0x格式的颜色转换为Hex格式&#xff0c;并计算插值返回rgb颜色 Vue.prototype.$convertToHex function (colorCode1, colorCode2, amount) {// 确保输入是字符串&#xff0c;并检查是否以0x开头let newCode1 let newCode2 if (t…