LSTM预测:糖尿病的发生情况

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 本期,做个二维结构化数据的分类预测。提到结构化数据,一般的分类算法常用有:逻辑回归(二分类)、KNN、SVM、决策树、贝叶斯、随机森林、XGBoost等。本次我们采用LSTM长短期记忆网络进行分类预测

一 数据导入 

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')
# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_excel('dia.xls')
print(df.head())
print(df.shape)

print('数据缺失值-------------------')
print(df.isnull().sum())
print('数据重复值-------------------')
print(df.duplicated().sum())
卡号  性别  年龄  高密度脂蛋白胆固醇  低密度脂蛋白胆固醇  极低密度脂蛋白胆固醇  甘油三酯  总胆固醇  脉搏  舒张压   
0  18054421   0  38       1.25       2.99        1.07  0.64  5.31  83   83  \
1  18054422   0  31       1.15       1.99        0.84  0.50  3.98  85   63   
2  18054423   0  27       1.29       2.21        0.69  0.60  4.19  73   61   
3  18054424   0  33       0.93       2.01        0.66  0.84  3.60  83   60   
4  18054425   0  36       1.17       2.83        0.83  0.73  4.83  85   67   

   高血压史   尿素氮     尿酸  肌酐  体重检查结果  是否糖尿病  
0     0  4.99  243.3  50       1      0  
1     0  4.72  391.0  47       1      0  
2     0  5.87  325.7  51       1      0  
3     0  2.40  203.2  40       2      0  
4     0  4.09  236.8  43       0      0  
(1006, 16)
数据缺失值-------------------
卡号            0
性别            0
年龄            0
高密度脂蛋白胆固醇     0
低密度脂蛋白胆固醇     0
极低密度脂蛋白胆固醇    0
甘油三酯          0
总胆固醇          0
脉搏            0
舒张压           0
高血压史          0
尿素氮           0
尿酸            0
肌酐            0
体重检查结果        0
是否糖尿病         0
dtype: int64
数据重复值-------------------
0

 二 探索性数据分析

columns = df.drop(['是否糖尿病','卡号','性别'],axis=1).columns
plt.figure(figsize=(15,10))
for i,col in enumerate(columns,1):
    plt.subplot(3,5,i)
    sns.boxplot(x=df['是否糖尿病'],y=df[col])
    plt.title(f'{col}的箱线图')
    plt.ylabel('数值')
    plt.grid(True)
plt.tight_layout()
plt.show()

三 相关性分析 

df_corr = df.drop(['卡号'],axis=1).corr()
plt.figure(figsize=(8,6))
plt.title('相关性热图')
sns.heatmap(df_corr,annot=True)
plt.show()

四 划分数据集 

将负相关的特征(高密度脂蛋白胆固醇)剔除,只保留正相关的特征

# 划分数据集
X = df.drop(['是否糖尿病','高密度脂蛋白胆固醇','卡号'],axis=1)
y = df['是否糖尿病']

X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)

train_X,test_X,train_y,test_y = train_test_split(X,y,test_size=0.3,random_state=1)

# 数据集构建
train_dl = DataLoader(TensorDataset(train_X,train_y),
                      batch_size=64,
                      shuffle=False)

test_dl = DataLoader(TensorDataset(test_X,test_y),
                     batch_size=64,
                     shuffle=False)

五 构建LSTM模型 

# 定义模型
class model_lstm(nn.Module):
    def __init__(self):
        super(model_lstm,self).__init__()
        self.lstm0 = nn.LSTM(input_size=13,hidden_size=200,num_layers=1,batch_first=True)
        self.lstm1 = nn.LSTM(input_size=200,hidden_size=200,num_layers=1,batch_first=True)
        self.fc0 = nn.Linear(200,2)

    def forward(self,x):
        out,hidden1 = self.lstm0(x)
        out,_ = self.lstm1(out,hidden1)
        out = self.fc0(out)
        return out

model = model_lstm().to(device)
print(model)

六 训练并评估 

# 训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)   # 训练集的大小
    num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)

    train_acc,train_loss = 0,0  # 初始化训练损失和正确率

    for x,y in dataloader:    # 获取数据
        X,y = x.to(device),y.to(device)

        # 计算预测误差
        pred = model(X)   # 网络输出
        loss = loss_fn(pred,y)   # 计算误差

        # 反向传播
        optimizer.zero_grad()    # grad属性归零
        loss.backward()   # 反向传播
        optimizer.step()   # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc,train_loss
