XGboost的整理

XGboost(extreme gradient boosting):高效实现了GBDT算法并进行了算法和工程上的许多改进。

XGboost的思路:

目标:建立k个回归树,使得树群的预测尽量接近真实值(准确率)而且有尽量大的泛化能力。

目标函数:L\left ( \phi \right )=\sum_{i}^{}l\left ( \hat{y_{i}}-y_{i} \right )+\sum_{k}^{}\Omega \left ( f_{k} \right )

i表示第i个样本,l\left ( \hat{y_{i}}-y_{i} \right )表示第i个样本的预测误差,误差越小越好,\sum_{k}^{}\Omega \left ( f_{k} \right )表示树的复杂度的函数,越小复杂度越低,泛化能力越强

\Omega \left ( f_{t} \right )=\gamma T+\frac{1}{2}\lambda \sum_{j=1}^{T}\omega _{j}^{2}

T:叶子的个数

\omega _{j}^{2}:w的L2模平方

目标要求预测尽量小,叶子节点尽量少,节点数值尽量不极端,回归树的参数(1)选取哪个特征分裂节点(2)节点的预测值。间接解决这两个参数的方法:贪心策略+最优化(二次最优化)

(1)选取哪个特征分裂节点:最简单的是枚举,选择loss function效果最好的那个

(2)确立节点的w以及最小的loss function,采用二次函数的求最值

步骤:选择一个feature分裂,计算loss function最小值,然后再选一个feature分列,又得到一个loss function最小值,枚举完成后,找一个效果最好的,把树分裂,在分裂的时候,每次节点分裂,loss function被影响的只有这个节点的样本,因而每次分裂,计算分裂的增益只需要关注打算分裂的那个节点的样本。接下来,继续分裂,按照上述方法,形成一棵树,再形成一棵树,每次在上一次的预测基础上取最优进一步分裂/建树。

停止条件:

①当引入的分裂带来的增益小于一个阈值的时候,可以剪掉这个分裂,所以并不是每一次分裂lossfunction整体都会增加的,有点预剪枝的意思,阈值参数为\gamma正则项里叶子节点数T的系数。

②当数达到最大深度时则停止建立决策树,设置一个超参数max_depth,树太深很容易出现过拟合。

③当样本权重和小于设定阈值时则停止建树,一个叶子节点样本太少时,终止,避免过拟合。

constant:常数,对于f\left ( x \right ),XGboost利用泰勒展开三项,做一个近似,f\left ( x \right )表示其中一颗回归树。

XGBoost与GBDT有什么不同:

1、GBDT是机器学习算法,XGboost是该算法的工程实现

2、在使用CART作为及分类器时,XGboost显式地加入了正则项来控制模型的复杂度,有利于防止过拟合,从而提高模型的泛化能力

3、GBDT在模型训练时只是用来代价函数的一阶导数信息,XGboost对代价函数进行二阶泰勒展开,可以同时使用一阶和二阶导数

4、传统的GBDT采用CART作为基分类器,XGboost支持多种类型的基分类器,比如线性分类器

5、传统的GBDT在每轮迭代时使用全部的数据,XGboost则采用了与随机森林相似的策略,支持对数据进行采样

6、传统的GBDT没有设计对缺失值的处理,而XGboost能够自动学习出缺失值的处理策略。

使用xgboost库中的XGBRegressor类来创建XGboost模型

import xgboost as xgb
xgb_clf=xgb.XGBRegressor(max_depth=8,
                         learning_rate=0.1,
                         objective="reg:linear",
                         eval_metric='rmse', 
                         n_estimators=3115,
                         colsample_bytree=0.6, 
                         reg_alpha=3, 
                         reg_lambda=2, 
                         gamma=0.6,
                         subsample=0.7, 
                         silent=1, 
                         n_jobs=-1)

XGBRegressor中的参数介绍:

max_depth:树的最大深度,增加这个值可以使模型更加复杂,并提高队训练数据的拟合程度,但可能会导致过拟合。通常需要通过交叉验证来调整这个参数。

