梯度提升决策树(GBDT)

GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器(树),并把多颗决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。

原理

GBDT的核心思想是将多个弱学习器(通常是决策树)组合成一个强大的预测模型。具体而言,GBDT的定义如下:

  • 初始化:首先GBDT使用一个常数(通常是目标变量的平均值,也可以是其他合适的初始值。初始预测值代表了模型对整体数据的初始估计。)作为初始预测值。这个初始预测值代表了我们对目标变量的初始猜测。
  • 迭代训练:GBDT是一个迭代算法,通过多轮迭代来逐步改进模型。在每一轮迭代中,GBDT都会训练一棵新的决策树,目标是减少前一轮模型的残差(或误差)。残差是实际观测值与当前模型预测值之间的差异,新的树将学习如何纠正这些残差。
  • 1)计算残差:在每轮迭代开始时,计算当前模型对训练数据的预测值与实际观测值之间的残差。这个残差代表了前一轮模型未能正确预测的部分。
  • 2):训练新的决策树:使用计算得到的残差作为新的目标变量,训练一棵新的决策树。这棵树将尝试纠正前一轮模型的错误,以减少残差。
  • 3):更新模型:将新训练的决策树与之前的模型进行组合。具体地,将新树的预测结果与之前模型的预测结果相加,得到更新后的模型。
  • 集成:最终,GBDT将所有决策树的预测结果相加,得到最终的集成预测结果。这个过程使得模型能够捕捉数据中的复杂关系,从而提高了预测精度。

每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
GBDT的弱分类器使用的是树模型。
在这里插入图片描述
如图是一个非常简单的帮助理解的示例,我们用GBDT去预测年龄:

  • 第一个弱分类器(第一棵树)预测一个年龄(如20岁),计算发现误差有10岁;
  • 第二棵树预测拟合残差,预测值6,计算发现差距还有4岁;
  • 第三棵树继续预测拟合残差,预测值3,发现差距只有1岁了;
  • 第四课树用1岁拟合剩下的残差,完成。
    最终,四棵树的结论加起来,得到30岁这个标注答案(实际工程实现里,GBDT是计算负梯度,用负梯度近似残差)。

GBDT的优势

1)高精度预测能力
GBDT以其强大的集成学习能力而闻名,能够处理复杂的非线性关系和高维数据。它通常能够在分类和回归任务中取得比单一决策树或线性模型更高的精度。
2)对各种类型数据的适应性
GBDT对不同类型的数据(数值型、类别型、文本等)具有很好的适应性,不需要对数据进行特别的预处理。这使得它在实际应用中更易于使用。

  • 处理混合数据类型
    在现实世界的数据挖掘任务中,常常会遇到混合数据类型的情况。例如,在房价预测问题中,特征既包括数值型(如房屋面积和卧室数量),还包括类别型(如房屋位置和建筑类型)和文本型(如房屋描述)数据。GBDT能够直接处理这些混合数据,无需将其转换成统一的格式。这简化了数据预处理的步骤,节省了建模时间。
  • 不需要特征缩放
    与某些机器学习算法(如支持向量机和神经网络)不同,GBDT不需要对特征进行缩放或归一化。这意味着特征的尺度差异不会影响模型的性能。在一些算法中,特征的尺度不一致可能导致模型无法正确学习,需要进行繁琐的特征缩放操作。而GBDT能够直接处理原始特征,减轻了数据预处理的负担。

3)在数据不平衡情况下的优势

  • 加权损失函数
    GBDT使用的损失函数允许对不同类别的样本赋予不同权重。这意味着模型可以更关注少数类别,从而提高对不平衡数据的处理能力。
  • 逐步纠正错误
    GBDT的迭代训练方式使其能够逐步纠正前一轮模型的错误。在处理不平衡数据时,模型通常会在多轮迭代中重点关注难以分类的少数类别样本。通过逐步纠正错误,模型逐渐提高了对少数类别的分类能力,从而改善了预测结果。

4)鲁棒性与泛化能力
GBDT在处理噪声数据和复杂问题时表现出色。其鲁棒性使得它能够有效应对数据中的异常值或噪声,不容易受到局部干扰而产生较大的预测误差。
5)特征重要性评估
GBDT可以提供有关特征重要性的信息,帮助用户理解模型的决策过程。通过分析每个特征对模型预测的贡献程度,用户可以识别出哪些特征对于问题的解决最为关键。这对于特征选择、模型解释和问题理解非常有帮助。
6)高效处理大规模数据
尽管GBDT通常是串行训练的,每棵树依赖于前一棵树的结果,但它可以高效处理大规模数据。这得益于GBDT的并行化实现和轻量级的决策树结构。此外,GBDT在处理大规模数据时可以通过特征抽样和数据抽样来加速训练过程,而不会牺牲太多预测性能。

关键参数与调优

参数解释

