RNN心脏病预测-Pytorch版本

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''
导入数据
'''
df = pd.read_csv('heart.csv')
print(df)

二 构建数据集 

'''
构建数据集
'''
# 1. 标准化
X = df.iloc[:,:-1]
y = df.iloc[:,-1]
sc = StandardScaler()
X = sc.fit_transform(X)

# 2.划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=14)

# 3.构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 三 模型训练

'''
模型训练
'''
# 1.构建模型
class model_rnn(nn.Module):
    def __init__(self):
        super(model_rnn,self).__init__()
        self.rnn0 = nn.RNN(input_size=13,hidden_size=200,num_layers=1,batch_first=True)
        self.fc0 = nn.Linear(200,50)
        self.fc1 = nn.Linear(50,2)

    def forward(self,x):
        out,hidden1 = self.rnn0(x)
        out = self.fc0(out)
        out = self.fc1(out)
        return out

model = model_rnn()
print(model)

 

四 定义训练函数 

# 2.定义训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失

        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

 五 定义测试函数

def t(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

 六 训练

loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 50

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = t(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:46.7%, Train_loss:0.698, Test_acc:54.8%,Test_loss:0.689
Epoch: 2, Train_acc:59.6%, Train_loss:0.682, Test_acc:58.1%,Test_loss:0.682
Epoch: 3, Train_acc:71.3%, Train_loss:0.666, Test_acc:58.1%,Test_loss:0.676
Epoch: 4, Train_acc:77.2%, Train_loss:0.652, Test_acc:54.8%,Test_loss:0.670
Epoch: 5, Train_acc:81.6%, Train_loss:0.637, Test_acc:54.8%,Test_loss:0.665
Epoch: 6, Train_acc:82.4%, Train_loss:0.623, Test_acc:58.1%,Test_loss:0.659
Epoch: 7, Train_acc:80.9%, Train_loss:0.609, Test_acc:58.1%,Test_loss:0.654
Epoch: 8, Train_acc:81.2%, Train_loss:0.594, Test_acc:58.1%,Test_loss:0.649
Epoch: 9, Train_acc:81.6%, Train_loss:0.580, Test_acc:58.1%,Test_loss:0.644
Epoch:10, Train_acc:81.6%, Train_loss:0.565, Test_acc:58.1%,Test_loss:0.639
Epoch:11, Train_acc:82.0%, Train_loss:0.550, Test_acc:58.1%,Test_loss:0.634
Epoch:12, Train_acc:82.0%, Train_loss:0.535, Test_acc:58.1%,Test_loss:0.630
Epoch:13, Train_acc:82.0%, Train_loss:0.520, Test_acc:58.1%,Test_loss:0.625
Epoch:14, Train_acc:82.0%, Train_loss:0.504, Test_acc:58.1%,Test_loss:0.621
Epoch:15, Train_acc:82.0%, Train_loss:0.488, Test_acc:58.1%,Test_loss:0.617
Epoch:16, Train_acc:82.4%, Train_loss:0.473, Test_acc:58.1%,Test_loss:0.614
Epoch:17, Train_acc:83.1%, Train_loss:0.457, Test_acc:61.3%,Test_loss:0.611
Epoch:18, Train_acc:83.1%, Train_loss:0.442, Test_acc:61.3%,Test_loss:0.608
Epoch:19, Train_acc:83.8%, Train_loss:0.427, Test_acc:64.5%,Test_loss:0.605
Epoch:20, Train_acc:85.7%, Train_loss:0.412, Test_acc:61.3%,Test_loss:0.603
Epoch:21, Train_acc:86.4%, Train_loss:0.398, Test_acc:64.5%,Test_loss:0.601
Epoch:22, Train_acc:87.1%, Train_loss:0.383, Test_acc:64.5%,Test_loss:0.600
Epoch:23, Train_acc:87.9%, Train_loss:0.370, Test_acc:64.5%,Test_loss:0.600
Epoch:24, Train_acc:88.2%, Train_loss:0.357, Test_acc:67.7%,Test_loss:0.601
Epoch:25, Train_acc:88.2%, Train_loss:0.344, Test_acc:71.0%,Test_loss:0.602
Epoch:26, Train_acc:88.6%, Train_loss:0.332, Test_acc:71.0%,Test_loss:0.605
Epoch:27, Train_acc:89.3%, Train_loss:0.321, Test_acc:71.0%,Test_loss:0.609
Epoch:28, Train_acc:89.3%, Train_loss:0.310, Test_acc:67.7%,Test_loss:0.614
Epoch:29, Train_acc:89.3%, Train_loss:0.300, Test_acc:67.7%,Test_loss:0.620
Epoch:30, Train_acc:90.1%, Train_loss:0.290, Test_acc:67.7%,Test_loss:0.627
Epoch:31, Train_acc:90.1%, Train_loss:0.281, Test_acc:71.0%,Test_loss:0.635
Epoch:32, Train_acc:90.8%, Train_loss:0.272, Test_acc:71.0%,Test_loss:0.644
Epoch:33, Train_acc:90.8%, Train_loss:0.264, Test_acc:67.7%,Test_loss:0.654
Epoch:34, Train_acc:90.8%, Train_loss:0.257, Test_acc:64.5%,Test_loss:0.663
Epoch:35, Train_acc:90.8%, Train_loss:0.249, Test_acc:64.5%,Test_loss:0.673
Epoch:36, Train_acc:90.8%, Train_loss:0.243, Test_acc:64.5%,Test_loss:0.683
Epoch:37, Train_acc:91.9%, Train_loss:0.236, Test_acc:64.5%,Test_loss:0.693
Epoch:38, Train_acc:92.3%, Train_loss:0.230, Test_acc:64.5%,Test_loss:0.703
Epoch:39, Train_acc:92.3%, Train_loss:0.224, Test_acc:64.5%,Test_loss:0.714
Epoch:40, Train_acc:92.3%, Train_loss:0.218, Test_acc:64.5%,Test_loss:0.724
Epoch:41, Train_acc:92.3%, Train_loss:0.213, Test_acc:61.3%,Test_loss:0.736
Epoch:42, Train_acc:92.3%, Train_loss:0.207, Test_acc:61.3%,Test_loss:0.748
Epoch:43, Train_acc:93.4%, Train_loss:0.202, Test_acc:58.1%,Test_loss:0.760
Epoch:44, Train_acc:93.8%, Train_loss:0.197, Test_acc:58.1%,Test_loss:0.772
Epoch:45, Train_acc:93.8%, Train_loss:0.192, Test_acc:58.1%,Test_loss:0.784
Epoch:46, Train_acc:94.5%, Train_loss:0.187, Test_acc:58.1%,Test_loss:0.798
Epoch:47, Train_acc:94.9%, Train_loss:0.182, Test_acc:58.1%,Test_loss:0.812
Epoch:48, Train_acc:94.9%, Train_loss:0.177, Test_acc:58.1%,Test_loss:0.826
Epoch:49, Train_acc:95.2%, Train_loss:0.172, Test_acc:54.8%,Test_loss:0.842
Epoch:50, Train_acc:95.2%, Train_loss:0.167, Test_acc:54.8%,Test_loss:0.858
Done

七 结果可视化

1.准确率和损失值

plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2.混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)

pred = model(X_test.to(device)).argmax(1).cpu().numpy()

print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)

from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay

# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)

plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()

八 使用模型进行预测 

'''
使用模型进行预测
'''
test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果为:',pred)
print('=='*20)
print('0:不会患心脏病')
print('1:可能患心脏病')
模型预测结果为: 0
========================================
0:不会患心脏病
1:可能患心脏病

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951533.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

科普CMOS传感器的工作原理及特点

在当今数字化成像的时代,图像传感器无疑是幕后的关键 “功臣”,它宛如一位神奇的 “光影魔法师”,通过光电效应这一奇妙的物理现象,将光子巧妙地转换成电荷,为图像的诞生奠定基础。而在众多类型的图像传感器中&#xf…

Ubuntu 18.04 解决screen无法滚屏的问题

Ubuntu 18.04 解决screen无法滚屏的问题_ubuntu screen 无法上滑-CSDN博客文章浏览阅读2.7k次,点赞2次,收藏3次。在etc/screenrc中加入termcapinfo xterm* ti:te重新进入screen的sessionscreen -d -r XXX_ubuntu screen 无法上滑https://blog.csdn.net/w…

2025年01月09日Github流行趋势

1. 项目名称:khoj 项目地址url:https://github.com/khoj-ai/khoj项目语言:Python历史star数:22750今日star数:1272项目维护者:debanjum, sabaimran, MythicalCow, aam-at, eltociear项目简介:你…

Spring boot接入xxl-job

Spring boot接入xxl-job 导入maven包加入配置增加配置类创建执行器类&#xff08;写job的业务逻辑&#xff09;去控制台中配置job 导入maven包 <dependency><groupId>com.xuxueli</groupId><artifactId>xxl-job-core</artifactId><version>…

Cglib动态代理中method.invoke与methodProxy.invokeSuper区别浅尝

前段时间看了黑马Spring教程中&#xff0c;有期视频讲解Cglib动态代理。 代码如下图&#xff1a; 可以看到调用目标对象的方法代码为&#xff1a; method.invoke(target,objects);在其他地方看到的此处代码是&#xff1a; methodProxy.invokeSuper(o,objects);注意&#xff…

【Linux 之一 】Linux常用命令汇总

Linux常用命令 ./catcd 命令chmodclearcphistoryhtoplnmkdirmvpwdrmtailunamewcwhoami 我从2021年4月份开始才开始真正意义上接触Linux&#xff0c;最初学习时是一脸蒙圈&#xff0c;啥也不会&#xff0c;啥也不懂&#xff0c;做了很多乱七八糟&#xff0c;没有条理的笔记。不知…

NO.3 《机器学习期末复习篇》以题(问答题)促习(人学习),满满干huo,大胆学大胆补!

目录 &#x1f50d; 1. 对于非齐次线性模型 &#xff0c;试将其表示为齐次线性模型形式。 ​编辑 &#x1f50d; 2. 某汽车公司一年内各月份的广告投入与月销量数据如表3-28所示&#xff0c;试根据表中数据构造线性回归模型&#xff0c;并使用该模型预测月广告投入为20万元时…

Javascript算法——贪心算法(一)

贪心算法详解&#xff08;JavaScript&#xff09;&#xff08;局部最优->全局最优&#xff09; 贪心算法&#xff08;Greedy Algorithm&#xff09;是一种在每一步选择中都采取当前状态下的最优选择&#xff08;局部最优&#xff09;的算法设计方法。通过局部最优解的累积&…

【Vue】分享一个快速入门的前端框架以及如何搭建

先上效果图: 登录 菜单: 下载地址: 链接&#xff1a;https://pan.baidu.com/s/1m-ZlBARWU6_2n8jZil_RAQ 提取码&#xff1a;ui20 … 主要是可以自定义设置token,更改后端请求地址较为方便。 应用设置: 登录与token设置: 在这里设置不用登录,可以请求的接口: request.js i…

jdk8升级JDK21(Springboot2.7.18升级Springboot3.4.0)

目录 背景&#xff1a; 一、maven升级 二、代码改造 2.1 javax替换为jakarta 2.2 swagger2升级swagger3相关更新 2.2.1 新增SpringDocConfig配置类 2.2.2 全局代码更新 2.2.3 全局代码替换&#xff08;普通正则替换&#xff09; 2.3 Mybatis Plus升级 2.4 logback.xm…

数据库(3)--针对列的CRUD操作

1.Create 新增 语法&#xff1a; insert into 表名 &#xff08;列名&#xff09;values &#xff08;列&#xff09;... 创建一个学生表用于演示&#xff1a; create table if not exists student( id bigint comment 编号, name varchar(20) comment 姓名 ); 1.1直接增加…

加速科技荣获“浙江省企业研究院”认定

近日&#xff0c;浙江省经济和信息化厅公布“2024年认定&#xff08;备案&#xff09;省级企业研发机构名单”。经过多轮严格评审和公示&#xff0c;加速科技荣获“省企业研究院”认定。这是加速科技继获国家级专精特新“小巨人”企业认定荣誉后的又一里程碑。 “浙江省企业研究…

leetcode:1784. 检查二进制字符串字段(python3解法)

难度&#xff1a;简单 给你一个二进制字符串 s &#xff0c;该字符串 不含前导零 。 如果 s 包含 零个或一个由连续的 1 组成的字段 &#xff0c;返回 true​​​ 。否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;s "1001" 输出&#xff1a;fa…

双向列表的实现(C++)

一.实现思路 主要是一个空间存储一个数值&#xff0c;然后为了索引后面的数据单元和前面的数据单元&#xff0c;所以在每个空间里面还要存储前面和后面数据单元的指针&#xff0c;就形成了每个数据单元 后面就是要管理的是双向列表的头结点和尾节点&#xff0c;方便实现后面的头…

【前端开发常用网站汇总-01】

1、仿mac界面代码截图 https://codeimg.io/?utm_sourceappinn.com 2、可视化大屏汇总(在线Demo) https://www.xiongze.net/viewdata/index.html 3、在线Photoshop(实现简单P图) https://ps.gaoding.com/#/ 4、在线生成ico图标(png转icon文件) https://www.bitbug.net/in…

腾讯云AI代码助手编程挑战赛-百事一点通

作品简介 百事通问答是一款功能强大的智能问答工具。它依托海量知识储备&#xff0c;无论你是想了解生活窍门、学习难点&#xff0c;还是工作中的专业疑惑&#xff0c;只需输入问题&#xff0c;就能瞬间获得精准解答&#xff0c;以简洁易懂的方式呈现&#xff0c;随时随地为你…

网络安全 信息收集入门

1.信息收集定义 信息收集是指收集有关目标应用程序和系统的相关信息。这些信息可以帮助攻击者了解目标系统的架构、技术实现细节、运行环境、网络拓扑结构、安全措施等方面的信息&#xff0c;以便我们在后续的渗透过程更好的进行。 2.收集方式-主动和被动收集 ①收集方式不同…

Qt QDockWidget详解以及例程

Qt QDockWidget详解以及例程 引言一、基本用法二、深入了解2.1 窗口功能相关2.2 停靠区域限制2.3 在主窗体布局 引言 QDockWidget类提供了一个可以停靠在QMainWindow内的小窗口 (理论上可以在QMainWindow中任意排列)&#xff0c;也可以作为QMainWindow上的顶级窗口浮动 (类似一…

Spring——自动装配

假设一个场景&#xff1a; 一个人&#xff08;Person&#xff09;有一条狗&#xff08;Dog&#xff09;和一只猫(Cat)&#xff0c;狗和猫都会叫&#xff0c;狗叫是“汪汪”&#xff0c;猫叫是“喵喵”&#xff0c;同时人还有一个自己的名字。 将上述场景 抽象出三个实体类&…

MySQL安装,配置教程

一、Linux在线yum仓库安装 打开MySQL官方首页&#xff0c;链接为&#xff1a;https://www.mysql.com/ 界面如下&#xff1a; 在该页面中找到【DOWNOADS】选项卡&#xff0c;点击进入下载页面。 在下载界面中&#xff0c;可以看到不同版本的下载链接&#xff0c;这里选择【My…