# 测试循环
def valid(dataloader,model,loss_fn):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)

    test_loss, test_acc = 0, 0  # 初始化训练损失和正确率

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target = imgs.to(device),target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred,target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc,test_loss
loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)

    model.eval()
    epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:56.5%,Train_loss:0.687,Test_acc:53.3%,Test_loss:0.686,lr:1.00E-04
Epoch: 2,Train_acc:56.7%,Train_loss:0.683,Test_acc:53.3%,Test_loss:0.687,lr:1.00E-04
Epoch: 3,Train_acc:56.5%,Train_loss:0.682,Test_acc:53.3%,Test_loss:0.686,lr:1.00E-04
Epoch: 4,Train_acc:56.5%,Train_loss:0.681,Test_acc:53.3%,Test_loss:0.685,lr:1.00E-04
Epoch: 5,Train_acc:56.7%,Train_loss:0.681,Test_acc:53.3%,Test_loss:0.685,lr:1.00E-04
Epoch: 6,Train_acc:56.8%,Train_loss:0.679,Test_acc:53.3%,Test_loss:0.684,lr:1.00E-04
Epoch: 7,Train_acc:57.0%,Train_loss:0.678,Test_acc:53.3%,Test_loss:0.683,lr:1.00E-04
Epoch: 8,Train_acc:56.8%,Train_loss:0.676,Test_acc:53.3%,Test_loss:0.681,lr:1.00E-04
Epoch: 9,Train_acc:56.5%,Train_loss:0.674,Test_acc:53.3%,Test_loss:0.679,lr:1.00E-04
Epoch:10,Train_acc:56.7%,Train_loss:0.671,Test_acc:53.3%,Test_loss:0.676,lr:1.00E-04
Epoch:11,Train_acc:57.0%,Train_loss:0.668,Test_acc:53.3%,Test_loss:0.673,lr:1.00E-04
Epoch:12,Train_acc:57.4%,Train_loss:0.665,Test_acc:53.3%,Test_loss:0.669,lr:1.00E-04
Epoch:13,Train_acc:57.7%,Train_loss:0.660,Test_acc:53.6%,Test_loss:0.664,lr:1.00E-04
Epoch:14,Train_acc:58.8%,Train_loss:0.655,Test_acc:53.6%,Test_loss:0.660,lr:1.00E-04
Epoch:15,Train_acc:59.5%,Train_loss:0.649,Test_acc:54.0%,Test_loss:0.655,lr:1.00E-04
Epoch:16,Train_acc:59.9%,Train_loss:0.643,Test_acc:56.3%,Test_loss:0.650,lr:1.00E-04
Epoch:17,Train_acc:61.8%,Train_loss:0.636,Test_acc:57.9%,Test_loss:0.644,lr:1.00E-04
Epoch:18,Train_acc:63.6%,Train_loss:0.628,Test_acc:60.3%,Test_loss:0.637,lr:1.00E-04
Epoch:19,Train_acc:65.3%,Train_loss:0.618,Test_acc:61.6%,Test_loss:0.630,lr:1.00E-04
Epoch:20,Train_acc:65.8%,Train_loss:0.607,Test_acc:63.2%,Test_loss:0.623,lr:1.00E-04
Epoch:21,Train_acc:66.6%,Train_loss:0.596,Test_acc:63.9%,Test_loss:0.616,lr:1.00E-04
Epoch:22,Train_acc:67.8%,Train_loss:0.584,Test_acc:64.6%,Test_loss:0.609,lr:1.00E-04
Epoch:23,Train_acc:70.3%,Train_loss:0.572,Test_acc:64.2%,Test_loss:0.602,lr:1.00E-04
Epoch:24,Train_acc:71.4%,Train_loss:0.560,Test_acc:66.6%,Test_loss:0.595,lr:1.00E-04
Epoch:25,Train_acc:72.4%,Train_loss:0.549,Test_acc:66.9%,Test_loss:0.590,lr:1.00E-04
Epoch:26,Train_acc:73.9%,Train_loss:0.538,Test_acc:66.6%,Test_loss:0.584,lr:1.00E-04
Epoch:27,Train_acc:74.3%,Train_loss:0.528,Test_acc:66.9%,Test_loss:0.579,lr:1.00E-04
Epoch:28,Train_acc:74.7%,Train_loss:0.518,Test_acc:67.5%,Test_loss:0.574,lr:1.00E-04
Epoch:29,Train_acc:76.0%,Train_loss:0.508,Test_acc:69.5%,Test_loss:0.570,lr:1.00E-04
Epoch:30,Train_acc:76.8%,Train_loss:0.499,Test_acc:70.5%,Test_loss:0.566,lr:1.00E-04
==================== Done ====================

