模型优化_XGBOOST学习曲线及改进,泛化误差

代码

from xgboost import XGBRegressor as XGBR
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.linear_model import LinearRegression as LR
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split,cross_val_score as CV,KFold
from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from time import time
import datetime

#加载数据
data=load_boston()
X=data.data
y=data.target

#划分数据集
Xtrain,Xtest,ytrain,ytest=train_test_split(X,y,test_size=0.3,random_state=420)

#定位模型,进行fit
reg=XGBR(n_estimators=100).fit(Xtrain,ytrain)

#进行预测
reg.predict(Xtest)
reg.score(Xtest,ytest)#返回的是R平方
MSE(ytest,reg.predict(Xtest))
reg.feature_importances_
#查看SKLEARN中所有的模型评估指标
import sklearn
sorted(sklearn.metrics.SCORERS.keys())



# ======================================
#交叉验证,与线性回归随机森林进行结果比对
reg=XGBR(n_estimators=100)

from sklearn.model_selection import train_test_split,cross_val_score
cross_val_score(reg,Xtrain,ytrain,cv=5).mean()#

#交叉验证既可以解决数据集的数据量不够大问题,也可以解决参数调优的问题。
#这块主要有三种方式:简单交叉验证(HoldOut检验)、k折交叉验证(k-fold交叉验证)
cross_val_score(reg,Xtrain,ytrain,cv=5,scoring="neg_mean_squared_error").mean()

#绘制学习曲线
def plot_learning_curve(estimator,title,X,y,
                       ax=None,#选择子图
                       ylim=None,#设置纵坐标的取值范围
                       cv=None,#交叉验证
                       n_jobs=None):
    from sklearn.model_selection import learning_curve
    train_sizes,train_scores,test_scores=learning_curve(estimator,X,y,shuffle=True,
                                                       cv=cv,random_state=420,n_jobs=n_jobs)
    if ax==None:
        ax=plt.gca()
    else:
        ax=plt.figure()
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel("Traing example")
    ax.set_ylabel("Score")
    ax.grid()#绘制网格
    ax.plot(train_sizes,np.mean(train_scores,axis=1),"o-",color="r",label="traing score")
    ax.plot(train_sizes,np.mean(test_scores,axis=1),"o-",color="g",label="test.py score")
    ax.legend(loc="best")
    return ax

#学习曲线的绘制
cv=KFold(n_splits=5,shuffle=True,random_state=42)
plot_learning_curve(XGBR(n_estimators=100,random_state=420),"XGB",Xtrain,ytrain,ax=None,cv=cv)

                        

#绘制学习曲线,查看n_estimators对模型的影响 

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5).mean()
    rs.append(cv1)
print(axis[rs.index(max(rs))],max(rs))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

泛化误差:用来衡量模型在未知数据集上的准确率

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)
    rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率
    var.append(cv1.var())
    ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

 

#绘制学习曲线,查看n_estimators对模型的影响

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,30,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)
    rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率
    var.append(cv1.var())
    ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
#添加方差线条
rs=np.array(rs)
var=np.array(var)#源代码这里*0.01
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.plot(axis,rs+var,c="black",linestyle="-.")
plt.plot(axis,rs-var,c="black",linestyle="-.")
plt.legend()
plt.show()

#看看泛化误差的可控部分如何

plt.figure(figsize=(20,5))
plt.plot(axis,ge,c='red',label="XGB")
plt.legend()
plt.show()
 

