梯度提升用于高效的分类与回归

使用 决策树(Decision Tree) 实现 梯度提升(Gradient Boosting) 主要是模拟 GBDT(Gradient Boosting Decision Trees) 的原理,即:

  1. 第一棵树拟合原始数据
  2. 计算残差(负梯度方向)
  3. 用新的树去拟合残差
  4. 累加所有树的预测值
  5. 重复步骤 2-4,直至达到指定轮数

下面是一个 纯 Python + PyTorch 实现 GBDT(梯度提升决策树) 的代码示例。

1. 纯 Python 实现梯度提升决策树

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 参数
n_trees = 50   # 多少棵树
learning_rate = 0.1  # 学习率

# 初始化预测值(全部为 0)
y_pred_train = np.zeros_like(y_train)
y_pred_test = np.zeros_like(y_test)

# 训练梯度提升决策树
trees = []
for i in range(n_trees):
    residuals = y_train - y_pred_train  # 计算残差(负梯度方向)
    
    tree = DecisionTreeRegressor(max_depth=3)  # 这里使用较浅的树
    tree.fit(X_train, residuals)  # 让树学习残差
    trees.append(tree)
    
    # 更新预测值(累加弱学习器的结果)
    y_pred_train += learning_rate * tree.predict(X_train)
    y_pred_test += learning_rate * tree.predict(X_test)

    # 计算损失
    mse = mean_squared_error(y_train, y_pred_train)
    print(f"Iteration {i+1}: MSE = {mse:.4f}")

# 计算最终测试集误差
final_mse = mean_squared_error(y_test, y_pred_test)
print(f"\nFinal Test MSE: {final_mse:.4f}")

代码解析

  • 第一步:构建一个基础决策树 DecisionTreeRegressor(max_depth=3)
  • 第二步:每棵树学习前面所有树的残差(负梯度方向)。
  • 第三步:训练 n_trees 棵树,每棵树的预测结果乘以 learning_rate 累加到最终预测值。
  • 第四步:每次迭代后更新预测值,减少误差。

2. 用 PyTorch 实现 GBDT

虽然 GBDT 主要基于决策树,但如果你希望用 PyTorch 计算梯度并模拟 GBDT,可以如下操作:

  • 用 PyTorch 计算 损失函数的梯度
  • sklearn.tree.DecisionTreeRegressor 拟合梯度
  • 用 PyTorch 计算最终误差
import torch
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 参数
n_trees = 50  # 多少棵树
learning_rate = 0.1  # 学习率

# 转换数据为 PyTorch 张量
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
y_train_torch = torch.tensor(y_train, dtype=torch.float32)

# 初始化预测值
y_pred_train = torch.zeros_like(y_train_torch)

# 训练 GBDT
trees = []
for i in range(n_trees):
    # 计算梯度(残差)
    residuals = y_train_torch - y_pred_train

    # 用决策树拟合梯度
    tree = DecisionTreeRegressor(max_depth=3)
    tree.fit(X_train, residuals.numpy())
    trees.append(tree)

    # 更新预测值
    y_pred_train += learning_rate * torch.tensor(tree.predict(X_train), dtype=torch.float32)

    # 计算损失
    mse = mean_squared_error(y_train, y_pred_train.numpy())
    print(f"Iteration {i+1}: MSE = {mse:.4f}")

PyTorch 实现的关键点

  1. y_train_torch - y_pred_train 计算 损失的梯度
  2. DecisionTreeRegressor 作为弱学习器,拟合梯度
  3. 预测值 += learning_rate * tree.predict(X_train)

3. 结合 PyTorch 和 XGBoost

如果你要 结合 PyTorch 和 GBDT,可以先用 XGBoost 训练 GBDT,再用 PyTorch 进行深度学习:

import xgboost as xgb
import torch.nn as nn
import torch.optim as optim
import torch
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


# 训练 XGBoost 作为特征提取器
xgb_model = xgb.XGBRegressor(n_estimators=50, max_depth=3, learning_rate=0.1)
xgb_model.fit(X_train, y_train)

