学习率设置(写给自己看)

现往你的.py文件上打上以下代码:

import torch
import numpy as np
from torch.optim import SGD
from torch.optim import lr_scheduler
from torch.nn.parameter import Parameter

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, lr=0.1)

然后在最后的循环打上以下代码:

epochs=100
for epoch in (1,epochs+1):
    train()
    test()
    lr_schedulers.step()

这里的train和test是你的训练和测试调用的函数。

学习率参数很难调节,针对图像分类任务,一般使用的是:

1.阶梯型衰减,

就是在指定的批次上降低指定倍数,比如如果100个epoch,设置在1/3和3/4处学习率减小一倍,这种有两种实现方式:

方式一:

lr_schedulers=lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

方式二:

epochs=100
for epoch in (1,epochs+1):
    if epoch%30 == 0:
        lr = lr*0.1
    train()
    test()
    lr_schedulers.step()

2.MultiStepLR:多个不同速率的衰减

方式一:

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)

方式二:

for epoch in (1,epochs+1):
    if epoch == 30:
        lr = lr*0.1
    if epoch == 40:
        lr = lr*0.5
    train()
    test()
    lr_schedulers.step()

3.指数型下降的学习率调节器

公式:

 curr_rate:当前的学习率

 init_rate:初始的学习率

gamma:衰减系数

epochs:计数器,从0计数到训练的迭代次数

decay_step:控制衰减速度

公式表达的含义其实很明显,gamma衰减系数代表的就是衰减函数的形状,>1学习率就增长了,<1学习率就衰减了。代码实现:

X = []
Y = []
# 初始学习率
learning_rate = 0.1
# 衰减系数
decay_rate = 0.1
# decay_steps控制衰减速度
# 如果decay_steps大一些,(global_step / decay_steps)就会增长缓慢一些
#   从而指数衰减学习率decayed_learning_rate就会衰减得慢一些
#   否则学习率很快就会衰减为趋近于0
decay_steps = 60
# 迭代轮数
global_steps = 120
# 指数学习率衰减过程
for global_step in range(0,global_steps):
    decayed_learning_rate = learning_rate * decay_rate**(global_step / decay_steps)
    X.append(global_step / decay_steps)
    Y.append(decayed_learning_rate)
    if global_step==0 or global_step==global_steps-1:
        print("global step: %d, learning rate: %f" % (global_step,decayed_learning_rate))
    
fig = plt.figure(1)
ax = fig.add_subplot(1,1,1)
curve = ax.plot(X,Y,'b',label="learning rate")
ax.legend()
ax.set_xlabel("epochs / decay_steps")
ax.set_ylabel("learning_rate")

你通过设置初始学习率和最后想要下降到的学习率试着模拟一下。 效果还是不错的。

 实现方式:

实质上pytorch里面有:

scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

但是和上面的公式是有出入的,他的实现方式其实就是当前的学习率乘以gamma系数值,所以在最后学习率肯定会同样的衰减率torch里面下降的是比上面的快的,所以有两种策略,第一种调整gamma系数然后打印每次的学习率的数值调整到自己想要的学习率大小,即:

我i试了试改成0.96差不多就可以了。

第二种就是把上面的方式封装成一个函数,在for循环里每次调用他,封装成函数就可以使用

LambdaLR学习策略

了,它可以自定义函数,实现方式如下:

# 初始学习率
learning_rate = 0.1
# 衰减系数
decay_rate = 0.1
# decay_steps 控制衰减速度
decay_steps = 60
# 迭代轮数
global_steps = 120


# 自定义指数衰减函数
def exponential_decay(initial_lr, decay_rate, decay_steps, global_step):
    return initial_lr * decay_rate**(global_step / decay_steps)



scheduler = LambdaLR(optimizer, lr_lambda=lambda step: exponential_decay(learning_rate, decay_rate, decay_steps, step))

# 记录学习率的变化
lr_history = []

# 模拟训练过程
for epoch in range(global_steps):
    # 执行训练步骤
    # ...

    # 记录当前学习率
    current_lr = optimizer.param_groups[0]['lr']
    lr_history.append(current_lr)

    # 更新学习率
    scheduler.step()

这个函数就非常的方便,像是上面的多阶段衰减也可以使用这个函数进行实现。

 OneCycleLR

scheduler=lr_scheduler.OneCycleLR(optimizer,max_lr=0.1,pct_start=0.5,total_steps=120,div_factor=10,final_div_factor=10)

可视化 OneCycleLR:

import torch
from torch.optim.lr_scheduler import OneCycleLR
import matplotlib.pyplot as plt

