模型融合之模型堆叠

一、理论

模型堆叠(Model Stacking)是一种集成学习的方法,其本质是将多个基学习器(Individual Learner)的预测结果作为新的特征,再训练一个元学习器(Meta Learner)来进行最终的预测。每个基学习器可以使用不同的算法或参数来训练,因此理论上你可以将任何可以输出概率值的模型作为基学习器进行模型堆叠

StackingClassifier应用于分类问题,StackingRegressor应用于回归问题。

level 0上训练的多个强学习器被称为基学习器(base-model),也叫做个体学习器。在level 1上训练的学习器叫元学习器(meta-model)。根据行业惯例,level 0上的学习器是复杂度高、学习能力强的学习器,例如集成算法、支持向量机,而level 1上的学习器是可解释性强、较为简单的学习器,如决策树、线性回归、逻辑回归等。有这样的要求是因为level 0上的算法们的职责是找出原始数据与标签的关系、即建立原始数据与标签之间的假设,因此需要强大的学习能力。但level 1上的算法的职责是融合个体学习器做出的假设、并最终输出融合模型的结果,相当于在寻找“最佳融合规则”,而非直接建立原始数据与标签之间的假设。

二、实例 

# 常用工具库
import re
import numpy as np
import pandas as pd
import matplotlib as mlp
import matplotlib.pyplot as plt
import time

# 算法辅助 & 数据
import sklearn
from sklearn.model_selection import KFold, cross_validate
from sklearn.datasets import load_digits  # 分类 - 手写数字数据集
from sklearn.datasets import load_iris
#from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

# 算法(单一学习器)
from sklearn.neighbors import KNeighborsClassifier as KNNC
from sklearn.neighbors import KNeighborsRegressor as KNNR
from sklearn.tree import DecisionTreeRegressor as DTR
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn.linear_model import LinearRegression as LR
from sklearn.linear_model import LogisticRegression as LogiR
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.ensemble import GradientBoostingRegressor as GBR
from sklearn.ensemble import GradientBoostingClassifier as GBC
from sklearn.naive_bayes import GaussianNB
import xgboost as xgb

# 融合模型
from sklearn.ensemble import StackingClassifier

data = load_digits()
X = data.data  #(1797, 64),代表了1797个样本,每个样本有64个特征,这对应了8x8像素的图片
y = data.target #一维数组,其维度为 (1797,),包含了对应图片的真实数字标签(0-9)

# 划分数据集
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.2, random_state=1412)

def fusion_estimators(clf):
    """
    对融合模型做交叉验证,对融合模型的表现进行评估
    """
    cv = KFold(n_splits=5, shuffle=True, random_state=1412)
    results = cross_validate(clf, Xtrain, Ytrain , cv=cv , scoring="accuracy", n_jobs=-1, return_train_score=True, verbose=False)
    test = clf.fit(Xtrain, Ytrain).score(Xtest, Ytest)
    print("train_score:{}".format(results["train_score"].mean())
          , "\n cv_mean:{}".format(results["test_score"].mean())
          , "\n test_score:{}".format(test))


def individual_estimators(estimators):
    """
    对模型融合中每个评估器做交叉验证,对单一评估器的表现进行评估
    """
    for estimator in estimators:
        cv = KFold(n_splits=5, shuffle=True, random_state=1412) #创建了一个5折交叉验证的KFold对象cv
        results = cross_validate(estimator[1], Xtrain, Ytrain, cv=cv , scoring="accuracy", n_jobs=-1, return_train_score=True, verbose=False)
        test = estimator[1].fit(Xtrain, Ytrain).score(Xtest, Ytest)
        print(estimator[0]
              , "\n train_score:{}".format(results["train_score"].mean()) #训练集得分的平均值
              , "\n cv_mean:{}".format(results["test_score"].mean())      #交叉验证得分的平均值
              , "\n test_score:{}".format(test), "\n")                    #测试集得分
        # estimator包含2个元素的元组('Logistic Regression', LogisticRegression(C=0.1, max_iter=3000, n_jobs=8, random_state=1412))

