【Python深度学习系列】网格搜索选择神经网络超参数:隐含层神经元数量(案例+源码)

这是我的第259篇原创文章。

一、引言

图片

在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:

  • model.add()

    • neurons(隐含层神经元数量)

    • init_mode(初始权值)

    • activation(激活函数)

    • dropout(丢弃率)

  • model.compile()

    • loss(损失函数)

    • optimizer(优化器)

      • learning rate(学习率)

      • momentum(动量)

      • weight decay(权重衰减系数)

  • model.fit()

    • batch size(批量大小)

    • epochs(迭代次数)

一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优,本文采用网格搜索选择神经网络隐含层神经元数量。

二、实现过程

2.1 准备数据

dataset:

dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)

图片

2.2 数据划分

# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)

2.3 创建模型

需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。

def create_model(neurons_1):
    # 创建模型
    model = Sequential()
    model.add(Dense(neurons_1, input_shape=(8, ), kernel_initializer='uniform', activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))

    # 编译模型
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

model = KerasClassifier(model=create_model, epochs=100, batch_size=80, verbose=0, neurons_1=1)

这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier没有定义隐含神经元的参数,需要自定义一个表示隐含层神经元的参数neurons_1,并赋默认值为1。

2.4 定义网格搜索参数

param_grid = {'neurons_1': [1, 5, 10, 15, 20, 25, 30]}

param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。

2.5 进行参数搜索

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model,  param_grid=param_grid)
grid_result = grid.fit(X, Y)

使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。

2.6 总结搜索结果

print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

结果:

图片

经过网格搜索,隐含层神经元数量,最优的结果是30。

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/544704.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

探索 SAM 在遥感方面的能力

分割任意模型 (SAM) 现在可在不同类型的数据(例如近距离图像和航空图像)中自由克隆和使用。在我看来,SAM 模型在近距离图像上效果更好,因为这些图像对目标特征和物体有独特的视角,使模型更容易准确地区分和分割它们。 现在,我们将探讨 SAM 模型在不同遥感数据上的能力,包…

软考128-上午题-【软件工程】-白盒测试

一、白盒测试(结构测试) 白盒测试也称为结构测试,根据程序的内部结构和逻辑来设计测试用例,对程序的路径和过程进行测试,检查是否满足设计的需要。 白盒测试常用的技术是:逻辑覆盖、循环覆盖和基本路径测…

Web前端 JavaScript笔记4

1、元素内容 属性名称说明元素名.innerText输出一个字符串,设置或返回元素中的内容,不识别html标签元素名.innerHTML输出一个字符串,设置或返回元素中的内容,识别html标签元素名.textContent设置或返回指定节点的文本内容&#x…

LeetCode 678——有效的括号字符串

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 需要两个栈,一个用来保存左括号所在的位置索引,一个用来保存星号所在的位置索引。 从左往右遍历字符串,如果是左括号或者星号,则将位置索引分别入栈,如…

linux shell脚本编写(2)

Shell: 命令转换器,高级语言转换成二进制语言。是Linux的一个外壳,它包在Lniux内核的外面,用户和内核之间的交互提供了一个接口。 内置命令:在shell内部不需要shell编辑 外置命令:高级语言要用shell转换成二进制语言 …

机器学习 | 使用Scikit-Learn实现分层抽样

在本文中,我们将学习如何使用Scikit-Learn实现分层抽样。 什么是分层抽样? 分层抽样是一种抽样方法,首先将总体的单位按某种特征分为若干次级总体(层),然后再从每一层内进行单纯随机抽样,组成…

Kubernetes的Ingress Controller

前言 Kubernetes暴露服务的方式有一下几种:LoadBlancer Service、ExternalName、NodePort Service、Ingress,使用四层负载均衡调度器Service时,当客户端访问kubernetes集群内部的应用时,数据包的走向如下面流程所示:C…

计算机三级数据库技术备考笔记(十四)

