【深度学习】如何找到最优学习率

经过了大量炼丹的同学都知道,超参数是一个非常玄乎的东西,比如batch size,学习率等,这些东西的设定并没有什么规律和原因,论文中设定的超参数一般都是靠经验决定的。但是超参数往往又特别重要,比如学习率,如果设置了一个太大的学习率,那么loss就爆了,设置的学习率太小,需要等待的时间就特别长,那么我们是否有一个科学的办法来决定我们的初始学习率呢?

在这篇文章中,我会讲一种非常简单却有效的方法来确定合理的初始学习率。

学习率的重要性

目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

深度学习:如何找到最优学习率

这里 α 就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。

深度学习:如何找到最优学习率

学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。

深度学习:如何找到最优学习率
深度学习:如何找到最优学习率

从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。

实现

上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重,model zoo参考于Cadene的repo。目前这个repo刚刚开始,欢迎有兴趣的小伙伴加入我。

下面就是部分代码,近期会把找学习率的代码合并到mxtorch中。这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码点这里。

   
   
  1. criterion = torch.nn.CrossEntropyLoss()
  2. net = model_zoo.resnet50(pretrained=True)
  3. net.fc = nn.Linear(2048, 120)
  4.  
  5. with torch.cuda.device(0):
  6. net = net.cuda()
  7.  
  8. basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)
  9. optimizer = ScheduledOptim(basic_optim)
  10.  
  11.  
  12. lr_mult = (1 / 1e-5) ** (1 / 100)
  13. lr = []
  14. losses = []
  15. best_loss = 1e9
  16. for data, label in train_data:
  17. with torch.cuda.device(0):
  18. data = Variable(data.cuda())
  19. label = Variable(label.cuda())
  20. # forward
  21. out = net(data)
  22. loss = criterion(out, label)
  23. # backward
  24. optimizer.zero_grad()
  25. loss.backward()
  26. optimizer.step()
  27. lr.append(optimizer.learning_rate)
  28. losses.append(loss.data[0])
  29. optimizer.set_learning_rate(optimizer.learning_rate lr_mult)
  30. if loss.data[0] < best_loss:
  31. best_loss = loss.data[0]
  32. if loss.data[0] > 4 best_loss or optimizer.learning_rate > 1.:
  33. break
  34.  
  35. plt.figure()
  36. plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))
  37. plt.xlabel(‘learning rate’)
  38. plt.ylabel(‘loss’)
  39. plt.plot(np.log(lr), losses)
  40. plt.show()
  41. plt.figure()
  42. plt.xlabel(‘num iterations’)
  43. plt.ylabel(‘learning rate’)
  44. plt.plot(lr)

one more thing

通过上面的例子我们能够有一个非常有效的方法寻找初始学习率,同时在我们的认知中,学习率的策略都是不断地做decay,而上面的论文别出心裁,提出了一种循环变化学习率的思想,能够更快的达到最优解,非常具有启发性,推荐大家去阅读阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/190003.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Seurat Tutorial 1:标准分析流程,基于 PBMC 3K 数据集

目录 1 设置 Seurat 对象2 标准预处理工作流程 2.1 QC 和选择细胞进行进一步分析3 数据归一化4 识别高变特征&#xff08;特征选择&#xff09;5 标准化数据6 执行线性降维7 确定数据集的维度8 细胞聚类9 运行非线性降维 (UMAP/tSNE)10 寻找差异表达特征&#xff08;cluster b…

OSG编程指南<十二>:OSG二三维文字创建及文字特效

1、字体基础知识 适当的文字信息对于显示场景信息是非常重要的。在 OSG 中&#xff0c;osgText提供了向场景中添加文字的强大功能&#xff0c;由于有第三方插件 FreeType 的支持&#xff0c;它完全支持TrueType 字体。很多人可能对 FreeType 和 TrueType 还不太了解&#xff0c…

小程序项目:springboot+vue基本微信小程序的宠物领养系统

项目介绍 当今科技发展迅速&#xff0c;交通环境也变得越来越复杂。人们的出行方式变得多元化&#xff0c;这给视障人士带来了一定的困扰。而导盲犬可以帮助视障人士外出行走&#xff0c;提高他们的生活质量。在我国&#xff0c;导盲犬的数量远远少于视障人士的数量。由于导盲…

WPF绘图技术介绍

作者&#xff1a;令狐掌门 技术交流QQ群&#xff1a;675120140 csdn博客&#xff1a;https://mingshiqiang.blog.csdn.net/ 文章目录 WPF绘图基本用法绘制直线在XAML中绘制直线在C#代码中绘制直线使用Path绘制直线注意 矩形绘制在XAML中绘制矩形在C#代码中绘制矩形设置矩形的位…

<JavaEE> Java中线程有多少种状态(State)?状态之间的关系有什么关系?

目录 一、系统内核中的线程状态 二、Java中的线程状态 一、系统内核中的线程状态 状态说明就绪状态线程已经准备就绪&#xff0c;随时可以接受CPU的调度。阻塞状态线程处于阻塞等待&#xff0c;暂时无法在CPU中执行。 二、Java中的线程状态 相比于系统内核&#xff0c;Java…

火锅店管理系统外卖点餐配送小程序的效果