# 逻辑回归没有增加多样性的选项
clf1 = LogiR(max_iter=3000, C=0.1, random_state=1412, n_jobs=8)
# 增加特征多样性与样本多样性
clf2 = RFC(n_estimators=100, max_features="sqrt", max_samples=0.9, random_state=1412, n_jobs=8)
# 特征多样性,稍微上调特征数量
clf3 = GBC(n_estimators=100, max_features=16, random_state=1412)

# 增加算法多样性,新增决策树与KNN
clf4 = DTC(max_depth=8, random_state=1412)
clf5 = KNNC(n_neighbors=10, n_jobs=8)
clf6 = GaussianNB()

# 新增随机多样性,相同的算法更换随机数种子
clf7 = RFC(n_estimators=100, max_features="sqrt", max_samples=0.9, random_state=4869, n_jobs=8)
clf8 = GBC(n_estimators=100, max_features=16, random_state=4869)

estimators = [("Logistic Regression", clf1), ("RandomForest", clf2)
    , ("GBDT", clf3), ("Decision Tree", clf4), ("KNN", clf5)
              # , ("Bayes",clf6)
    , ("RandomForest2", clf7), ("GBDT2", clf8)
              ]

#选择单个评估器中分数最高的随机森林作为元学习器
#也可以尝试其他更简单的学习器
final_estimator = RFC(n_estimators=100
                      , min_impurity_decrease=0.0025
                      , random_state= 420, n_jobs=8)
clf = StackingClassifier(estimators=estimators #level0的7个体学习器
                         ,final_estimator=final_estimator #level 1的元学习器
                         ,n_jobs=8)
fusion_estimators(clf)
individual_estimators(estimators)

结果:

参考链接:https://blog.csdn.net/weixin_46803857/article/details/128700297 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/288541.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

