机器学习参数调优

手动调参  

分析影响模型的参数,设计步长进行交叉验证

我们以随机森林为例:

本文将使用sklearn自带的乳腺癌数据集,建立随机森林,并基于泛化误差(Genelization Error)与模型复杂度的关系来对模型进行调参,从而使模型获得更高的得分。

泛化误差是机器学习中,用来衡量模型在未知数据上的准确率的指标;

1、导入相关包

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2、导入乳腺癌数据集,建立模型

由于sklearn自带的数据集已经很工整了,所以无需做预处理,直接使用。

# 导入乳腺癌数据集
data = load_breast_cancer()

# 建立随机森林
rfc = RandomForestClassifier(n_estimators=100, random_state=90)

用交叉验证计算得分
score_pre = cross_val_score(rfc, data.data, data.target, cv=10).mean()
score_pre

3、调参

随机森林主要的参数有n_estimators(子树的数量)、max_depth(树的最大生长深度)、min_samples_leaf(叶子的最小样本数量)、min_samples_split(分支节点的最小样本数量)、max_features(最大选择特征数)。它们对随机森林模型复杂度的影响如下图所示:

n_estimators是影响程度最大的参数,我们先对其进行调整

# 调参,绘制学习曲线来调参n_estimators(对随机森林影响最大)
score_lt = []

# 每隔10步建立一个随机森林,获得不同n_estimators的得分
for i in range(0,200,10):
    rfc = RandomForestClassifier(n_estimators=i+1
                                ,random_state=90)
    score = cross_val_score(rfc, data.data, data.target, cv=10).mean()
    score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),
      '子树数量为:{}'.format(score_lt.index(score_max)*10+1))

# 绘制学习曲线
x = np.arange(1,201,10)
plt.subplot(111)
plt.plot(x, score_lt, 'r-')
plt.show()

如图所示,当n_estimators从0开始增大至21时,模型准确度有肉眼可见的提升。这也符合随机森林的特点:在一定范围内,子树数量越多,模型效果越好。而当子树数量越来越大时,准确率会发生波动,当取值为41时,获得最大得分。


框架自动调参

Optuna是一个自动化的超参数优化软件框架,专门为机器学习而设计。 这里对其进行简单的入门介绍,详细的学习可以参考:https://github.com/optuna/optuna

 

optuna是一个使用python编写的超参数调节框架。一个极简的 optuna 的优化程序中只有三个最核心的概念,目标函数(objective),单次试验(trial),和研究(study)。其中 objective 负责定义待优化函数并指定参/超参数数范围,trial 对应着 objective 的单次执行,而 study 则负责管理优化,决定优化的方式,总试验的次数、试验结果的记录等功能。

  • objective:根据目标函数的优化Session,由一系列的trail组成。
  • trail:根据目标函数作出一次执行。
  • study:根据多次trail得到的结果发现其中最优的超参数。

随机森林iris数据集调优

from sklearn.datasets import load_iris
x, y = load_iris().data, load_iris().target
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
def objective(trial):
    global x, y
    X_train, X_test, y_train, y_test=train_test_split(x, y, train_size=0.3)# 数据集划分
    param = {
        "n_estimators": trial.suggest_int('n_estimators', 5, 20),
        "criterion": trial.suggest_categorical('criterion', ['gini','entropy'])
    }

    dt_clf = RandomForestClassifier(**param)
    dt_clf.fit(X_train, y_train)
    pred_dt = dt_clf.predict(X_test)
    score = (y_test==pred_dt).sum() / len(y_test)
    return score
