【机器学习系列】Python实战:使用GridSearchCV优化AdaBoost分类器及其基分类器

目录

一、AdaBoost的标准实现中是否支持使用不同类型的基分类器?

二、Adaboost的参数

三、Python实现Adaboost

(一)导入库和数据集

(二) 划分训练集

(三)选择基分类器--决策树

 (四)创建Adaboost分类器

(五) 网格调参寻找Adaboost中最优参数:n_estimators和learning_rate

(六)创建 GridSearchCV 对象并执行网格搜索

 (七)获取最优参数和最佳准确率

(八) 打印最佳参数和最佳准确率

(九) 使用最佳参数训练AdaBoost分类器

(十) 进行预测并计算准确率

(十一) 输出评估报告


本文旨在深入探索AdaBoost算法的标准实现,并解释如何通过网格搜索(GridSearchCV)对其及其基分类器(如决策树)的参数进行优化,以在分类任务中达到更高的准确率。我们将从AdaBoost的基本概念讲起,介绍其在Python中的实现方式,并通过一个实例详细展示如何划分训练集、选择基分类器、创建AdaBoost分类器、调参优化以及评估预测性能。本篇博客将帮助读者理解AdaBoost算法的调优步骤,并能够运用网格搜索技术寻找最优的模型参数,从而提高模型在实际应用中的预测精度。

一、AdaBoost的标准实现中是否支持使用不同类型的基分类器?

在标准的AdaBoost实现中,由于算法设计时假设所有的基分类器都是同质的,因此通常所有的基分类器都是同一类型。这意味着在同一个AdaBoost模型中,通常所有基分类器都是决策树、KNN或其他单一类型的分类器。

然而,一些变体的集成学习方法允许使用不同类型的基分类器。例如,Stacking(堆叠)是一种集成学习方法,它可以在不同的层中使用不同类型的分类器。在Stacking中,第一层的分类器可以是不同类型的,而第二层的分类器则使用第一层的分类器的输出作为输入。

对于AdaBoost来说,如果你想在一个模型中同时使用决策树和KNN,你需要使用一个允许混合不同类型基分类器的集成学习方法,而不是标准的AdaBoost。例如,你可以手动实现一个自定义的集成学习方法,它结合了AdaBoost的权重更新机制和不同类型的基分类器。

二、Adaboost的参数

在机器学习领域,Adaboost算法是一个强大的集成学习技术,它结合多个弱学习器来生成一个强大的预测模型。为了优化Adaboost算法的性能,有几个关键参数需要调整:

1. 估计器数量(n_estimators):此参数决定集成中包含的弱学习器的数量。增加估计器的数量通常能提高模型的性能,但也会增加计算时间和可能引起过拟合。通过交叉验证选择最优的估计器数量是一种常见的做法。

2. 学习率(learning_rate):学习率决定了每个弱学习器对最终模型的贡献比例。较小的学习率能使训练过程更慢,但有助于防止过拟合;较大的学习率虽然能加快训练速度,但也可能导致过拟合。这个参数通常与估计器数量一起调整以找到最佳平衡点。

3. 基础估计器选择(estimator):选择合适的基础估计器(弱学习器)非常关键,常用的如决策树桩(一层浅决策树)或简单的线性模型。不同的基础估计器可能对不同数据集的适应性各不相同,因此尝试多种基础估计器是有益的。

4. 基础分类器超参数:如果基础估计器有超参数(例如决策树的max_depth),调整这些参数也会影响AdaBoost模型的性能。

5. 随机种子(random_state):设置一个固定的随机种子可以确保实验的可重复性。但在超参数调优时,尝试不同的随机种子可以测试模型在不同随机初始化下的稳定性。

6. 交叉验证:交叉验证是超参数调整中不可或缺的一部分,可以帮助评估模型在未见数据上的性能,并防止在训练集上过拟合。使用网格搜索或随机搜索等技术可以有效探索超参数空间。

通过精心调整这些参数,可以显著提升Adaboost模型的性能,实现在各种数据集上的最优表现。

三、Python实现Adaboost

(一)导入库和数据集

from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
# 导入鸢尾花数据集
iris = load_iris()
X = iris.data # 特征
y = iris.target # 类别
feature_names = iris.feature_names # 特征名称
class_names = iris.target_names # 类别名称

(二) 划分训练集

 # 将数据集划分为训练集和测试集,比例为 8:2
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

(三)选择基分类器--决策树

# 首先创建基础分类器
base_estimator = DecisionTreeClassifier(criterion="entropy", max_depth=2, min_samples_split=2, min_samples_leaf=2, random_state=0)

 (四)创建Adaboost分类器

# 然后创建 AdaBoost 分类器实例
ada_classifier = AdaBoostClassifier(estimator = base_estimator,random_state=0) 
# ada_classifier.base_estimator = base_estimator

(五) 网格调参寻找Adaboost中最优参数:n_estimators和learning_rate


# 定义参数网格
param_grid = {
    'n_estimators': [10, 20, 30,40,50],
    'learning_rate': [0.01, 0.02,0.001]
}