learning_rate:学习率,用于控制每次迭代更新权重时的步长。

objective:定义了学习任务和相应的损失函数,“reg:linear” 表示我们正在解决一个线性回归问题。

eval_metric:评估指标,用于在训练过程中对模型的表现进行评估,‘rmse’ 表示均方根误差(Root Mean Squared Error),它是回归问题中常用的性能指标。

n_estimators:森林中树的数量,值越大,模型越复杂,训练时间也会相应增加。通常需要通过交叉验证来调整这个参数。

colsample_bytree:构建每棵树时对特征进行采样的比例。较小的值可以减少过拟合,提高模型的泛化能力。

reg_alpha:L1正则化项的权重,增加这个值同样也可以增加模型的正则化强度。

gamma:树的叶子节点进一步分裂所需的最小损失减少量。较大值会导致模型更保守,可能会导致模型的过拟合。

subsample:用于训练每棵树的样本占整个训练集的比例。

silent:设置为1可以关闭在运行时的日志信息。

n_jobs:并行运行的作业数。

基本模型:

import pandas as pd
import xgboost as xgb
import pandas
import numpy as np

# 将pandas数据框加载到DMatrix
data_train = pandas.DataFrame(np.arange(12).reshape((4,3)), columns=['a', 'b', 'c'])
label_train = pandas.DataFrame(np.random.randint(2, size=4))
dtrain = xgb.DMatrix(data_train, label=label_train, missing=np.NaN) # 缺失值可以用构造函数中的默认值替换DMatrix

data_test = pandas.DataFrame(np.arange(12, 24).reshape((4,3)), columns=['a', 'b', 'c'])
label_test = pandas.DataFrame(np.random.randint(2, size=4))
dtest = xgb.DMatrix(data_test, label=label_test, missing=np.NaN) # 缺失值可以用构造函数中的默认值替换DMatrix

# # 将CSV文件加载到DMatrix
# # label_column specifies the index of the column containing the true label
# dtrain = xgb.DMatrix('train.csv?format=csv&label_column=0')
# dtest = xgb.DMatrix('test.csv?format=csv&label_column=0')
# # XGBoost 中的解析器功能有限。当使用Python接口时,建议使用pandasread_csv或其他类似的实用程序而不是XGBoost的内置解析器。

param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
param['nthread'] = 4
param['eval_metric'] = ['auc', 'ams@0'] # 指定多个评估指标
# 指定验证集以观察性能
evallist = [(dtrain, 'train'), (dtest, 'eval')]

# 训练
num_round = 20
bst = xgb.train(param, dtrain, num_round, evallist, early_stopping_rounds=10) # 返回最后一次迭代的模型,而不是最好的模型
# early_stopping_rounds=10作用:如果模型在10轮内没有改善,则训练将提前停止,如果设置多个指标,则最后一个指标将用于提前停止
# 训练完成后,保存模型
bst.save_model('test_xgboost/0001.model')
# 模型转储到文本文件中
bst.dump_model('test_xgboost/dump.raw.txt')
# 加载模型
bst = xgb.Booster({'nthread': 4})  # 初始化模型,将线程数设置为4
bst.load_model('test_xgboost/0001.model')  # 加载模型
# 如果训练期间启动提前停止,可以从最佳迭代中获得预测
ypred = bst.predict(dtest, iteration_range=(0, bst.best_iteration + 1))
ypred = pd.DataFrame(ypred)
ypred.to_csv('test_xgboost/xgb_predict.csv', index=False)

 使用scikit-learn的方法

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb

X, y = load_breast_cancer(return_X_y=True) # 加载数据
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=94)
# stratify=y:按目标变量分层划分,确保训练集和测试集中目标变量的比例与原始数据集相同
# random_state=94: 设置随机种子,保证每次划分的结果相同

# 使用hist来构建树,并启用早期停止
early_stop = xgb.callback.EarlyStopping(
    rounds=2, metric_name='logloss', data_name='validation_0', save_best=True
)
clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop])
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# 保存模型
clf.save_model("test_xgboost/clf.json")

