GBDT算法详解

GBDT算法详解

梯度提升决策树(Gradient Boosting Decision Trees,GBDT)是机器学习中一种强大的集成算法。它通过构建一系列的决策树,并逐步优化模型的预测能力,在各种回归和分类任务中取得了显著的效果。本文将详细介绍GBDT算法的原理,并展示其在实际数据集上的应用。

GBDT算法原理

GBDT是一种集成学习方法,通过逐步建立多个决策树,每棵树都在前一棵树的基础上进行改进。GBDT的基本思想是逐步减少残差(即预测误差),使模型的预测能力不断提高。

算法步骤

  1. 初始化模型:使用常数模型初始化,比如回归问题中可以用目标值的均值初始化模型。
  2. 计算残差:计算当前模型的残差,即预测值与真实值之间的差异。
  3. 拟合残差:用新的决策树拟合残差,并更新模型。
  4. 更新模型:将新决策树的预测结果加到模型中,以减少残差。
  5. 重复步骤2-4:直到达到预设的迭代次数或残差足够小。

公式表示

初始化模型:
F 0 ( x ) = arg ⁡ min ⁡ γ ∑ i = 1 n L ( y i , γ ) F_0(x) = \arg\min_{\gamma} \sum_{i=1}^{n} L(y_i, \gamma) F0(x)=argγmini=1nL(yi,γ)
对于每一次迭代 (m = 1, 2, \ldots, M):

  1. 计算负梯度(残差): r i m = − [ ∂ L ( y i , F ( x i ) ) ∂ F ( x i ) ] F ( x ) = F m − 1 ( x ) r_{im} = -\left[ \frac{\partial L(y_i, F(x_i))}{\partial F(x_i)} \right]_{F(x) = F_{m-1}(x)} rim=[F(xi)L(yi,F(xi))]F(x)=Fm1(x)

  2. 拟合一个新的决策树来预测残差: h m ( x ) = arg ⁡ min ⁡ h ∑ i = 1 n ( r i m − h ( x i ) ) 2 h_m(x) = \arg\min_{h} \sum_{i=1}^{n} (r_{im} - h(x_i))^2 hm(x)=arghmini=1n(rimh(xi))2

  3. 更新模型: F m ( x ) = F m − 1 ( x ) + ν h m ( x ) F_m(x) = F_{m-1}(x) + \nu h_m(x) Fm(x)=Fm1(x)+νhm(x)
    其中, ν \nu ν是学习率,控制每棵树对最终模型的贡献。

GBDT算法的特点

  1. 高准确性:GBDT通过逐步减少残差,不断优化模型,使其在很多任务中具有很高的准确性。
  2. 灵活性:GBDT可以处理回归和分类任务,并且可以使用各种损失函数。
  3. 鲁棒性:GBDT对数据的噪声和异常值有一定的鲁棒性。
  4. 可解释性:决策树本身具有一定的可解释性,通过特征重要性等方法可以解释GBDT模型。

GBDT参数说明

以下是GBDT(Gradient Boosting Decision Trees,梯度提升决策树)常用参数及其详细说明:

参数名称描述默认值示例
n_estimators树的棵数,提升迭代的次数100n_estimators=200
learning_rate学习率,控制每棵树对最终模型的贡献0.1learning_rate=0.05
max_depth树的最大深度,控制每棵树的复杂度3max_depth=4
min_samples_split分裂一个内部节点需要的最少样本数2min_samples_split=5
min_samples_leaf叶子节点需要的最少样本数1min_samples_leaf=3
subsample样本采样比例,用于训练每棵树1.0subsample=0.8
max_features寻找最佳分割时考虑的最大特征数Nonemax_features='sqrt'
loss要优化的损失函数devianceloss='exponential'
criterion分裂节点的标准friedman_msecriterion='mae'
init初始估计器Noneinit=some_estimator
random_state随机数种子,用于结果复现Nonerandom_state=42
verbose控制训练过程信息的输出频率0verbose=1
warm_start是否使用上次调用的解决方案来初始化训练Falsewarm_start=True
presort是否预排序数据以加快分裂查找deprecated-
validation_fraction用于提前停止训练的验证集比例0.1validation_fraction=0.2
n_iter_no_change如果在若干次迭代内验证集上的损失没有改善,则提前停止训练Nonen_iter_no_change=10
tol提前停止的阈值1e-4tol=1e-3
ccp_alpha最小成本复杂度修剪参数0.0ccp_alpha=0.01

通过合理调整这些参数,可以优化GBDT模型在特定任务和数据集上的性能。

GBDT算法在回归问题中的应用

