第15周:RNN心脏病预测

目录

前言

二、前期准备

2.1 设置GPU

2.2 导入数据

2.2.1 数据介绍

2.2.2 导入代码

2.2.3 检查数据

三、数据预处理

3.1 划分训练集与测试集

3.2 标准化

四、构建RNN模型

4.1 基本概念

4.2 搭建代码

五、编译模型

六、训练模型

七、模型评估

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客
  • 🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)

说在前面

本周目标:本地读取并加载数据、了解循环神经网络(RNN)的构建过程、调整代码是的测试机acuuracy达到87%;拔高目标——测试集accuracy达到89%

我的环境:Python3.8、Pycharm2020、tensorflow2.4.0

数据来源:[K同学啊](https://mtyjkh.blog.csdn.net/)

代码的流程图如下:


一、RNN简介

传统神经网络结构比较简单是输入层——隐藏层——输出层,而RNN与传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示,左图为传统神经网络,右图为RNN

 以一个案例具体分析RNN工作过程,用户说了一句“what time is it?”,我们的神经网络首先会将这句话分为五个基本单元(四个单词➕一个问号);然后按照顺序将5个基本单元输入RNN网络,what作为RNN的输入得到输出01,按照顺序将“time”输入RNN网络,得到输出02,这个过程中可以看到输入“time”的时候,前面“what”的输出也会对02的输出产生了影响(如下图中所示,隐藏层中有一半是黑色的),依次类推,前面所有的输入产生的结果都对后续的输出产生了印象(下图中最后的圆形中就包含了前面所有的颜色) 

当神经网络判断意图的时候,只需要最后一层的输出05,如下图所示

                               

循环神经网络(RNN)是一类用于处理序列数据的神经网络。不同于传统的前馈神经网络,RNN 能够处理序列长度变化的数据,如文本、语音等。RNN 的特点是在模型中引入循环,使得网络能够保持某种状态,从而在处理序列数据时表现出更好的性能。

上图左边简单描述 RNN 的原理,x 是输入层,o 是输出层,中间 s 是隐藏层,在 s 层进行一个循环,右边表示展开循环看到的逻辑,其实是和时间 t 相关的一个状态变化,也就是说神经网络在处理数据的时候,能看到前一时刻、后一时刻的状态,也就是常说的上下文

二、前期准备

2.1 设置GPU

代码如下:

#一、前期准备
#1.1 导入所需包和设置GPU
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
import matplotlib.pyplot as plt

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
print(gpus)

2.2 导入数据

2.2.1 数据介绍

  • age:1)年龄
  • sex:2)性别
  • cp:3)胸痛类型(4 values)
  • trestbps:4)静息血压
  • chol:5)血清胆甾淳(mg/dl)
  • fbs:6)空腹血糖>120mg/dl
  • restecg:7)静息心电图结果(值0,1,2)
  • thalach:8)达到的最大心率
  • exang:9)运动诱发的心绞痛
  • olddpeak:10)相对静止状态,运动引起的ST段压低
  • slope:11)运动峰值ST段的斜率
  • ca:12)荧光透视着色的主要血管数量(0-3)
  • thal:13)0=正常,1=固定缺陷;2=可逆转的缺陷
  • target:14)0=心脏病发作的几率较小,1=心脏病发作的几率更大

2.2.2 导入代码

#1.2 导入数据
df = pd.read_csv('heart.csv')
print(df)

2.2.3 检查数据

检查是否存在空值

df.isnull().sum()  #检查是否有空值

数据打印显示如下

三、数据预处理

3.1 划分训练集与测试集

补充:测试集与验证集的关系——①验证集并没有参与训练中梯度下降的过程,狭义上来讲是没有参与模型的参数训练更新的;②但广义上来说,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后的模型在vaild data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等;③所以也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

#二、数据预处理
#2.1 数据集划分
x = df.iloc[:,:-1]
y = df.iloc[:,-1]

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
print(x_train.shape, y_train.shape)

打印输出:(272, 13) (272,)

3.2 标准化

代码如下:

# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)

x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)

四、构建RNN模型

4.1 基本概念

函数原型:tf.keras.layers.SimpleRNN(units,activation='tanh',use_bias=True,kernel_initializer='glorot_uniform',recurrent_initializer='orthogonal',bias_initializer='zeros',kernel_regularizer=Noe,recurrent_regularizer=Noe,bias_regularizer=None,activity_regularizer=None,keenel_constraint=None,recurrent_constraint=None,bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,unroll=False,**kwargs)