# 定义神经网络和优化器
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# 定义 OneCycleLR 学习率调度器
scheduler = OneCycleLR(optimizer, max_lr=0.1, pct_start=0.5, total_steps=120, div_factor=10, final_div_factor=10)

# 记录学习率的变化
lr_history = []

# 模拟训练过程
for epoch in range(120):
    # 执行训练步骤
    # ...

    # 记录当前学习率
    current_lr = optimizer.param_groups[0]['lr']
    lr_history.append(current_lr)

    # 更新学习率
    scheduler.step()

# 绘制学习率变化曲线
plt.plot(range(120), lr_history, label="learning rate")
plt.xlabel("epochs")
plt.ylabel("learning rate")
plt.legend()
plt.show()

最后一个余弦退火学习率衰减CosineAnnealingLR

CosineAnnealingLR是余弦退火学习率,T_max是周期的一半,最大学习率在optimizer中指定,最小学习率为eta_min。这里同样能够帮助逃离鞍点。值得注意的是最大学习率不宜太大,否则loss可能出现和学习率相似周期的上下剧烈波动。

基本上的选择方式是选择1/4个余弦函数的周期。

可视化:

这里官方文档的公式说明讲的很清晰,自行学习吧: 

Parameters 参数

  • optimizer (Optimizer) - 包装优化器。

  • T_max (int) - 最大迭代次数。

  • eta_min (float) - 最低学习率。默认值:0。

  • last_epoch (int) - 上一个纪元的索引。默认值:-1。

  • verbose (bool) – 如果 True ,则在每次更新时向 stdout 打印一条消息。默认值: False .

今天的学习就到这里,散会!

ps:最近心情有点糟糕,六级+期末考试+实验出了些问题,好累,今晚好好睡一觉吧,晚安各位。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/217030.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

12.04 二叉树中等题

513. 找树左下角的值 给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1: 输入: root [2,1,3] 输出: 1 思路&#xff1a;找到最低层中最左侧的节点值&#xff0c;比较适合层序遍历&#xff0c;返回最…

【matlab】QR分解

QR分解 给定一个mn的矩阵A&#xff0c;其中m≥n&#xff0c;即矩阵A是高矩阵或者是方阵&#xff0c;QR分解将矩阵A分解为两个矩阵Q和R的乘积&#xff0c;其中矩阵Q是一个mn的各列正交的矩阵&#xff0c;即QTQI&#xff0c;矩阵R是一个nn的上三角矩阵&#xff0c;其对角线元素为…

初识动态规划算法(题目加解析)

文章目录 什么是动态规划正文力扣题第 N 个泰波那契数三步问题使用最小花费爬楼梯 总结 什么是动态规划 线性动态规划&#xff1a;是可以用一个dp表来存储内容&#xff0c;并且找到规律存储,按照规律存储。让第i个位置的值等于题目要求的答案 >dp表&#xff1a;dp表就是用一…

【数据结构】——栈|队列(基本功能)

目录 栈 基本概念 栈的常见基本操作 栈的存储 ✌栈的基本操作实现 栈的构建 栈的初始化 入栈 打印栈 出栈 获取栈顶元素 获取栈的有效元素个数 判断栈是否为空 销毁栈 队列 基本概念 队列的常见基本操作 ✌队列的基本操作实现 队列的构建 初始化 入队列 出…

不再只是android,华为自爆Harmony将对标iOS

今年10月&#xff0c;华为官方宣布&#xff0c;鸿蒙OS 4升级设备数量已突破1亿&#xff0c;成为史上升级最快的鸿蒙OS版本。 日前&#xff0c;据数码博主“定焦数码”消息&#xff0c;大厂技术员工做适配&#xff0c;通过线下沟通时&#xff0c;华为反复提到一个问题&#xff…

实战技巧:为Android应用设置独立的多语言

原文链接 实战技巧&#xff1a;为Android应用设置独立的多语言 通常情况下多语言的设置都在系统设置中&#xff0c;应用需要做的就是提供本应用所使用的字串的多语言翻译&#xff0c;使用时使用R.string.app_name类似的引用&#xff0c;然后系统会根据用户在系统设置中的选项来…

不瞒各位,不安装软件也能操作Xmind文档

大家好&#xff0c;我是小悟 作为搞技术的一个人群&#xff0c;时不时就要接收产品经理发过来的思维脑图&#xff0c;而此类文档往往是以Xmind编写的&#xff0c;如果你的电脑里面没有安装Xmind的话&#xff0c;不好意思&#xff0c;是打不开这类后缀结尾的文档。 打不开的话…

【雷电模拟器桥接问题解决方法】

