【Python深度学习系列】网格搜索神经网络超参数:批量大小和迭代周期数(案例+源码)

这是我的第297篇原创文章。

一、引言

图片

      在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:

  • model.add()
    • neurons(隐含层神经元数量)
    • init_mode(初始权重方法)
    • activation(激活函数)
    • dropout(丢弃率)
  • model.compile()
    • loss(损失函数)
    • optimizer(优化器)
      • learning rate(学习率)
      • momentum(动量)
      • weight decay(权重衰减系数)
  • model.fit()
    • batch size(批量大小)
    • epochs(迭代周期数)

      一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优在深度学习中,Epoch(周期)和 Batch Size(批大小)是训练神经网络时经常使用的两个重要的超参数。

  • Epoch(周期):一个Epoch就是将所有训练样本训练一次的过程。然而,当一个Epoch的样本(也就是所有的训练样本)数量可能太过庞大(对于计算机而言),就需要把它分成多个小块,也就是就是分成多个Batch 来进行训练。

  • Batch(批 / 一批样本):将整个训练样本分成若干个Batch。

  • Batch_Size(批大小):每批样本的大小。即1次迭代所使用的样本量。

  • Iteration(一次迭代):训练一个Batch就是一次Iteration(这个概念跟程序语言中的迭代器相似)每次迭代更新1次网络结构的参数

  • step(一步):训练一个样本就是一个step。

       比如我有1000个训练样本,bachsize设置为10,则数据分成了100个batch,所有训练样本训练一次即一个epoch需要100个iteration,训练一个batch就是一次iteration。本文采用网格搜索选择Epoch和Batch_size。

二、实现过程

2.1 准备数据

dataset:

dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)

图片

2.2 数据划分

# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)

2.3 创建模型

需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。

def create_model():
    # 创建模型
    model = Sequential()
    model.add(Dense(50, input_shape=(8, ), kernel_initializer='uniform', activation='relu'))
    model.add(Dropout(0.05))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))

    # 编译模型
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model = KerasClassifier(model=create_model)

这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier有批量大小、迭代次数的参数,不需要自定义表示。

2.4 定义网格搜索参数

param_grid = {'batch_size': [20, 40], 'epochs': [10, 50]}

param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。value是key可取的值,也就是要尝试的方案。

2.5 进行参数搜索

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model,  param_grid=param_grid)
grid_result = grid.fit(X, Y)

使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。

2.6 总结搜索结果

print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

结果:

图片

经过网格搜索,批量大小的最优选择是20,迭代次数最优选择是50。

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/686565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在线Logo背景去除:pixian.ai

文章目录 简介特色 简介 pixian.ai是一款智能图片背景去除工具,进入网页后,会非常醒目地提示你准备【Free】还是【Paid】,这点就非常好,不向有一些网站,主打免费使用,但时不时弹出“免费注册”&#xff0c…

1782java英语陪学记词系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java英语陪学记词系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助采用了java设计,系统具有完整的源代码和数据库,系统采用web模式,系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&…

RabbitMQ-工作模式(简单模式工作队列)

文章目录 简单模式(simple)工作队列(work)准备工作轮询调度消息确认消息持久性公平分发代码示例 本篇总结 更多相关内容可查看 简单模式(simple) 通俗概括:生产者-队列-消费者 想详细了解Rabbit的基础或简…

ESD防护SP3232E真+3.0V至+5.5V RS-232收发器

特征 采用3.0V至5.5V电源,符合真正的EIA/TIA-232-F标准 满载时最低 120Kbps 数据速率 1μA 低功耗关断,接收器处于活动状态 (SP3222E) 可与低至 2.7V 电源的 RS-232 互操作 增强的ESD规格: 15kV人体模型 15kV IEC1000…

软件杯 题目:基于深度学习卷积神经网络的花卉识别 - 深度学习 机器视觉

文章目录 0 前言1 项目背景2 花卉识别的基本原理3 算法实现3.1 预处理3.2 特征提取和选择3.3 分类器设计和决策3.4 卷积神经网络基本原理 4 算法实现4.1 花卉图像数据4.2 模块组成 5 项目执行结果6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 基…

计算机毕业设计Spark+Flink+Hive地铁客流量预测 交通大数据 地铁客流量大数据 交通可视化 大数据毕业设计 深度学习 机器学习

项目说明​ ​ 1该项目主要分析通刷卡数据,通过大数据技术来研究地铁客运能力及探索优化服务的方向​ 2主要讲解Flink流处理实时分析部分,离线部分较简单,暂时略过​ ​ 技术架构​ ​项目流程:​ 采用python请求深圳地铁数…

70 Realistic Mountain Environment Textures Cliff(70+张真实的山地环境纹理)

