R6:LSTM实现糖尿病探索与预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、实验目的:

学习使用LSTM对糖尿病进行探索预测

二、实验环境:

  • 语言环境:python 3.8
  • 编译器:Jupyter notebook
  • 深度学习环境:Pytorch
    • torch==2.4.0+cu124
    • torchvision==0.19.0+cu124

三、数据预处理

逻辑回归在二分类问题中应用广泛;KNN(K 近邻算法)、SVM(支持向量机)、决策树、贝叶斯分类器、随机森林和 XGBoost(极端梯度提升树)都是常见的用于结构化数据分类的算法。

本次实验我们采用 LSTM(长短期记忆网络)进行分类预测。LSTM 主要用于处理序列数据,虽然在一些特定情况下可以对序列数据进行分类,但对于一般的二维结构化数据,上述提到的传统分类算法通常更加合适。二维结构化数据通常指表格形式的数据,每一行代表一个样本,每一列代表一个特征,对于这类数据,传统的机器学习分类算法在计算效率和可解释性方面往往具有优势。

在这里插入图片描述

1. 设置GPU、导入数据

#设置GPU 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision,torch 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device 
#导入数据
import numpy   as np
import pandas  as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['savefig.dpi'] = 500 #图片像素
plt.rcParams['figure.dpi'] = 500 #分辨率

plt.rcParams['font.sans-serif'] = ['SimHei'] #用来正常显示中文标签

import warnings
warnings.filterwarnings('ignore')

DataFrame = pd.read_excel('diabetes.xls')
DataFrame.head()

在这里插入图片描述

DataFrame.shape
(1006, 16)    

2. 数据检查

#查看数据是否有缺失值
print('数据缺失值--------------------------')
print(DataFrame.isnull().sum())

在这里插入图片描述

#查看数据是否有重复值
print('数据重复值--------------------------')
print('数据集的重复值为:'f'{DataFrame.duplicated().sum()}')

在这里插入图片描述

3. 数据分布分析

feature_map = { '年龄': '年龄',
    '高密度脂蛋白胆固醇': '高密度脂蛋白胆固醇',
    '低密度脂蛋白胆固醇': '低密度脂蛋白胆固醇',
    '极低密度脂蛋白胆固醇': '极低密度脂蛋白胆固醇',
    '甘油三酯': '甘油三酯',
    '总胆固醇': '总胆固醇',
    '脉搏': '脉搏',
    '舒张压':'舒张压',
    '高血压史':'高血压史',
    '尿素氮':'尿素氮',
    '尿酸':'尿酸',
    '肌酐':'肌酐',
    '体重检查结果':'体重检查结果'}

plt.figure(figsize=(15,10))