在本节中,我们将使用波士顿房价数据集来展示如何使用GBDT算法进行回归任务。

导入库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score

加载和预处理数据

# 生成合成回归数据集
X, y = make_regression(n_samples=1000, n_features=20, noise=0.1, random_state=42)

# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

训练GBDT模型

# 训练GBDT模型
gbdt = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbdt.fit(X_train, y_train)

预测与评估

# 预测
y_pred = gbdt.predict(X_test)

# 评估
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Mean Squared Error: {mse:.2f}')
print(f'R^2 Score: {r2:.2f}')

特征重要性

# 特征重要性
# 特征重要性
feature_importances = gbdt.feature_importances_
plt.barh(range(X.shape[1]), feature_importances, align='center')
plt.yticks(np.arange(X.shape[1]), [f'Feature {i}' for i in range(X.shape[1])])
plt.xlabel('Feature Importance')
plt.ylabel('Feature')
plt.title('Feature Importances in GBDT')
plt.show()

在这里插入图片描述

GBDT算法在分类问题中的应用

在本节中,我们将使用20类新闻组数据集来展示如何使用GBDT算法进行文本分类任务。

导入库

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

加载和预处理数据

# 生成分类数据集
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=42)

# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

训练GBDT模型

# 训练GBDT模型
gbdt = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbdt.fit(X_train, y_train)

预测与评估

# 预测
y_pred = gbdt.predict(X_test)

# 评估
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')

# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print('Confusion Matrix:')
print(conf_matrix)

# 分类报告
class_report = classification_report(y_test, y_pred)
print('Classification Report:')
print(class_report)

结语

本文详细介绍了GBDT算法的原理和特点,并展示了其在回归和分类任务中的应用。首先介绍了GBDT算法的基本思想和公式,然后展示了如何在回归数据集使用GBDT进行回归任务,以及如何在分类数据集上使用GBDT进行文本分类任务。

我的其他同系列博客

支持向量机(SVM算法详解)
knn算法详解
GBDT算法详解
XGBOOST算法详解
CATBOOST算法详解
随机森林算法详解
lightGBM算法详解
对比分析:GBDT、XGBoost、CatBoost和LightGBM
机器学习参数寻优:方法、实例与分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/731022.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【mysql 安装启动失败】 没有网下 libssl.so.10 not found 如何解决?

问题描述: libssl.so.10 > not found libcrypto.so.10 > not found [rootmysql tools]# ls -l /usr/sbin/mysqld -rwxr-xr-x. 1 root root 64290024 Sep 14 2022 /usr/sbin/mysqld [rootmysql tools]# ldd /usr/sbin/mysqldlinux-vdso.so.1 (0x00007fff97105…

Blazor 组件:创建、生命周期、嵌套和 UI 集成

在本文中,您将获得以下问题的答案。 什么是 Blazor 组件?如何使用组件?Blazor 组件的生命周期是什么?我们可以从一个组件调用另一个组件吗?如何创建 Blazor 组件?在组件中哪里写 C# 代码? 什么…

碳化硅陶瓷膜的生产工艺和应用

一、生产工艺 碳化硅陶瓷膜的生产工艺多样,其中浸渍提拉法和喷涂法为两大主流技术。 浸渍提拉法 浸渍提拉法是一种广泛应用的制备方法。其过程主要包括:先将陶瓷颗粒或者聚合物前体分散在水或有机溶剂中,形成均质稳定的制膜液。随后&#xff…

深入探索C++中的AVL树

引言 在数据结构和算法的世界里,平衡二叉搜索树(Balanced Binary Search Tree, BST)是一种非常重要的数据结构。AVL树(Adelson-Velsky和Landis发明的树)就是平衡二叉搜索树的一种,它通过自平衡来维护其性质…

ELK+Filebeat+kafka+zookeeper构建海量日志分析平台

ELK是什么(What)? ELK组件介绍 ELK 是ElasticSearch开源生态中提供的一套完整日志收集、分析以及展示的解决方案,是三个产品的首字母缩写,分别是ElasticSearch、Logstash 和 Kibana。除此之外,FileBeat也是…

海外版coze前端代码助手

定位 解决前端同事的开发问题 参数配置 测试 支持 最屌的大模型及语音播报。 体验地址 海外版前端代码助手 需要魔法才能体验油

索尼MXF文件断电变2G恢复方法(PXW-Z280V)

PXM-Z280V算是索尼比较经典的机型,也是使用MXF文件格式的机型之一。近期接到很多例索尼MXF量突然不正常的案例(如变成512字节或者2G),下面来看下这个案例。 故障存储: 128G存储卡 /文件系统:exFAT 故障现象: 客户反…

