optuna和 lightgbm

文章目录

  • optuna使用
    • 1.导入相关包
    • 2.定义模型可选参数
    • 3.定义训练代码和评估代码
    • 4.定义目标函数
    • 5.运行程序
    • 6.可视化
    • 7.超参数的重要性
    • 8.查看相关信息
    • 9.可视化的一个完整示例
    • 10.lightgbm实验

optuna使用

1.导入相关包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fvcore.nn import FlopCountAnalysis

import optuna


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30   # 128 * 30个训练
N_VALID_EXAMPLES = BATCHSIZE * 10   # 128 * 10个预测

2.定义模型可选参数

optuna支持很多种搜索方式:
(1)trial.suggest_categorical(‘optimizer’, [‘MomentumSGD’, ‘Adam’]):表示从SGD和adam里选一个使用;
(2)trial.suggest_int(‘num_layers’, 1, 3):从1~3范围内的int里选;
(3)trial.suggest_uniform(‘dropout_rate’, 0.0, 1.0):从0~1内的uniform分布里选;
(4)trial.suggest_loguniform(‘learning_rate’, 1e-5, 1e-2):从1e-5~1e-2的log uniform分布里选;
(5)trial.suggest_discrete_uniform(‘drop_path_rate’, 0.0, 1.0, 0.1):从0~1且step为0.1的离散uniform分布里选;

def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3) # 从[1,3]范围里面选一个
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)

3.定义训练代码和评估代码

# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    flops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),)).total()
    return flops, accuracy

4.定义目标函数

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy

5.运行程序

运行30次实验,每次实验返回 flops,accuracy

study = optuna.create_study(directions=["minimize", "maximize"]) # flops 最小化, accuracy 最大化
study.optimize(objective, n_trials=30, timeout=300)

print("Number of finished trials: ", len(study.trials))

6.可视化

flops, accuracy 二维图
optuna.visualization.plot_pareto_front(study, target_names=[“FLOPS”, “accuracy”])

在这里插入图片描述

7.超参数的重要性

对于flops
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[0], target_name=“flops”
)

对于accuracy
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[1], target_name=“accuracy”
)

在这里插入图片描述

8.查看相关信息

# https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/002_multi_objective.html
# 利用pytorch mnist 识别
# 设置了一些超参数,lr, layer number, feature_number等
# 然后目标是 flops 和 accurary

# 最后是可视化:
# 显示试验的一些结果:
# optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
# 左上角是最好的

# 显示重要性:
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[0], target_name="flops"
# )
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[1], target_name="accuracy"
# )


# trials的属性:
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")

trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print(f"Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")

9.可视化的一个完整示例

# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timeline

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)

        val_accuracy = eval_model(model, val_loader)
        trial.report(val_accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_accuracy

study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=30, timeout=300)

运行之后可视化:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

10.lightgbm实验

"""
Optuna example that optimizes a classifier configuration for cancer dataset using LightGBM.

In this example, we optimize the validation accuracy of cancer detection using LightGBM.
We optimize both the choice of booster model and their hyperparameters.

"""

import numpy as np
import optuna

import lightgbm as lgb
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split


# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)

    param = {
        "objective": "binary",
        "metric": "binary_logloss",
        "verbosity": -1,
        "boosting_type": "gbdt",
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 2, 256),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
        "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
    }

    gbm = lgb.train(param, dtrain)
    preds = gbm.predict(valid_x)
    pred_labels = np.rint(preds)
    accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
    return accuracy


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

运行结果:
在这里插入图片描述

https://github.com/microsoft/LightGBM/tree/master/examples

https://blog.csdn.net/yang1015661763/article/details/131364826

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/942618.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Yonghong 企业日常问题 06】上传的文件不在白名单,修改allow.jar.digest属性添加允许上传的文件SH256值?

