梯度提升树GBDT系列算法

Boosting方法的基本元素与基本流程💫

在Boosting集成算法当中,我们逐一建立多个弱评估器(基本是决策树),并且下一个弱评估器的建立方式依赖于上一个弱评估器的评估结果,最终综合多个弱评估器的结果进行输出。

这个过程相当于有意地加重“难以被分类正确的样本”的权重,同时降低“容易被分类正确的样本”的权重,而将后续要建立的弱评估器的注意力引导到难以被分类正确的样本上。

不同的Boosting算法之间的核心区别就在于上一个弱评估器的结果具体如何影响下一个弱评估器的建立过程。此外,Boosting算法在结果输出方面表现得十分多样。早期的Boosting算法的输出一般是最后一个弱评估器的输出,当代Boosting算法的输出都会考虑整个集成模型中全部的弱评估器。一般来说,每个Boosting算法会其以独特的规则自定义集成输出的具体形式

💥由此,我们可以确立任意boosting算法的三大基本元素以及boosting算法自适应建模的基本流程:

  • 损失函数L(x,y) :用以衡量模型预测结果与真实结果的差异
  • 弱评估器f(x) :(一般为)决策树,不同的boosting算法使用不同的建树过程
  • 综合集成结果H(x):即集成算法具体如何输出集成结果

几乎所有boosting算法的原理都围绕这三大元素构建。在此三大要素基础上,所有boosting算法都遵循以下流程进行建模:

💢正如之前所言,Boosting算法之间的不同之处就在于使用不同的方式来影响后续评估器的构建。无论boosting算法表现出复杂或简单的流程,其核心思想一定是围绕上面这个流程不变的。 

梯度提升树GBDT的基本思想 

梯度提升树(Gradient Boosting Decision Tree,GBDT)是提升法中的代表性算法,它即是当代强力的XGBoost、LGBM等算法的基石,也是工业界应用最多、在实际场景中表现最稳定的机器学习算法之一。在最初被提出来时,GBDT被写作梯度提升机器(Gradient Boosting Machine,GBM),它融合了Bagging与Boosting的思想、扬长避短,可以接受各类弱评估器作为输入,在后来弱评估器基本被定义为决策树后,才慢慢改名叫做梯度提升树。
作为一个Boosting算法,GBDT中自然也包含Boosting三要素,并且也遵循boosting算法的基本流程进行建模,不过需要注意的是,GBDT在整体建树过程中有几个关键点:

  • 弱评估器💯                       
  • GBDT的弱评估器输出类型不再与整体集成算法输出类型一致。对于基础的Bagging和Boosting算法来说,当集成算法执行的是回归任务时,弱评估器也是回归器,当集成算法执行分类任务时,弱评估器也是分类器。但对于GBDT而言,无论GBDT整体在执行回归/分类/排序任务,弱评估器一定是回归器。GBDT通过sigmoid或softmax函数输出具体的分类结果,但实际弱评估器一定是回归器。     
  • 损失函数💯
  • 在GBDT算法中,可以选择的损失函数非常多(‘deviance’, ‘exponential’),是因为这个算法从数学原理上做了改进——损失函数的范围不在局限于固定或者单一的某个损失函数,而是推广到了任意可微的函数。

  • GBDT分类器损失函数:‘deviance’, ‘exponential’

    GBDT回归器损失函数:‘squared_error’, ‘absolute_error’, ‘huber’, ‘quantile’

  • 拟合残差💯

GBDT依然自适应调整弱评估器的构建,但不再通过调整数据分布来间接影响后续弱评估器,而是通过修改后续弱评估器的拟合目标来直接影响后续弱评估器的结构。

具体地来说,在GBDT当中,我们不修改样本权重,但每次用于建立弱评估器的是样本以及当下集成输出与真实标签的差异()。这个差异在数学上被称之为残差(Residual),因此GBDT不修改样本权重,而是通过拟合残差来影响后续弱评估器结构

GBDT加入了随机森林中随机抽样的思想,在每次建树之前,允许对样本和特征进行抽样来增大弱评估器之间的独立性(也因此可以有袋外数据集)。虽然Boosting算法不会大规模地依赖于类似于Bagging的方式来降低方差,但由于Boosting算法的输出结果是弱评估器结果的加权求和,因此Boosting原则上也可以获得由“平均”带来的小方差红利。当弱评估器表现不太稳定时,采用与随机森林相似的方式可以进一步增加Boosting算法的稳定性

