深度学习中的学习率调度器(scheduler)分析并作图查看各方法差异

文章目录

    • 1. 指数衰减调度器(Exponential Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 2. 余弦退火调度器(Cosine Annealing Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 3. 步长衰减调度器(Step Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)
      • 工作原理
      • 适用场景
      • 实现示例
    • 5. 多步衰减调度器(MultiStep Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 总结
    • 参考资料

在深度学习模型训练过程中, 学习率调度器(Learning Rate Scheduler)是优化过程中不可或缺的重要组成部分。它们能够在训练的不同阶段自动调整学习率,从而提高模型的收敛速度和最终性能。选择合适的学习率调度器对于优化训练过程至关重要,不同的调度器适用于不同的训练需求和模型架构。本文将介绍几种常用的学习率调度器,并通过 PyTorch 提供的 torch.optim.lr_schedulertransformers 库中的调度器,展示具体的实现示例及其适用场景。可以通过 运行示例代码来作图查看学习率变化情况,能帮助大家更好的了解不同方法的区别。

1. 指数衰减调度器(Exponential Decay Scheduler)

请添加图片描述

工作原理

指数衰减调度器通过在每个训练步骤中以固定的速率减小学习率,从而逐步降低学习率。这种调度器适用于需要平稳且持续减小学习率的训练过程,有助于模型在训练后期稳定收敛。

适用场景

  • 稳定收敛:适用于希望学习率在整个训练过程中持续且缓慢降低,以避免训练后期的震荡。
  • 简单调整:当训练过程相对稳定,不需要复杂的学习率调整策略时,指数衰减是一个简单有效的选择。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR

# 定义优化器和参数
initial_lr = 5e-5
num_training_steps = 3000
decay_rate = 0.99
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义指数衰减调度器
scheduler = ExponentialLR(optimizer, gamma=decay_rate)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Exponential Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

2. 余弦退火调度器(Cosine Annealing Scheduler)

请添加图片描述

工作原理

余弦退火调度器通过余弦函数调整学习率,使其在训练过程中呈现周期性变化。这种调度器特别适用于处理模型训练中的振荡现象,能够在训练末期提供较低的学习率以帮助模型更好地收敛。

适用场景

  • 避免局部最优:通过周期性调整学习率,可以帮助模型跳出局部最优解。
  • 动态调整:适用于需要在训练过程中动态调整学习率以应对不同训练阶段需求的场景。
  • 模型复杂度较高:对于复杂模型,如深层神经网络,余弦退火有助于更好地探索参数空间。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# 优化器和参数定义同上
initial_lr = 5e-5
num_training_steps = 3000
T_max = 1000  # 一个周期内的步数
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义余弦退火调度器
scheduler = CosineAnnealingLR(optimizer, T_max=T_max)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Cosine Annealing Scheduler')
plt.legend()
plt.grid(True)
plt.show()

3. 步长衰减调度器(Step Decay Scheduler)

在这里插入图片描述

工作原理

步长衰减调度器在训练过程中每隔一定的步数(step_size)后按指定的因子(gamma)降低学习率。这种调度器适用于需要在训练过程中分阶段减小学习率的场景,有助于模型在不同训练阶段进行有效的学习。

适用场景

  • 分阶段训练:适用于需要在训练的特定阶段进行学习率调整的任务,如先快速学习再细致调整。
  • 明确的训练阶段:当训练过程可以划分为多个明确的阶段,每个阶段需要不同学习率时,步长衰减是理想选择。
  • 资源受限的训练:在有限的训练资源下,通过分阶段调整学习率可以更有效地利用计算资源。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
step_size = 500  # 每隔 step_size 个 step,学习率衰减一次
gamma = 0.1      # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义步长衰减调度器
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Step Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)

在这里插入图片描述

工作原理

多项式衰减与预热调度器结合了学习率预热和多项式衰减的优势。训练初期通过预热阶段逐步增加学习率,随后按照多项式函数逐步降低学习率。这种调度器适用于如 BERT 等复杂模型的训练,有助于在训练初期稳定模型参数并在后期促进收敛。

适用场景

  • 复杂模型训练:适用于需要在训练初期进行稳定性的复杂模型,如 Transformer、BERT 等。
  • 防止初期震荡:通过预热阶段逐步增加学习率,可以防止训练初期由于学习率过高导致的梯度震荡。
  • 需要精细控制:适用于需要对学习率进行精细控制,以实现最佳收敛效果的任务。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from transformers import get_polynomial_decay_schedule_with_warmup

# 调度器参数
initial_lr = 5e-5
warmup_steps = 100
num_training_steps = 3000
lr_end = 1e-7  # 最低学习率
power = 2.0    # 多项式衰减的幂次
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义多项式衰减与预热调度器
scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=warmup_steps, 
    num_training_steps=num_training_steps, 
    lr_end=lr_end, 
    power=power
)  # 二次衰减

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.axvline(x=warmup_steps, color='r', linestyle='--', label='End of Warm-up')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Polynomial Decay Scheduler with Warm-up')
plt.legend()
plt.grid(True)
plt.show()

5. 多步衰减调度器(MultiStep Decay Scheduler)

在这里插入图片描述

工作原理

多步衰减调度器在预设的多个步数(milestones)时刻按指定的因子(gamma)降低学习率。这种调度器允许在训练过程中在多个关键点调整学习率,适用于需要在多个阶段显著改变学习率的训练任务。

适用场景

  • 多阶段训练:适用于训练过程中有多个关键阶段,每个阶段需要不同学习率的任务。
  • 灵活调整:当训练过程不规则或需要根据训练进展手动调整学习率时,多步衰减提供了灵活性。
  • 特定任务需求:适用于一些特定任务或模型架构,需要在特定步数后调整学习率以优化性能。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import MultiStepLR

# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
milestones = [1000, 2000]  # 指定的步数
gamma = 0.1  # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)

# 定义多步衰减调度器
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
for i, milestone in enumerate(milestones):
    if i == 0:
        plt.axvline(x=milestone, color='r', linestyle='--', label=f'Milestone at Step {milestone}')
    else:
        plt.axvline(x=milestone, color='r', linestyle='--')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('MultiStep Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

注意:在多步衰减调度器的绘图代码中,plt.axvline 函数仅在第一个里程碑处添加标签,后续的里程碑标签设置为 None'_nolegend_',以避免图例中出现重复的标签。

总结

以上示例代码展示了不同学习率调度器的实现方式以及学习率随训练步骤变化的过程。选择合适的调度器可以根据具体任务和模型的需求来优化训练效果。以下是各类调度器的快速参考:

  • 指数衰减调度器(Exponential Decay Scheduler):适用于希望学习率持续且缓慢降低,稳定收敛的训练过程。
  • 余弦退火调度器(Cosine Annealing Scheduler):适用于需要动态调整学习率以避免局部最优,尤其适合复杂模型。
  • 步长衰减调度器(Step Decay Scheduler):适用于分阶段训练,明确划分训练阶段的任务。
  • 多项式衰减与预热调度器(Polynomial Decay with Warm-up):适用于复杂模型训练,防止初期震荡并促进后期收敛。
  • 多步衰减调度器(MultiStep Decay Scheduler):适用于多阶段训练,需要在多个关键点调整学习率的任务。

在实际应用中,可以根据模型的复杂度、数据集的特性以及训练的阶段性需求,灵活选择和调整学习率调度策略,以实现最佳的训练效果。

参考资料

  • PyTorch 官方文档 - Learning Rate Scheduler
  • Transformers 库 - 调度器

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/954336.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IPSEC实验

实验要求 某小型企业为扩大网络规模,设立分公司,今日要求分公司能够访问主公司对应的资源,为此很是苦恼 为满足其跨区域访问对端网络的要求,现要求使用IPSEC搭建隧道使得分公司能够与主公司通讯 实验拓扑 该公司与分公司拓扑大…

[c语言日寄]精英怪:三子棋(tic-tac-toe)3命慢通[附免费源码]

哈喽盆友们,今天带来《c语言》游戏中[三子棋boss]速通教程!我们的目标是一边编写博文,一边快速用c语言实现三子棋游戏。准备好瓜子,我们计时开始! 前期规划 在速通中,我们必须要有清晰的前期规划&#xf…

TensorFlow DAY3: 高阶 API(Keras,Estimator)(完)

TensorFlow 作为深度学习框架,当然是为了帮助我们更便捷地构建神经网络。所以,本次实验将会了解如何使用 TensorFlow 来构建神经网络,并学会 TensorFlow 构建神经网络的重要函数和方法。 知识点 Keras 顺序模型Keras 函数模型Keras 模型存储…

数据结构(Java版)第九期:LinkedList与链表

专栏:数据结构(Java版) 个人主页:手握风云 目录 一、LinkedList的模拟实现 1.1. 头插法 1.2. 尾插法 1.3. 插入中间节点 1.4. 删除某个节点 1.5. 删除所有为key的元素 二、LinkedList的使用 2.1. 什么是LinkedList 2.2. LinkedList的使⽤ 三、…

ubuntu18.04开发环境下samba服务器的搭建

嵌入式linux的发展很快,最近准备在一个新项目上采用新一代的linux核心板,发现linux内核的版本已经更新到5.4以上甚至6.0以上;之前常用的linux内核版本是2.6.4,虽然在某些项目上还能用但是明显跟不上时代的步伐了,所以要…

【优先算法】滑动窗口--(结合例题讲解解题思路)(C++)

目录 1. 例题1:最大连续1的个数 1.1 解题思路 1.2代码实现 1.3 错误示范如下:我最开始写了一种,但是解答错误,请看,给大家做个参考 2. 将 x 减到 0 的最小操作数 2.1解题思路 2.2代码实现 1. 例题1&#xff…

数据结构二叉树-C语言

数据结构二叉树-C语言 1.树1.1树的概念与结构1.2树的相关术语1.3树的表示1.4树形结构实际运用场景 2.二叉树2.1概念与结构2.2特殊的二叉树2.2.1满二叉树2.2.2完全二叉树 2.3二叉树存储结构2.3.1顺序结构2.3.2链式结构 3.实现顺序结构的二叉树4.实现链式结构二叉树4.1前中后序遍…

Qt/C++进程间通信:QSharedMemory 使用详解(附演示Demo)

在开发跨进程应用程序时,进程间通信(IPC)是一个关键问题。Qt 框架提供了多种 IPC 技术,其中 QSharedMemory 是一种高效的共享内存方式,可以实现多个进程之间快速交换数据。本文将详细讲解 QSharedMemory 的概念、用法及…

【vue3项目使用 animate动画效果】

vue3项目使用 animate动画效果 前言一、下载或安装npm 安装 二、引入组件三、复制使用四、完整使用演示总结 前言 提示:干货篇,不废话,点赞收藏,用到会后好找藕~ 点击这里,直接看官网哦 👉 官网地址&#…

Android 15应用适配指南:所有应用的行为变更

Android系统版本适配,一直是影响App上架Google Play非常重要的因素。 当前Google Play政策规定 新应用和应用更新 必须以 Android 14(API 级别 34)为目标平台,才能提交到Google Play。现有应用 必须以 Android 13(AP…

qml TargetDirection详解

1、概述 TargetDirection是QML(Qt Modeling Language)中一个用于指定粒子系统中粒子移动方向的类型。它允许粒子朝向一个目标点移动,这个目标点可以是QML界面上的一个具体位置,也可以是另一个QML元素的中心。TargetDirection通常…

Linux C 使用ZBar库解析二维码和条形码

1. 编译zbar库 下载 zbar 库源码,这里需要注意下,如果识别的二维码中有中文的话,会出现乱码,一般二维码里中文为UTF-8编码,zbar会默认给你把UTF-8转换为ISO8859-1。有两种解决办法,一是自己再转换一下编码…

金融项目实战 06|Python实现接口自动化——日志、实名认证和开户接口

目录 一、日志封装及应用(理解) 二、认证开户接口脚本编写 1、代码编写 1️⃣api目录 2️⃣script目录 2、BeautifulSoup库 1️⃣简介及例子 2️⃣提取html数据工具封装 3、认证开户参数化 一、日志封装及应用(理解) &…

基于springboot+vue+微信小程序的宠物领养系统

基于springbootvue微信小程序的宠物领养系统 一、介绍 本项目利用SpringBoot、Vue和微信小程序技术,构建了一个宠物领养系统。 本系统的设计分为两个层面,分别为管理层面与用户层面,也就是管理者与用户,管理权限与用户权限是不…

【微服务】面试题 5、分布式系统理论:CAP 与 BASE 详解

分布式系统理论:CAP 与 BASE 详解 一、CAP 定理 背景与定义:1998 年由加州大学科学家埃里克布鲁尔提出,分布式系统存在一致性(Consistency)、可用性(Availability)、分区容错性(Part…

【Vue】Vue组件--上

目录 一、组件基础 二、组件的嵌套关系 1. 基础架构 2. 嵌套 三、组件注册方式 1. 局部注册: 2. 全局注册: 四、组件传递数据 1. 基础架构 2. 传递多值 3. 动态传递数据 五、组件传递多种数据类型 1. Number 2. Array 3. Object 六、组…

鸿蒙UI开发——键盘弹出避让模式设置

1、概 述 我们在鸿蒙开发时,不免会遇到用户输入场景,当用户准备输入时,会涉及到输入法的弹出,我们的界面针对输入法的弹出有两种避让模式:上抬模式、压缩模式。 下面针对输入法的两种避让模式的设置做简单介绍。 2、…

python Streamlit和AKShare 实现的股票数据查询系统

1. 系统概述 这是一个基于Streamlit和AKShare的股票数据查询系统,提供了便捷的股票数据查询和可视化功能。系统支持按板块筛选股票、多股票代码查询、数据导出等功能。 1.1 主要功能 股票代码直接输入查询按板块筛选和选择股票历史数据和实时行情查询财务报表数据…

蓝桥杯备赛:顺序表和单链表相关算法题详解(上)

目录 一.询问学号(顺序表) 1.题目来源: 2.解析与代码实现: (1)解析: (2)代码实现: 二.寄包柜(顺序表) 1.题目来源: …

数据结构-ArrayLIst-一起探索顺序表的底层实现

各位看官早安午安晚安呀 如果您觉得这篇文章对您有帮助的话 欢迎您一键三连,小编尽全力做到更好 欢迎您分享给更多人哦 大家好,我们今天来学习java数据结构的第一章ArrayList(顺序表) 1.ArrayList的概念 那小伙伴就要问了线性表到…