【经验总结】 常用的模型优化器

优化器是一种用于优化模型权重和偏差的算法,它根据训练数据更新模型参数,以模型的预测结果更加准确。

1. 常见的优化器

  1. SGD(Stochastic Gradient Descent):SGD是一种基本的优化算法,它在每次迭代中随机选择一个样本进行梯度计算和参数更新。SGD使用固定的学习率,通常需要更多的迭代才能收敛,但在一些情况下也可以取得很好的效果,简单易于实现,但在非凸优化问题中可能会出现收敛速度慢的问题。

  2. RMSprop(Root Mean Square Propagation):RMSprop是自适应学习率的一种方法,它在训练过程中调整学习率,以便更好地适应不同特征的梯度。RMSprop通过维护梯度平方的移动平均来调整学习率。

  3. Adagrad(Adaptive Gradient):Adagrad是一种自适应学习率的优化器,它根据参数的历史梯度进行学习率调整。Adagrad适用于稀疏数据集,在训练初期对稀疏特征有较大的学习率,随着训练的进行逐渐减小。可能会在训练后期由于学习率过小导致收敛速度变慢。

  4. Adadelta:Adadelta是Adagrad的改进版本,它通过引入梯度平方的衰减平均来解决Adagrad学习率过早衰减的问题。Adadelta不需要手动设置学习率,并且在训练过程中可以自适应地调整学习率。

  5. Adam(Adaptive Moment Estimation):Adam是一种基于梯度的优化器,结合了自适应学习率和动量的概念。它在训练过程中自适应地调整学习率,并利用动量来加速梯度更新。Adam在很多NLP任务中表现良好。

  6. Adamax:Adamax是Adam的变体,它使用了无穷范数(infinity norm)来对梯度进行归一化。

  7. AdamW:AdamW是Adam的一种变体,它引入了权重衰减(weight decay)的概念。权重衰减可以有效防止模型过拟合。

  8. Nadam(Nesterov-accelerated Adaptive Moment Estimation):Nadam是Adam与Nesterov动量法的结合,它在Adam的基础上加入了Nesterov动量的修正项。

2. 示例代码

python代码:

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import imdb
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Adagrad, Adadelta, Adamax, Nadam
from transformers import AdamW
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing import sequence

# 加载IMDB情感分类数据集
(X_train, y_train), (_, _) = imdb.load_data(num_words=10000)
X_train = X_train[:3000]  # 只使用部分数据进行演示
y_train = y_train[:3000]


# 数据预处理:将序列填充为相同长度
X_train = sequence.pad_sequences(X_train, maxlen=10000)

# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=36)

# 构建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10000))
model.add(Dense(1, activation='sigmoid'))

# 定义优化器列表
optimizers = [Adam(), SGD(), RMSprop(), Adagrad(), Adadelta(), Adamax(), Nadam()]
optimizer_names = ['Adam', 'SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adamax', 'Nadam']
# optimizer_names = ['Adam', 'SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adamax', 'Nadam', 'AdamW']
histories = []

# 训练模型并记录历史
for optimizer in optimizers:
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=160, batch_size=32, verbose=0)
    histories.append(history)

# 绘制学习曲线
plt.figure(figsize=(12, 6))
for i, history in enumerate(histories):
    plt.plot(history.history['val_loss'], label=optimizer_names[i])
plt.title('Validation Loss Comparison')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/610128.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

揭秘Ping32如何实现上网行为监控

企业上网行为管理软件在现代企业管理中扮演着举足轻重的角色。它不仅能够监控和记录员工的上网行为,还能有效防止数据泄露和不当使用,从而保障企业的信息安全。 一、Ping32上网监控软件的具体功能包括: 1.网页浏览监控:对Chrome…

jvm面试题30问

