数据分析 | 调用Optuna库实现基于TPE的贝叶斯优化 | 以随机森林回归为例

1. Optuna库的优势

        对比bayes_opt和hyperoptOptuna不仅可以衔接到PyTorch等深度学习框架上,还可以与sklearn-optimize结合使用,这也是我最喜欢的地方,Optuna因此特性可以被使用于各种各样的优化场景。

 

2. 导入必要的库及加载数据

        用的是sklearn自带的房价数据,只是我把它保存下来了。

import optuna
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold,cross_validate
print(optuna.__version__)
from sklearn.ensemble import RandomForestRegressor as RFR
data = pd.read_csv(r'D:\2暂存文件\Sth with Py\贝叶斯优化\data.csv')
X = data.iloc[:,0:8]
y = data.iloc[:,8]

3. 定义目标函数与参数空间

        Optuna相对于其他库,不需要单独输入参数或参数空间,只需要直接在目标函数中定义参数空间即可。这里以负均方误差为损失函数。

def optuna_objective(trial) :
    # 定义参数空间
    n_estimators = trial.suggest_int('n_estimators',10,100,1)
    max_depth = trial.suggest_int('max_depth',10,50,1)
    max_features = trial.suggest_int('max_features',10,30,1)
    min_impurtity_decrease = trial.suggest_float('min_impurity_decrease',0.0, 5.0, step=0.1)

    # 定义评估器
    reg = RFR(n_estimators=n_estimators,
              max_depth=max_depth,
              max_features=max_features,
              min_impurity_decrease=min_impurtity_decrease,
              random_state=1412,
              verbose=False,
              n_jobs=-1)

    # 定义交叉过程,输出负均方误差
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    validation_loss = cross_validate(reg,X,y,
                                     scoring='neg_mean_squared_error',
                                     cv=cv,
                                     verbose=True,
                                     n_jobs=-1,
                                     error_score='raise')
    return np.mean(validation_loss['test_score'])

4.  定义优化目标函数

        在Optuna中我们可以调用sampler模块进行选用想要的优化算法,比如TPE、GP等等。

def optimizer_optuna(n_trials,algo):

    # 定义使用TPE或GP
    if algo == 'TPE':
        algo = optuna.samplers.TPESampler(n_startup_trials=20,n_ei_candidates=30)
    elif algo == 'GP':
        from optuna.integration import SkoptSampler
        import skopt
        algo = SkoptSampler(skopt_kwargs={'base_estimator':'GP',
                                          'n_initial_points':10,
                                          'acq_func':'EI'})
    study = optuna.create_study(sampler=algo,direction='maximize')
    study.optimize(optuna_objective,n_trials=n_trials,show_progress_bar=True)

    print('best_params:',study.best_trial.params,
              'best_score:',study.best_trial.values,
              '\n')

    return study.best_trial.params, study.best_trial.values

5. 执行部分

import warnings
warnings.filterwarnings('ignore',message='The objective has been evaluated at this point before trails')
optuna.logging.set_verbosity(optuna.logging.ERROR)
best_params, best_score = optimizer_optuna(200,'TPE')

6. 完整代码

import optuna
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold,cross_validate
print(optuna.__version__)
from sklearn.ensemble import RandomForestRegressor as RFR

data = pd.read_csv(r'D:\2暂存文件\Sth with Py\贝叶斯优化\data.csv')
X = data.iloc[:,0:8]
y = data.iloc[:,8]

def optuna_objective(trial) :
    # 定义参数空间
    n_estimators = trial.suggest_int('n_estimators',10,100,1)
    max_depth = trial.suggest_int('max_depth',10,50,1)
    max_features = trial.suggest_int('max_features',10,30,1)
    min_impurtity_decrease = trial.suggest_float('min_impurity_decrease',0.0, 5.0, step=0.1)

    # 定义评估器
    reg = RFR(n_estimators=n_estimators,
              max_depth=max_depth,
              max_features=max_features,
              min_impurity_decrease=min_impurtity_decrease,
              random_state=1412,
              verbose=False,
              n_jobs=-1)

    # 定义交叉过程,输出负均方误差
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    validation_loss = cross_validate(reg,X,y,
                                     scoring='neg_mean_squared_error',
                                     cv=cv,
                                     verbose=True,
                                     n_jobs=-1,
                                     error_score='raise')
    return np.mean(validation_loss['test_score'])