# 提取 XGBoost 叶子节点特征
X_train_leaves = xgb_model.apply(X_train)
X_test_leaves = xgb_model.apply(X_test)

# 定义 PyTorch 神经网络
class NeuralNet(nn.Module):
    def __init__(self, input_size):
        super(NeuralNet, self).__init__()
        self.fc = nn.Linear(input_size, 1)

    def forward(self, x):
        return self.fc(x)

# 训练 PyTorch 神经网络
model = NeuralNet(X_train_leaves.shape[1])
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

X_train_tensor = torch.tensor(X_train_leaves, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)

for epoch in range(100):
    optimizer.zero_grad()
    output = model(X_train_tensor)
    loss = loss_fn(output, y_train_tensor)
    loss.backward()
    optimizer.step()

print("Training complete!")

结论

方法适用场景备注
纯 Python GBDT适合小规模数据使用 sklearn.tree.DecisionTreeRegressor
PyTorch 计算梯度 + GBDT适合梯度优化实验计算梯度后用 DecisionTreeRegressor 训练
XGBoost + PyTorch适合大规模数据先用 XGBoost 提取特征,再用 PyTorch 训练

如果你的数据是结构化的(如 表格数据),建议 直接使用 XGBoost/LightGBM,再结合 PyTorch 进行特征工程或后处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/961828.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

node 爬虫开发内存处理 zp_stoken 作为案例分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言 主要说3种我们补环境过后如果用…

python——Django 框架

Django 框架 1、简介 Django 是用python语言写的开源web开发框架,并遵循MVC设计。 Django的**主要目的是简便、快速的开发数据库驱动的网站。**它强调代码复用,多个组件可以很方便的以"插件"形式服务于整个框架,Django有许多功能…

嵌入式知识点总结 Linux驱动 (五)-linux内核

针对于嵌入式软件杂乱的知识点总结起来,提供给读者学习复习对下述内容的强化。 目录 1.内核镜像格式有几种?分别有什么区别? 2.内核中申请内存有哪几个函数?有什么区别? 3.什么是内核空间,用户空间&…

SpringBoot+Vue的理解(含axios/ajax)-前后端交互前端篇

文章目录 引言SpringBootThymeleafVueSpringBootSpringBootVue(前端)axios/ajaxVue作用响应式动态绑定单页面应用SPA前端路由 前端路由URL和后端API URL的区别前端路由的数据从哪里来的 Vue和只用三件套axios区别 关于地址栏url和axios请求不一致VueJSPS…

网络直播时代的营销新策略:基于受众分析与开源AI智能名片2+1链动模式S2B2C商城小程序源码的探索

摘要:随着互联网技术的飞速发展,网络直播作为一种新兴的、极具影响力的媒体形式,正逐渐改变着人们的娱乐方式、消费习惯乃至社交模式。据中国互联网络信息中心数据显示,网络直播用户规模已达到3.25亿,占网民总数的45.8…

将ollama迁移到其他盘(eg:F盘)

文章目录 1.迁移ollama的安装目录2.修改环境变量3.验证 背景:在windows操作系统中进行操作 相关阅读 :本地部署deepseek模型步骤 1.迁移ollama的安装目录 因为ollama默认安装在C盘,所以只能安装好之后再进行手动迁移位置。 # 1.迁移Ollama可…

《Trustzone/TEE/安全从入门到精通-标准版》

CSDN学院课程连接:https://edu.csdn.net/course/detail/39573 讲师介绍 拥有 12 年手机安全、汽车安全、芯片安全开发经验,擅长 Trustzone/TEE/ 安全的设计与开发,对 ARM 架构的安全领域有着深入的研究和丰富的实践经验,能够将复杂的安全知识和处理器架构知识进行系统整…

手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码)

手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码) 目录 手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码)Stable Diffusion 原理图Stable Diffusion的原理解释Stable Diffusion 和Di…

基于 AWS SageMaker 对 DeepSeek-R1-Distilled-Llama-8B 模型的精调与实践