文章目录 前言问题描述问题分析问题解决1.允许所有用户上传驱动文件2.如果是想只上传白名单的驱动 前言 该方法适合永洪BI系列产品,包括不限于vividime desktop,vividime z-suit,vividime x-suit产品。 问题描述 当我们连接数据源的时候&a…

[项目][boost搜索引擎#4] cpp-httplib使用 log.hpp 前端 测试及总结

目录 编写http_server模块 1. 引入cpp-httplib到项目中 2. cpp-httplib的使用介绍 3. 正式编写http_server 九、添加日志到项目中 十、编写前端模块 十一. 详解传 gitee 十二、项目总结 项目的扩展 写在前面 [项目详解][boost搜索引擎#1] 概述 | 去标签 | 数据清洗 |…

项目练习:若依系统的svg-icon功能实现

文章目录 一、svg图片准备二、自定义Svg组件三、svg插件开发四、Svg组件使用 一、svg图片准备 src/assets/icons/svg 其中svg目录里,存放了所需要的图片 index.js import Vue from vue import SvgIcon from /components/SvgIcon// svg component// register glob…

水库大坝三维模型的开发和使用3Dmax篇

成果图 开发过程 工具插件three.js先加载模型做水体衔接水位测量标尺水位标记断面标记大坝监测点打点 上代码&#xff0c;技术交流V: bloxed <template><div class"box w100 h100"><el-row :gutter"20" v-loading"loading"e…

Win10提示“缺少fbgemm.dll”怎么办?缺失fbgemm.dll文件的修复方法来啦!

fbgemm.dll文件的作用 fbgemm.dll&#xff08;Facebook GEMM library&#xff09;是一个动态链接库文件&#xff0c;它主要用于优化矩阵乘法运算&#xff0c;提高计算性能。虽然它不是Windows 10系统的核心组件&#xff0c;但在某些应用程序或游戏中&#xff0c;尤其是那些需要…

Petalinux使用QSPI FLASH引导启动

目录 1. 预分配Flash空间 1.1 大小估计 1.2 其他注意事项 2. 配置Flash分区 3. 配置各主要文件在Flash中的地址范围 4. 配置boot.scr的偏移 5. 修改U-Boot环境变量在Flash的偏移量 6. 配置设备树中的Flash 7. 开启对EXT4分区管理的支持(根据需要) 8. 编译u-boot 9.…

Android——自定义按钮button

项目中经常高频使用按钮&#xff0c;要求&#xff1a;可设置颜色&#xff0c;有圆角且有按下效果的Button 一、自定义按钮button button的代码为 package com.fslihua.clickeffectimport android.annotation.SuppressLint import android.content.Context import android.gra…

黑龙江等保测评费用怎么收?

‌黑龙江二级等保测评费用‌&#xff1a;费用区间大致在3万至6万人民币之间&#xff0c;具体费用取决于测评机构的定价策略、所提供的服务内容以及企业的实际需求‌&#xff0c;服务内容包括防火墙、Web应用防火墙(WAF)、堡垒机、日志审计、漏洞扫描以及等保安全整改等‌。 ‌…

中文拼写检测纠正 Read, Listen, and See Leveraging Multimodal Information 论文

拼写纠正系列 NLP 中文拼写检测实现思路 NLP 中文拼写检测纠正算法整理 NLP 英文拼写算法&#xff0c;如果提升 100W 倍的性能&#xff1f; NLP 中文拼写检测纠正 Paper java 实现中英文拼写检查和错误纠正&#xff1f;可我只会写 CRUD 啊&#xff01; 一个提升英文单词拼…

vue2 elementui if导致的rules判断失效

优化目标 和 目标转化出价必填的 切换的时候还会隐藏掉 这时候的if语句会导致rules判断失效 我的办法是把判断拉到外面 别放在el-form-item里 <section v-if"unitForm.baseTarget OCPM && unitForm.cpaTargetOptions ! undefined && unitForm.cpaTa…

前端(Ajax)

1.客户端请求 向https://jsonplaceholder.typicode.com/users发送get请求 const xhr new XMLHttpRequest(); console.log(xhr.readyState); xhr.open(‘get’, ‘https://jsonplaceholder.typicode.com/users’) console.log(xhr.readyState); xhr.send(); console.log(xhr.…

uboot, s5pv210 , main_loop 分析(16)

main_loop 的代码如下&#xff1a; 4443 void main_loop (void)42 {41 #ifndef CONFIG_SYS_HUSH_PARSER E 40 ▎ static char lastcommand[CONFIG_SYS_CBSIZE] { 0, }; ■ Use of undeclared identifier CONFIG_SYS_CBSIZE39 ▎ int len;38 ▎ int rc 1;37 ▎ …

信号强劲,通信清晰:北斗三号多模对讲机TD70——专业通信解决方案

在边防海防等国家安全的关键领域&#xff0c;通信的稳定性和可靠性关乎着任务的成败和战士们的安全。北斗三号多模对讲机TD70&#xff0c;凭借其卓越的性能和全面的功能&#xff0c;成为了边防海防通信的利器&#xff0c;守护着国家安全的前沿哨兵。 一、三网融合&#xff0c;…

Arduino驱动DS18B20测量环境温度

DS18B20是一款高精度的单总线数字温度传感器&#xff0c;具体参数如下表所示&#xff1a; 参数名称 参数特性 测温范围 -55~125℃ 测量精度 在-10~85℃范围内的误差为0.5℃ 分辨率 9~12位数字信号&#xff0c;分辨率分别为0.5℃、0.25℃、0.125℃和0.0625℃ 通信方式 …

vector快慢指针+例题详解

1.快慢指针 例题 给定一个链表&#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;我们使用整数 pos 来表示链表尾连接到链表中的位置&#xff08;索引从…

C++--------效率和表示

C 效率和表示 效率 时间效率&#xff1a;在 C 中&#xff0c;不同的数据结构和算法有着各异的时间复杂度。例如&#xff0c;访问数组元素的时间复杂度是 O ( 1 ) O(1) O(1)&#xff0c;而遍历链表查找元素的时间复杂度最坏情况下是 O ( n ) O(n) O(n)。选择合适的算法与数据…

【Mac】终端改色-让用户名和主机名有颜色

效果图 配置zsh 1.打开终端&#xff0c;进入.zshrc配置 cd ~ vim .zshrc2.添加如下配置并保存 # 启用命令行颜色显示 export CLICOLOR1 ## 加载颜色支持 autoload -U colors && colors # 配置 zsh 提示符 PROMPT"%{$fg_bold[red]%}%n%{$reset_color%}%{$fg_bol…

模拟——郑益慧_笔记1_绪论

B站视频链接 模电是数电的基础&#xff1b;参考书&#xff1a; 模拟电子技术基础&#xff08;第四版&#xff09;华成英、童诗白主编&#xff0c;高等教育出版社&#xff1b;电子技术基础 模拟部分 康华光主编&#xff0c;高等教育出版社&#xff1b; 电子技术的发展史 电子…

YOLOv11模型改进-模块-引入多尺度大核注意力Multi-scale Large Kernel Attention

MLKA 的提出源于图像超分辨率任务的挑战性&#xff0c;该任务需重建低质量图像缺失的高频信息&#xff0c;但因 LR 与 HR 图像对应关系复杂&#xff0c;寻找像素相关性困难。此前模型扩展容量的方法增加了训练负担和数据收集成本&#xff0c;而采用的注意力机制无法同时获取局部…

【gym】给定的强化学习环境简介(二)

文章目录 环境介绍一 box2dbipedal_walkercar_dynamicscar_racinglunar_lander 二、 classic_controlacrobotCartPolecontinuous_mountain_carmountain_carpendulum 三、toy_textblackjackcliffwalkingfrozentaxi 四、mujocoAnt&#xff1a;HalfCheetah&#xff1a;Hopper&…