RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')

# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_excel('dia.xls')
df

二 数据处理分析

# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
 Age  Gender  Ethnicity  EducationLevel        BMI  Smoking   
0      73       0          0               2  22.927749        0  \
1      89       0          0               0  26.827681        0   
2      73       0          3               1  17.795882        0   
3      74       1          0               1  33.800817        1   
4      89       0          0               0  20.716974        0   
...   ...     ...        ...             ...        ...      ...   
2144   61       0          0               1  39.121757        0   
2145   75       0          0               2  17.857903        0   
2146   77       0          0               1  15.476479        0   
2147   78       1          3               1  15.299911        0   
2148   72       0          0               2  33.289738        0   

      AlcoholConsumption  PhysicalActivity  DietQuality  SleepQuality  ...   
0              13.297218          6.327112     1.347214      9.025679  ...  \
1               4.542524          7.619885     0.518767      7.151293  ...   
2              19.555085          7.844988     1.826335      9.673574  ...   
3              12.209266          8.428001     7.435604      8.392554  ...   
4              18.454356          6.310461     0.795498      5.597238  ...   
...                  ...               ...          ...           ...  ...   
2144            1.561126          4.049964     6.555306      7.535540  ...   
2145           18.767261          1.360667     2.904662      8.555256  ...   
2146            4.594670          9.886002     8.120025      5.769464  ...   
2147            8.674505          6.354282     1.263427      8.322874  ...   
2148            7.890703          6.570993     7.941404      9.878711  ...   

      FunctionalAssessment  MemoryComplaints  BehavioralProblems       ADL   
0                 6.518877                 0                   0  1.725883  \
1                 7.118696                 0                   0  2.592424   
2                 5.895077                 0                   0  7.119548   
3                 8.965106                 0                   1  6.481226   
4                 6.045039                 0                   0  0.014691   
...                    ...               ...                 ...       ...   
2144              0.238667                 0                   0  4.492838   
2145              8.687480                 0                   1  9.204952   
2146              1.972137                 0                   0  5.036334   
2147              5.173891                 0                   0  3.785399   
2148              6.307543                 0                   1  8.327563   

      Confusion  Disorientation  PersonalityChanges   
0             0               0                   0  \
1             0               0                   0   
2             0               1                   0   
3             0               0                   0   
4             0               0                   1   
...         ...             ...                 ...   
2144          1               0                   0   
2145          0               0                   0   
2146          0               0                   0   
2147          0               0                   0   
2148          0               1                   0   

      DifficultyCompletingTasks  Forgetfulness  Diagnosis  
0                             1              0          0  
1                             0              1          0  
2                             1              0          0  
3                             0              0          0  
4                             1              0          0  
...                         ...            ...        ...  
2144                          0              0          1  
2145                          0              0          1  
2146                          0              0          1  
2147                          0              1          1  
2148                          0              1          0  

[2149 rows x 33 columns]

三 探索性数据分析 

1.得病分布

res = df.groupby('Diabetes')['Age'].count()
print(res)

plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,
        colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1,  0),
        wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

 2.BMI分布直方图

# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

3.年龄分布直方图 

# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

 

四 构建划分数据集 

X = df.iloc[:,:-1]
y = df.iloc[:,-1]

sc = StandardScaler()
X = sc.fit_transform(X)

# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)

X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)

# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 五 训练模型

1.构建模型

# 构建模型
class model_rnn(nn.Module):
        def __init__(self):
                super(model_rnn, self).__init__()
                self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)
                self.fc0 = nn.Linear(200,50)
                self.fc1 = nn.Linear(50,2)

        def forward(self,x):
                out , hidden1 = self.rnn0(x)
                out = self.fc0(out)
                out = self.fc1(out)
                return out

model = model_rnn().to(device)
print(model)
model_rnn(
  (rnn0): RNN(32, 200, batch_first=True)
  (fc0): Linear(in_features=200, out_features=50, bias=True)
  (fc1): Linear(in_features=50, out_features=2, bias=True)
)

2.训练函数 