梯度提升树GBDT的快速实现         

 

sklearn当中集成了GBDT分类与GBDT回归,我们使用如下两个类来调用它们: 

  • class sklearn.ensemble.GradientBoostingClassifier
  • class sklearn.ensemble.GradientBoostingRegressor                       
  • GBDT算法的超参数看起来很多,但是仔细观察的话,你会发现GBDT回归器与GBDT分类器的超参数高度一致。并且所有超参数都给出了默认值,需要人为输入的参数为0。所以,就算是不了解参数的含义,我们依然可以直接使用sklearn库来调用GBDT算法。

使用GBDT完成分类任务

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.datasets import load_wine
from sklearn.ensemble import GradientBoostingClassifier as GBC
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

X,y = load_wine(return_X_y=True,as_frame=True)

# 切分训练集和测试集
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3,random_state=0)

# 使用GBDT完成对红酒数据集的预测
clf = GBC()    #实例化GBDT分类器,并使用默认参数
clf = clf.fit(Xtrain,Ytrain)

train_score = clf.score(Xtrain,Ytrain)
test_score = clf.score(Xtest,Ytest)
print(f"GBDT在训练集上的预测准确率为{train_score}")
print(f"GBDT在测试集上的预测准确率为{test_score}")
  • GBDT在训练集上的预测准确率为1.0
  • GBDT在测试集上的预测准确率为0.9629629629629629

梯度提升分类与其他算法的对比

dtc = DTC(random_state=0) #实例化单棵决策树
dtc = dtc.fit(Xtrain,Ytrain)
score_dtc = dtc.score(Xtest,Ytest)

rfc = RFC(random_state=0) #实例化随机森林
rfc = rfc.fit(Xtrain,Ytrain)
score_rfc = rfc.score(Xtest,Ytest)

gbc = GBC(random_state=0) #实例化GBDT
gbc = gbc.fit(Xtrain,Ytrain)
score_gbc = gbc.score(Xtest,Ytest)
# 默认使用准确度(accuracy)作为评分方式,即预测正确的样本数占总样本数的比例

print("决策树:{}".format(score_dtc))
print("随机森林:{}".format(score_rfc))
print("GBDT:{}".format(score_gbc))
  • 决策树:0.9444444444444444
  • 随机森林:0.9814814814814815
  • GBDT:0.9629629629629629

💥画出决策树、随机森林和GBDT在十组五折交叉验证下的效果对比

score_dtc = []
score_rfc = []
score_gbc = []

for i in range(10):
    dtc = DTC()
    cv1 = cross_val_score(dtc,X,y,cv=5)
    score_dtc.append(cv1.mean())
    
    rfc = RFC()
    cv2 = cross_val_score(rfc,X,y,cv=5)
    score_rfc.append(cv2.mean())
    
    gbc = GBC()
    cv3 = cross_val_score(gbc,X,y,cv=5)
    score_gbc.append(cv3.mean())

plt.plot(range(1,11),score_dtc,label = "DecisionTree")
plt.plot(range(1,11),score_rfc,label = "RandomForest")
plt.plot(range(1,11),score_gbc,label = "GBDT")
plt.legend(bbox_to_anchor=(1.4,1))
plt.show()

使用GBDT完成回归任务

X,y = fetch_california_housing(return_X_y=True,as_frame=True)

Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3,random_state=0)

# 使用GBDT完成对加利福尼亚房屋数据集的预测

gbr = GBR(random_state=0) #实例化GBDT
gbr = gbr.fit(Xtrain,Ytrain)
r2_gbdt = gbr.score(Xtest,Ytest) # 回归器默认评估指标为R2
r2_gbdt
# 0.7826346388949185

# 计算GBDT回归器的评估指标:均方误差MSE
from sklearn.metrics import mean_squared_error
pred = gbr.predict(Xtest)
MSE = mean_squared_error(Ytest,pred)
MSE

# 0.28979949770874125

梯度提升回归与其他算法的对比

import time
modelname = ["DecisionTree","RandomForest","GBDT","RF-D"]
models = [DTR(random_state=0)
          ,RFR(random_state=0)
          ,GBR(random_state=0)
          ,RFR(random_state=0,max_depth=3)]