大量适合山区和其他岩石环境的纹理--悬崖、岩石、砾石等等 每个纹理都是可贴的/无缝的,并且完全兼容各种不同的场景--标准Unity地形、Unity标准着色器、URP、HDRP等等都兼容。 所有的纹理都是4096x4096,并包括一个HDRP掩码,以完全支持HDRP。 特点。 70种质地 70种材料 70个地…

基于springboot实现农产品直卖平台系统项目【项目源码+论文说明】

基于springboot实现农产品直卖平台系统的设计演示 摘要 计算机网络发展到现在已经好几十年了,在理论上面已经有了很丰富的基础,并且在现实生活中也到处都在使用,可以说,经过几十年的发展,互联网技术已经把地域信息的隔…

内网快速传输工具

常见的有LANDrop,支持多种设备,如电脑、pad、手机等等之间互传。但本文介绍的这款是很小的电脑间互传工具。 特点是非常的快速,文件很小,不用安装解压就可用。

transformers peft加载lora模型;TextStreamer流式输出,kv cache使用

1、transformers peft加载lora模型 https://github.com/hiyouga/LLaMA-Factory/blob/cae47379079ff811aa385c297481a27020a8da6b/scripts/loftq_init.py#L13 代码: from peft import AutoPeftModelForCausalLM, PeftModel from transformers import AutoTokenizer…

《手把手教你》系列练习篇之13-python+ selenium自动化测试 -压轴篇(详细教程)

1. 简介 “压轴”原本是戏曲名词,指一场折子戏演出的倒数第二个剧目。在现代社会中有很多应用,比如“压轴戏”,但压轴也是人们知识的一个盲区。“压轴”本意是指倒数第二个节目,而不是人们常说的倒数第一个,倒数第一个…

Incredibuild for Mac 来了!

Mac 开发者在寻找适合自己需求的工具时可能会遇到一些困难,因为 Mac 操作系统相对封闭,不像其他系统那样开放和灵活。尽管如此,Mac 开发者在开发应用程序时的需求(比如功能、效率等)和使用其他操作系统的开发者是类似的…

C++ - 查找算法 和 其他 算法

目录 一. 查找算法: 1.顺序查找: 2.二分查找: 二. 其他算法: 1.遍历算法: 2.求和、求平均值等聚合算法。 a.求和算法: b.求平均值算法: 一. 查找算法: 1.顺序查找&#xff1…

如何访问内网数据库?

现如今,随着信息化的不断发展,数据库已经成为了企业管理和数据存储的重要组成部分。由于安全等原因,很多公司和组织将自己的数据库部署在内网中,限制了外部的访问。有些情况下,我们仍然需要在外部网络环境中访问内网的…

C++开发基础之初探CUDA计算环境搭建

一、前言 项目中有使用到CUDA计算的相关内容。但是在早期CUDA计算环境搭建的过程中,并不是非常顺利,编写此篇文章记录下。对于刚刚开始研究的你可能会有一定的帮助。 二、环境搭建 搭建 CUDA 计算环境涉及到几个关键步骤,包括安装适当的 C…

:长亭雷池社区版动态防护体验测评

序 长亭雷池在最近发布了动态防护功能,据说可以动态加密保护网页前端代码和阻止爬虫行为、阻止漏洞扫描行为等。今天就来体验测试一下 WAF 是什么 WAF 是 Web Application Firewall 的缩写,也被称为 Web 应用防火墙。区别于传统防火墙,WAF …

Error:..\FreeRTOS\portable\RVDS\ARM_CM7\r0p1\port.c,265

移植完FreeRTOS后,使用Keil进行编译,编译未报错,串口打印助手打印了错误报告。 串口打印的错误报告: Error:..\FreeRTOS\portable\RVDS\ARM_CM7\r0p1\port.c,265看一下265行 该行所在函数为prvTaskExitError函数,功能…

阅读笔记:Multi-threaded Rasterization in the Chromium Compositor

Multi-threaded Rasterization in the Chromium Compositor PPT 原始链接: https://docs.google.com/presentation/d/1nPEC4YRz-V1m_TsGB0pK3mZMRMVvHD1JXsHGr8I3Hvc/edit?uspsharing PPT主要介绍了Chromium浏览器中使用多线程光栅化(Impl-side painting)的机制&a…

目标检测——FGVC-Aircraft数据集

引言 亲爱的读者们,您是否在寻找某个特定的数据集,用于研究或项目实践?欢迎您在评论区留言,或者通过公众号私信告诉我,您想要的数据集的类型主题。小编会竭尽全力为您寻找,并在找到后第一时间与您分享。 …

【Vue】Vue路由-重定向

问题 网页打开时, url 默认是 / 路径,未匹配到组件时,会出现空白 解决方案 重定向 → 匹配 / 后, 强制跳转 /home 路径 语法 { path: 匹配路径, redirect: 重定向到的路径 }, 比如: { path:/ ,redirect:/home }代码示例 const…