def optimizer_optuna(n_trials,algo):

    # 定义使用TPE或GP
    if algo == 'TPE':
        algo = optuna.samplers.TPESampler(n_startup_trials=20,n_ei_candidates=30)
    elif algo == 'GP':
        from optuna.integration import SkoptSampler
        import skopt
        algo = SkoptSampler(skopt_kwargs={'base_estimator':'GP',
                                          'n_initial_points':10,
                                          'acq_func':'EI'})
    study = optuna.create_study(sampler=algo,direction='maximize')
    study.optimize(optuna_objective,n_trials=n_trials,show_progress_bar=True)

    print('best_params:',study.best_trial.params,
              'best_score:',study.best_trial.values,
              '\n')

    return study.best_trial.params, study.best_trial.values

import warnings
warnings.filterwarnings('ignore',message='The objective has been evaluated at this point before trails')
optuna.logging.set_verbosity(optuna.logging.ERROR)
best_params, best_score = optimizer_optuna(200,'TPE')








 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/78815.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Ordinals 之后,以太坊铭文协议 Ethscriptions 如何再塑 NFT 资产形态

随着加密市场的发展,NFT 赛道逐渐形成了其独有的市场。但在加密熊市的持续影响下,今年 NFT 赛道的发展充满坎坷与挑战。据 NFTGO 数据显示,截至 8 月 7 日,与去年相比,NFT 市值总计约 56.4 亿美元,过去 1 年…

【业务功能篇69】Springboot 树形菜单栏功能设计

业务场景: 系统的界面,前端设计的时候,一般会给一个菜单栏,顶部横向以及左侧纵向的导航栏菜单,这里后端返回菜单栏的时候,就涉及层级父子项的问题,所以返回数据的时候,我们需要按照树化形式返回…

SpringBoot 操作Redis、创建Redis文件夹、遍历Redis文件夹

文章目录 前言依赖连接 RedisRedis 配置文件Redis 工具类操作 Redis创建 Redis 文件夹查询数据遍历 Redis 文件夹 前言 Redis 是一种高性能的键值存储数据库,支持网络、可基于内存亦可持久化的日志型,而 Spring Boot 是一个简化了开发过程的 Java 框架。…

CSRF

文章目录 CSRF(get)CSRF(post)CSRF Token CSRF(get) 根据提示的用户信息登录 点击修改个人信息 开启bp代理,点击submit 拦截到请求数据包 浏览器关闭代理 刷新页面 CSRF(post) 使用BP生成CSRF POC post请求伪造,可以通过钓鱼网站,诱导用户去…

【LeetCode】买卖股票的最佳时机最多两次购买机会