for name,model in zip(modelname,models):
    start = time.time()
    result = cross_val_score(model,X,y,cv=5,scoring="neg_mean_squared_error").mean()
    end = time.time()-start
    print(name)
    print("\t MSE:{:.3f}".format(abs(result)))
    print("\t time:{:.2f}s".format(end))
    print("\n")

结果:

DecisionTree
	 MSE:0.818
	 time:0.66s


RandomForest
	 MSE:0.425
	 time:70.69s


GBDT
	 MSE:0.412
	 time:16.84s


RF-D
	 MSE:0.639
	 time:11.49s

 

对比决策树和随机森林来说,GBDT默认参数状态下已经能够达到很好的效果。

梯度提升树GBDT的重要参数和属性

由于GBDT超参数数量较多,因此我们可以将GBDT的参数分为以下5大类别,其他属性我们下次再进行分析验证💨

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/698771.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

div拖拽改变宽高

目前是点击按照右下角边框拖拽改变大小 如果要点击按住内容拖拽也改变大小 则传入事件 $ event即可 startDrag(index,$event)和 drag(index,$event) 以下代码可直接使用 <template><div><div>目前是点击按照右下角边框拖拽改变大小 <br> 如果要点击按…

如何系统学习vue框架

前言 在软件开发的浩渺星海中&#xff0c;编程规范如同航海的罗盘&#xff0c;为我们指引方向&#xff0c;确保我们的代码之旅能够顺利、高效地到达目的地。无论是个人开发者还是大型团队&#xff0c;编程规范都是提升代码质量、保障项目成功不可或缺的一环。 因此&#xff0c…

MySQL表设计经验汇总篇

文章目录 1、命名规范2、选择合适的字段类型3、主键设计要合理4、选择合适的字段长度5、优先考虑逻辑删除&#xff0c;而不是物理删除6、每个表都需要添加通用字段7、一张表的字段不宜过多8、定义字段尽可能not null9、合理添加索引10、通过业务字段冗余来减少表关联11、避免使…

【漏洞复现】宏景eHR openFile.jsp 任意文件读取漏洞

0x01 产品简介 宏景eHR人力资源管理软件是一款人力资源管理与数字化应用相融合&#xff0c;满足动态化、协同化、流程化、战略化需求的软件。 0x02 漏洞概述 宏景eHR openFile.jsp 接口处存在任意文件读取漏洞&#xff0c;未经身份验证攻击者可通过该漏洞读取系统重要文件(如…

树-二叉树的最大路径和

一、问题描述 二、解题思路 因为各个节点的值可能为负数&#xff0c;初始化res(最大路径和)的值为最小整数&#xff1a;Integer.MIN_VALUE 我们这里使用深度遍历&#xff08;递归&#xff09;的方法&#xff0c;先看某一个子树的情况&#xff1a; 这里有一个技巧&#xff0c;…

纯音听力检测图有哪些形状?

纯音听力检测图有哪些形状&#xff1f; 当选择合适的放大装置时,听力图形状很重要。例如,听力图为下降型或高频陡降型的顾客可能受益于开放式验配,即可以泄漏低频声音,并对高频声音进行放大。 听力图形状分为以下几种&#xff1a; 下降型:低频听力较好,高频听力较差 上升型…

icloud 邮箱登入失败

APP NAME mail2HOSTING APP NAME cloudos2CLIENT TIME Tue Jun 11 2024 09:00:47 GMT0800 (中国标准时间) (1718067647802)USER AGENT Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36HOSTNAME www.icloud.…

20个国家科学数据中心(下)

15、国家海洋科学数据中心 平台网址&#xff1a;https://mds.nmdis.org.cn/ 简介&#xff1a;国家海洋科学数据中心由国家海洋信息中心牵头&#xff0c;采用“主中心分中心数据节点”模式&#xff0c;联合相关涉海单位、科研院所和高校等十余家单位共同建设。以“建立…

普通人想要自学ai,该如何入手,看完这篇你就懂了,零基础教程!

学会了AIGC之后&#xff0c;我只想说&#xff1a;无敌是多么寂寞&#xff1f; 之前我整理一篇会议记录起码要2小时。现在交给AI &#xff0c;5分钟搞定&#xff1b; 之前整理账目总是出错&#xff0c;现在利用AI财务整合器&#xff0c;轻松解决统计难题&#xff1b; 之前写个…

逻辑题 :谁是凶手?

