机器学习 sklearn 中的超参数搜索方法

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 超参数搜索
    • 默认参数
    • GridSearchCV
    • RandomizedSearchCV
    • HalvingGridSearchCV
    • HalvingRandomSearchCV


超参数搜索

在建模时模型的超参数往往会对精度造成一定影响,而设置和调整超参数的取值,往往称为调参

在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜索,找到最优的超参数组合。

在 sklearn 中,提供了四种超参数搜索方法:

  • GridSearchCV
  • RandomizedSearchCV
  • HalvingGridSearchCV
  • HalvingRandomSearchCV

默认参数

为了方便起见,我们先定义一个默认参数的模型,用于后续的超参数搜索。

# 导入相关库
import random
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import (
    train_test_split,
    GridSearchCV,
    RandomizedSearchCV,
    HalvingGridSearchCV,
    HalvingRandomSearchCV,
)
from sklearn.ensemble import RandomForestRegressor

# 设置随机种子
seed = 1
random.seed(seed)
np.random.seed(seed)

# 加载数据集
data = datasets.load_diabetes()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)

# 使用默认参数模型进行分类并评分
reg = RandomForestRegressor(random_state=seed)
reg.fit(X_train, y_train)
print(round(reg.score(X_test, y_test), 6))

最终默认参数的模型在测试集上的 R 2 R^2 R2 分数约为 0.269413

GridSearchCV

GridSearchCV 是一种网格搜索超参数的方法,它会遍历所有的超参数组合,然后评估模型的性能,最终选择性能最好的一组超参数。

# 设置超参数搜索范围
param_grid = {
    "max_depth": [2, 4, 5, 6, 7],
    "min_samples_leaf": [1, 2, 3],
    "min_weight_fraction_leaf": [0, 0.1],
    "min_impurity_decrease": [0, 0.1, 0.2]
}

# 使用 GridSearchCV 进行超参数搜索
reg = GridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

在这个超参数搜索空间中,一共有 5 * 3 * 2 * 3 = 90 种超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

RandomizedSearchCV

RandomizedSearchCV 是一种随机搜索超参数的方法,它的使用方法与 GridSearchCV 类似,但是它不会遍历所有的超参数组合,而是在超参数的取值范围内随机选择一组超参数进行训练,然后评估模型的性能,最终选择性能最好的一组超参数。

# 使用 RandomizedSearchCV 进行超参数搜索
reg = RandomizedSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    n_iter=20,  # 设置迭代次数
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

RandomizedSearchCV 一共进行了 20 次迭代,即尝试了 20 组超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

HalvingGridSearchCV

HalvingGridSearchCVGridSearchCV 类似,但在迭代的过程中采用减半超参数搜索空间的方法,以此来减少超参数搜索的时间。

在搜索的最开始,HalvingGridSearchCV 使用很少的数据样本来在完整的超参数搜索空间中进行搜索,筛选其中最优的超参数,之后再增加数据进行进一步筛选。

# 使用 HalvingGridSearchCV 进行超参数搜索
reg = HalvingGridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

可以看到,HalvingGridSearchCV 得到的最优超参数组合与 GridSearchCV 得到的最优超参数组合相同。

HalvingRandomSearchCV

HalvingRandomSearchCVHalvingGridSearchCV 类似,都是逐步增加样本数量,减少超参数组合,但是 HalvingRandomSearchCV 每次生成的超参数组合是随机的。

# 使用 HalvingRandomSearchCV 进行超参数搜索
reg = HalvingRandomSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/226402.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一个能让你涨薪5k的自动化测试小技巧项目流程优化

引言及概览 项目流程优化是一个持续过程,每个公司,每个团队情况不一样,总原则是:如果在项目过程中你感觉到某一点很别扭、很不爽、痛了,那么这就是优化点。 优化的手段是多样化的,如通过流程规范去约束、…

将图像增广应用于Mnist数据集

将图像增广应用于Mnist数据集 不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用 首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片 import time import torch from torch im…

C++ STL容器与常用库函数

STL是提高C编写效率的一个利器 STL容器&#xff1a; 一、#include <vector> 英文翻译&#xff1a;vector &#xff1a;向量 vector是变长数组(动态变化)&#xff0c;支持随机访问&#xff0c;不支持在任意位置O(1)插入。为了保证效率&#xff0c;元素的增删一般应该在末尾…

测开速成考点

用户登录界面需要做哪些测试&#xff1f; ⑴功能&#xff1a;  正确的用户名&#xff0c;正确的密码&#xff0c;点击提交按钮&#xff0c;验证能否正确登录 是否跳转正确页面。  错误的用户名 或 密码&#xff0c;点击提交&#xff0c;提示相应的错误信息。  记住用户名…

JavaScript实现手写签名,可触屏手写,支持移动端与PC端双端保存

目录 1.HTML模板 2.获取DOM元素和定义变量 3.创建两个canvas元素&#xff0c;并设置它们的宽度和高度 4.绑定触摸事件&#xff1a;touchstart, touchmove, touchend和click 5.实现触摸事件回调函数&#xff1a;startDrawing, draw和stopDrawing 6.实现绘制线段的函数&…

C语言——输出菱形

法一&#xff1a; #include<stdio.h> #define N 7 //假设输出7层菱形 int main(){int i;//i控制第几行 int j;//j控制每一行空格的循环个数 int k;//k控制每一行*的循环次数 for(i1;i<4;i){//将图形分为两部分,前四行(第一部分) for(j1;j<4-i;j){//输出第i行的…

记录一下本地源码安装部署ThingsBoard可能踩到的坑

