pytorch08:学习率调整策略

在这里插入图片描述

目录

  • 一、为什么要调整学习率?
    • 1.1 class _LRScheduler
  • 二、pytorch的六种学习率调整策略
    • 2.1 StepLR
    • 2.2 MultiStepLR
    • 2.3 ExponentialLR
    • 2.4 CosineAnnealingLR
    • 2.5 ReduceLRonPlateau
    • 2.6 LambdaLR
  • 三、学习率调整小结
  • 四、学习率初始化

一、为什么要调整学习率?

学习率(learning rate):控制更新的步伐
一般在模型训练过程中,在开始训练的时候我们会设置学习率大一些,随着模型训练epoch的增加,学习率会逐渐设置小一些。

1.1 class _LRScheduler

学习率调整的父类函数
在这里插入图片描述
主要属性:
• optimizer:关联的优化器
• last_epoch:记录epoch数
• base_lrs:记录初始学习率
主要方法:
• step():更新下一个epoch的学习率,该操作必须放到epoch循环下面
• get_lr():虚函数,计算下一个epoch的学习率

二、pytorch的六种学习率调整策略

2.1 StepLR

在这里插入图片描述

功能:等间隔调整学习率
主要参数:
• step_size:调整间隔数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现:

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 1 Step LR ------------------------------
# flag = 0
flag = 1
if flag:

    scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # 设置学习率下降策略,50轮下降一次,每次下降10倍

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()  # 学习率更新策略

    plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果:
在这里插入图片描述

因为我们设置每50个epoch降低一次学习率,所以在7774554

2.2 MultiStepLR

在这里插入图片描述

功能:按给定间隔调整学习率
主要参数:
• milestones:设定调整时刻数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现

flag = 1
if flag:

    milestones = [50, 125, 160]  # 设置学习率下降的位置
    scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

根据我们设置milestones = [50, 125, 160],发现学习率在这三个地方发生下降。

2.3 ExponentialLR

在这里插入图片描述

功能:按指数衰减调整学习率
主要参数:
• gamma:指数的底
调整方式:lr = lr * gamma^epoch;这里的gamma通常设置为接近1的数值,例如:0.95

代码实现

flag = 1
if flag:

    gamma = 0.95
    scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):

        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

可以发现学习率是呈指数下降的。

2.4 CosineAnnealingLR

在这里插入图片描述

功能:余弦周期调整学习率
主要参数:
• T_max:下降周期
• eta_min:学习率下限
调整方式:
在这里插入图片描述

代码实现

flag = 1
if flag:
    t_max = 50
    scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)
    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        lr_list.append(scheduler_lr.get_lr())
        epoch_list.append(epoch)
        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()
    plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

2.5 ReduceLRonPlateau

在这里插入图片描述

功能:监控指标,当指标不再变化则调整,例如:可以监控我们的loss或者准确率,当其不发生变化的时候,调整学习率。
主要参数:
• mode:min/max 两种模式
min模式:当某一个值不下降的时候我们调整学习率,通常用于监控损失
max模型:当某一个值不上升的时候我们调整学习率,通常用于监控精确度
• factor:调整系数
• patience:“耐心”,接受几次不变化
• cooldown:“冷却时间”,停止监控一段时间
• verbose:是否打印日志
• min_lr:学习率下限
• eps:学习率衰减最小值

代码实现

flag = 1
if flag:
    loss_value = 0.5
    accuray = 0.9
    factor = 0.1  # 学习率变换参数
    mode = "min"
    patience = 10  # 能接受多少轮不变化
    cooldown = 10  # 停止监控多少轮
    min_lr = 1e-4  # 设置学习率下限
    verbose = True  # 打印更新日志
    scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,
                                                        cooldown=cooldown, min_lr=min_lr, verbose=verbose)
    for epoch in range(max_epoch):
        for i in range(iteration):
            # train(...)
            optimizer.step()
            optimizer.zero_grad()
        #if epoch == 5:
           # loss_value = 0.4
        scheduler_lr.step(loss_value) #监控的标量是否下降

输出结果
在这里插入图片描述

2.6 LambdaLR

在这里插入图片描述
功能:自定义调整策略
主要参数:
• lr_lambda:function or list

代码实现