设 &#xff1a; A 甲是凶手 这个是题中1的 如果甲不是凶手 我们假设A条件是甲是凶手&#xff0c;取反就可是甲不是凶手&#xff0c;B 乙是凶手 这个是题中1的 如果乙或者是凶手 我们假设B条件乙是凶手C 乙是知情人 这个是题中1的 或者是知情人 我们假设C条件乙是知情人D …

RT-DETR 详解之 Uncertainty-minimal Query Selection

引言 在上一章博客中博主已经完成查询去噪向量构造部分的讲解&#xff08;DeNoise&#xff09;在本篇博客中&#xff0c;我们将进行Uncertainty-minimal Query Selection创新点的讲解。 Uncertainty-minimal Query Selection是RT-DETR提出的第二个创新点&#xff0c;其作用是…

大模型的高考数学成绩单:及格已经非常好了

让考生头皮发麻的高考数学&#xff0c;可难倒了顶尖 AI 大模型。 一年一度的高考即将落幕&#xff0c;衷心希望各位考生都超常发挥&#xff0c;考出满意的好成绩&#xff01;&#xff01; 和往年一样&#xff0c;除了让 AI 大模型写写高考作文&#xff0c;我们也选取了六家国…

超级会员小程序积分商城源码系统 前后端分离 带完整的安装代码包以及搭建部署

系统概述 在数字化时代&#xff0c;积分商城作为企业增强用户粘性、促进消费的重要工具&#xff0c;其重要性不言而喻。为了帮助企业快速构建高效、易用的积分兑换平台&#xff0c;我们特别推出了“超级会员小程序积分商城源码系统”&#xff0c;采用前后端分离架构设计&#…

硬盘危机:磁盘损坏无法打开的应对策略

在数字化时代&#xff0c;磁盘作为数据存储和传输的核心设备&#xff0c;其稳定性和安全性至关重要。然而&#xff0c;在日常使用过程中&#xff0c;我们时常会面临磁盘损坏无法打开的困境。这不仅会影响我们的工作效率&#xff0c;还可能造成重要数据的丢失。本文将深入探讨磁…

java中toCharArray用法详细分析(全)

将字符串中的字符转换为字符数组 public char[] toCharArray()括号内没有参数 返回值是一个字符数组接收 1.函数代码&#xff1a; package com.ithehema;public class Test {public static void main(String[] args) {String b"ss123456";char []cb.toCharArray()…

SCI三区快速检索——期刊推荐IEEE Access

IEEE Access 是一个综合性的、开放获取的多学科工程和技术期刊&#xff0c;由美国电气电子工程师协会&#xff08;IEEE&#xff09;出版。以下是关于IEEE Access期刊的一些关键信息&#xff1a; 1. 开放获取【即开源】 IEEE Access 是开放获取&#xff08;Open Access&#x…

【Linux】生产者消费者模型——阻塞队列BlockQueue

> 作者&#xff1a;დ旧言~ > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;理解【Linux】生产者消费者模型——阻塞队列BlockQueue。 > 毒鸡汤&#xff1a;有些事情&#xff0c;总是不明白&#xff0c;所以我不会坚持。早安!…

【Git】Windows下使用可视化工具Sourcetree

参考&#xff1a;[最全面] SourceTree使用教程详解(连接远程仓库&#xff0c;克隆&#xff0c;拉取&#xff0c;提交&#xff0c;推送&#xff0c;新建/切换/合并分支&#xff0c;冲突解决&#xff0c;提交PR) 1.Git工具–sourcetree 之前文章介绍过Linux系统中的Git工具&…

C++ 11 【可变参数模板】【lambda】

&#x1f493;博主CSDN主页:麻辣韭菜&#x1f493;   ⏩专栏分类&#xff1a;C修炼之路⏪   &#x1f69a;代码仓库:C高阶&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学习更多C知识   &#x1f51d;&#x1f51d; 目录 前言 一、新的类功能 1.1默认成员函数—…

78%的中小企业担心网络攻击会导致其业务中断,中小企业如何确保网络安全?

在当今数字化时代&#xff0c;网络攻击手段层出不穷&#xff0c;网络安全事件不断增加&#xff0c;根据ConnectWise的一项调查数据显示&#xff0c;94%的中小企业至少经历过一次网络攻击&#xff0c;78%的中小企业担心网络攻击会导致其业务中断&#xff0c;企业声誉受损。由此&…