https://xgboost.readthedocs.io/en/latest/python/index.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434243.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker安装+基础命令

一、检测、配置安装环境 (1)查看linux版本,是否符合>centos 7 (2)查看网络是否通畅 (3)安装gcc,gcc-c编译器 (4)安装device-mapper-persistent-data和lvm2…

IPsec VPN协议框架

IPsec是IETF(Internet Engineering Task Force)制定的一组开放的网络安全协议。它并不是一个单独的协议,而是一系列为IP网络提供安全性的协议和服务的集合,包括认证头AH(Authentication Header)和封装安全载…

LeetCode刷题---填充每个节点的下一个右侧节点指针

官方题解:LeetCode官方题解 解题思想: 因为是一棵满二叉树,所以除了叶子节点外的其他节点都有两个子节点。 可以根据每一层来依次遍历 从根节点开始,根节点的左子节点的next节点就指向根节点的右子节点 因为根节点的next节点为NULL,开始从根…

C语言结构体的大小,结构体内存对齐

1. 结构体的大小 在自己正真了解过之前&#xff0c;一直认为结构体的大小就是结构体内部成员大小的总和。 但当你去尝试打印结构体的大小时&#xff0c;会发现事实并非如此&#xff0c;也不会像你想的那样简单。 #include <stdio.h>struct S1 {char c1;char c2;int i;…

腾讯云服务器99元一年购买入口链接

腾讯云服务器99元一年购买入口链接如下&#xff0c;现在已经降价到61元一年&#xff0c;官方活动链接如下&#xff1a; 腾讯云99元服务器一年购买页面腾讯云活动汇聚了腾讯云最新的促销打折、优惠折扣等信息&#xff0c;你在这里可以找到云服务器、域名、数据库、小程序等等多种…

springcloud:3.5测试慢调用熔断降级

服务提供者【test-provider8001】 Openfeign远程调用服务提供者搭建 文章地址http://t.csdnimg.cn/06iz8 相关接口 测试远程调用&#xff1a;http://localhost:8001/payment/index 服务消费者【test-consumer-resilience4j8004】 Openfeign远程调用消费者搭建 文章地址http://t…

Python 系统学习总结(基础语法+函数+数据容器+文件+异常+包+面向对象)

&#x1f525;博客主页&#xff1a; A_SHOWY&#x1f3a5;系列专栏&#xff1a;力扣刷题总结录 数据结构 云计算 数字图像处理 力扣每日一题_ 六天时间系统学习Python基础总结&#xff0c;目前不包括可视化部分&#xff0c;其他部分基本齐全&#xff0c;总结记录&#xff0…

使用数据库实现增删改查

#include<myhead.h>//定义添加数据函数int do_add(sqlite3 *ppDb) {//1.准备sql语句,输入要添加的信息int add_numb; //工号char add_name[20]; //姓名char add_sex[10]; //性别double add_score; //工资printf("请输入要添加的工号:")…

Android开发经典实战,Android面试题目

关于Android的近况 大家都知道&#xff0c;今年移动开发不那么火热了&#xff0c;完全没有了前两年Android开发那种火热的势头&#xff0c;如此同时&#xff0c;AI热火朝天&#xff0c;很多言论都说Android不行了。其实不光是Android&#xff0c;iOS也有类似的言论。 那么到底…

[项目设计] 从零实现的高并发内存池(四)

&#x1f308; 博客个人主页&#xff1a;Chris在Coding &#x1f3a5; 本文所属专栏&#xff1a;[高并发内存池] ❤️ 前置学习专栏&#xff1a;[Linux学习] ⏰ 我们仍在旅途 ​ 目录 6.内存回收 6.1 ThreadCache回收内存 6.2 CentralCache回收内存 Rele…

STM32CubeMX学习笔记11 ---RTC实时时钟

