如何用sklearn对随机森林调参

文章目录

  • 一、概述
  • 二、实操
    • 1、导入相关包
    • 2、导入乳腺癌数据集,建立模型
    • 3、调参
  • 三、总结

Link:https://zhuanlan.zhihu.com/p/126288078
Author:陈罐头

一、概述

sklearn是目前python中十分流行的用来实现机器学习的第三方包,其中包含了多种常见算法如:决策树,逻辑回归、集成算法(如随机森林)等等。

本文将使用sklearn自带的乳腺癌数据集,建立随机森林,并基于 泛化误差(Genelization Error)模型复杂度的关系来对模型进行调参,从而使模型获得更高的得分。

泛化误差是机器学习中,用来衡量模型在未知数据上的准确率的指标,其与模型复杂度的关系如下图所示:
在这里插入图片描述
当模型复杂度不足时,机器学习不足,会出现欠拟合现象,泛化误差变大;当复杂度逐渐提高到最佳模型复杂度时,泛化误差会达到最低点(即最高准确度);若复杂度仍在提高,泛化误差从最小值开始逐渐增大,出现过拟合现象。

因此,我们的目的,是通过不断调参来不断调整模型复杂度,尽可能地接近泛化误差最低点

二、实操

1、导入相关包

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2、导入乳腺癌数据集,建立模型

由于sklearn自带的数据集已经很工整了,所以无需做预处理,直接使用。

# 导入乳腺癌数据集
data = load_breast_cancer()

# 建立随机森林
rfc = RandomForestClassifier(n_estimators=100, random_state=90)

用交叉验证计算得分
score_pre = cross_val_score(rfc, data.data, data.target, cv=10).mean()
score_pre

在这里插入图片描述

初始得分

3、调参

随机森林主要的参数有n_estimators(子树的数量)、max_depth(树的最大生长深度)、min_samples_leaf(叶子的最小样本数量)、min_samples_split(分支节点的最小样本数量)、max_features(最大选择特征数)。

它们对随机森林模型复杂度的影响如下图所示:
在这里插入图片描述

可以看到,n_estimators是影响程度最大的参数,我们先对其进行调整:

# 调参,绘制学习曲线来调参n_estimators(对随机森林影响最大)
score_lt = []
# 每隔10步建立一个随机森林,获得不同n_estimators的得分
for i in range(0,200,10):
    rfc = RandomForestClassifier(n_estimators=i+1, random_state=90)
    score = cross_val_score(rfc, data.data, data.target, cv=10).mean()
    score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),
      '子树数量为:{}'.format(score_lt.index(score_max)*10+1))
# 绘制学习曲线
x = np.arange(1,201,10)
plt.subplot(111)
plt.plot(x, score_lt, 'r-')
plt.show()

在这里插入图片描述
如图所示,当n_estimators从0开始增大至21时,模型准确度有肉眼可见的提升。这也符合随机森林的特点:在一定范围内,子树数量越多,模型效果越好。而当子树数量越来越大时,准确率会发生波动,当取值为41时,获得最大得分。

接下来,我们在将取值范围缩小至41左右,以获得更好的取值。

# 在41附近缩小n_estimators的范围为30-49
score_lt = []
for i in range(30,50):
    rfc = RandomForestClassifier(n_estimators=i
                                ,random_state=90)
    score = cross_val_score(rfc, data.data, data.target, cv=10).mean()
    score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),
      '子树数量为:{}'.format(score_lt.index(score_max)+30))

# 绘制学习曲线
x = np.arange(30,50)
plt.subplot(111)
plt.plot(x, score_lt,'o-')
plt.show()

在这里插入图片描述

如图所示,当n_estimators=45时,获得最大得分score_max=0.9719,相较于score_pre提升0.005
在这里插入图片描述

由此我们发现:当n_estimators100减小至45时(模型复杂度由大到小),模型准确度提升了(泛化误差减小),说明在泛化误差图中,模型往左移动了!

因此,接下来的调参方向是使模型复杂度减小的方向,从而接近泛化误差最低点。我们使用能使模型复杂度减小,并且影响程度排第二的max_depth

# 建立n_estimators为45的随机森林
rfc = RandomForestClassifier(n_estimators=45, random_state=90)

# 用网格搜索调整max_depth
param_grid = {'max_depth':np.arange(1,20)}
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(data.data, data.target)

best_param = GS.best_params_
best_score = GS.best_score_
print(best_param, best_score)

在这里插入图片描述

如图所示,最佳深度为11,最大得分为0.9718,竟然比不调整深度的得分0.9719还低,难道我们刚才就已经十分接近最低泛化误差了吗?