for i, (col, col_name) in enumerate(feature_map.items(), 1):
    plt.subplot(3,5,i)
    sns.boxplot(x=DataFrame['是否糖尿病'], y=DataFrame[col])
    plt.title(f'{col_name}的箱线图', fontsize=14)
    plt.ylabel('数值', fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
plt.tight_layout()
plt.show()

在这里插入图片描述
以下是分析箱线图的方法,并以年龄的箱线图为例进行介绍:

一、认识箱线图的组成部分

  1. 箱体:箱体的上下边界分别代表数据的上四分位数(Q3)和下四分位数(Q1)。箱体中间的线通常代表中位数。
  2. whiskers(须):从箱体延伸出去的线段,代表数据的范围。一般来说,须的长度是由一些特定的规则决定的,常见的是 1.5 倍的四分位距(IQR,即 Q3 - Q1)。超出须范围的数据点被视为异常值,可能会以单独的点显示。

二、分析年龄箱线图的具体步骤

  1. 观察中位数:

    • 首先找到箱体中间的线,它代表了年龄数据的中位数。如果这条线在箱线图的中间位置附近,说明数据分布相对较为对称;如果偏向箱体的上边界或下边界,则说明数据可能存在偏斜。
    • 假设年龄箱线图中,中位数线靠近箱体上边界,这可能意味着年龄数据整体上偏大,即大部分人的年龄较高。
  2. 分析箱体长度:

    • 箱体的长度反映了数据的离散程度。如果箱体较短,说明数据比较集中;如果箱体较长,说明数据的分散程度较大。
    • 例如,如果年龄箱线图的箱体较短,说明年龄数据相对集中在一个较小的范围内。
  3. 观察须的长度:

    • 须的长度可以让你了解数据的整体范围。较长的须表示数据的范围较大;较短的须可能意味着数据比较集中在一个较小的区间内。
    • 如果年龄箱线图的须较长,说明年龄数据的跨度较大,可能有一些年龄较大或较小的极端值。
  4. 检查异常值:

    • 异常值通常以单独的点显示在箱线图之外。观察异常值的数量和分布,可以了解数据中是否存在极端情况。
    • 如果年龄箱线图中有一些异常值,需要进一步分析这些异常值的来源,例如是否是由于数据录入错误或者特殊的个体情况导致的。
  5. 比较不同组别的箱线图:

    • 如果有多个组别的年龄箱线图,可以比较它们的中位数、箱体长度、须的长度和异常值情况,以了解不同组之间年龄分布的差异。
    • 例如,比较糖尿病患者和非糖尿病患者的年龄箱线图,看是否存在明显的差异。如果糖尿病患者的年龄箱线图中位数较高,箱体较长,可能说明糖尿病患者的年龄普遍较大。

通过以上步骤,你可以对年龄箱线图进行较为全面的分析,了解年龄数据的分布特征和潜在的问题。对于其他变量的箱线图,也可以采用类似的方法进行分析。

df_corr = DataFrame.drop(['卡号'],axis=1).corr()
plt.figure(figsize=(12,10))
plt.title('相关性热图')
sns.heatmap(df_corr,annot=True)
plt.show()

在这里插入图片描述

四、LSTM模型

#数据集构建

from sklearn.preprocessing import StandardScaler

# '高密度脂蛋白胆固醇'字段与糖尿病负相关,故而在 X 中去掉该字段
X = DataFrame.drop(['卡号','是否糖尿病','高密度脂蛋白胆固醇'],axis=1)
y = DataFrame['是否糖尿病']

# sc_X    = StandardScaler()
# X = sc_X.fit_transform(X)

X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)

train_X, test_X, train_y, test_y = train_test_split(X, y, 
                                                    test_size=0.2,
                                                    random_state=1)
train_X.shape, train_y.shape

from torch.utils.data import TensorDataset, DataLoader

train_dl = DataLoader(TensorDataset(train_X, train_y),batch_size=64,shuffle=False)
test_dl  = DataLoader(TensorDataset(test_X, test_y),batch_size=64,shuffle=False)
#定义模型
class model_lstm(nn.Module):
    def __init__(self):
        super(model_lstm, self).__init__()
        self.lstm0 = nn.LSTM(input_size=13,  hidden_size=200, num_layers=1, batch_first=True)
        self.lstm1 = nn.LSTM(input_size=200, hidden_size=200, num_layers=1, batch_first=True)
        self.fc0   = nn.Linear(200, 2)
        
    def forward(self, x):
        out, hidden1 = self.lstm0(x)
        out, _       = self.lstm1(out, hidden1)
        out          = self.fc0(out)
        return out
    
model = model_lstm().to(device)
model

在这里插入图片描述

五、训练模型

#定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss
#定义测试函数
def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss
#训练模型
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs     = 30

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
 
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
print("="*20, 'Done', "="*20)

在这里插入图片描述

六、模型评估

#Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、总结

分析数据可知存在有一定的过拟合迹象:

  • 随着训练的进行,训练准确率不断上升,而测试准确率在前期长时间停滞,后期虽然有所上升,但上升幅度小于训练准确率。这表明模型在训练集上的学习能力较强,但在测试集上的泛化能力相对较弱。
  • 训练损失持续下降,而测试损失在下降过程中出现了波动,并且在后期与训练损失的差距有一定程度的扩大。这也暗示模型可能过度拟合了训练数据,导致在测试集上的表现不如在训练集上的表现稳定。
  • 实验中尝试通过提高学习率至1e-3,可以将预测准确率提升到71.3%,而提高训练轮数则始终难以收敛。而在构建数据集部分,可以看到注释部分的代码为数据的标准化处理。
  • 除此之外,还可以考虑采用正则化方法、增加数据量、早停法等技术来缓解过拟合问题。

在划分数据集过程中添加标准化处理可以提升测试数据集准确率的原因主要有以下几点:

一、消除量纲影响

  1. 不同特征往往具有不同的量纲和尺度。例如,一个特征可能取值范围在 0 到 100 之间,而另一个特征可能取值在 0 到 1 之间。这会使得在某些算法中,具有较大数值范围的特征对模型的影响更大,从而可能导致模型偏向于这些特征,而忽略了其他重要特征的作用。
  2. 标准化处理将数据的各个特征转换到相同的尺度上,通常使得特征的均值为 0,标准差为 1。这样可以确保每个特征在模型中具有相对平等的影响力,避免了因量纲差异而导致的不公平性。

二、加速模型收敛

  1. 许多优化算法在处理标准化后的数据时能够更快地收敛。例如,梯度下降算法在标准化的数据上能够更有效地确定下降的方向和步长,因为数据的分布更加稳定,不会因为特征的尺度差异而导致梯度在不同方向上的变化幅度差异巨大。
  2. 当数据经过标准化后,模型在训练过程中可以更稳定地更新参数,减少了因数据尺度不一致而引起的震荡,从而更快地找到最优解,这也有助于提高模型在测试集上的准确率。

三、提高模型的泛化能力

  1. 标准化可以使模型对不同单位和尺度的输入数据具有更好的适应性,从而提高模型的泛化能力。如果模型在训练时只适应了特定尺度的数据集,那么在面对测试集上不同尺度的数据时,可能表现不佳。
  2. 标准化处理可以减少异常值对模型的影响。异常值在未标准化的数据中可能会对模型产生较大的干扰,而经过标准化后,异常值的影响相对减小,模型能够更加关注数据的整体分布特征,从而提高在测试集上的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/907376.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM+小程序的计算机实验室排课与查询管理系统(实验室2)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 1、管理员功能有个人中心,学生管理,教师管理,实验室信息管理,实验室预约管理,取消预约管理,实验课程管理&#xff0…

软件项目实施方案,实施计划,总体实施管理方案(word原件)

一、 概述 二、 项目介绍 2.1 概览 三、 项目实施 3.1 项目实施概况 3.2 项目实施管理原则 3.3 项目组织结构 3.4 项目团队 四、 项目实施计划 4.1 项目实施工作流程 4.2 项目软件部分进度安排 4.3 网络拓扑图 4.4 服务器需求清单 五、 人员培训 5.1 培训内容 5…

【日记】第一次觉得 “届く” 是一个很难的东西(1528 字)