'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)   # 训练集的大小
    num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)

    train_acc,train_loss = 0,0  # 初始化训练损失和正确率

    for x,y in dataloader:    # 获取数据
        X,y = x.to(device),y.to(device)

        # 计算预测误差
        pred = model(X)   # 网络输出
        loss = loss_fn(pred,y)   # 计算误差

        # 反向传播
        optimizer.zero_grad()    # grad属性归零
        loss.backward()   # 反向传播
        optimizer.step()   # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc,train_loss

3.测试函数 

# 测试循环
def valid(dataloader,model,loss_fn):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)

    test_loss, test_acc = 0, 0  # 初始化训练损失和正确率

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target = imgs.to(device),target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred,target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc,test_loss

4.正式训练 

loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)

    model.eval()
    epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04
Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04
Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04
Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04
Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04
Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04
Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04
Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04
Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04
Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04
Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04
Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04
Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04
Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04
Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04
Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04
Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04
Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04
Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04
Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04
Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04
Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04
Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04
Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04
Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04
Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04
Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04
Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04
==================== Done ====================

六 结果可视化 

1.Loss和Acc图

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

 2.调用模型预测

test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0
========================================
0:未患病
1:已患病

3.绘制混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)

pred = model(X_test.to(device)).argmax(1).cpu().numpy()

print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)

from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay

# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)

plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为==============
X_test.shape: torch.Size([215, 32])
y_test.shape: torch.Size([215])

==========输出数据shape为==============
pred.shape: (215,)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/959372.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ---面试题

常见面试题 1.MQ的作用及应用场景 类似问题:项目什么情况下用到了MQ,为什么要用MQ MQ的主要应用场景,消息队列的应用场景,为什么说消息队列可以削峰 首先MQ是一种用来接收和转发消息的队列,常见的应用常见如下&…

全球AI模型百科全书,亚马逊云科技Bedrock上的100多款AI模型

今天小李哥给大家介绍的是亚马逊云科技上的AI模型管理平台Amazon Bedrock上的Marketplace,这是亚马逊云科技在今年re:Invent发布的一个全新功能,将亚马逊的电商基因带到了其云计算平台,让我们能够通过Amazon Bedrock访问100多种流行、新兴和专…

八种排序算法【C语言实现】

系列文章目录 🎈 🎈 我的CSDN主页:OTWOL的主页,欢迎!!!👋🏼👋🏼 🎉🎉我的C语言初阶合集:C语言初阶合集,希望能…

在 vscode + cmake + GNU 工具链的基础上配置 JLINK

安装 JLINK JLINK 官网链接 下载安装后找到安装路径下的可执行文件 将此路径添加到环境变量的 Path 中。 创建 JFlash 项目 打开 JFlash,选择新建项目 选择单片机型号 在弹出的窗口中搜索单片机 其他参数根据实际情况填写 新建完成: 接下来设置…

PyQt5之QtDesigner的若干配置和使用

1.描述 QtDesigner是一个可视化工具,可以通过该工具设计页面 2.简单使用 1.下载PyQt5-tools pip install pyqt5-tools 2.打开designer.exe文件 我采用的是虚拟环境,该文件位于C:\Users\24715\anaconda3\envs\pyqt\Lib\site-packages\qt5_applicatio…

【高项】6.3 排列活动顺序 ITTO

输入 项目管理计划组件: ① 进度管理计划;② 范围基准 项目文件: ① 假设日志;② 活动属性;③ 活动清单;④ 里程碑清单 工具与技术 紧前关系绘图法(PDM) ① 完成到开始&…

Elasticsearch 自定义分成器 拼音搜索 搜索自动补全 Java对接