1、RTC实时时钟简介 STM32的实时时钟RTC是一个独立的定时器&#xff0c;RTC模块拥有一组连续计数的计数器&#xff0c;在相应软件配置下&#xff0c;可提供时钟日历的功能&#xff0c;修改计数器的值可以重新设置系统当前的时间和日期 RTC模块和时钟配置系统&#xff08;RCC_B…

100 spring-security 中 /oauth/token 发送请求不携带参数 报错 “401 Unauthorized“

前言 最近存在这样的一个问题, 大致的复现方式是 访问 /oauth/token 接口, 然后不携带任何参数, 结果 服务器抛出了一个 "401 Unauthorized" 针对这个 401, 这里 梳理一下这个流程, 也会衍生出一些其他的问题 测试用例 客户端这边大致的情况是 构造参数, 然后发…

Linux 之三:CentOS7 目录结构 和 日期及时区设置

Linux 目录 以下是对这些目录的解释&#xff1a; /bin&#xff1a;bin是Binary的缩写, 这个目录存放着最经常使用的命令。/boot&#xff1a; 这里存放的是启动Linux时使用的一些核心文件&#xff0c;包括一些连接文件以及镜像文件。/dev &#xff1a; dev是Device(设备)的缩写…

[Java安全入门]二.序列化与反序列化

一.概念 Serialization&#xff08;序列化&#xff09;是一种将对象以一连串的字节描述的过程&#xff1b;反序列化deserialization是一种将这些字节重建成一个对象的过程。将程序中的对象&#xff0c;放入文件中保存就是序列化&#xff0c;将文件中的字节码重新转成对象就是反…

【广度优先搜索】【堆】【C++算法】407. 接雨水 II

作者推荐 【二分查找】【C算法】378. 有序矩阵中第 K 小的元素 本文涉及知识点 广度优先搜索 堆 LeetCoce407. 接雨水 II 给你一个 m x n 的矩阵&#xff0c;其中的值均为非负整数&#xff0c;代表二维高度图每个单元的高度&#xff0c;请计算图中形状最多能接多少体积的雨…

【C语言】还有柔性数组?

前言 也许你从来没有听说过柔性数组&#xff08;flexible array&#xff09;这个概念&#xff0c;但是它确实是存在的。C99中&#xff0c;结构中的最后⼀个元素允许是未知⼤⼩的数组&#xff0c;这就叫做『柔性数组』成员。 欢迎关注个人主页&#xff1a;逸狼 创造不易&#xf…

二级水平导航菜单栏的实现

1. 这个是本人设计的一带一路的二级水平导航栏HTML代码&#xff1b; 这里最后实现的效果是鼠标悬停在导航栏上面&#xff0c;就会显示下面的4个部分页面&#xff0c;这里只是以评论热 点作为例子&#xff0c;其他的类似&#xff1b; 2.首先要设计DIV&#xff0c;然后利用无…

GO语言环境安装---VScode.2024

目录 一、下载并安装GO 二、配置环境变量 三、VScode环境安装 由于工作原因&#xff0c;需要用到go来写web后端&#xff0c;正好从零记录下环境安装 一、下载并安装GO 首先在官网根据PC系统选择对应的包下载 源地址&#xff1a;https://go.dev/dl/ 打不开的用这个也行&a…

论文阅读:Dataset Quantization

摘要 最先进的深度神经网络使用大量&#xff08;百万甚至数十亿&#xff09;数据进行训练。昂贵的计算和内存成本使得在有限的硬件资源上训练它们变得困难&#xff0c;特别是对于最近流行的大型语言模型 (LLM) 和计算机视觉模型 (CV)。因此最近流行的数据集蒸馏方法得到发展&a…

第三天 Kubernetes进阶实践

第三天 Kubernetes进阶实践 本章介绍Kubernetes的进阶内容&#xff0c;包含Kubernetes集群调度、CNI插件、认证授权安全体系、分布式存储的对接、Helm的使用等&#xff0c;让学员可以更加深入的学习Kubernetes的核心内容。 ETCD数据的访问 kube-scheduler调度策略实践 预选与…