机器学习——6.模型训练案例: 预测儿童神经缺陷分类TD/ADHD

案例目的

有一份EXCEL标注数据,如下,训练出合适的模型来预测儿童神经缺陷分类。

参考文章:机器学习——5.案例: 乳腺癌预测-CSDN博客

代码逻辑步骤

  1. 读取数据
  2. 训练集与测试集拆分
  3. 数据标准化
  4. 数据转化为Pytorch张量
  5. label维度转换
  6. 定义模型
  7. 定义损失计算函数
  8. 定义优化器
  9. 定义梯度下降函数
  10. 模型训练(正向传播、计算损失、反向传播、梯度清空)
  11. 模型测试
  12. 精度计算

代码实现

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


df = pd.read_excel('/Users/guojun/Desktop/Learning/machine_learning/Preprocess_Without_WDE_Channels_Data.xlsx')

X = df[df.columns[0:8]].values
mapping = {"TD":0,"ADHD":1}
Y = df["Class"].replace(mapping)

# 数据集拆分
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.2,random_state=5)
Y_train = Y_train.to_numpy()
Y_test = Y_test.to_numpy()

# 数据标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)


# 转化为张量
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
Y_train = torch.from_numpy(Y_train.astype(np.float32))
Y_test = torch.from_numpy(Y_test.astype(np.float32))

# 真值转为为二维数据
Y_train = Y_train.view(Y_train.shape[0],-1)
Y_test = Y_test.view(Y_test.shape[0],-1)

# 定义模型
class Model(torch.nn.Module):
    def __init__(self,n_input_features):
        super(Model,self).__init__()
        self.linear = torch.nn.Linear(n_input_features,1)
        
    def forward(self,x):
        return torch.sigmoid(self.linear(x))

model = Model(X_train.shape[1])
# 定义损失函数
loss = torch.nn.BCELoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

# 梯度下降函数
def gradient_descent():
    # 预测Y值
    pre_y = model(X_train)
    # 计算损失
    l = loss(pre_y,Y_train)
    # 反向传播
    l.backward()
    # 梯度更新
    optimizer.step()
    # 梯度清空
    optimizer.zero_grad()
    return l,list(model.parameters())

# 模型训练
for i in range(10000):
    l,p = gradient_descent()
    print(l,p)

# 模型测试
mapping = {0:"TD",1:"ADHD"}
index = np.random.randint(0,X_test.shape[0])
pre_y = model(X_test[index])
pre_y = mapping[int(pre_y.round().item())]
gt_y = mapping[int(Y_test[index].item())]
print(pre_y,gt_y)


# 计算模型准确率
pres_y = model(X_test).round()
result = np.where(pres_y==Y_test,1,0)
ac = np.sum(result)/result.size
print(ac)

 即使调整参数后,损失在0.68左右就不会再下降了。

最终的准确率只有54%-60%,我会在后面的笔记中使用深度神经网络来重新训练,提升模型精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/611720.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL慢查询SQL优化

一、慢查询日志 描述:通过慢查询日志等定位那些执行效率较低的SQL语句 查看 # 慢查询是否开启 show variables like slow_query_log%; # 慢查询超时时间 show variables like long_query_time%;执行SQL 开启慢查询日志 set global slow_query_log ON;设置慢查…

Secure Transformer Inference Made Non-interactive

目录 1.概述2.Attention2.1 Matrix multiplication (ciphertext-plaintext).2.2 Matrix multiplication (ciphertext-ciphertext)2.3 Placement of bootstrapping3.SIMD密文压缩和解压缩4.SIMD槽折叠5.实验结果 1.概述 我们提出了NEXUS,这是第一个用于安全变压器推…

pdf编辑软件,四款软件让你轻松玩转PDF编辑!

在信息爆炸的当今时代,PDF格式文档因其跨平台、不易被篡改的特性而深受大家喜爱。然而,如何高效地编辑PDF文档却成为许多人的难题。今天,我将为大家推荐四款实用的PDF编辑软件,让你轻松玩转PDF编辑,告别繁琐操作&#…

【Linux】为什么有僵尸状态,什么是僵尸进程,造成危害以及如何避免“内存泄漏”问题详解

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