本着严谨的态度,我们再进行调整。调整max_depth使模型复杂度减小,却获得了更低的得分,因此接下来我们需要朝着复杂度增大的方向调整。我们在n_estimators=45max_depth=11的情况下,对唯一能够增加模型复杂度的参数max_features进行调整:
在这里插入图片描述

查看数据集大小,发现一共有30列特征,由于max_features默认取值特征数量的开平方值,因此我们从5开始调整:

# 用网格搜索调整max_features
param_grid = {'max_features':np.arange(5,31)}

rfc = RandomForestClassifier(n_estimators=45
                            ,random_state=90
                            ,max_depth=11)
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(data.data, data.target)
best_param = GS.best_params_
best_score = GS.best_score_
print(best_param, best_score)     

在这里插入图片描述

输出结果为5,和默认值一样。得分为0.9718,仍然小于0.9719。因此,仅需n_estimators=45就能使模型的准确率达到最高0.9719,相较于初始得分0.9667,提升0.005,最接近最小泛化误差,调参工作到此结束。

三、总结

总结一下在sklearn中调参的思路:

① 基于泛化误差模型复杂度的关系来进行调参;

② 根据对模型的影响程度,由大到小对参数排序,并确定哪些参数会使模型复杂度减小,哪些会增大;

③ 依次选择合适的参数,通过绘制学习曲线或网格搜索的方法调参,直到找到最大准确得分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/123355.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【ChatGPT】人工智能的下一个前沿

🎊专栏【ChatGPT】 🌺每日一句:慢慢变好,我是,你也是 ⭐欢迎并且感谢大家指出我的问题 文章目录 一、引言 二、ChatGPT的工作原理 三、ChatGPT的主要特点 四、ChatGPT的应用场景 五、结论与展望 ​​​​​​​ 一、引言 随着人工智能技…

【QEMU-tap-windows-Xshell】QEMU 创建 aarch64虚拟机(附有QEMU免费资源)

“从零开始:在Windows上创建aarch64(ARM64)虚拟机” 前言 aarch64(ARM64)架构是一种现代的、基于 ARM 技术的计算架构,具有诸多优点,如低功耗、高性能和广泛应用等。为了在 Windows 平台上体验…

界面控件DevExpress WPF PDF Viewer,更快实现应用的PDF文档浏览

DevExpress WPF PDF Viewer控件可以轻松地直接在Windows应用程序中显示PDF文档,而无需在最终用户的机器上安装外部PDF查看器。 P.S:DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress…

1995-2020年全国各省二氧化碳排放量面板数据

1995-2020年全国各省二氧化碳排放面板数据 1、时间:1995-2020 2、范围:全国、30省 3、来源:中国能源统计NJ 4、指标: 统计年度、地区代码、地区名称、煤炭二氧化碳排放量、焦炭二氧化碳排放量、原油二氧化碳排放量、汽油二氧…

Android Studio布局

线性布局 水平或竖直排列子元素的布局容器 相对布局 可针对容器内每个子元素设置相对位置(相对于父容器或同级子元素的位置) 网格布局 找了下面这篇文章连接可以参考(不再赘述) GridLayout(网格布局) | 菜鸟教程 (runoob.com) …

HCIA-PPPOE原理与配置