买卖股票的最佳时机 题目描述算法分析程序代码 链接: 买卖股票的最佳时机 题目描述 算法分析 程序代码 class Solution { public:int maxProfit(vector<int>& prices) {int n prices.size();vector<vector<int>> f(n,vector<int>(3,-0x3f3f3f))…

数据生成 | MATLAB实现WGAN生成对抗网络数据生成

数据生成 | MATLAB实现WGAN生成对抗网络数据生成 目录 数据生成 | MATLAB实现WGAN生成对抗网络数据生成生成效果基本描述程序设计参考资料 生成效果 基本描述 1.WGAN生成对抗网络&#xff0c;数据生成&#xff0c;样本生成程序&#xff0c;MATLAB程序&#xff1b; 2.适用于MATL…

基于 Debian 12 的MX Linux 23 正式发布!

导读MX Linux 是基于 Debian 稳定分支的面向桌面的 Linux 发行&#xff0c;它是 antiX 及早先的 MEPIS Linux 社区合作的产物。它采用 Xfce 作为默认桌面环境&#xff0c;是一份中量级操作系统&#xff0c;并被设计为优雅而高效的桌面与如下特性的结合&#xff1a;配置简单、高…

阿里云云解析DNS核心概念与应用

文章目录 1.DNS解析基本概念1.1.DNS基本介绍1.2.域名的分层结构1.3.DNS解析原理1.4.DNS递归查询和迭代查询的区别1.5.DNS常用的解析记录 2.使用DNS云解析将域名与SLB公网IP进行绑定2.1.进入云解析DNS控制台2.2.添加域名解析记录2.3.验证解析是否生效 1.DNS解析基本概念 DNS官方…

C++超基础语法

&#x1f493;博主个人主页:不是笨小孩&#x1f440; ⏩专栏分类:数据结构与算法&#x1f440; C&#x1f440; 刷题专栏&#x1f440; C语言&#x1f440; &#x1f69a;代码仓库:笨小孩的代码库&#x1f440; ⏩社区&#xff1a;不是笨小孩&#x1f440; &#x1f339;欢迎大…

SSH远程连接MacOS catalina并进行终端颜色配置

一、开关SSH服务 在虚拟机上安装了MacOS catalina&#xff0c;想要使用SSH远程进行连接&#xff0c;但是使用“系统偏好设置”/“共享”/“远程登录”开关进行打开&#xff0c;却一直是正在启动“远程登录”&#xff1a; 难道是catalina有BUG&#xff1f;不过还是有方法的&…

AcrelEMS-SW智慧水务能效管理平台解决方案-安科瑞黄安南

系统概述 安科瑞电气具备从终端感知、边缘计算到能效管理平台的产品生态体系&#xff0c;AcrelEMS-SW智慧水务能效管理平台通过在污水厂源、网、荷、储、充的各个关键节点安装保护、监测、分析、治理装置&#xff0c;用于监测污水厂能耗总量和能耗强度&#xff0c;重点监测主要…

Spring系列七:声明式事务

&#x1f418;声明式事务 和AOP有密切的联系, 是AOP的一个实际的应用. &#x1f432;事务分类简述 ●分类 1.编程式事务: 示意代码, 传统方式 Connection connection JdbcUtils.getConnection(); try { //1.先设置事务不要自动提交 connection.setAutoCommit(false…

【爬虫】爬取旅行评论和评分

以马蜂窝“普达措国家公园”为例&#xff0c;其评论高达3000多条&#xff0c;但这3000多条并非是完全向用户展示的&#xff0c;向用户展示的只有5页&#xff0c;数了一下每页15条评论&#xff0c;也就是75条评论&#xff0c;有点太少了吧&#xff01; 因此想了个办法尽可能多爬…

Swagger

目录 简介 使用方式&#xff1a; 常用注解 简介 使用Swagger你只需要按照他的规范去定义接口及接口相关信息再通过Swagger衍生出来的一系列项目和工具&#xff0c;就可以做到生成各种格式的接口文档&#xff0c;以及在线接口调试页面等等。 官网&#xff1a;https://swagger…

【0基础入门Python笔记】二、python 之逻辑运算和制流程语句

二、python 之逻辑运算和制流程语句 逻辑运算控制流程语句条件语句&#xff08;if语句&#xff09;循环结构&#xff08;for循环、while循环&#xff09;控制流程语句的嵌套以及elif 逻辑运算 Python提供基本的逻辑运算&#xff1a;不仅包括布尔运算&#xff08;and、or、not&…

基于深度信念神经网络+长短期神经网络的降雨量预测,基于dbn-lstm的降雨量预测,dbn原理,lstm原理

目录 背影 DBN神经网络的原理 DBN神经网络的定义 受限玻尔兹曼机(RBM) LSTM原理 DBN-LSTM的降雨量预测 基本结构 主要参数 数据 MATALB代码 结果图 展望 背影 DBN是一种深度学习神经网络,拥有提取特征,非监督学习的能力,通过dbn进行无监督学习提取特征,然后长短期神经…

mysql 插入数据锁等待超时报错:Lock wait timeout exceeded; try restarting transaction

报错信息 Lock wait timeout exceeded; try restarting transaction 锁等待超时 Lock wait timeout exceeded; try restarting transaction&#xff0c;是当前事务在等待其它事务释放锁资源造成的 解决办法 1、数据库中执行如下sql&#xff0c;查看当前数据库的线程情况&…

[C++ 网络协议编程] TCP/IP协议

目录 1. TCP/IP协议栈 2. TCP原理 2.1 TCP套接字中的I/O缓冲 2.2 TCP工作原理 2.2.1 三次握手&#xff08;连接&#xff09; 2.2.2 与对方主机的数据交换 2.2.3 四次握手&#xff08;断开与套接字的连接&#xff09; TCP&#xff08;Transmission Control Protocol传输控…

python bytes基本用法

目录 1 第一个字符变大写&#xff0c;其余字符变小写 capitalize() 2 生成指定长度内容&#xff0c;然后把指定的bytes放到中间 center() 3 计数 count() 4 解码 decode() 5 是否以指定的内容结尾 endswith() 6 将制表符调整到指定大小 expandtabs() 7 寻找指…

Springboot 实践(7)springboot添加html页面,实现数据库数据的访问

前文讲解&#xff0c;项目已经实现了数据库Dao数据接口&#xff0c;并通过spring security数据实现了对系统资源的保护。本文重点讲解Dao数据接口页面的实现&#xff0c;其中涉及页面导航栏、菜单栏及页面信息栏3各部分。 1、创建html页面 前文讲解中&#xff0c;资源目录已经…