七 可视化

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

总结:

 

1. 处理长序列数据

LSTM通过其独特的门控机制(输入门、遗忘门和输出门)有效地保留和忘记信息,解决了传统RNN在长序列学习中面临的梯度消失和梯度爆炸问题。这使得LSTM在处理长序列数据时能保持较好的性能。

2. 捕捉时间依赖性

LSTM能够捕捉序列中长期和短期的时间依赖性,对于许多应用(如语言建模、语音识别和视频分析等)来说,这种能力至关重要。它能够记住之前的信息并利用这些信息来影响当前的输出。

3. 适用于多种输入类型

LSTM不仅可以处理一维的时间序列数据,还可以处理多维的序列数据,适用于图像序列、文本序列等多种数据类型。这使得LSTM在许多不同的领域具有广泛的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/894136.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Jenkins配置流水线任务-实践操作(Pipeline-script)

Jenkins配置流水线任务-实践操作(Pipeline-script) 1、新增jenkins 任务,选择流水线 2、参数化 3、流水线配置 pipeline {agent anystages {stage(aoePlugin_mysql) {steps {echo "xxx,数据库:Mysql"echo "${HOST},${USER_NAME}"b…

王爽汇编语言第三版实验1

前言 本系列的文章是对王爽老师的汇编语言中的实验的解答记录,原书一共有17个实验,由于学校的教学流程只做到了第14个实验,因此本文章只会有前十四个实验的解答记录,还有个比较重要的是,文章中会有原书实验中没有的题目&#xff…

C语言 | Leetcode C语言题解之第477题汉明距离总和

题目&#xff1a; 题解&#xff1a; int totalHammingDistance(int* nums, int numsSize) {int ans 0;for (int i 0; i < 30; i) {int c 0;for (int j 0; j < numsSize; j) {c (nums[j] >> i) & 1;}ans c * (numsSize - c);}return ans; }

element plus的el-select分页

摘要&#xff1a; el-select的数据比较多的时候&#xff0c;必须要分页&#xff0c;处理方案有全部数据回来&#xff0c;或者添加搜索功能&#xff0c;但是就有个问题就是编辑的时候回显问题&#xff0c;必须要保证select的数据有对应的id与name匹配回显&#xff01; <el-fo…

如何用pyhton修改1000+图片的名字?

import os oldpath input("请输入文件路径&#xff08;在windows中复制那个图片文件夹的路径就可以):") #注意window系统中的路径用这个‘\分割&#xff0c;但是编程语言中一般都是正斜杠也就是’/‘ #这里写一个代码&#xff0c;将 \ > / path "" fo…

数字图像处理:图像复原应用

数字图像处理&#xff1a;图像复原应用 1.1 什么是图像复原&#xff1f; 图像复原是图像处理中的一个重要领域&#xff0c;旨在从退化&#xff08;例如噪声、模糊等&#xff09;图像中恢复出尽可能接近原始图像的结果。图像复原与图像增强不同&#xff0c;复原更多地依赖于图…

服务器数据恢复—服务器硬盘指示灯亮黄灯,raid崩溃的数据恢复案例

服务器数据恢复环境&#xff1a; 一台浪潮服务器中有一组由6块SAS硬盘组建的RAID。服务器上划分了1个卷&#xff0c;存放Oracle数据库文件。 服务器故障&检测&#xff1a; 服务器上有两个硬盘指示灯亮黄灯&#xff0c;RAID崩溃&#xff0c;服务器不可用。 将故障服务器中所…

LLM:deepspeed zero-2时模型训练所占显存分析

前置&#xff1a; fp16占2字节&#xff0c;fp32占4字节。换算就是1B的参数量&#xff0c;以fp16表示&#xff0c;占2G的内存。 模型参数为32B 全量微调&#xff1a; 模型参数&#xff1a;fp16的模型前向传播副本。fp32的模型的优化参数副本。这就是322324192G 梯度&#xff…

Jmeter简介

基础介绍 Jmeter录制脚本的原始是配置一个HTTP代理&#xff0c;然后浏览器通过这个代理访问测试页面从而完成脚本录制。 一、下载安装 jmeter本身不需要安装&#xff0c;需要配置环境变量JDK&#xff0c;然后打开bin文件夹中的jmeter.vbs即可。建议jdk 1.7及以上版本。 基本祖…

CVE-2024-22120:Zabbix低权限SQL注入至RCE+权限绕过

所有利用代码&#xff1a; GitHub - W01fh4cker/CVE-2024-22120-RCE: Time Based SQL Injection in Zabbix Server Audit Log --> RCE 一、漏洞环境搭建 1.1 下载vmware镜像并设置 直接懒人一键搭建&#xff1a; https://cdn.zabbix.com/zabbix/appliances/stable/6.0/6.0…

得物App3D创新应用引关注,世界设计之都大会启幕

近日&#xff0c;2024世界设计之都大会&#xff08;WDCC&#xff09;在上海盛大启幕。此次大会以“设计无界 新质生长”为主题&#xff0c;汇聚了全球设计领域的精英与前沿成果&#xff0c;展现了设计作为新质生产力的巨大潜力。主场展览占据了整整3个楼面&#xff0c;总面积达…

k8s-对命名空间资源配额

对k8s命名空间限制的方法有很多种&#xff0c;今天来演示一下很常用的一种 用的k8s对象就是ResourceQuota 一&#xff1a;创建命名空间 kubectl create ns test #namespace命名空间可以简写成ns 二&#xff1a; 对命名空间进行限制 创建resourcequota vim resourcequ…

基于Javaweb的医院挂号预约管理系统

系统展示 用户前台界面 管理员后台界面 医生后台界面 系统背景 在现代社会&#xff0c;随着医疗需求的不断增长&#xff0c;病患挂号成为医院面临的一大挑战。传统的挂号方式不仅耗时耗力&#xff0c;还容易引发混乱和不满。病患需要排队等候&#xff0c;挂号过程繁琐&#xff…

Nginx(Linux):启动停止Nginx

目录 1、理解Nginx后台进程2、停止Nginx(方式一&#xff1a;使用信号源)2.1 获取master进程号2.1 设置信号源 3、停止Nginx(方式二&#xff1a;使用命令行) 1、理解Nginx后台进程 Nginx后台进程包含master和worker两类进程。 master进程&#xff1a;主要用来管理worker进程&am…

Docker 教程四 (Docker 镜像加速)

Docker 镜像加速 国内从 DockerHub 拉取镜像有时会遇到困难&#xff0c;此时可以配置镜像加速器。 目前国内 Docker 镜像源出现了一些问题&#xff0c;基本不能用了&#xff0c;后期能用我再更新下。* Docker 官方和国内很多云服务商都提供了国内加速器服务&#xff0c;例如…

C++ | Leetcode C++题解之第479题最大回文数乘积

题目&#xff1a; 题解&#xff1a; class Solution { public:int largestPalindrome(int n) {if (n 1) {return 9;}int upper pow(10, n) - 1;for (int left upper;; --left) { // 枚举回文数的左半部分long p left;for (int x left; x > 0; x / 10) {p p * 10 x %…

Maxwell 底层原理 详解

Maxwell 是一个 MySQL 数据库的增量数据捕获&#xff08;CDC, Change Data Capture&#xff09;工具&#xff0c;它通过读取 MySQL 的 binlog&#xff08;Binary Log&#xff09;来捕获数据变化&#xff0c;并将这些变化实时地发送到如 Kafka、Kinesis、RabbitMQ 或其他输出端。…

信息搜集 --前端js打点

前端js打点 什么是js渗透测试 在Javascript中也存在变量和函数&#xff0c;当存在可控变量及函数调用即可参数漏洞JS开发的WEB应用和PHP&#xff0c;JAVA,NET等区别在于即没有源代码&#xff0c;也可以通过浏览器的查看源代码获取真实的点。获取URL&#xff0c;获取Js敏感信息&…

永磁同步电机控制算法--非线性自抗扰NLADRC转速环控制器(1)

一、原理介绍 ADRC由跟踪微分器(TD)、扩张状态观测器(ESO)和非线性状态误差反馈控制律(NLSEF)三部分组成。 其中SMC-LSEF表达式为: e3v1-z1u0ζsign(s)-ks-ce3u(u0-z2)/b 二、仿真验证 在MATLAB/simulink里面验证所提算法&#xff0c;采用和实验中一致的控制周期1e-4&#…

Navigation2 算法流程

转自 https://zhuanlan.zhihu.com/p/405670882 此文仅作学习笔记 启动流程 在仿真环境中启动导航包的示例程序&#xff0c;执行nav2_bringup/bringup/launch/tb3_simulation_launch.py文件。ROS2的launch文件支持采用python语言来编写以支持更加复杂的功能&#xff0c;本文件…