关键参数说明:

  • units——正整数,输出空间的维度
  • activation——要使用的激活函数,默认为双曲正切(tanh),如果传入None,则不使用激活函数(即线性激活a(x)=x)
  • use_bias——布尔值,该层是否使用偏置向量
  • kernel_initializer——kernel权值矩阵的初始化器,用于输入的线性转换
  • recurrent_initializer——recurrent_kernel权值矩阵的初始化器,用于循环层状态的线性转换
  • bias_initializer——偏置向量的初始化器
  • dropout:在-0和1之间的浮点数,单元的丢弃比例,用于输入的线性转换

4.2 搭建代码

#三、构建RNN模型

model = Sequential()
model.add(SimpleRNN(128, input_shape= (13,1),return_sequences=True,activation='relu'))
model.add(SimpleRNN(64,return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

模型输出如下:

五、编译模型

代码如下:

#四、编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy', optimizer=opt,metrics=['accuracy'])

六、训练模型

代码如下:

#五、训练模型
epochs = 100
history = model.fit(x_train, y_train,
                    epochs=epochs,
                    batch_size=128,
                    validation_data=(x_test, y_test),
                    verbose=1)

训练过程:

七、模型评估

代码如下

#六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

scores = model.evaluate(x_test,y_test,verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

打印结果:

accuracy: 90.32%


总结

RNN实战应用,是一种用于处理序列数据的神经网络,了解了基于Tensorflow搭建RNN的过程;学习了对于文本类数据,是怎么将其数字化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/762151.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

掌握 Python 中 isinstance 的正确用法

👋 简介 isinstance() 函数用于判断一个对象是否是一个特定类型或者在继承链中是否是特定类型的实例。它常用于确保函数接收到的参数类型是预期的。 📖 正文 1 语法 isinstance(object, classinfo) object参数是要检查的对象;classinfo参数…

幻兽帕鲁联机延迟高、无法联机、联机卡顿?这样解决

幻兽帕鲁是一款超人气的冒险游戏,该作曾被讽刺为抄袭怪、缝合怪,但是依旧架不住其在全球的爆火的架势,近期该作更新了游戏内的首个大型地图,并且还新增了区域系统上限、多人专用斗技场和部分游玩内容优化,也吸引了很多…

昇思25天学习打卡营第03天 | 张量 Tensor

昇思25天学习打卡营第03天 | 张量 Tensor 文章目录 昇思25天学习打卡营第03天 | 张量 Tensor张量张量的创建张量的属性Tensor与NumPy转换稀疏张量CSRTensorCOOTensor 总结打卡 张量 张量(Tensor)是一种类似于数组和矩阵的特殊数据结构,是神经…

AI智能在Type-C领域的应用

随着科技的飞速发展,Type-C接口凭借其卓越的性能和广泛的应用场景,已成为现代电子设备中不可或缺的一部分。而AI智能技术的兴起,为Type-C领域带来了革命性的变革,推动了其功能的进一步完善和应用领域的拓展。本文将探讨AI智能在Ty…

Redis缓存管理机制

在当今快节奏的数字世界中,性能优化对于提供无缝的用户体验至关重要。缓存在提高应用程序性能方面发挥着至关重要的作用,它通过将经常使用或处理的数据存储在临时高速存储中来减少数据库负载并缩短响应时间,从而减少系统的延迟。Redis 是一种…

基于深度学习的水果蔬菜检测识别系统(Python源码+YOLOv8+Pyqt5界面+数据集+训练代码 MX_004期)

系统演示: 基于深度学习的水果蔬菜检测识别系统 界面图: 技术组成: 深度学习模型(YOLOv8): YOLOv8是基于YOLO系列的目标检测模型,具有较快的检测速度和良好的准确率,适合于实时应用场…

在 Java 中的使用Selenium 测试框架

Selenium 测试框架:在 Java 中的使用 Selenium 测试框架就是这样一个强大的工具,它为 Web 应用的自动化测试提供了全面且高效的解决方案。 一、Selenium 简介 Selenium 是一个开源的自动化测试工具集,专门用于测试 Web 应用程序。它支持多…

聊一聊质量测试框架

质量测试框架的概述: 质量测试框架是一个为测试人员提供指导、工具和技术的系统,用于确保软件满足预定的质量标准和用户需求。它涵盖了测试计划、测试用例设计、测试执行、结果分析和测试报告等多个方面。 质量测试框架相关术语: 外部性质的…

解决OneDrive “拒绝访问文件” 问题

问题描述: 在尝试将其他文件拖入oneDrive或是打开OneDrive中的文件时。出现如下报错: 拒绝访问文件 无法访问XXXXXXX中的文件。可能已移动或删除了此文件,或者受制于文件权限而不能访问。 ERR_ACCESS_DENIED 解决办法: 1. 找到O…

【MySQL备份】Percona XtraBackup实战篇

目录 1. 前言 2.准备工作 2.1.创建备份目录 2.2.配置/etc/my.cnf文件 2.3.授予root用户BACKUP_ADMIN权限 3.全量备份 4.准备备份 5.数据恢复 6.总结 "实战演练:利用Percona XtraBackup执行MySQL全量备份操作详解" 1. 前言 本文将继续上篇【My…

论文笔记:MobilityGPT: Enhanced Human MobilityModeling with a GPT mode

1 intro 1.1 背景 尽管对人类移动轨迹数据集的需求不断增加,但其访问和分发仍面临诸多挑战 首先,这些数据集通常由私人公司或政府机构收集,因此可能因泄露个人敏感生活模式而引发隐私问题其次,公司拥有的数据集可能会暴露专有商…

侯捷C++面向对象高级编程(上)-6-三大函数:拷贝构造、拷贝复制、析构

1. 2.三个特殊函数 3.构造函数和析构函数 4.浅拷贝(系统默认仅把指针拷贝过去) 5.拷贝构造函数(深拷贝,拷贝的内容,重写string函数) 6.拷贝赋值

遇到多语言跨境电商系统源码问题?这里有解决方案!

从手机到电脑,从线下到线上,如今,跨境电商正在打破地域界限,成为全球贸易的新引擎。在这个全球化的背景下,跨境电商平台的运营也面临着一系列的挑战,其中之一就是多语言问题。如果你遇到了多语言跨境电商系…

每日一题-字符串相加

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” #define _CRT_SECURE_NO_WARNINGS #include<iostream> using namespace std;class Solution { public:string addStrings(string num1, string num2) {//11123//从个位开…

基于偏微分方程模型的一维信号降噪(MATLAB)

自然界很多领域如天文、物理、生物、化学等的运动和状态的变化受多个因素影响&#xff0c;而偏微分方程恰好可以描述这种多变量问题&#xff0c;因此被引入科学研究是一种必然。偏微分方程早期的时侯用来描述机械物体和流体的自然运动和物理规律&#xff0c;且其应用领域不断拓…

【收藏】SaaS运营方法论:寻找合适的合作伙伴的四大方法

一、使用关键字研究工具查找您所在行业的相关博客、频道和网站 但是&#xff0c;根据你的业务规模和性质&#xff0c;如果你需要主动出击寻找合适的推广伙伴&#xff0c;而不仅限于让潜在合作伙伴找你&#xff0c;你可以使用关键字研究工具。 实话实说&#xff0c;最好的联盟营…

告别手工录入,企业财务凭证同步迈入智能新时代!

一、客户介绍 某金融租赁股份有限公司作为一家领先的金融租赁企业&#xff0c;一直秉持着创新驱动、服务至上的经营理念。随着业务的快速发展&#xff0c;该公司在财务管理和凭证管理方面遇到了新的挑战。为了更好地提升工作效率&#xff0c;降低运营成本&#xff0c;该公司决…

vue+js实现鼠标右键页面时在鼠标位置出现弹窗

首先是弹窗元素 <div class"tanchuang move-win1"id"tanchuang1"><el-button>111</el-button></div>然后在需要弹窗的地方监听点击事件&#xff0c;可以将这个方法写在页面载入事件中 // 获取弹窗元素 var tanchuang document.…

Open3D Ransac点云配准算法(粗配准)

目录 一、概述 1.1简介 1.2RANSAC在点云粗配准中的应用步骤 二、代码实现 2.1关键函数 2.2完整代码 2.3代码解析 2.3.1计算FPFH 1. 法线估计 2. 计算FPFH特征 2.3.2 全局配准 1.函数&#xff1a;execute_global_registration 2.距离阈值 3.registration_ransac_b…

二维码中的文件如何列表排列?轻松在线做文件活码的简单技巧

现在二维码是常用来展示文件的手段之一&#xff0c;将文件生成二维码之后扫码就能够快速在手机上查看文件内容&#xff0c;减少了下载内存占用及传输的时间&#xff0c;通过这种方式可以更加方便快捷的获取文件内容。怎么把多个文件用列表的方式来扫码展示呢&#xff1f; 多个…