在当今人工智能蓬勃发展的时代,语言模型的性能优化和定制化成为研究与应用的关键方向。本文聚焦于 AWS SageMaker 平台上对 DeepSeek-R1-Distilled-Llama-8B 模型的精调实践,详细探讨这一过程中的技术细节、操作步骤以及实践价值。 一、实验背景与目标 …

三、SysTick系统节拍定时器

3.1 SysTick简介 系统节拍定时器SysTick是ARM Cortex-M0内核提供的一个24位递减定时器,当计数值达到0时产生中断,可以为操作系统和其他管理软件提供固定时间的中断。 当系统节拍定时器被被使能时,定时器从重装值递减计数,到0进中断…

算法每日双题精讲 —— 前缀和(【模板】一维前缀和,【模板】二维前缀和)

在算法竞赛与日常编程中,前缀和是一种极为实用的预处理技巧,能显著提升处理区间和问题的效率。今天,我们就来深入剖析一维前缀和与二维前缀和这两个经典模板。 一、【模板】一维前缀和 题目描述 给定一个长度为 n n n 的整数数组 a a a&…

学习数据结构(2)空间复杂度+顺序表

1.空间复杂度 (1)概念 空间复杂度也是一个数学表达式,表示一个算法在运行过程中根据算法的需要额外临时开辟的空间。 空间复杂度不是指程序占用了多少bytes的空间,因为常规情况每个对象大小差异不会很大,所以空间复杂…

MybatisX插件快速创建项目

一、安装插件 二、创建一个数据表测试 三、IDEA连接Mysql数据库 四、选择MybatiX构造器 五、配置参数 六、项目结构

基于SpringBoot的假期周边游平台的设计与实现(源码+SQL脚本+LW+部署讲解等)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

Java设计模式:结构型模式→组合模式

Java 组合模式详解 1. 定义 组合模式(Composite Pattern)是一种结构型设计模式,它允许将对象组合成树形结构以表示“部分-整体”的层次。组合模式使得客户端能够以统一的方式对待单个对象和对象集合的一致性,有助于处理树形结构…

FastReport.NET控件篇之富文本控件

简介 FastReport.NET 提供了 RichText 控件,用于在报表中显示富文本内容。富文本控件支持多种文本格式(如字体、颜色、段落、表格、图片等),非常适合需要复杂排版和格式化的场景。 富文本控件(RichText)使用场景不多&#xff0c…

单片机基础模块学习——NE555芯片

一、NE555电路图 NE555也称555定时器,本文主要利用NE555产生方波发生电路。整个电路相当于频率可调的方波发生器。 通过调整电位器的阻值,方波的频率也随之改变。 RB3在开发板的位置如下图 测量方波信号的引脚为SIGHAL,由上面的电路图可知,NE555已经构成完整的方波发生电…

(done) MIT6.S081 2023 学习笔记 (Day6: LAB5 COW Fork)

网页:https://pdos.csail.mit.edu/6.S081/2023/labs/cow.html 任务1:Implement copy-on-write fork(hard) (完成) 现实中的问题如下: xv6中的fork()系统调用会将父进程的用户空间内存全部复制到子进程中。如果父进程很大,复制过程…

三天急速通关JavaWeb基础知识:Day 1 后端基础知识

三天急速通关JavaWeb基础知识:Day 1 后端基础知识 0 文章说明1 Http1.1 介绍1.2 通信过程1.3 报文 Message1.3.1 请求报文 Request Message1.3.2 响应报文 Response Message 2 XML2.1 介绍2.2 利用Java解析XML 3 Tomcat3.1 介绍3.2 Tomcat的安装与配置3.3 Tomcat的项…

SQLServer 不允许保存更改(主键)

在我们进行数据库表格编辑的时候,往往会出现同一个名字,就比如我们的账号一样,我们在注册自己QQ的时候,我们通常注册过的账号,别人就不能注册了,这是为了保证严密性 所以我们需要点击表格>右键>设计 点击某一列>右键>设计主键 当我们Ctrls 保存的时候回弹出下…