什么是JVM的跨平台? 什么是JVM的语言无关性? 什么是JVM的解释执行 什么是JIT? JIT:在Java编程语言和环境中,即时编译器(JIT compiler,just-in-time compiler)是一个把Java的字节码(…

流量卡就该这么选,用起来性价比真的超高!

很多朋友会私信小编,让小编给大家推荐几款流量卡,在这里小编告诉大家,流量卡可以推荐,但是每个人的喜好不同,小编也忙不过来,今天,小编整理了一篇选购指南,大家可以参考选择&#xf…

2024 B2B企业出海营销白皮书(展会篇)

来源:科特勒&微吼 根据36氪研究院发布的《2023-2024年中国企业出海发展研究报告》中指出,随着全球化浪潮席卷以及中国智造的崛起,中国企业出海主力从过去的低附加值行业逐步扩展至信息技术、先进制造、医疗健康、汽车交通、新消费等附加…

106短信平台疑难解答:为何手机正常却收不到短信?

当您使用群发短信平台发送消息时,有时尽管系统提示发送成功,但手机却未能收到短信。这背后可能隐藏着一些不为人知的原因。 首先,我们要明确,在正常情况下,只要手机状态正常,都应该能够接收到短信。然而&am…

为什么站长们喜欢使用新加坡站群服务器呢?

为什么站长们喜欢使用新加坡站群服务器呢? 站群优化一直是站长们追逐的目标之一,而新加坡站群服务器则备受站长们的青睐。为什么会如此呢?让我们深入了解一下。 为什么站长们喜欢使用新加坡站群服务器呢? 站群,简单来说,就是一组相互关联…

Python专题:十、字典(1)

数据类型:字典,是一个集合性质的数据类型 字典的初始化 字典{关键字:数值} 新增元素 修改元素 字典元素访问 字典[关键字} in 操作符 字典关键字检测 字典元素遍历 ①遍历关键字

Android build.prop生成过程源码分析

Android的build.prop文件是在Android编译时刻收集的各种property【LCD density/语言/编译时间, etc.】&#xff1b;编译完成之后&#xff0c;文件生成在out/target/product/<board【OK1000】>/system/目录下&#xff1b;在Android运行时刻可以通过property_get()[c/c域] …

深度学习论文: LightGlue: Local Feature Matching at Light Speed

深度学习论文: LightGlue: Local Feature Matching at Light Speed LightGlue: Local Feature Matching at Light Speed PDF: https://arxiv.org/pdf/2306.13643 PyTorch代码: https://github.com/shanglianlm0525/CvPytorch PyTorch代码: https://github.com/shanglianlm0525/…

python数据分析——数据预处理

数据预处理 前言一、查看数据数据表的基本信息查看info&#xff08;&#xff09;示例 查看数据表的大小shape&#xff08;&#xff09;示例 数据格式的查看type()dtype&#xff08;&#xff09;dtypes&#xff08;&#xff09;示例一示例二 查看具体的数据分布describe()示例 二…

机器人学【一、刚体运动】

机器人学 文章目录 机器人学1. 刚体运动1.1 刚体变换刚体刚体变换 1.2 三维空间中的旋转运动群求质点坐标的相对变换旋转矩阵的合成法则用线性算子来计算叉积叉积的右手法则叉积用于计算线速度旋转的指数坐标Rodrigues公式计算旋转矩阵的例子四元数 1.3 三维空间中的刚体运动齐…

二分查找入门、二分查找模板

二分查找的具体实现是一个难点&#xff0c;挺复杂的&#xff0c;可以背住一个模板&#xff0c;然后以后再慢慢学习。下面是y总的二分模板(比较难懂&#xff0c;之后再学) y总的模板 二分的本质是在一个边界内&#xff0c;定义了两种不同的形状&#xff0c;其中某点是这两个性…

Golang | Leetcode Golang题解之第68题文本左右对齐

题目&#xff1a; 题解&#xff1a; // blank 返回长度为 n 的由空格组成的字符串 func blank(n int) string {return strings.Repeat(" ", n) }func fullJustify(words []string, maxWidth int) (ans []string) {right, n : 0, len(words)for {left : right // 当前…

详细解析DBC文件

《AUTOSAR谱系分解(ETAS工具链)》之总目录_autosar的uart模块-CSDN博客

Docker Desktop 修改容器的自启动设置

Docker Desktop 允许用户控制容器的自启动行为。如果你不希望某个容器在 Docker 启动时自动启动&#xff0c;你可以通过以下步骤来更改设置&#xff1a; 1. 打开 Docker Desktop 应用。 2. 点击右上角的设置&#xff08;Settings&#xff09;按钮&#xff0c;或者使用快捷键 Cm…

Hive Aggregation 聚合函数

Hive Aggregation 聚合函数 基础聚合 增强聚合

找最大数字-第12届蓝桥杯国赛Python真题解析

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第60讲。 找最大数字&#…

67万英语单词学习词典ACCESS\EXCEL数据库

这似乎是最多记录的英语单词学习词典&#xff0c;包含复数、过去分词等形式的单词。是一个针对想考级的人员辅助背单词学英语必备的数据&#xff0c;具体请自行查阅以下的相关截图。 有了数据才能想方设法做好产品&#xff0c;结合权威的记忆理论&#xff0c;充分调动用户的眼…

OpenSearch 与 Elasticsearch:7 个主要差异及如何选择

OpenSearch 与 Elasticsearch&#xff1a;7 个主要差异及如何选择 1. 什么是 Elasticsearch&#xff1f; Elasticsearch 是一个基于 Apache Lucene 构建的开源、RESTful、分布式搜索和分析引擎。它旨在处理大量数据&#xff0c;使其成为日志和事件数据管理的流行选择。 Elasti…

国产猫粮哪家强?福派斯三文鱼猫粮成新宠!

1️⃣ 品质保证&#xff1a;福派斯三文鱼猫粮是一款由国内知名宠物食品品牌生产的猫粮产品。该品牌有着严格的品质控制&#xff0c;确保每一粒猫粮都符合国家相关标准和规范&#xff0c;为猫咪提供安全、健康的食品。 2️⃣ 营养丰富&#xff1a;福派斯三文鱼猫粮采用新鲜三文鱼…