(六)创建 GridSearchCV 对象并执行网格搜索

# 创建 GridSearchCV 对象
grid_search = GridSearchCV(estimator=ada_classifier, param_grid=param_grid, cv=5, scoring='accuracy')
# 执行网格搜索
grid_search.fit(X_train, y_train)

 (七)获取最优参数和最佳准确率

# 获取最佳参数和最佳准确率
best_params = grid_search.best_params_
best_score = grid_search.best_score_

(八) 打印最佳参数和最佳准确率

# 打印最佳参数和最佳准确率
print("Best parameters:", best_params)
print("Best cross-validation score (accuracy):", best_score)

(九) 使用最佳参数训练AdaBoost分类器

best_ada_classifier = AdaBoostClassifier(estimator = base_estimator,random_state=0,n_estimators=best_params['n_estimators'], learning_rate=best_params['learning_rate'])
best_ada_classifier.fit(X_train, y_train)

(十) 进行预测并计算准确率

# 使用训练好的AdaBoost分类器进行预测
predictions = best_ada_classifier.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, predictions)
# 打印准确率
print("Accuracy with best parameters:", accuracy)

(十一) 输出评估报告

from sklearn.metrics import accuracy_score, classification_report
print('模型的准确率为:\n', accuracy_score(y_test, predictions))
print('模型的评估报告:\n', classification_report(y_test, predictions))

 补充:为什么改变基础弱分类器的参数,Adaboost的准确率不变?

当你改变决策树(base_classifier)的参数,例如criterion、max_depth、min_samples_split和min_samples_leaf,实际上是在改变每个基础分类器的性能。然而,AdaBoostClassifier的fit方法会根据给定的estimator(这里就是基础弱分类器)训练多个弱分类器,并对它们的权重进行调整。如果这些基础模型的性能变化不明显,或者它们之间有很好的互补性,那么整个AdaBoost模型的组合效果可能不会显著提升准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/725475.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ADOP带你了解:数据中心的高速互联解决方案

随着大语言模型和AIGC的飞速发展,数据中心对于高速、高可靠性的网络连接需求日益增长。ADOP系列产品正是在这样的背景下应运而生,为现代数据中心提供了全面的连接解决方案。 ADOP系列产品概览 ADOP系列产品旨在为云、高性能计算、Web 2.0、企业、电信、…

镁光全球扩张HBM3E内存生产,提升市场占有率

镁光科技(Micron)近期透露了其在高性能内存领域的大胆布局,意图显著提升其在高带宽内存(HBM)市场的份额。根据《Nikkei》的最新报道,这家内存巨头正计划在全球范围内扩大HBM3E内存的生产能力,一…

Golang | Leetcode Golang题解之第167题两数之和II-输入有序数组

题目&#xff1a; 题解&#xff1a; func twoSum(numbers []int, target int) []int {low, high : 0, len(numbers) - 1for low < high {sum : numbers[low] numbers[high]if sum target {return []int{low 1, high 1}} else if sum < target {low} else {high--}}r…

高速公路收费图片分析系统深入理解

当今社会&#xff0c;随着交通运输业的快速发展&#xff0c;高速公路已成为人们出行的重要选择。而高速公路收费系统作为保障道路可持续运营的重要组成部分&#xff0c;其效率和准确性对于保障道路畅通和交通安全至关重要。近年来&#xff0c;随着技术的不断进步&#xff0c;高…

Spring MVC学习记录(基础)

目录 1.SpringMVC概述1.1 MVC介绍1.2 Spring MVC介绍1.3 Spring MVC 的核心组件1.4 SpringMVC 工作原理 2.Spring MVC入门2.1 入门案例2.2 总结 3.RequestMapping注解4.controller方法返回值4.1 返回ModelAndView4.2 返回字符串4.2.1 逻辑视图名4.2.2 Redirect重定向4.2.3 forw…

Java | Leetcode Java题解之第168题Excel表列名称

题目&#xff1a; 题解&#xff1a; class Solution {public String convertToTitle(int columnNumber) {StringBuffer sb new StringBuffer();while (columnNumber ! 0) {columnNumber--;sb.append((char)(columnNumber % 26 A));columnNumber / 26;}return sb.reverse().t…

信息打点web篇----企业宏观资产打点

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 专栏描述&#xff1a;因为第一遍过信息收集的时候&#xff0c;没怎么把收集做回事 导致后来在实战中&#xff0c;遭遇资产获取少&#xff0c;可渗透点少的痛苦&#xff0c;如今决定 从头来过&#xff0c;全面全方位…

Parallels Desktop 19 for mac破解版安装激活使用指南

Parallels Desktop 19 for Mac 乃是一款适配于 Mac 的虚拟化软件。它能让您在 Mac 计算机上同时运行多个操作系统。您可借此创建虚拟机&#xff0c;并于其中装设不同的操作系统&#xff0c;如 Windows、Linux 或 macOS。使用 Parallels Desktop 19 mac 版时&#xff0c;您可在 …