n_estimators:迭代次数,即最终模型中弱学习器的数量。
learning_rate(学习率):每次迭代时,新决策树对预测结果的贡献权重。
max_depth:决策树的最大深度,控制着树的复杂度。
min_samples_split:节点分裂所需的最小样本数。
subsample:用于训练每棵树的样本采样比例,小于1时可实现随机梯度提升。
loss:即我们GBDT算法中的损失函数。分类模型和回归模型的损失函数是不一样的。1)对于分类模型,有对数似然损失函数"deviance"和指数损失函数"exponential"两者输入选择。默认是对数似然损失函数"deviance"。2)对于回归模型,有均方差"ls", 绝对损失"lad", Huber损失"huber"和分位数损失“quantile”。默认是均方差"ls"。一般来说,如果数据的噪音点不多,用默认的均方差"ls"比较好。如果是噪音点较多,则推荐用抗噪音的损失函数"huber"。而如果我们需要对训练集进行分段预测的时候,则采用“quantile”。
subsample:用于训练每个弱学习器的样本比例。减小该参数可以降低方差,但也可能增加偏差。

调优策略

学习率与迭代次数的平衡:较低的学习率通常需要更多的迭代次数来达到较好的性能,但能减少过拟合的风险。
树的深度与样本采样:合理限制树的深度和采用子采样可以提高模型的泛化能力。
早停机制:在验证集上监控性能,一旦性能不再显著提升,则提前终止训练。

为了解决GBDT的效率问题,LightGBM和XGBoost等先进框架被提出,它们通过优化算法结构(如直方图近似)、并行计算等方式显著提高了训练速度。

python实现

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor

GradientBoostingClassifier(*, loss='deviance', learning_rate=0.1, n_estimators=100, subsample=1.0, 
        criterion='friedman_mse', min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
         max_depth=3, min_impurity_decrease=0.0, min_impurity_split=None, init=None, random_state=None,
         max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', 
         validation_fraction=0.1, n_iter_no_change=None, tol=0.0001, ccp_alpha=0.0)
         
GradientBoostingRegressor(*, loss='ls', learning_rate=0.1, n_estimators=100, subsample=1.0, criterion='friedman_mse', min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3, min_impurity_decrease=0.0, min_impurity_split=None, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, tol=0.0001, ccp_alpha=0.0) 

回归实现

# 导入必要的库
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

# 加载波士顿房价数据集
boston = load_boston()
X, y = boston.data, boston.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化GBDT回归器
gbdt_reg = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 训练模型
gbdt_reg.fit(X_train, y_train)

# 预测
y_pred = gbdt_reg.predict(X_test)

# 计算并打印均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")

GBDT正则化

针对GBDT正则化,我们通过子采样比例方法和定义步长v方法来防止过拟合。

  • 子采样比例:通过不放回抽样的子采样比例(subsample),取值为(0,1]。如果取值为1,则全部样本都使用。如果取值小于1,利用部分样本去做GBDT的决策树拟合。选择小于1的比例可以减少方差,防止过拟合,但是会增加样本拟合的偏差。因此取值不能太低,推荐在[0.5, 0.8]之间。
  • 定义步长v:针对弱学习器的迭代,我们定义步长v,取值为(0,1]。对于同样的训练集学习效果,较小的v意味着我们需要更多的弱学习器的迭代次数。通常我们用步长和迭代最大次数一起来决定算法的拟合效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/694668.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】轻松拿捏-联合体

谢谢观看!希望以下内容帮助到了你,对你起到作用的话,可以一键三连加关注!你们的支持是我更新地动力。 因作者水平有限,有错误还请指出,多多包涵,谢谢! 联合体 一、联合体类型的声明二…

DDMA信号处理以及数据处理的流程---随机目标生成

Hello,大家好,我是Xiaojie,好久不见,欢迎大家能够和Xiaojie一起学习毫米波雷达知识,Xiaojie准备连载一个系列的文章—DDMA信号处理以及数据处理的流程,本系列文章将从目标生成、信号仿真、测距、测速、cfar…

RabbitMQ(五)集群配置、Management UI

文章目录 一、安装RabbitMQ1、前置要求2、安装docker版复制第一个节点的.erlang.cookie进入各节点命令行配置集群检查集群状态 3、三台组合集群安装版rabbitmq节点rabbitmq-node2节点rabbitmq-node3节点 二、负载均衡:Management UI1、说明2、安装HAProxy3、修改配置…

找出链表倒数第k个元素-链表题