探索 Vue 实例方法的魅力:提升 Vue 开发技能(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

Green Sock | GSAP 动画库

1.什么是“GSAP”? GreenSock Animation Platform(GSAP) 是一个业界知名的动画工具套件,在超过1100万个网站上使用,其中包括大量获奖网站! 您可以使用GSAP在任何框架中制作几乎任何JavaScript可以触及的动…

迅腾文化传播:触动每个移动消费者心灵的品牌故事缔造者

迅腾文化传播:触动每个移动消费者心灵的品牌故事缔造者 在这个高速发展的移动互联网时代,信息如同浩渺星海中的流星,瞬息万变。每个人的手机、平板、智能手表等移动设备,都成为了他们与世界连接的窗口。品牌,作为这个…

谷歌推出创新SynCLR技术:借助AI生成的数据实现高效图像建模,开启自我训练新纪元!

谷歌推出了一种创新性的合成图像框架,这一框架独特之处在于它完全不依赖真实数据。这个框架首先从合成的图像标题开始,然后基于这些标题生成相应的图像。接下来,通过对比学习的技术进行深度学习,从而训练出能够精准识别和理解这些…

STM32 学习(三)OLED 调试工具

目录 一、简介 二、使用方法 2.1 接线图 2.2 配置引脚 2.3 编写代码 三、Keil 工具调试 一、简介 在进行单片机开发时,有很多调试方法,如下图: 其中 OLED 就是一种比较好用的调试工具: OLED 硬件电路如下&#xff0c…

使用Redis进行搜索

文章目录 构建反向索引 构建反向索引 在Begin-End区域编写 tokenize(content) 函数,实现文本标记化的功能,具体参数与要求如下: 方法参数 content 为待标记化的文本; 文本标记的实现:使用正则表达式提取全小写化后的…

【竞技宝】DOTA2:Mad Kings官宣新阵容 南美新星Jimpark加盟!

北京时间2024年1月3日,随着本月DOTA2ESL吉隆坡站的比赛结束,下一个值得关注的大赛梦幻联赛S22举办的时间要等到今年的二月份了。虽然休赛期的转会狂潮已经过去,但目前还是有很多队伍依然在调整新赛季的阵容。 近日,南美战队Mad Kings在社交平台上官宣了发文,原阵容的所有选手(一…

LeetCode 每日一题 Day 28293031 ||三则模拟||找循环节(hard)

1185. 一周中的第几天 给你一个日期,请你设计一个算法来判断它是对应一周中的哪一天。 输入为三个整数:day、month 和 year,分别表示日、月、年。 您返回的结果必须是这几个值中的一个 {“Sunday”, “Monday”, “Tuesday”, “Wednesday…

创建Qt项目

项目工程名称一般不要有特殊符号,不要有中文 项目工程保存路径可修改的,但路径不要带中文 构建系统,有3种,这里使用qmake qmake和cmake区别 构建过程不同,项目管理不同。 1、构建过程,qmake是Qt框架自带的…

完善 Golang Gin 框架的静态中间件:Gin-Static

Gin 是 Golang 生态中目前最受用户欢迎和关注的 Web 框架,但是生态中的 Static 中间件使用起来却一直很不顺手。 所以,我顺手改了它,然后把这个改良版开源了。 写在前面 Gin-static 的改良版,我开源在了 soulteary/gin-static&a…

Twincat中PLC编程的ST语言

在Twincat中,PLC编程使用的是Structured Text(ST)语言。ST语言是一种类似于Pascal的高级编程语言,专为工业自动化领域的程序开发而设计。它提供了结构化的控制流和数据操作,使得PLC编程更加高效和可靠。 https://kunal…

数字信号处理期末复习——计算大题(一)

个人名片: 🦁作者简介:一名喜欢分享和记录学习的在校大学生 🐯个人主页:妄北y 🐧个人QQ:2061314755 🐻个人邮箱:2061314755qq.com 🦉个人WeChat:V…

一文讲清数据资产入表实操

《中共中央 国务院关于构建数据基础制度更好发挥数据要素作用的意见》已发布一年,数据资产化和入表已成为2023年的热门话题,随着2023年底国家数据局吹风《"数据要素x"三年行动计划(2024-2026年)》即将发布,这…

JRT控制打印机

本次测试打印机控制和纸张方向控制。 打印机状态 选择打印机 控制纸张 定义纸张 旋转纸张 不旋转纸张 A4

提高工作效率的Postman环境变量使用方法

在 Postman 中,用 Environments 来管理环境变量。我们在开发的过程中,往往会用到多个环境:开发环境,测试环境,UAT 环境,生产环境等。我们要调用不同环境的 API 时,只需切换 Postman 的 Environm…

Vuex(vue2中的状态机)

目录 Vuex state属性 getters属性 mutations属性 actions属性 modules属性 辅助函数 Vuex 状态管理模式 维护公共状态 公共数据 使用状态机模块维护状态 A组件中分发工作(发起异步请求)--->获取数据--->提交突变(将数据提交给突变 ) 通过突变修改状态…

Win32 基本程序设计原理总结

目录 1. Windows系统 基本原理 2. 需要什么函数库(.LIB) 2.1 C Runtimes: 2.2 Windows API 3. 需要什么头文件(.H) 4. Windows 程序运行的本质 5. 窗口类的注册与窗口的诞生 6.消息 6.1 消息分类:…

【每天五道题,轻松公务员】Day1:图形推理

目录 专栏了解 ☞欢迎订阅☜ ★专栏亮点★ ◇专栏作者◇ 图形推理 题目一 题目二 题目三 题目四 题目五 答案 详细讲解 讲解一 讲解二 讲解三 讲解四 讲解五 专栏了解 ☞欢迎订阅☜ 欢迎订阅此专栏:考公务员,必订!https://…

科研学习|论文解读——信息行为和社会控制:了解家庭慢性病管理中的冲突信息行为

摘要 信息与控制的关系引起了社会科学家的兴趣。然而,许多先前的工作都集中在组织而不是家庭上。交互式信息行为的研究也侧重于组织和协作,而不是冲突。因此,在有慢性疾病的家庭中,我们调查了健康相关社会控制背景下的信息行为以及…

什么是安全信息和事件管理(SIEM),有什么用处

安全信息和事件管理(SIEM)对于企业主动识别、管理和消除安全威胁至关重要。SIEM 解决方案采用事件关联、AI 驱动的异常检测以及机器学习驱动的用户和实体行为分析 (UEBA) 等机制来检测、审查和应对网络安全威胁。这些功能使 SIEM …