PPPOE原理与配置 实验拓扑图实现步骤家庭网关 AR201PPPOE客户端( ISP光猫)PPPOE服务器(ISP路由器) 实验拓扑图 实现步骤 家庭网关 AR201 E0/0/0-7为LAN口(二层接口)E0/0/8为WAN口(三层接口&am…

SpringBoot 学习笔记(四) - 原理篇

一、自动配置 1.1 bean加载方式 bean的加载方式1 - xml方式声明bean 导入依赖&#xff1a; <dependencies><dependency><groupId>org.springframework</groupId><artifactId>spring-context</artifactId><version>5.3.9</ver…

muduo源码剖析之TcpClient客户端类

简介 muduo用TcpClient发起连接&#xff0c;TcpClient有一个Connector连接器&#xff0c;TCPClient使用Conneccor发起连接, 连接建立成功后, 用socket创建TcpConnection来管理连接, 每个TcpClient class只管理一个TcpConnecction&#xff0c;连接建立成功后设置相应的回调函数…

HP惠普暗影精灵9P OMEN 17.3英寸游戏本17-cm2000(70W98AV)原装出厂Windows11-22H2系统镜像

链接&#xff1a;https://pan.baidu.com/s/1gJ4ZwWW2orlGYoPk37M-cg?pwd4mvv 提取码&#xff1a;4mvv 惠普暗影9Plus笔记本电脑原厂系统自带所有驱动、出厂主题壁纸、 Office办公软件、惠普电脑管家、OMEN Command Center游戏控制中心等预装程序 所需要工具&#xff1a;3…

论文实验可视化方法

真实值预测值误差 张永, 龚众望, 郑英, 等. 工业设备的健康状态评估和退化趋势预测联合研究. 中国科学: 技术科学, 2022, 52: 180–197 Zhang Y, Gong Z W, Zheng Y, et al. Joint study on health state assessment and degradation trend prediction of industrial equipment…

技术分享 | Spring Boot 异常处理

Java 异常类 首先让我们简单了解或重新学习下 Java 的异常机制。 Java 内部的异常类 Throwable 包括了 Exception 和 Error 两大类&#xff0c;所有的异常类都是 Object 对象。 Error 是不可捕捉的异常&#xff0c;通俗的说就是由于 Java 内部 JVM 引起的不可预见的异常&#…

2009-2018年全国各省财政透明度数据

2009-2018年全国各省财政透明度数据 1、时间&#xff1a;2009-2018年 2、指标&#xff1a;财政透明度 3、范围&#xff1a;31省 4、来源&#xff1a;财政透明度报告 5、指标解释&#xff1a; 财政透明度是公开透明的重要方面&#xff0c;体现了现代预算制度和法治政府的特…

深入分析MySQL索引与磁盘读取原理

索引 索引是对数据库表中一列或者多列数据检索时&#xff0c;为了加速查询而创建的一种结构。可以在建表的时候创建&#xff0c;也可以在后期添加。 USER表中有100万条数据&#xff0c;现在要执行一个查询"SELECT * FROM USER where ID999999"&#xff0c;如果没有索…

selenium xpath定位

selenium-xpath定位 <span style"background-color:#2d2d2d"><span style"color:#cccccc"><code class"language-javascript">element_xpath <span style"color:#67cdcc"></span> driver<span styl…

嵌入式养成计划-48----QT--信息管理系统:百川仓储管理

一百二十二、信息管理系统&#xff1a;百川仓储管理 122.1 UI界面 122.2 思路 客户端&#xff1a; 用户权限有两种类型&#xff0c;一种是用户权限&#xff0c;一种是管理员权限&#xff0c;登录时服务器端会根据数据库查询到的此用户名的权限返回不同的结果&#xff0c;客户…

CodeWhisperer 的正确使用

1、重点&#xff1a; 重点1&#xff1a; 推出 Amazon Bedrock。这项新服务允许用户通过 API 访问来自 AI21 Labs、Anthropic、Stability AI 和亚马逊的基础模型。&#xff08;Anthropic 就是之前跟 ChatGPT 掰手腕的 Claude 的模型。Stability AI 就是 Stable Diffusion 背后的…

IP 地址冲突检测工具

IP 冲突是一个术语&#xff0c;用于表示同一网络或子网中尝试使用相同 IP 地址的两个或多个设备的状态&#xff0c;这可能会导致发往特定主机的通信与其他主机混淆&#xff0c;因为两者都使用相同的 IP&#xff0c;为了避免这种情况&#xff0c;某些主机在发生 IP 冲突时会失去…

MySQL中的多列子查询

-- 多列子查询 -- 如何查询与WOARD 的部门和岗位完全相同的所有雇员(并且不含smith本人) -- (字段1&#xff0c;字段2...) (select 字段1&#xff0c;字段2 from ...) -- 分析&#xff1a; 1. 得到smith的部门和岗位 SELECT deptno,job FROM empWHERE ename WARD; -- 2.使…

大数据-玩转数据-Flume

一、Flume简介 Flume提供一个分布式的,可靠的,对大数据量的日志进行高效收集、聚集、移动的服务,Flume只能在Unix环境下运行。Flume基于流式架构,容错性强,也很灵活简单。Flume、Kafka用来实时进行数据收集,Spark、Flink用来实时处理数据,impala用来实时查询。二、Flume…

挑战100天 AI In LeetCode Day05(热题+面试经典150题)

挑战100天 AI In LeetCode Day05&#xff08;热题面试经典150题&#xff09; 一、LeetCode介绍二、LeetCode 热题 HOT 100-72.1 题目2.2 题解 三、面试经典 150 题-73.1 题目3.2 题解 一、LeetCode介绍 LeetCode是一个在线编程网站&#xff0c;提供各种算法和数据结构的题目&am…