基于Springboot的校园竞赛管理系统(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的校园竞赛管理系统(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构…

HarmonyOS开发案例:【生活健康app之编写通用工具类】(5)

本节将介绍日志打印、时间换算等通用工具类的编写和使用,工具类可以简化应用代码编写和业务流程处理。 日志类 日志类Logger旨在提供一个全局的日志打印、日志管理的地方,既可以规范整个应用的日志打印,也方便日后对日志工具类进行修改&…

02.文件IO

文件描述符 表述打开的文件的 它是open函数的返回值,一个进程启动之后,会默认打开3个文件标识符 0标准输入,1标准输出,2标准错误 新的打开的文件返回文件描述符表中未使用过的最小的文件描述符 open函数 用来打开或者新建一个文件…

vue3实现动态表格

vue3结合element-plus实现动态表格&#xff0c;可添加、删除、对单行数据判断。 实现效果&#xff1a;查看源代码 实现代码&#xff1a; <div class"arrTable-Box"><el-table :data"tableData" border max-height"250"><el-t…

思维导图如何用AI生成?借助这几款工具

思维导图如何用AI生成&#xff1f;在数字化时代&#xff0c;思维导图作为一种高效的信息组织与展示工具&#xff0c;被广泛应用于学习、工作和项目管理中。随着人工智能技术的飞速发展&#xff0c;AI生成思维导图已成为现实&#xff0c;极大地提升了创建思维导图的效率和创意。…

2024数维杯A题可运行思路代码文章成品

为了能够精确地确定飞行器在三维空间中的位置&#xff0c;理论上至少需要从三个不同位置的发射源接收TOA数据。下面是使用TOA数据确定位置所需的计算基础和原理&#xff1a; 单个TOA数据&#xff1a; 单个TOA测量可以确定接收器与发射源之间的距离&#xff0c;这在三维空间中形…

合并两个有序链表(C语言)———链表经典算法题

题目描述​​​​​​21. 合并两个有序链表 - 力扣&#xff08;LeetCode&#xff09;&#xff1a; 答案展示: 迭代&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* mergeTwoLis…

BBS客户端服务器的编写

根据网络编程中的内容&#xff0c;我们本篇文章将讲解一个bbs通信的项目&#xff0c;首先让我们了解一下什么是bbs. 一、bbs介绍 BBS&#xff0c;即Bulletin Board System的缩写&#xff0c;中文译为“电子公告板系统”或“网络论坛”。它是一个在网络上进行信息交流和讨论的…

STM32MP157_程序烧录

STM32MP157_程序烧录 说明&#xff1a; 1、使用emmc作为存储媒介&#xff0c;emmc是核心板上的存储颗粒空间有8GB 2、SD卡作为存储媒介&#xff0c;底板上有SD卡的插槽 emmc方式 软件&#xff1a;烧录软件使用STM32CubeProgrammer 连接线&#xff1a;硬件连接线使用type_c数据线…

RTSP/Onvif安防监控系统EasyNVR级联视频上云系统EasyNVS报错“Login error”的原因排查与解决

EasyNVR安防视频云平台是旭帆科技TSINGSEE青犀旗下支持RTSP/Onvif协议接入的安防监控流媒体视频云平台。平台具备视频实时监控直播、云端录像、云存储、录像检索与回看、告警等视频能力&#xff0c;能对接入的视频流进行处理与多端分发&#xff0c;包括RTSP、RTMP、HTTP-FLV、W…

jenkins使用gitLab(极狐)认证登陆

jenkins安装 GitLab Authentication插件 我因为java版本和最新GitLab Authentication 1.19版本不兼容&#xff0c;选择了本地安装 找个历史版本1.13版本&#xff0c;然后下载到电脑上 - 本地上传插件并安装 在极狐上创建一个应用 - 配置应用信息 应用名&#xff1a;jenkinsLo…

2024年最新方法下载钉钉群直播回放

链接&#xff1a;百度网盘 请输入提取码 提取码&#xff1a;1234 --来自百度网盘超级会员V10的分享 1.首先解压好所有的压缩包&#xff0c;这个压缩包里面还套着一共逍遥一仙下载器压缩包&#xff0c;也解压 2.进入逍遥一仙下载器文件夹&#xff0c;打开M3U8 V1.4.8 0508.e…

找不到msvcp140.dll无法执行代码的原因分析及修复方法

当用户在尝试运行某些应用程序或游戏时&#xff0c;可能会遇到系统弹出错误提示&#xff0c;显示“找不到msvcp140.dll无法执行代码”这一错误信息&#xff0c;它会导致程序无法正常启动。为了解决这个问题&#xff0c;我经过多次尝试和总结&#xff0c;找到了以下五种解决方法…

宏集Panorama SCADA软件获BACnet BTL认证

Panorama 获得BACnet BTL认证 建筑物的组件&#xff08;空调系统、照明传感器等&#xff09;能否使用共同通讯协议&#xff1f;这正是标准化 BACnet协议&#xff08;Building Automation and Control Networks&#xff09;所提供的功能。该协议旨在实现建筑物中各种设备和系统…

初探 JUC 并发编程:读写锁 ReentrantReadWriteLock 原理(8000 字源码详解)

本文中会涉及到一些前面 ReentrantLock 中学到的内容&#xff0c;先去阅读一下我关于独占锁 ReentrantLock 的源码解析阅读起来会更加清晰。 初探 JUC 并发编程&#xff1a;独占锁 ReentrantLock 底层源码解析 6.4&#xff09;读写锁 ReentrantReadWriteLock 原理 前面提到的 R…

谈基于ATTCK框架的攻击链溯源

引言 网络安全在当今数字化时代变得尤为关键&#xff0c;而MITRE公司开发的ATT&CK框架则成为了安全专业人员的重要工具。ATT&CK是一种广泛使用的攻击行为分类和描述框架。其目的在于提供一个共同的语言&#xff0c;使安全专业人员能够更好地理解攻击者的行为和目标&…