LCR 140. 训练计划 II - 力扣(LeetCode) 快慢指针。快指针臂慢指针快cnt个元素到最后; class Solution { public:ListNode* trainingPlan(ListNode* head, int cnt) {struct ListNode* quick head;struct ListNode* slow head;for(int i …

Spring配置多数据库(采用数据连接池管理)

一,前言 大家在开发过程中,如果项目大一点就会遇到一种情况,同一个项目中可能会用到很多个数据源,那么这篇文章,博主为大家分享在spring应用中如何采用数据库连接池的方式配置配置多数据源。 本篇文章采用大家用的最…

音视频转为文字SuperVoiceToText

音视频转为文字SuperVoiceToText,它能够把视频或语音文件高效地转换为文字,它是基于最为先进的 AI 大模型,通过在海量语音资料上进行训练学习而造就,具备极为卓越的识别准确率。 不仅如此,它支持包括汉语、英语、日语…

Java Set系列集合的使用规则和场景(HashSet,LinkedHashSet,TreeSet)

Set集合 package SetDemo;import java.util.HashSet; import java.util.Iterator; import java.util.Set;public class SetDemo {public static void main(String[] args) {/*Set集合的特点:1.Set系列集合的特点:Set集合是一个存储元素不能重复的集合方…

如何下载BarTender软件及详细安装步骤

BarTender是美国海鸥科技推出的一款优秀的条码打印软件,应用于 WINDOWS95 、 98 、 NT 、 XP 、 2000 、 2003 和 3.1 版本, 产品支持广泛的条形码码制和条形码打印机, 不但支持条形码打印机而且支持激光打印机,还为世界知名品牌条…

Modbus主站和从站的区别

Modbus主站,从站 在工业自动化领域,Modbus是一种常用的通信协议,用于设备之间的数据交换。在Modbus通信中,主站和从站是两个关键的角色。了解主站和从站之间的区别对正确配置和管理Modbus网络至关重要。 Modbus主站的特点和功能 1.通信请求发…

嵌入式仪器模块:音频综测仪和自动化测试软件

• 24 位分辨率 • 192 KHz 采样率 • 支持多种模拟/数字音频信号的输入/输出 应用场景 • 音频信号分析:幅值、频率、占空比、THD、THDN 等指标 • 模拟音频测试:耳机、麦克风、扬声器测试,串扰测试 • 数字音频测试:平板电…

【DrissionPage】Linux上如何将https改为http

最近有个老板找我做一个自动化的程序,要求部署到Linux上 这是一个http协议的网站,chrome在默认设置下,会将http的网站识别成不安全的内容,然后自动将http转化成https访问 但是,这个http的网站它的加载项里既有http的…

Python基础——字符串

一、Python的字符串简介 Python中的字符串是一种计算机程序中常用的数据类型【可将字符串看作是一个由字母、数字、符号组成的序列容器】,字符串可以用来表示文本数据。 通常使用一对英文的单引号()或者双引号(")…

三维重建 虚拟内窥镜(VE)是什么?怎么实现 使用场景

1.虚拟内窥镜: 就是利用计算机图形学、虚拟现实、图像处理和科学可视化等信息处理技术仿真光学内窥镜对病人进行诊断的一种技术。 VE(Virtual Endoscopy),虚拟内镜技术。这种CT重建图像可以模拟各种内镜检查的效果,它是假设视线位于所要观察…

【数据结构】队列——循环队列(详解)

目录 0 循环队列 1 特定条件下循环队列队/空队满判断条件 1.1 队列为空的条件 1.2 队列为满的条件 2 循环队列的实现 3 示例 4 注意事项 0 循环队列 循环队列(Circular Queue)是队列的一种实现方式,它通过将队列存储空间的最后一…

【Java】Hutool发送邮件功能

目录 开通qq邮箱的stmp实战pom.xmlapplication.ymlcontrollerservice实体类辅助类 需要实现一个通过邮箱找回密码的功能 正常来说,找回密码的验证码,一般来说,都是通过手机号来找回的居多,那为什么会有通过邮箱找回的方式该说不说…

Luminar Neo - AI智能修图软件超越PS和LR,简单易用又高效!

很多人都想美化自己的风景和人物的图片,得到更加美丽耀眼的效果。然而,专业摄影师和设计师在电脑上使用的后期工具如 Photoshop 和 LightRoom 过于复杂。 通常为了一些简单的效果,你必须学习许多教程。而一些针对小白用户的“一键式美颜/美化…

Python第二语言(八、Python包)

目录 1. 什么是Python包 2. 创包步骤 2.1 new包 2.2 查看创建的包 2.3 拖动文件到包下 3. 导入包 4. 安装第三方包 4.1 什么是第三方包 4.2 安装第三方包-pip 4.3 pip网络优化 1. 什么是Python包 包下有__init__.py就是包,无__init__.py就是文件夹。于Ja…

开发没有尽头,尽力既是完美

最近遇到了一些难题,开发系统总有一些地方没有考虑周全,偏偏用户使用的时候“完美复现”了这个隐藏的Bug...... 讲道理创业一年之久为了生存,我一直都有在做复盘,复盘的核心就是:如何提升营收、把控开发质量&#xff0…

gdb调试器

目录 一、前言 debug和release 二、调试操作 2.1、退出 quit 2.2、调试 run 2.3、打断点 b 2.4、查看断点 info b 2.5、查看代码 l 2.6、删除断点 d 2.7、逐过程 n 2.8、打印变量内容 p 2.9、逐语句(进入函数) s 2.10、查看函数调用堆栈 …