flag = 1
if flag:

    lr_init = 0.1

    weights_1 = torch.randn((6, 3, 5, 5))
    weights_2 = torch.ones((5, 5))

    optimizer = optim.SGD([
        {'params': [weights_1]},
        {'params': [weights_2]}], lr=lr_init)

    # 设置两种不同的学习率调整方法
    lambda1 = lambda epoch: 0.1 ** (epoch // 20)  # 每到20轮的时候学习率变为原来的0.1倍
    lambda2 = lambda epoch: 0.95 ** epoch  # 将学习率进行指数下降

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        for i in range(iteration):
            # train(...)

            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

        lr_list.append(scheduler.get_lr())
        epoch_list.append(epoch)

        print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_lr()))

    plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")
    plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.title("LambdaLR")
    plt.legend()
    plt.show()

输出结果
在这里插入图片描述

通过lambda方法定义了两种不同的学习率下降策略。

三、学习率调整小结

  1. 有序调整:Step、MultiStep、Exponential 和 CosineAnnealing
  2. 自适应调整:ReduceLROnPleateau
  3. 自定义调整:Lambda

四、学习率初始化

1、设置较小数:0.01、0.001、0.0001
2、搜索最大学习率: 参考该篇《Cyclical Learning Rates for Training Neural Networks》
方法:我们可以设置学习率逐渐从小变大观察精确度的一个变化,下面这幅图,当学习率为0.055左右的时候模型精确度最高,当学习率大于0.055的时候精确度出现下降情况,所以在模型训练过程中我们可以设置学习率为0.055作为我们的初始学习率。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/293846.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

宝宝洗衣机哪个牌子质量好?好用的小型洗衣机推荐

当婴儿的到来,确实会给家庭带来许多变化,就好比如对于宝宝相关衣物的清洗需求。对于新生儿及婴幼儿的衣服,一般都要给予特殊的照顾与清洗,以保证不含细菌及过敏原。尤其是刚刚出生的婴儿,这时候宝宝们的皮肤很是幼嫩。…

JAVA基础学习笔记-day12-泛型

JAVA基础学习笔记-day12-泛型 1. 泛型概述1.1 泛型的引入 2. 使用泛型举例2.1 集合中使用泛型2.2 相关使用说明 3. 自定义泛型结构3.1 泛型的基础说明3.2 自定义泛型类或泛型接口3.2.1 说明3.2.2 注意 3.3 自定义泛型方法3.3.1 说明 4. 泛型在继承上的体现5. 通配符的使用5.1 通…

MySQL面试题汇总