正文 昨天晚上吃饭路上,听到喵喵叫。四处张望了一下,看见一只野猫。天黑,看不清具体什么样子。我冲它喵喵叫,试图走近它。跑掉了。 看来我还是不讨小动物喜欢呢(笑。 一大早去了医院。这次膝盖看了三个医生。 放射科报…

【Python单元测试】pytest框架单元测试 配置 命令行操作 测试报告 覆盖率

单元测试(unit test),简称UT。本文将介绍在Python项目中,pytest测试框架的安装,配置,执行,测试报告与覆盖率 pytest简介 pytest是一款流行的,简单易上手的单元测试框架,…

三次样条插值算法及推导过程

目录 1、定义 2、已知条件求解 3、具体推导 4、matlab案例 5、案例结果 6、matlab仿真 1、定义 给定 n 1 n1 n1个数据点,共有 n n n个区间,三次样条方程 S ( n ) S(n) S(n)满足以下条件:在每个分段区间内 ( x i , x i 1 ) (x_i,x_{i1}) (…

C#与C++结构体的交互

C#在和C进行交互时,有时候会需要传递结构体。 做一些总结,避免大家在用的时候踩坑。 一般情况 例如我们在C里定义了一个struct_basic结构体 1 struct struct_basic 2 { 3 WORD value_1; 4 LONG value_2; 5 DWORD value_3; 6 UINT v…

Flutter Color 大调整,需适配迁移,颜色不再是 0-255,而是 0-1.0,支持更大色域

在之前的 3.10 里, Flutter 的 Impeller 在 iOS 上支持了 P3 广色域图像渲染,但是当时也仅仅是当具有广色域图像或渐变时,Impeller 才会在 iOS 上显示 P3 的广色域的颜色,而如果你使用的是 Color API,会发现使用的还是…

Android平台RTSP转RTMP推送之采集麦克风音频转发

技术背景 RTSP转RTMP推送,好多开发者第一想到的是采用ffmpeg命令行的形式,如果对ffmpeg比较熟,而且产品不要额外的定制和更高阶的要求,未尝不可,如果对产品稳定性、时延、断网重连等有更高的技术诉求,比较…

【十九周】文献阅读:图像识别的深度残差学习

目录 摘要Abstract图像识别的深度残差学习研究背景研究动机解决办法Residual LearningShortcut Connections网络结构 实验结果代码实践论文原文总结 摘要 在之前对神经网络的基础学习中,师兄推荐了我去了解一下 ResNet。因此本周对 ResNet 的开山之作—Deep Residu…

MATLAB/Simulink学习|在Simulink中调用C语言-01使用C Function 实现比例运算

前面的博客中,提到如果想将Simulink仿真推进至硬件实验,需要将积木式的仿真搭建,变换成C语言实现,那么如何在Simulink中验证C代码的正确性呢?我将一边学习,一边更新,一边比较不同方法实现C语言&…

基于BP神经网络的手写体数字图像识别

基于BP神经网络的手写体数字图像识别 摘要 在信息化飞速发展的时代,光学字符识别是一个重要的信息录入与信息转化的手段,其中手写体数字的识别有着广泛地应用,如:邮政编码、统计报表、银行票据等等,因其广泛地应用范围…

分享一个免费的网页转EXE的工具

HTML2EXE是一款在Windows系统下将Web项目或网站打包成EXE执行程序的免费工具。这款工具能够将单页面应用、传统HTMLJavaScriptCSS生成的网站、Web客户端,以及通过现代前端框架(如Vue)生成的应用转换成独立的EXE程序运行。它支持将任何网站打包…

ubuntu安装与配置Nginx(2)

1. 配置 Nginx Nginx 的配置文件通常位于 /etc/nginx/nginx.conf,而虚拟主机的配置文件通常在 /etc/nginx/sites-available/ 和 /etc/nginx/sites-enabled/ 目录中。 在/etc/nginx/conf.d目录下新建xx.conf文件,配置文件, nginx -t 检查语法…

C++_day2

目录 1. 引用 reference(重点) 1.1 基础使用 1.2 特性 1.3 引用参数 2. C窄化(了解) 3. 输入(熟悉) 4. string 字符串类(掌握) 4.1 基础使用 4.2 取出元素 4.3 字符串与数字转换 5. …

JAVA WEB — HTML CSS 入门学习

本文为JAVAWEB 关于HTML 的基础学习 一 概述 HTML 超文本标记语言 超文本 超越文本的限制 比普通文本更强大 除了文字信息 还可以存储图片 音频 视频等标记语言 由标签构成的语言HTML标签都是预定义的 HTML直接在浏览器中运行 在浏览器解析 CSS 是一种用来表现HTML或XML等文…

第十五章 Vue工程化开发及Vue CLI脚手架

目录 一、引言 二、Vue CLI 基本介绍 三、安装Vue CLI 3.1. 安装npm和yarn 3.2. 安装Vue CLI 3.3. 查看 Vue 版本 四、创建启动工程 4.1. 创建项目架子 4.2. 启动工程 五、脚手架目录文件介绍 六、核心文件讲解 6.1. index.html 6.2. main.js 6.3. App.vue 一、…

【1个月速成Java】基于Android平台开发个人记账app学习日记——第4天,注册登录功能设计

24.11.03 1.修改项目目录 从今天开始将正式进行功能的设计,首先需要对原来的项目结构进行修改,主要是添加新的文件夹用于存放新的文件。下面进行展示和讲解: 我用红圈圈出了新添加的文件夹,介绍下它们都是干啥的: da…

动态库实现lua网络请求GET, POST, 下载文件

DLL需要使用的网络封装 WinHttp异步实现GET, POST, 多线程下载文件_webclient post下载文件-CSDN博客文章浏览阅读726次。基于WinHttp封装, 实现异步多线程文件下载, GET请求, POST请求_webclient post下载文件https://blog.csdn.net/Flame_Cyclone/article/details/142644088…

unet中的attn_processor的修改(用于设计新的注意力模块)

参考资料 文章目录 unet中的一些变量的数据情况attn_processorunet.configunet_sd 自己定义自己的attn Processor ,对原始的attn Processor进行修改 IP-adapter中设置attn的方法 参考的代码: 腾讯ailabipadapter 的官方训练代码 unet中的一些变量的数据…

深度学习基础—序列采样

引言 深度学习基础—循环神经网络(RNN)https://blog.csdn.net/sniper_fandc/article/details/143417972?fromshareblogdetail&sharetypeblogdetail&sharerId143417972&sharereferPC&sharesourcesniper_fandc&sharefromfrom_link …