使用git下载源码后, 必须运行 mvn clean install -DskipTests这一步很重要, 有很多文件需要初始化, 如果直接放入idea可能存在各种问题, 最好是用命令行执行 初始化时, 可能报错停止, 这个一般是网络问题, 可以尝试修改maven镜像, 这是我成功构建的镜像 <!--阿里云仓库--…

膜结构建筑:未来体育可持续发展的绿色引擎

随着城市化的飞速发展&#xff0c;现代建筑迫切需要创新性的解决方案&#xff0c;而膜结构建筑以其独特的设计理念和可持续性特点&#xff0c;正在成为未来城市发展的重要引擎。本文将深入探讨膜结构建筑在可持续城市发展中的关键作用&#xff0c;包括其在节能减排、资源有效利…

“分割“安卓用户,对标iOS,鸿蒙崛起~

近期关于**“华为于明年推出不兼容安卓的鸿蒙版本”**的消息传出&#xff0c;引起了业界的热议关注。自从2019年8月&#xff0c;美国制裁下&#xff0c;华为不再能够获得谷歌安卓操作系统相关付费服务&#xff0c;如此情况下&#xff0c;华为“备胎”鸿蒙操作系统一夜转正。 华…

Swift 中 User Defaults 的读取和写入

文章目录 前言介绍 User Defaults共享 User DefaultsUser Defaults 存储数据类型响应更改监控 User Defaults 更改覆盖User Defaults 设置考虑的替代方案Keychain 用于安全性用于跨平台的 CloudKit 结论 前言 User Defaults 是 Swift 应用程序存储在应用启动之间保持的首选项的…

音视频学习(二十)——rtsp收流(udp方式)

前言 本文主要介绍通过udp方式实现rtsp拉流。 流程图 流程说明&#xff1a; 相较于tcp方式“信令数据”复用同一连接拉流&#xff0c;udp方式拉流“信令数据”采用不同的连接&#xff0c;信令传输采用tcp&#xff0c;流数据传输采用udp&#xff1b;客户端向服务端&#xff0…

Android平板还能编程?Ubuntu本地安装code-server远程编程写代码

文章目录 1.ubuntu本地安装code-server2. 安装cpolar内网穿透3. 创建隧道映射本地端口4. 安卓平板测试访问5.固定域名公网地址6.结语 1.ubuntu本地安装code-server 准备一台虚拟机,Ubuntu或者centos都可以&#xff0c;这里以VMwhere ubuntu系统为例 下载code server服务,浏览器…

毕设:《基于hive的音乐数据分析系统的设计与实现》

文章目录 环境启动一、爬取数据1.1、歌单信息1.2、每首歌前20条评论1.3、排行榜 二、搭建环境1.1、搭建JAVA1.2、配置hadoop1.3、配置Hadoop环境&#xff1a;YARN1.4、MYSQL1.5、HIVE(数据仓库)1.6、Sqoop&#xff08;关系数据库数据迁移&#xff09; 三、hadoop配置内存四、导…

原生cesium、mars3d、supermap-cesium在vue3+vite中引入

1. 原生cesium 需要下载 yarn add cesiumyarn add vite-plugin-cesium2. mars3d 需要下载 yarn add mars3d mars3d-cesiumyarn add vite-plugin-mars3d3. supermap-cesium 只需要引入官网下载的包&#xff0c;build文件夹下的cesium&#xff0c;以及项目中引入的其他cesiu…

ChatGPT 应用开发(一)ChatGPT OpenAI API 免代理调用方式(通过 Cloudflare 的 AI Gateway)

前言 开发 ChatGPT 应用&#xff0c;我觉得最前置的点就是能使用 ChatGPT API 接口。首先我自己要能成功访问&#xff0c;这没问题&#xff0c;会魔法就可以本地调用。 那用户如何调用到我的应用 API 呢&#xff0c;我的理解是通过用户能访问到的中转服务器向 OpenAI 发起访问…

Linux篇:进程间通信

一、进程间通信原理&#xff1a; 1、通信是有成本的&#xff1a;两个或者多个进程&#xff0c;实现数据层面的交互&#xff0c;因为进程独立性的存在&#xff0c;导致进程通信的成本比较高。 2、进程间通信的方式&#xff1a; ①基本数据 ②发送命令 ③某种协同 ④通知 .....…

C++新经典模板与泛型编程:用成员函数重载实现std::is_class

用成员函数重载实现is_class std::is_class功能&#xff0c;是一个C11标准中用于判断某个类型是否为一个类类型&#xff08;但不是联合类型&#xff09;的类模板。当时在讲解的时候并没有涉及std::is_class的实现代码&#xff0c;在这里实现一下。简单地书写一个IsClass类模板…

Linux学习笔记7-IIC的应用和AP3216C

接下来进入其他两种串行通信方式&#xff1a;SPI和I2C的学习&#xff0c;因为以后的项目中会用到这些通信方式&#xff0c;而且正点原子的开发板里面也有用I2C和SPI通信的传感器来做实例&#xff0c;分别是一个距离传感器和六轴陀螺仪&#xff0c;这样就可以很好的通过实例来学…

TCP通讯

第二十一章 网络通信 本章节主要讲解的是TCP和UDP两种通信方式它们都有着自己的优点和缺点 这两种通讯方式不通的地方就是TCP是一对一通信 UDP是一对多的通信方式 接下来会一一讲解 TCP通信 TCP通信方式呢 主要的通讯方式是一对一的通讯方式&#xff0c;也有着优点和缺点…

maven-assembly-plugin 自定义打包

我想把input文件夹给打包进去 pom文件 <build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-assembly-plugin</artifactId><executions><execution><!-- 配置执行器 --><i…