1.ROOT权限开启 2.开启网络桥接模式&#xff0c;选择静态IP设置&#xff0c;点击安装桥接网卡&#xff0c;填写IP地址&#xff08;注意&#xff1a;IP地址要与host主机在同一IP段内&#xff09; 3.重启后 adb shell就能进入到模拟器控制台中了&#xff0c;如果出现以下内容&…

进程程序替换和shell实现

先前fork说创建子进程执行代码&#xff0c;如何让子进程执行和父进程完全不一样的代码?程序替换。 一 单进程替换演示 1 execl函数使用 最近转到在vs code下写代码&#xff0c;之前也在xhell下用过execl函数&#xff0c;所以才想写篇博客总结总结&#xff0c;没想到在vs code…

(C语言)计算n的阶乘

要求使用双精度 #include<stdio.h> double factorial(int n) {if(n 1)return 1;return n * factorial(n-1); } int main() {int n ;double res;scanf("%d",&n);res factorial(n);printf("%lf",res); return 0; } 运行截图&#xff1a; 注&am…

oops-framework框架 之 界面管理(三)

引擎&#xff1a; CocosCreator 3.8.0 环境&#xff1a; Mac Gitee: oops-game-kit 注&#xff1a; 作者dgflash的oops-framework框架QQ群&#xff1a; 628575875 回顾 在上文中主要通过oops-game-kit大家了一个新的模版项目&#xff0c; 主要注意项是resources目录下的两个文…

Python Opencv实践 - Yolov3目标检测

本文使用CPU来做运算&#xff0c;未使用GPU。练习项目&#xff0c;参考了网上部分资料。 如果要用TensorFlow做检测&#xff0c;可以参考这里 使用GPU运行基于pytorch的yolov3代码的准备工作_little han的博客-CSDN博客文章浏览阅读943次。记录一下自己刚拿到带独显的电脑&a…

卷积神经网络(CNN):艺术作品识别

文章目录 一、前言一、设置GPU二、导入数据1. 导入数据2. 检查数据3. 配置数据集4. 数据可视化 三、构建模型四、编译五、训练模型六、评估模型1. Accuracy与Loss图2. 混淆矩阵3. 各项指标评估 一、前言 我的环境&#xff1a; 语言环境&#xff1a;Python3.6.5编译器&#xf…

继承 多态 拆箱装箱 128陷阱 枚举类

继承 在java里一个类只能继承一个类&#xff0c;但可以被多个类继承&#xff1b;c里一个类可以继承多个类&#xff1b; 子类可以使用父类的方法&#xff1b; 在java中&#xff0c;Object是所有类的父类&#xff1b; equals方法比较的是对象是否指向同一个地方&#xff0c;这个方…

原生横向滚动条 吸附 页面底部

效果图 /** 横向滚动条 吸附 页面底部 */ export class StickyHorizontalScrollBar {constructor(options {}) {const { el, style } optionsthis.createScrollbar(style)this.insertScrollbar(el)this.setScrollbarSize()this.onEvent()}/** 创建滚轴组件元素 */createS…

Windows下打包C++程序无法执行:无法定位程序输入点于动态链接库

1、问题描述 环境&#xff1a;CLionCMakeMinGW64遇到问题&#xff1a;打包的exe无法运行&#xff0c;提示无法定位程序输入点于动态链接库。 2、解决思路 ​ 通过注释头文件的方式&#xff0c;初步定位问题是因为使用了#include <thread> 多线程库引起的。而且exe文件…

外包干了2个月,技术倒退2年。。。。。

先说一下自己的情况&#xff0c;本科生&#xff0c;20年通过校招进入深圳某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年国庆&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试…

如何创建maven项目的多模块项目

Maven多模块项目是指一个Maven项目中包含多个子模块&#xff0c;每个子模块又是一个独立的Maven项目&#xff0c;但它们之间可以存在依赖关系。Maven多模块项目可以方便地管理多个子模块的依赖和构建过程&#xff0c;同时也可以提高项目的可维护性和可扩展性。创建maven项目的父…

ChatGPT发布一年后,搜索引擎的日子还好吗?

导读&#xff1a;生成式AI&#xff0c;搜索引擎的终结者还是进化加速器 ChatGPT发布刚刚一年&#xff0c;互联网世界已经换了人间。 2023年&#xff0c;以ChatGPT和大模型为代表的生成式AI浪潮对全球互联网、云计算、人工智能领域都带来巨大冲击。而且生成式AI在各行各业的应用…

深入理解JVM虚拟机第二十七篇:详解JVM当中InvokeDynamic字节码指令,Java是动态类型语言么?

😉😉 学习交流群: ✅✅1:这是孙哥suns给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚📚 工作微信:BigTreeJava 拉你进微信群,免费领取! 🍎🍎4:本文章内容出自上述:Sp…