第十四章 数据仓库与数据挖掘 决策支持系统的发展 决策支持系统及其演化 操作型数据(Operalional Data)是指由企业的基本业务系统所产生的数据,操作型数据及相应数据处理所处的环境,即用于支持企业基本业务应用的环境,一般被称为联机事务处理(0nLine Transaction Processing,0…

COMSOL多孔介质流仿真

使用Comsol进行多孔介质流仿真_哔哩哔哩_bilibili 目录 多孔介质 饱和多孔介质中的流动 达西定律 Brinkman方程:用于过渡区 裂隙流 变饱和多孔介质流 理查兹方程 多孔介质多相流 多物理场耦合 多孔介质中的传热 多孔弹性接口 多孔介质稀物质传递 多孔介质…

c# 无处不在的二分搜索

我们知道二分查找算法。二分查找是最容易正确的算法。我提出了一些我在二分搜索中收集的有趣问题。有一些关于二分搜索的请求。我请求您遵守准则:“我真诚地尝试解决问题并确保不存在极端情况”。阅读完每个问题后,最小化浏览器并尝试解决它。 …

NSL-KDD数据集详细介绍及下载

链接:https://pan.baidu.com/s/1hX4xpVPo70vwLIo0gdsM8A?pwdq88b 提取码:q88b 一般认为数据质量决定了机器学习性能的上限,而机器学习模型和算法的优化最多 只能逼近这个上限。因此在数据采集阶段需要对采集任务进行规划。在数据采集之前, 主要是从数据…

第十二讲 查询计划 优化

到目前为止,我们一直在说,我们得到一个 SQL 查询,我们希望可以解析它,将其转化为某种逻辑计划,然后生成我们可以用于执行的物理计划。而这正是查询优化器【Optimizer】的功能,对于给定的 SQL ,优…

.net框架和c#程序设计第三次测试

目录 一、测试要求 二、实现效果 三、实现代码 一、测试要求 二、实现效果 数据库中的内容&#xff1a; 使用数据库中的账号登录&#xff1a; 若不是数据库中的内容&#xff1a; 三、实现代码 login.aspx文件&#xff1a; <% Page Language"C#" AutoEventW…

DB schema表中使用全局变量及在DB组件中查询

DB schema表中使用全局变量及在DB组件中查询 规则如下&#xff1a; 使用如下&#xff1a; 如果在unicloud-db组件上不加判断条件&#xff0c;就会报错&#xff0c;并进入到登录页。 那么就会进入到登录页&#xff0c;加上了判断条件&#xff0c;有数据了就不会了。 因为在sc…

TQ15EG开发板教程:在MPSOC上运行ADRV9371(vivado2018.3)

首先需要在github上下载两个文件&#xff0c;本例程用到的文件以及最终文件我都会放在网盘里面&#xff0c; 地址放在本文最后。首先在github搜索hdl选择第一个&#xff0c;如下图所示 GitHub网址&#xff1a;https://github.com/analogdevicesinc/hdl/releases 点击releases…

【Maven工具】

maven Maven是一个主要用于Java项目的构建自动化工具。它有助于管理构建过程&#xff0c;包括编译源代码、运行测试、将编译后的代码打包成JAR文件以及管理依赖项。Maven使用项目对象模型&#xff08;POM&#xff09;文件来描述项目配置和依赖关系。 Maven通过提供标准的项目…

分布式系统中的唯一ID生成方法

通常在分布式系统中&#xff0c;有生成唯一ID的需求&#xff0c;唯一ID有多种实现方式。我们选择其中几种&#xff0c;简单阐述一下实现原理、适用场景、优缺点等信息。 目录 数据库多主复制UUID工单服务器雪花算法总结 数据库多主复制 数据库通常有自增属性&#xff0c;在单机…

解决vue启动项目报错:npm ERR! Missing script: “serve“【详细清晰版】

目录 问题描述问题分析和解决情况一解决方法情况二&#xff08;常见于vue3&#xff09;解决方法情况三解决方法 问题描述 在启动vue项目时通常在控制台输入npm run serve 但是此时出现如下报错&#xff1a; npm ERR! Missing script: "serve" npm ERR! npm ERR! T…

80% 的人都不会的 15 个 Linux 实用技巧

熟悉 Linux 系统的同学都知道&#xff0c;它高效主要体现在命令行。通过命令行&#xff0c;可以将很多简单的命令&#xff0c;通过自由的组合&#xff0c;得到非常强大的功能。 命令行也就意味着可以自动化&#xff0c;自动化会使你的工作更高效&#xff0c;释放很多手工操作&…

纸制品ERP怎么样

在纸制品行业中&#xff0c; ERP系统的应用已经成为提升企业竞争力的关键因素。本文将探讨万达宝ERP系统在制造成本控制、商品生命周期管理以及自动对接主流平台方面的作用&#xff0c;并分析其在业务流程优化、高效调节各类关系以及多种模式生产方面的特点和益处。 制造成本控…