火锅是餐饮行业重要的组成部分&#xff0c;在每个地方都受到追捧&#xff0c;其从业商家众多&#xff0c;头部连锁品牌也不少&#xff0c;近些年行业市场规模也一直增加&#xff0c;而随着消费升级及数字化转型&#xff0c;传统火锅店经营痛点不少&#xff1a; 火锅店的需求非…

智能优化算法应用:基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蚁狮算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

SpringBoot中如何优雅地使用重试

1 缘起 项目中使用了第三方的服务&#xff0c; 第三方服务偶尔会出现不稳定、连接不上的情况&#xff0c; 于是&#xff0c;在调用时为了保证服务的相对高可用&#xff0c;添加了超时连接重试&#xff0c; 当连接第三方服务超时时&#xff0c;多重试几次&#xff0c;比如3次&a…

XML映射文件

<?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapperPUBLIC "-//mybatis.org//DTD Mapper 3.0//EN""http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"org.mybatis.example.BlogMapper&q…

FreeRTOS学习之路,以STM32F103C8T6为实验MCU(2-7:软件定时器)

学习之路主要为FreeRTOS操作系统在STM32F103&#xff08;STM32F103C8T6&#xff09;上的运用&#xff0c;采用的是标准库编程的方式&#xff0c;使用的IDE为KEIL5。 注意&#xff01;&#xff01;&#xff01;本学习之路可以通过购买STM32最小系统板以及部分配件的方式进行学习…

Python大数据考题

Python大数据考题&#xff1a; 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;可能很多算法学生都得去找开发&#xff0c;测开 测开的话&#xff0c;你就得学数据库&#xff0c;sql&#xff0c;oracle&#xff0c;尤其sql要…

CDA一级备考思维导图

CDA一级备考思维导图 第一章 数据分析概述与职业操守1、数据分析概念、方法论、角色2、数据分析师职业道德与行为准则3、大数据立法、安全、隐私 CDA一级复习备考资料共计七个章节&#xff0c;如需资料&#xff0c;请留言&#xff0c;概览如下图&#xff1a; 第一章 数据分析…

Spring面向切面编程(AOP);Spring控制反转(IOC);解释一下Spring AOP里面的几个名词;Spring 的 IoC支持哪些功能

文章目录 Spring面向切面编程(AOP)什么是AOPSpring AOP and AspectJ AOP 的区别&#xff1f;Spring AOP中的动态代理如何理解 Spring 中的代理&#xff1f;解释一下Spring AOP里面的几个名词Spring在运行时通知对象Spring切面可以应用5种类型的通知&#xff1a;什么是切面 Aspe…

【数据结构】堆(C语言)

今天我们来学习堆&#xff0c;它也是二叉树的一种&#xff08;我滴神树&#xff01;&#xff09; 目录 堆的介绍&#xff1a;堆的代码实现&#xff1a;堆的结构体创建&#xff1a;堆的初始化&#xff1a;堆的销毁&#xff1a;堆的push&#xff1a;堆的pop&#xff1a;判空 &am…

OpenCV简介及安装

前言 因为最近想做图像处理、人脸检测/识别之类的相关开发&#xff0c;所以就开始补OpenCV的相关知识&#xff0c;便开个专栏用于记录学习历程和在学习过程中遇到的一些值得注意的重点和坑。 学习过程基本上也是面向官方文档和Google。 简介 OpenCV(开源的计算机视觉库)是基于…

十大排序之归并排序(详解)

文章目录 &#x1f412;个人主页&#x1f3c5;算法思维框架&#x1f4d6;前言&#xff1a; &#x1f380;归并排序 时间复杂度O(n*logn)&#x1f387;1. 算法步骤思想&#x1f387;2、动画演示&#x1f387;3.代码实现 &#x1f412;个人主页 &#x1f3c5;算法思维框架 &#…

渗透测试【一】:渗透测试常见问题

渗透测试【一】&#xff1a;渗透测试常见问题 1、问题清单2、问题现象及解决办法2.1、点击劫持2.2、用户枚举2.3、Springboot未授权访问2.4、Swagger未授权访问2.5、Host头注入2.6、任意文件上传2.7、敏感路径泄露2.8、跨域资源共享2.9、Spring Cloud Gateway RCE2.10、Content…

vsVode C++开发远程虚拟机工程配置

在使用VS Code进行C/C的开发过程中&#xff0c;有三个至关重要的配置文件&#xff0c;分别是 tasks.json, launch.json 和 c_cpp_properties.json 1. tasks.json tasks.json 是在 vscode 中辅助程序编译的模块&#xff0c;可以代你执行类似于在命令行输入 “gcc hello.c -o h…

文件搜索工具HoudahSpot mac中文版特点

HoudahSpot mac是一款文件搜索工具&#xff0c;它可以帮助用户快速准确地找到文件和文件夹&#xff0c;支持高级搜索和过滤&#xff0c;同时提供了多种视图和操作选项&#xff0c;方便用户进行文件管理和整理。 HoudahSpot mac软件特点 高级搜索和过滤功能&#xff1a;软件支持…

大数据项目--学习笔记

新零售项目介绍 1&#xff0c;行业背景介绍 一&#xff0c;百货商店 百货商店是世界商业史上第一个实行新销售方法的现代大量销售组织。其新型销售方法有&#xff1a; 1&#xff0e;顾客可以毫无顾忌地、自由自在地进出商店&#xff1b; 2&#xff0e;商品销售实行“明码标价…