从这个过程中观察n_estimators参数对模型的影响,我们可以得出以下结论:
首先,XGB中的树的数量决定了模型的学习能力,树的数量越多,模型的学习能力越强。只要XGB中树的数量足够
了,即便只有很少的数据, 模型也能够学到训练数据100%的信息,所以XGB也是天生过拟合的模型。但在这种情况
下,模型会变得非常不稳定。
第二,XGB中树的数量很少的时候,对模型的影响较大,当树的数量已经很多的时候,对模型的影响比较小,只能有
微弱的变化。当数据本身就处于过拟合的时候,再使用过多的树能达到的效果甚微,反而浪费计算资源。当唯一指标
或者准确率给出的n_estimators看起来不太可靠的时候,我们可以改造学习曲线来帮助我们。
第三,树的数量提升对模型的影响有极限,最开始,模型的表现会随着XGB的树的数量一起提升,但到达某个点之
后,树的数量越多,模型的效果会逐步下降,这也说明了暴力增加n_estimators不一定有效果。
这些都和随机森林中的参数n_estimators表现出一致的状态。在随机森林中我们总是先调整n_estimators,当
n_estimators的极限已达到,我们才考虑其他参数,但XGB中的状况明显更加复杂,当数据集不太寻常的时候会更加
复杂。这是我们要给出的第一个超参数,因此还是建议优先调整n_estimators,一般都不会建议一个太大的数目,
300以下为佳。

参考:

XGBOOST学习曲线及改进,泛化误差-CSDN博客

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/418174.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

javascript作用域编译浅析

作用域思维导图 1:编译原理 分词/词法分析 如果词法单元生成器在判断a是一个独立的词法单元还是其他词法单元的一部分时,调用的是有状态的解析规则,那么这个过程就被称为词法分析。 解析/语法分析 由词法单元流转换成一个由元素逐级嵌套所组…

Linux 下安装Jupyter

pip3 install jupyter pip3 install ipython -------------------------------------------- pip3 install jupyterlab jupyter lab pip3 list | grep jupyterlab 启动: python3 -m jupyter lab 2.安装朱皮特 pip3 install -i https://pypi.douban.com/simpl…

免费音频剪辑

在数字时代,音频剪辑已成为许多职业和爱好者不可或缺的技能。无论是制作播客、教育视频、还是进行广告宣传,高质量的音频剪辑都能为作品增色不少。今天,我要为大家强烈安利一款免费且功能强大的音频剪辑工具,它绝对是你办公桌上不…

[分类指标]准确率、精确率、召回率、F1值、ROC和AUC、MCC马修相关系数

准确率、精确率、召回率、F1值 定义: 1、准确率(Accuracy) 准确率是指分类正确的样本占总样本个数的比例。准确率是针对所有样本的统计量。它被定义为: 准确率能够清晰的判断我们模型的表现,但有一个严重的缺陷&…

【OJ比赛日历】快周末了,不来一场比赛吗? #03.02-03.08 #11场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号会推送最新的比赛消息,欢迎关注! 以下信息仅供参考,以比赛官网为准 目录 2024-03-02(周六) #4场比赛2024-03-03…

如何解决局域网tcp延迟高来进行安全快速内外网传输呢?

在当今企业运营中,数据的快速流通变得至关重要,但局域网内的TCP延迟问题却成为了数据传输的障碍。本文旨在分析局域网TCP延迟的成因,并探讨几种企业数据传输的常见模式,以及如何为企业选择合适的传输策略,以确保数据在…

软考52-上午题-【数据库】-关系模式2

一、关系模式的回顾 见:软考38-上午题-【数据库】-关系模式 二、关系模式 2-1、关系模式的定义 示例: 念法:A——>B A决定B,或者,B依赖于A。 2-2、函数依赖 1、非平凡的函数依赖 如果X——>Y,&a…

Javascript:输入输出

目录 一.前言 二.正文 1.输出 2.输入 3.字面量 概念: 三.结语 一.前言 Javascript作为运行浏览器的语言,对于学习前端的同学来说十分重要,那么从现在开始我们将开始介绍有关 Javascript。 二.正文 1.输出 document.write() : 向body内…

UVa11726 Crime Scene

题目链接 UVa11726 - Crime Scene 题意 给定n(n≤100)个物体,每个物体都是一个圆或者k(k≤10)边形,用长度尽量小的绳子把它们包围起来。 分析 孟加拉国Manzurur Rahman Khan (Sidky)大神出的难题&#xff…

HTML5+CSS3+JS小实例:右键菜单

实例:右键菜单 技术栈:HTML+CSS+JS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"> <head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><met…

Linux x86平台获取sys_call_table