代码随想录算法训练营第四十三天|518. 零钱兑换 II ,LCR 103. 零钱兑换,377. 组合总和 Ⅳ

518. 零钱兑换 II - 力扣&#xff08;LeetCode&#xff09; class Solution {public int change(int amount, int[] coins) {// 创建dp数组,dp[i][j] 表示使用前i个硬币&#xff08;下标为0的硬币是前1个&#xff09;凑成总金额j的硬币组合数int[][] dp new int[coins.length …

Vue--》从零开始打造交互体验一流的电商平台(三)

今天开始使用 vue3 + ts 搭建一个电商项目平台,因为文章会将项目的每处代码的书写都会讲解到,所以本项目会分成好几篇文章进行讲解,我会在最后一篇文章中会将项目代码开源到我的github上,大家可以自行去进行下载运行,希望本文章对有帮助的朋友们能多多关注本专栏,学习更多…

C#开发-集合使用和技巧(八)集合中的排序Sort、OrderBy、OrderByDescending

C#开发-集合使用和技巧&#xff08;八&#xff09;集合中的排序Sort、OrderBy、OrderByDescending List<T>.Sort()IEnumerable<T>.OrderBy()Enumerable<T>.OrderByDescending() 在C#中&#xff0c;List<T> 类提供了多种方法来进行排序&#xff0c;最常…

数学建模基础:数学建模概述

目录 前言 一、数学建模的步骤 二、模型的分类 三、模型评价指标 四、常见的数学建模方法 实际案例&#xff1a;线性回归建模 步骤 1&#xff1a;导入数据 步骤 2&#xff1a;数据预处理 步骤 3&#xff1a;建立线性回归模型 步骤 4&#xff1a;模型验证 步骤 5&…

springboot集成积木报表,怎么将平台用户信息传递到积木报表

springboot集成积木报表后怎么将平台用户信息传递到积木报表 起因是因为需要研究在积木报表做数据筛选的时候需要拿到系统当前登录用户信息做筛选新的模块 起因是因为需要研究在积木报表做数据筛选的时候需要拿到系统当前登录用户信息做筛选 官网有详细介绍怎么集成进去的&…

JAVA 注解搜索工具类与注解原理讲解(获取方法和类上所有的某个注解,父类继承的注解也支持获取)

文章目录 JAVA 注解搜索工具类与注解原理讲解&#xff08;获取方法和类上所有的某个注解&#xff0c;父类继承的注解也支持获取&#xff09;代码测试方法上加注解&#xff0c;类上不加类上加注解、方法上加注解 注解原理性能测试 JAVA 注解搜索工具类与注解原理讲解&#xff08…

HTML基本标签使用【超链接标签、表格标签、表单标签、input标签】

目录 一、基本介绍1.1 概念1.2 HTML的核心特点 二、HTML基本标签三、超链接标签四、表格标签✌<table> 标签属性✍<tr> 标签属性✌ <td> 和 <th> 标签属性演示注意事项 五、表单标签综合应用 最后 一、基本介绍 1.1 概念 HTML&#xff0c;全称为超文…

win11照片裁剪视频无法保存问题解决

win11照片默认走核显&#xff0c;intel的显卡可能无法解码&#xff0c;在设置里把照片的显示卡默认换成显卡就好了

NSSCTF中的[WUSTCTF 2020]朴实无华、[FSCTF 2023]源码!启动! 、[LitCTF 2023]Flag点击就送! 以及相关知识点

目录 [WUSTCTF 2020]朴实无华 [FSCTF 2023]源码&#xff01;启动! [LitCTF 2023]Flag点击就送&#xff01; 相关知识点 1.intval 绕过 绕过的方式&#xff1a; 2.session伪造攻击 [WUSTCTF 2020]朴实无华 1.进入页面几乎没什么可用的信息&#xff0c;所以想到使用dis…

小白学-WEBGL

第一天&#xff1a; 1.canvas和webgl的区别 Canvas 和 WebGL 都是用于在网页上绘制图形的技术&#xff0c;它们通过浏览器提供的 API 使开发者能够创建丰富的视觉内容&#xff0c;但它们的工作原理和用途有所不同。 Canvas Canvas API 提供了一个通过 JavaScript 和 HTML <…

【STM32-ST-Link】

STM32-ST-Link ■ ST-Link简介■ ST-Link驱动的安装。■ ST-Link编程软件(MDK)配置。■ ST-Link固件升级方法 ■ ST-Link简介 由于德产 J-LINK 价格非常昂贵&#xff0c; 而国产 J-LINK 因为版权问题将在万能的淘宝销声匿迹。 所以我们有必要给大家介绍 JTAG/SWD 调试工具中另…

PHP转Go系列 | 字符串的使用姿势

大家好&#xff0c;我是码农先森。 输出 在 PHP 语言中的输出比较简单&#xff0c;直接使用 echo 就可以。此外&#xff0c;在 PHP 中还有一个格式化输出函数 sprintf 可以用占位符替换字符串。 <?phpecho 码农先森; echo sprintf(码农:%s, 先森);在 Go 语言中调用它的输…