介绍 通常用于将文档中的文本数据拆分成易于索引的词项(tokens)。有时,默认的分词器无法满足特定应用需求,这时就可以创建 自定义分词器 来实现定制化的文本分析。 自定义分词器组成 Char Filters(字符过滤器&#x…

Leecode刷题C语言之完成所有交易的初始最少钱数

执行结果:通过 执行用时和内存消耗如下&#xff1a; long long minimumMoney(int** transactions, int transactionsSize,int* transactionsColSize) {long long total_lose 0;int res 0;for (int i 0; i < transactionsSize; i) {int cost transactions[i][0];int cas…

案例研究丨浪潮云洲通过DataEase推进多维度数据可视化建设

浪潮云洲工业互联网有限公司&#xff08;以下简称为“浪潮云洲”&#xff09;成立于2018年&#xff0c;定位于工业数字基础设施建设商、具有国际影响力的工业互联网平台运营商、生产性互联网头部服务商。截至目前&#xff0c;浪潮云洲工业互联网平台连续五年入选跨行业跨领域工…

C++——动态管理

目录 一、C/C内存分布二、C语言中动态内存管理方式&#xff1a;malloc/calloc/realloc/free三、C内存管理方式3.1 new/delete操作内置类型3.2 new和delete操作自定义类型 四、operator new与operator delete函数4.1 operator new与operator delete函数 五、new和delete的实现原…

three.js+WebGL踩坑经验合集(4.1):THREE.Line2的射线检测问题(注意本篇说的是Line2,同样也不是阈值方面的问题)

上篇大家消化得如何了&#xff1f; 笔者说过&#xff0c;1级编号不同的两篇博文相对独立&#xff0c;所以这里笔者还是先给出完整代码&#xff0c;哪怕跟&#xff08;3&#xff09;没有太大区别。 这里我们把线的粗细调成5&#xff08;排除难选中的因素&#xff09;&#xff…

Vue2下篇

插槽&#xff1a; 基本插槽&#xff1a; 普通插槽&#xff1a;父组件向子组件传递静态内容。基本插槽只能有一个slot标签&#xff0c;因为这个是默认的位置&#xff0c;所以只能有一个 <!-- ParentComponent.vue --> <template> <ChildComponent> <p>…

【Unity3D】aab包太大无法上传Google问题

目录 一、勾选Split Application Binary&#xff0c;Unity直接打aab包 勾选Split Application Binary选项的影响 不勾选Split Application Binary选项的影响 总结 2、导出Android工程打包aab 一、勾选Split Application Binary&#xff0c;Unity直接打aab包 超出150MB部分…

第6章 数据结构—列表与列表项讲解--总结

整理 野火 《FreeRTOS 内核实现与应用开发实战指南》—基于野火 STM32 全系列&#xff08;M3/4/7&#xff09;开发板 文章目录 第6章 数据结构—列表与列表项讲解--总结6.1 C 语言链表简介6.1.1 单向链表6.1.2 双向链表6.1.3 链表与数组的对比 6.2 FreeRTOS 中链表的实现6.2.1 …

强化学习-Deep Q Network

文章目录 Deep Q Networkzip(*batch)的内部实现假设&#xff1a;结果&#xff1a; Deep Q Network 这种方式很适合格子游戏。因为格子游戏中的每一个格子就是一个状态&#xff0c;这是离散的&#xff0c;但在现实生活中&#xff0c;很多状态并不是离散而是连续的。所以我们可以…

C语言-构造数据类型

1、构造数据类型 结构体、共用体、枚举。 2、结构体 1、结构体的定义 结构体是一个自定义的复合数据类型&#xff0c;它允许将不同类型的数据组合在一起。 struct 结构体名 {数据类型1 成员变量1;数据类型2 成员变量2;数据类型3 成员变量3;数据类型4 成员变量4; } 2、结构体变…

FPGA实现任意角度视频旋转(二)视频90度/270度无裁剪旋转

本文主要介绍如何基于FPGA实现视频的90度/270度无裁剪旋转&#xff0c;旋转效果示意图如下&#xff1a; 为了实时对比旋转效果&#xff0c;采用分屏显示进行处理&#xff0c;左边代表旋转前的视频在屏幕中的位置&#xff0c;右边代表旋转后的视频在屏幕中的位置。 分屏显示的…

Spark/Kafka

文章目录 项目地址一、Spark1. RDD1.1 五大核心属性1.2 执行原理1.3 四种创建方式二、Kafka2.1 生产者(1)分区器(2)生产者提高吞吐量(3) 生产者数据可靠性数据传递语义幂等性和事务数据有序2.2 Broker(1)Broker工作流程(2)节点服役和退役2.3 副本(1)Follower故障细…

win32汇编环境,函数的编写与调用、传值或返回值等

;运行效果 ;win32汇编环境,函数的编写与调用、传值或返回值等 ;函数在被调用的时候&#xff0c;如果此函数实体在前面&#xff0c;可以不用声明。如果实体在后面&#xff0c;则需要先声明。类似于下面的DlgProc函数&#xff0c;因为它的实体在后面&#xff0c;所以需要在调用之…

[Spring] Gateway详解

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…