文章目录 前言一、根据call *sys_call_table来获取二、使用dump_stack三、使用sys_close参考资料 前言 Linux 3.10.0 – x86_64 最简单获取sys_call_table符号的方法&#xff1a; # cat /proc/kallsyms | grep sys_call_table ffffffff816beee0 R sys_call_table一、根据cal…

想给金三银四找工作的程序员几点建议,被面试官问的Android问题难倒了

Android没凉&#xff0c;只是比以前难混了 多年前Android异军突起&#xff0c;成了新的万亿级市场&#xff0c;无数掘金人涌入&#xff0c;期待可以一展拳脚。 那时候大环境下的手游圈&#xff0c;只要你能有个可以运行的连连看就能找到工作&#xff0c;走上赛道被浪潮推着前…

LeetCode73题:矩阵置零(python3)

代码思路&#xff1a; 这里用矩阵的第一行和第一列来标记是否含有0的元素&#xff0c;但这样会导致原数组的第一行和第一列被修改&#xff0c;无法记录它们是否原本包含 0。因此我们需要额外使用两个标记变量分别记录第一行和第一列是否原本包含 0。 class Solution:def setZe…

二、TensorFlow结构分析(2)

目录 1、会话 1.1 __init__(target,graphNone,configNone) 1.2 会话的run() 1.3 feed操作 TF数据流图图与TensorBoard会话张量变量OP高级API 1、会话 1.1 __init__(target,graphNone,configNone) def session_demo():# 会话的演示# Tensorflow实现加法运算a_t tf.constan…

python+Django+Neo4j中医药知识图谱与智能问答平台

文章目录 项目地址基础准备正式运行 项目地址 https://github.com/ZhChessOvO/ZeLanChao_KGQA 基础准备 请确保您的电脑有以下环境&#xff1a;python3&#xff0c;neo4j 在安装目录下进入cmd&#xff0c;输入指令“pip install -r requirement.txt”,安装需要的python库 打…

【二叉搜索树】【递归】【迭代】Leetcode 700. 二叉搜索树中的搜索

【二叉搜索树】【递归】【迭代】Leetcode 700. 二叉搜索树中的搜索 二叉搜索树解法1 递归法解法2 迭代法 ---------------&#x1f388;&#x1f388;题目链接&#x1f388;&#x1f388;------------------- 二叉搜索树 二叉搜索树&#xff08;Binary Search Tree&#xff…

【详识JAVA语言】运算符

什么是运算符 计算机的最基本的用途之一就是执行数学运算&#xff0c;比如&#xff1a; int a 10; int b 20;a b; a < b; 上述 和< 等就是运算符&#xff0c;即&#xff1a;对操作数进行操作时的符号&#xff0c;不同运算符操作的含义不同。 作为一门计算机语言&…

mprpc分布式RPC网络通信框架

mprpc 项目介绍 该项目是一个基于muduo、Protobuf和Zookeeper实现的轻量级分布式RPC网络通信框架。 可以把任何单体架构系统的本地方法调用&#xff0c;重构成基于TCP网络通信的RPC远程方法调用&#xff0c;实现同一台机器的不同进程之间的服务调用&#xff0c;或者不同机器…

FreeRTOS 软件定时器

目录 一、软件定时器简介 1、软件定时器概述 2、编写回调函数的注意事项 二、软件定时器实现机制 1、软件定时器实现机制 2、软件定时器相关配置 三、单次定时器 四、周期定时器 五、软件定时器的基本操作 1、创建软件定时器 2、复位软件定时器 3、开启软件定时器 …

【ZooKeeper 】安装和使用,以及java客户端

目录 1. 前言 2. ZooKeeper 安装和使用 2.1. 使用Docker 安装 zookeeper 2.2. 连接 ZooKeeper 服务 2.3. 常用命令演示 2.3.1. 查看常用命令(help 命令) 2.3.2. 创建节点(create 命令) 2.3.3. 更新节点数据内容(set 命令) 2.3.4. 获取节点的数据(get 命令) 2.3.5. 查看…