常规: 1、数据库三大范式 1NF : 表中字段的数据不可再拆分。(原子性)2NF : 在满足第一范式的情况下,遵循唯一性,消除部分依赖。即,表中任意一个主键或任意一组联合主键,可以确定除该主键外的所有的非主键值。(一个表…

(适趣AI)Vue笔试题

📑前言 本文主要是【Vue】——(适趣AI)Vue笔试题的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 …

Android studio报错误提示 Some Kotlin libraries attached to this project 问题解决方案

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂,风趣幽默",感觉非常有意思,忍不住分享一下给大家。 👉点击跳转到教程 Android新建项目后,报以下错误 错误提示内容为: 这个项目附带的一…

pygame学习(二)——绘制线条、圆、矩形等图案

导语 pygame是一个跨平台Python库(pygame news),专门用来开发游戏。pygame主要为开发、设计2D电子游戏而生,提供图像模块(image)、声音模块(mixer)、输入/输出(鼠标、键盘、显示屏)模…

Spring Boot 与 Spring 框架的区别

一、前言 Spring Boot 和 Spring 框架是由 Spring 项目提供的两个关键的技术栈,它们在 Java 开发中扮演着不同的角色。在阐述其区别之前,我们先大致了解下这两个框架 二、Spring 框架 1、背景 Spring 框架是一个全栈的企业应用开发框架,起…

营销的尽头是矩阵!如何通过小魔推短视频矩阵快速破圈?

“ 营销的尽头是矩阵! 相信很多做互联网的朋友都听过这么一句话 在抖音上我们看到过大批的博主,都是通过矩阵的方式火遍全网,就比如张琦、小杨哥等等,矩阵的方式适用于大多数的实体品牌,以及个人IP 等&#xff0c…

DevOps(3)

目录 11.描述root账户? 12.如何在发出命令时打开命令提示符? 14.Linux系统下交换分区的典型大小是多少? 15.什么是符号链接? 11.描述root账户? root账户就像一个系统管理员账户,允许你完全控制系统。 …

目标跟踪算法中的卡尔曼滤波学习

在使用多目标跟踪算法时,接触到卡尔曼滤波,一直没时间总结下,现在来填坑。 1. 背景知识 在理解卡尔曼滤波前,有几个概念值得考虑下:时序序列模型,滤波,线性动态系统 1. 时间序列模型 时间序…

AspectJ入门(二)— 应用

AspectJ便于调试、测试和性能调整工作。定义的行为范围从简单的跟踪到分析,再到应用程序内部一致性到测试。AspectJ可以干净地模块化这类功能,从而可以在需要时轻松地启用和禁用这些功能。 1 基础 本节将继续介绍AspectJ到一些基础功能,为后…

负载均衡案例:如何只用2GB内存统计20亿个整数中出现次数最多的整数

基于python实现。 如果是常规的小型文件&#xff0c;我们可以迅速地想到要建立字典。 以数字为key&#xff0c;以数字的出现次数为value&#xff0c;建立<int,int>类型的键值对存入字典&#xff0c;然后使用 max 函数结合字典的 items 方法来找到一个字典中 value 最大的…

2023 波卡年度报告选读:Polkadot SDK 与开发者社区

原文&#xff1a;https://dashboards.data.paritytech.io/reports/2023/index.html#section6 编译&#xff1a;OneBlock 编者注&#xff1a;Parity 数据团队发布的 2023 年 Polkadot 年度数据报告&#xff0c;对推动生态系统的关键数据进行了深入分析。报告全文较长&#xff…

一键减低PNG像素,轻松优化图片质量!

在数字时代&#xff0c;我们每天都要处理大量的图片文件&#xff0c;从网站设计、广告素材到社交媒体图片等。PNG作为一种常用的无损压缩格式&#xff0c;在保证图片质量的同时&#xff0c;也占用了较大的存储空间。为了优化存储空间和提高加载速度&#xff0c;我们需要对PNG图…

数据结构学习笔记——查找算法中的树形查找(B树、B+树)

目录 前言一、B树&#xff08;一&#xff09;B树的概念&#xff08;二&#xff09;B树的性质&#xff08;三&#xff09;B树的高度&#xff08;四&#xff09;B树的查找&#xff08;五&#xff09;B树的插入&#xff08;六&#xff09;B树的删除 二、B树&#xff08;一&#xf…

科技助力教育:数字化如何改变家校社协同育人?

近年来,随着社会的快速发展,教育的责任已不再仅局限于学校。家庭、学校和社会协同育人理念,正成为促进教育高质量发展的关键要素。 2023年初,教育部等十三部门联合印发《关于健全学校家庭社会协同育人机制的意见》,提出到“十四五”时期末,形成更加完善的由“学校积极主导、家…

Excel如何将单元格设为文本

文章目录 一、打开excel文件二、选中单元格三、右键设置单元格格式四、设置界面选择文本后点确定五、其他问题 在caa开发过程中遇到从CATUnicodeString转成CString时&#xff0c;通过SetItemText写入将ID号写入单元格&#xff0c;无法保存ID号中的数字0&#xff0c;故将单元格格…

统信UOS_麒麟KYLINOS修改图标显示名称

原文链接&#xff1a;统信UOS/麒麟KYLINOS修改图标显示名称 hello&#xff0c;大家好啊&#xff01;今天我要给大家介绍的是在统信UOS及麒麟KYLINOS操作系统上如何修改软件的名称。这种自定义可以帮助您更快地识别和访问常用的应用程序&#xff0c;也可以使您的桌面环境更加个性…

【MATLAB】CEEMD_LSTM神经网络时序预测算法

有意向获取代码&#xff0c;请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 CEEMD-LSTM神经网络时序预测算法是一种结合了完全扩展经验模态分解&#xff08;CEEMD&#xff09;和长短期记忆神经网络&#xff08;LSTM&#xff09;的时间序列预测方法。 CEEMD是一种改…

基于MyCat2.0实现MySQL分库分表方案

目录 一、MyCat概述 二、MyCat作用 2.1 数据分片 2.1.1 垂直拆分 2.1.1.1 垂直分库 2.1.1.2 垂直分表 2.1.1.3 总结 2.1.2 水平拆分 2.1.2.1 水平分库 2.1.2.2 水平分表 2.1.2.3 总结 2.2 读写分离 2.3 多数据源整合 三、MyCat 与ShardingJDBC的区别 3.1 MyCat …