study=optuna.create_study(direction='maximize')
n_trials=20 # try50次
study.optimize(objective, n_trials=n_trials)
print(study.best_value)
print(study.best_params)
#######################################结果######################################
[32m[I 2021-04-12 16:20:13,627][0m A new study created in memory with name: no-name-47fe20d7-e9c0-4bed-bc6d-8113edae0bec[0m
[32m[I 2021-04-12 16:20:13,652][0m Trial 0 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 15, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,662][0m Trial 1 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 5, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,680][0m Trial 2 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 15, 'criterion': 'entropy'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,689][0m Trial 3 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 7, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,704][0m Trial 4 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 14, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,721][0m Trial 5 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 14, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,733][0m Trial 6 finished with value: 0.9619047619047619 and parameters: {'n_estimators': 10, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,753][0m Trial 7 finished with value: 0.9619047619047619 and parameters: {'n_estimators': 18, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,764][0m Trial 8 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 8, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,771][0m Trial 9 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 5, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,795][0m Trial 10 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 20, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,809][0m Trial 11 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 9, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,827][0m Trial 12 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 12, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,842][0m Trial 13 finished with value: 0.9238095238095239 and parameters: {'n_estimators': 11, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,855][0m Trial 14 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 8, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,880][0m Trial 15 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 18, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,899][0m Trial 16 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 13, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,911][0m Trial 17 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 7, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,933][0m Trial 18 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 17, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,948][0m Trial 19 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 11, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m


0.9714285714285714
{'n_estimators': 14, 'criterion': 'gini'}
##################################################################################

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/70694.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【趋势检测和隔离】使用小波进行趋势检测和隔离研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

MD-MTSP:星雀优化算法NOA求解多仓库多旅行商问题MATLAB(可更改数据集,旅行商的数量和起点)

一、星雀优化算法NOA 星雀优化算法(Nutcracker optimizer algorithm,NOA)由Mohamed Abdel-Basset等人于2023年提出,该算法模拟星雀的两种行为,即:在夏秋季节收集并储存食物,在春冬季节搜索食物的存储位置。星雀优化算法(Nutcrack…

Azure通过自动化账户实现对资源变更

Azure通过自动化账户实现对资源变更 创建一个自动化账户第一种方式 添加凭据(有更改资源权限的账户,没有auth认证情况)创建一个Runbook,测试修改 AnalysisServices 定价层设置定时任务:开始定时任务: 第二种…

STM32F429IGT6使用CubeMX配置GPIO点亮LED灯

1、硬件电路 2、设置RCC,选择高速外部时钟HSE,时钟设置为180MHz 3、配置GPIO引脚 4、生成工程配置 5、部分代码 /* USER CODE BEGIN WHILE */while (1){/* USER CODE END WHILE *//* USER CODE BEGIN 3 */HAL_GPIO_WritePin(LED_RGB_GPIO_Port, LED_RGB_Pin, GPIO_…

RocketMQ 延迟消息

RocketMQ 延迟消息 RocketMQ 消费者启动流程 什么是延迟消息 RocketMQ 延迟消息是指,生产者发送消息给消费者消息,消费者需要等待一段时间后才能消费到。 使用场景 用户下单之后,15分钟未支付,对支付账单进行提醒或者关单处理…

OpenCV实例(八)车牌字符识别技术(二)字符识别

车牌字符识别技术(二)字符识别 1.字符识别原理及其发展阶段2.字符识别方法3.英文、数字识别4.车牌定位实例 1.字符识别原理及其发展阶段 匹配判别是字符识别的基本思想,与其他模式识别的应用非常类似。字符识别的基本原理就是对字符图像进行…

实时时钟+闹钟

在江科大实时时钟的基础上添加闹钟的配置,参考http://t.csdn.cn/YDlYy。 实现功能 :每隔time秒蜂鸣器响一次、设置闹钟的年月日时分秒,到时间蜂鸣器响。 前三个函数没有变,添加 void RTC_AlarmInit(void) 闹钟的中断配置void…

【算法题】螺旋矩阵I (求解n阶螺旋矩阵问题)

一、问题的提出 螺旋矩阵是一种常见的矩阵形式,它的特点是按照螺旋的方式排列元素。n阶螺旋矩阵是指矩阵的大小为nn,其中n为正整数。 二、解决的思路 当N1时,矩阵为; 当N2时,矩阵为; 当N>2(N为偶数如N4)时,矩阵…

Arduino ESP32 v2 使用记录:开发环境搭建

文章目录 目的开发环境搭建程序下载测试使用VS Code进行开发批量烧录固件到模块中总结 目的 在之前的文章 《使用Arduino开发ESP32(01):开发环境搭建》 中介绍了使用Arduino开发ESP32的开发环境搭建内容,只不过当时的 Arduino co…

Django进阶

1.orm 1.1 基本操作 orm,关系对象映射。 类 --> SQL --> 表 对象 --> SQL --> 数据特点:开发效率高、执行效率低( 程序写的垃圾SQL )。 编写ORM操作的步骤: settings.py,连…

Metasploitable2靶机漏洞复现

一、信息收集 nmap扫描靶机信息 二、弱口令 1.系统弱口令 在Kali Linux中使用telnet远程连接靶机 输入账号密码msfadmin即可登录 2.MySQL弱口令 使用mysql -h 靶机IP地址即可连接 3.PostgreSQL弱密码登录 输入psql -h 192.168.110.134 -U postgres 密码为postgres 输入\…

线性代数(二) 矩阵及其运算

前言 行列式det(A) 其实表示的只是一个值 ∣ a b c d ∣ a d − b c \begin{vmatrix} a & b\\ c & d\end{vmatrix} ad -bc ​ac​bd​ ​ad−bc,其基本变化是基于这个值是不变。而矩阵表示的是一个数表。 定义 矩阵与线性变换的关系 即得 ( a 11 a 12…

数据结构——单链表的实现(c语言版)

前言 单链表作为顺序表的一种,了解并且熟悉它的结构对于我们学习更加复杂的数据结构是有一定意义的。虽然单链表有一定的缺陷,但是单链表也有它存在的价值, 它也是作为其他数据结构的一部分出现的,比如在图,哈希表中。…

java之junit Test

JUnit测试简介 1.什么是单元测试 单元测试是针对最小的功能单元编写测试代码Java程序最小的功能单元是方法单元测试就是针对单个Java方法的测试 2.测试驱动开发 3.单元测试的好处 确保单个方法运行正常如果修改了方法代码,只需确保其对应的单元测试通过测试代码…

【深度学习】【风格迁移】Visual Concept Translator,一般图像到图像的翻译与一次性图像引导,论文

General Image-to-Image Translation with One-Shot Image Guidance 论文:https://arxiv.org/abs/2307.14352 代码:https://github.com/crystalneuro/visual-concept-translator 文章目录 Abstract1. Introduction2. 相关工作2.1 图像到图像转换2.2. Di…

使用chatGPT生成提示词,在文心一言生成装修概念图

介绍 家是情感的港湾,而家居装修则是将情感融入空间的艺术。如何在有限的空间里展现个性与美感,成为了现代人关注的焦点。而今,随着人工智能的发展,我们发现了一个新的创意助手——ChatGPT,它不仅为我们带来了更多可能…

nodejs+vue+elementui招聘求职网站系统的设计与实现-173lo

(1)管理员的功能是最高的,可以对系统所在功能进行查看,修改和删除,包括企业和用户功能。管理员用例如下: 图3-1管理员用例图 (2)企业关键功能包含个人中心、岗位类型管理、招聘信息…

C语言每日一题:16:数对。

思路一&#xff1a;基本思路 1.x,y均不大于n&#xff0c;就是小于等于n。 2.x%y大于等于k。 3.一般的思路使用双for循环去遍历每一对数。 代码实现&#xff1a; #include <stdio.h> int main() {int n 0;int k 0;//输入scanf("%d%d", &n, &k);int x…

【深度学习注意力机制系列】—— ECANet注意力机制(附pytorch实现)

ECANet&#xff08;Efficient Channel Attention Network&#xff09;是一种用于图像处理任务的神经网络架构&#xff0c;它在保持高效性的同时&#xff0c;有效地捕捉图像中的通道间关系&#xff0c;从而提升了特征表示的能力。ECANet通过引入通道注意力机制&#xff0c;以及在…

【Plex】FRP内网穿透后 App无法使用问题

能搜索到这个文章的&#xff0c;应该都看过这位同学的分析【Plex】FRP内网穿透后 App无法使用问题_plex frp无效_Fu1co的博客-CSDN博客 这个是必要的过程&#xff0c;但是设置之后仍然app端无法访问&#xff0c;原因是因为网络端口的问题 这个里面的这个公开端口&#xff0c;可…