Centos SFTP搭建

SFTP配置、连接及挂载教程_sftp连接-CSDN博客1、确认是否安装yum list installed | grep openssh-server 2、创建用户和组 sudo groupadd tksftpgroup sudo useradd -g tksftpgroup -d /home/www/tk_data -s /sbin/nologin tksftp01 sudo passwd tksftp013. 配置SFTP注意&a…

设置浏览器互不干扰

目录 一、查看浏览器文件路径 二、 其他盘新建文件夹Cache 三、以管理员运行CMD 四、执行命令 一、查看浏览器文件路径 chrome://version/ 二、 其他盘新建文件夹Cache D:\chrome\Cache 三、以管理员运行CMD 四、执行命令 Mklink /d "C:\Users\Lenovo\AppData\Loca…

国产化ETL产品必备的特性(非开源包装)

ETL负责将分布的、异构数据源中的数据如关系数据、平面数据文件等抽取到临时中间层后进行抽取、清洗(净化)、转换、装载、标准、集成(汇总)...... 最后加载到数据仓库或数据集市中,成为联机分析处理、数据挖掘的基础。…

关键属性描述ASYNC_REG

关键属性描述 属性信息 本章提供有关XilinxVivadoDesign Suite属性的信息。条目 每个属性包含以下信息(如适用): •物业说明,包括其主要用途。 •支持该特性的Xilinx FPGA体系结构,包括UltraScale™ 架构设备&#xff…

数据结构【二叉树】

前言 我们在前面学习了使用数组来实现二叉树,但是数组实现二叉树仅适用于完全二叉树(非完全二叉树会有空间浪费),所以我们本章讲解的是链式二叉树,但由于学习二叉树的操作需要有一颗树,才能学习相关的基本…

2024.6.23周报

目录 摘要 ABSTRACT 一、文献阅读 一、题目 二、摘要 三、网络架构 四、创新点 五、文章解读 1、Introduction 2、Method 3、实验 4、结论 二、代码实验 总结 摘要 本周阅读了一篇题目为NAS-PINN: NEURAL ARCHITECTURE SEARCH-GUIDED PHYSICS-INFORMED NEURAL N…

生成式AI和LLM的一些基本概念和名词解释

1. Machine Learning 机器学习是人工智能(AI)的一个分支,旨在通过算法和统计模型,使计算机系统能够从数据中学习并自动改进。机器学习算法使用数据来构建模型,该模型可用于预测或决策。机器学习应用于各种领域&#x…

Windows环境下使用VisualGDB进行Linux项目开发

1.新建项目-打开文件下的新建项目菜单 2.工程项目类型配置 3.Linux机器选择设置 4.设置代码位置 5.编译选项设置 6.调试环境设置

(Python)可变类型不可变类型;引用传递值传递;浅拷贝深拷贝

从一段代码开始说事,先上代码: a [[1],[2],[3]] b [[4,5],[6,7],[7,8]] for i,j in zip(a,b):print(i,j)i [9]#i[0] 8j[:2][1,2]print(i, j) print(a) print(b) 运行的结果: [1] [4, 5] [9] [1, 2] [2] [6, 7] [9] [1, 2] [3] [7, 8] …

后仿真中 module path polarity 问题

目录 一 未知极性 二 正极性 三 负极性 不知道大家有没有遇到这个问题:什么?我们知道的module path delay 指的是定义在specify...endspecify block 中的语句,指示输入-输出的延迟信息。 这里的module path 竟然还有极性问题,今天,来学习一下。 模块路径的极性是一…

使用dify.ai做一个婚姻法助手

步骤 1:注册并登录 Dify.ai 访问 Dify.ai 官网,注册一个账号并登录。 步骤 2:创建新项目 登录后,点击“创建新项目”。为项目命名,例如“婚姻法助手”。 步骤 3:导入婚姻法文本到知识库 在项目中&…

如何使用idea连接Oracle数据库?

idea版本:2021.3.3 Oracle版本:10.2.0.1.0(在虚拟机Windows sever 2003 远程连接数据库) 数据库管理系统:PLSQL Developer 在idea里面找到database,在idea侧面 选择左上角加号,新建&#xff…

定义和反射Annotation类(注解)

文章目录 前言一、定义Annotation类二、反射Anootation类 1.元注解2.反射注解总结 前言 在写代码的过程中,我们经常会写到注释,以此来提醒代码中的点。但是,这些注释不会被查看,也不在整个代码之中,只能在源代码中进行…