【SGDR】《SGDR:Stochastic Gradient Descent with Warm Restarts》

在这里插入图片描述

arXiv-2016

code: https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py


文章目录

  • 1 Background and Motivation
  • 2 Related Work
  • 3 Advantages / Contributions
  • 4 Method
  • 5 Experiments
    • 5.1 Datasets and Metric
    • 5.2 Single-Model Results
    • 5.3 Ensemble Results
    • 5.4 Experiments on a Dataset of EEG Recordings
    • 5.5 Preliminary Experiments on a downsampled ImageNet Dataset
  • 6 Conclusion(own) / Future work


1 Background and Motivation

训练深度神经网络的过程可以视为找下面这个方程 min 解的过程

在这里插入图片描述
或者用二阶导的形式

在这里插入图片描述

然而 inverse Hessian 不好求(【矩阵学习】Jacobian矩阵和Hessian矩阵)

虽然有许多改进的优化方法来尽可能的逼近 inverse Hessian,但是,目前在诸多计算机视觉相关任务数据集上表现最好的方法还是 SGD + momentum

有了实践中比较猛的 optimization techniques 后,The main difficulty in training a DNN is then associated with the scheduling of the learning rate and the amount of L2 weight decay regularization employed.

本文,作者从 learning rate schedule 角度出发,提出了 SGDR 学习率策略,periodically simulate warm restarts of SGD

在这里插入图片描述
使得深度学习任务收敛的更快更好

2 Related Work

In applied mathematics, multimodal optimization deals with optimization tasks that involve finding all or most of the multiple (at least locally optimal) solutions of a problem, as opposed to a single best solution.

  • restarts in gradient-free optimization

    based on niching methods(见文末总结部分)

  • restarts in gradient-based optimization
    《Cyclical Learning Rates for Training Neural Networks》(WACV-2017)
    closely-related to our approach in its spirit and formulation but does not focus on restarts
    一个 soft restart 一个 hard restart(SGDR)

3 Advantages / Contributions

SGD + warm restart 技术的结合,或者说 warm restart 在 SGD 上的应用

两者均非原创,在一些小数据集(输入分辨率有限)上有提升,泛化性能还可以

速度上比SGD收敛的要快一些,x2 ~x4

4 Method

periodically simulate warm restarts of SGD

SGD with momentum
在这里插入图片描述
再加 warm start

在这里插入图片描述
蓝色和红色是之前的 step 式 learning rate schedule( A common learning rate schedule is to use a constant learning rate and divide it by a fixed constant in (approximately) regular interval)

其余颜色是作者的 SGDR 伴随不同的参数配置

核心公式

在这里插入图片描述

  • η \eta η 是学习率

  • i − t h i-th ith run

  • T c u r T_{cur} Tcur accounts for how many epochs have been performed since the last restar, T c u r = 0 T_{cur} = 0 Tcur=0 学习率最大,为 η m a x \eta_{max} ηmax T c u r = T i T_{cur}=T_i Tcur=Ti 时学习率最小为 η m i n \eta_{min} ηmin

  • T i T_i Ti cosine 的一个下降周期对应的 epoch 数或 iteration 数

每次 new start 的时候, η m i n \eta_{min} ηmin 或者 η m a x \eta_{max} ηmax 可调整

让每个周期变得越来越长的话,可以设置 T m u l t T_{mult} Tmult > 1(eg =2, it doubles the maximum number of epochs for every new restart. The main purpose of this doubling is to reach good test error as soon as possible)

code,来自 Cosine Annealing Warm Restart

# 导包
from torch import optim
from torch.optim import lr_scheduler

# 定义模型
model, parameters = generate_model(opt)

# 定义优化器
if opt.nesterov:
    dampening = 0
else:
    dampening = 0.9
optimizer = opt.SGD(parameters, lr=0.1, momentum=0.9, dampening=dampending, weight_decay=1e-3, nesterov=opt.nesterov)

# 定义热重启学习率策略
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0, last_epoch=-1)

在这里插入图片描述

5 Experiments

5.1 Datasets and Metric

CIFAR-10:Top1 error
CIFAR-100:Top1 error
a dataset of electroencephalographic (EEG)
downsampled ImageNet(32x32): Top1 and Top5 error

5.2 Single-Model Results

在这里插入图片描述
WRN 网络 with depth d and width k

CIFAR10 上 T m u l t i = 1 T_{multi} = 1 Tmulti=1 比较猛

CIFAR100 上 T m u l t i = 2 T_{multi} = 2 Tmulti=2 比较猛

在这里插入图片描述
收敛速度快的优势

Since SGDR achieves good performance faster, it may allow us to train larger networks

在这里插入图片描述

CIFAR10 上 T m u l t i = 1 T_{multi} = 1 Tmulti=1 比较猛,黑白色

CIFAR100 上 T m u l t i = 2 T_{multi} = 2 Tmulti=2 比较猛

只看收敛效果的话,白色最猛, cosine learning rate

5.3 Ensemble Results

这里是复刻下《SNAPSHOT ENSEMBLES: TRAIN 1, GET M FOR FREE》(ICLR-2017)中的方法
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
来自 优化器: Snapshots Ensembles 快照集成


在这里插入图片描述
具体的实验细节不是很了解,N =16,M=3 表示总共 200 epoch,16个 restart 周期,选出来3个模型平均?还是说跑了 16*200个 epoch,每次run(200epoch)选出来 M 个模型平均?

M = 3,30,70,150
M = 2, 70,150
M = 1 ,150

5.4 Experiments on a Dataset of EEG Recordings

在这里插入图片描述

5.5 Preliminary Experiments on a downsampled ImageNet Dataset

our downsampled ImageNet contains exactly the same images from 1000 classes as the original ImageNet but resized with box downsampling to 32 × 32 pixels.

在这里插入图片描述
在这里插入图片描述

6 Conclusion(own) / Future work

  • 实战中 T i T_i Ti 怎么设计比较好,是设置成最大 epoch 吗?还是多 restart 几次, T m u l T_{mul} Tmul 是不是大于 1 比等于 1 效果好?

  • Restart techniques are common in gradient-free optimization to deal with multimodal functions

  • Stochastic subGradient Descent with restarts can achieve a linear convergence rate for a class of non-smooth and non-strongly convex optimization problems

  • Cyclic Learning rate和SGDR-学习率调整策略论文两篇

    可以将 SGDR 称为hard restart,因为每次循环开始时学习率都是断崖式增加的,相反,CLR应该称为soft restart
    在这里插入图片描述

  • 理解深度学习中的学习率及多种选择策略

  • 什么是ill-conditioning 对SGD有什么影响? - Martin Tan的回答 - 知乎
    https://www.zhihu.com/question/56977045/answer/151137770

在这里插入图片描述
在这里插入图片描述


  • niching

    What is niching scheme?

    Niching methods:
    在这里插入图片描述

    小生境(Niche):来自于生物学的一个概念,是指特定环境下的一种生存环境,生物在其进化过程中,一般总是与自己相同的物种生活在一起,共同繁衍后代。例如,热带鱼不能在较冷的地带生存,而北极熊也不能在热带生存。把这种思想提炼出来,运用到优化上来的关键操作是:当两个个体的海明距离小于预先指定的某个值(称之为小生境距离)时,惩罚其中适应值较小的个体。

    海明距离(Hamming Distance):在信息编码中,两个合法代码对应位上编码不同的位数称为码距,又称海明距离。例如,10101和00110从第一位开始依次有第一位、第四、第五位不同,则海明距离为3。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/539242.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

kali工具----枚举工具

一、枚举工具 枚举是一类程序,它允许用户从一个网络中收集某一类的所有相关信息。本节将介绍DNS枚举和SNMP枚举技术。DNS枚举可以收集本地所有DNS服务和相关条目。DNS枚举可以帮助用户收集目标组织的关键信息,如用户名、计算机名和IP地址等,…

HarmonyOS实战开发-视频播放、如何实现了视频播放、暂停、调节倍速、切换视频的功能。

介绍 视频播放的主要工作是将视频数据转码并输出到设备进行播放,同时管理播放任务。本文将对视频播放全流程、视频切换、视频循环播放等场景开发进行介绍说明。 本示例主要展示了播放本地视频和网络视频相关功能,使用 ohos.multimedia.media,ohos.resourceManager,…

Python 全栈系列239 使用消息队列完成分布式任务

说明 在Python - 深度学习系列32 - glm2接口部署实践提到,通过部署本地化大模型来完成特定的任务。 由于大模型的部署依赖显卡,且常规量级的任务需要大量的worker支持,从成本考虑,租用算力机是比较经济的。由于任务是属于超高计…

【opencv】示例-inpaint.cpp 图像修复是通过填充损坏图像部分从而修复这些损坏的过程...

原始图像 这段代码展示了一个使用OpenCV库进行图像修复的例子。它首先包含了处理图像编码、解码、显示、处理和照片处理所必要的OpenCV模块的头文件。然后利用cv和std命名空间下的类和方法。通过定义一个鼠标回调函数onMouse来处理图像上的绘图操作,并通过主函数mai…

React添加到现有项目

1.检查现有项目的根目录下是否有package.json文件 如果没有,则在项目的根目录下初始化一个package.json配置文件 2.在根目录下安装react和react-dom依赖 npm install --save react react-dom react-scripts安装成功后,react、react-dom以及react-scr…

汽车制造业PMC组态应用最佳实践

01案例及行业介绍 汽车制造工业是我国国民经济的重要支柱产业,汽车制造工厂一般包含冲压、焊装、涂装、总装四大车间。每辆汽车的生产过程被分解成很多加工任务下发给各个车间进行完成。车辆从冲压车间开始到总装车间结束一直进行不同类型的工序加工。 PMC即生产控…

Sarson Funds 在 Casper 测试网推出稳定币 csprUSD

Sarson Funds 与 Casper Association 合作,在 Casper Network (CSPR)测试网上推出了 csprUSD 稳定币。 作为最新的法币背书型稳定币,csprUSD 进入了数字货币市场,与 Ripple 和 Cardano 等组织近期推出的产品定位一致。…

使用vite从头搭建一个vue3项目(二)创建目录文件夹以及添加vue-router

目录 一、创建 vue3 项目 vite-vue3-project-js二、创建项目目录三、创建Home、About组件以及 vue-router 配置路由四、修改完成后页面 一、创建 vue3 项目 vite-vue3-project-js 使用 vite 创建一个极简 vue3 项目请参考此文章:使用Vite创建一个vue3项目 下面是我…

基于GitHub的开源讨论系统,赋予网站交互可能

Giscus:让每一条见解直达GitHub,用Giscus开启网站与社区的无缝对话新纪元!- 精选真开源,释放新价值。 概览 纯静态网站或博客,由于没有数据存储功能,经常借助第三方的评论系统以插件的方式集成进来&#x…

Java 自定義 List<T> 分頁工具

Java 自定義 List 分頁工具 PS: T可修改为对应的实体 rt com.google.common.collect.Lists;import java.util.Arrays; import java.util.Collections; import java.util.List;/*** ClassName: MyPageHelper* Descripution: List<T>分頁工具**/ public class MyPageHelp…

elementUI 下拉框加提示文案

效果如下&#xff1a; 展示文案在最下面&#xff0c;跟选项有个分割线 <el-select v-model"value" placeholder"请选择" clearable popper-class"addNotice" class"addNoticeS" visible-change"(v) >selectNotice(v,展示…

使用clickhouse-backup备份和恢复数据

作者&#xff1a;俊达 介绍 clickhouse-backup是altinity提供的一个clickhouse数据库备份和恢复的工具&#xff0c;开源项目地址&#xff1a;https://github.com/Altinity/clickhouse-backup 功能上能满足日常数据库备份恢复的需求&#xff1a; 支持单表/全库备份支持备份上…

Day23_学点儿IDEA_单元测试@Test在新module项目中失效、Jackson核心对象 ObjectMapper识别不到

版本 IntelliJ IDEA 2023.2.4 目录结构 Study(Project) ├──JavaSE(Module) │ └──xxxx └──JavaWeb(Module)└──xxxx问题 1.1 在JavaSE项目中正常可以用的单元测试Test&#xff0c;到了JavaWeb项目中不起作用了 1.2 解决方法 如果是新创建的项目&#xff0c;先…

时间序列分析 # 平稳性检验和ARMA模型的识别与定阶 #R语言

掌握单位根检验的原理并能解读结果&#xff1b;掌握利用序列的自相关图和偏自相关图识别模型并进行初步定阶。 原始数据在文末&#xff01;&#xff01;&#xff01; 练习1、根据某1971年9月-1993年6月澳大利亚季度常住人口变动&#xff08;单位&#xff1a;千人&#xff09;的…

Arthas排查工具

简介 | arthas (aliyun.com) 在线安装 #下载jar包 curl -O https://arthas.aliyun.com/arthas-boot.jar#启动会先检测虚拟机进程&#xff0c;如果没有启动失败(idea) java -jar arthas-boot.jar linux安装与window一样 卸载arthas rm -rf ~/.arthas/ rm -rf ~/logs/arthas

时钟周期检测标志信号

在某些情况下需要对系统时钟分频后的时钟进行周期检测&#xff0c;引出周期标志信号以便在后续其他情况的使用。虽然在大多数情况下我们能够知道分频后的时钟是系统时钟的几倍分频&#xff0c;但为增强在分频时钟改变情况下周期标志信号的复用性或对未知时钟的周期检测&#xf…

C++格式化输出开源库fmt入手教程

fmt项目快速上手指南 1. cmake环境配置 include(FetchContent) FetchContent_Declare(fmtGIT_REPOSITORY https://github.com/fmtlib/fmtGIT_TAG 10.0.0GIT_SHALLOW TRUE) # 1. 下载fmt库 FetchContent_MakeAvailable(fmt)add_executable(fmt_guide main.cpp) # 2. 链接fmt库…

Git基础操作及其分支管理

目录 一、git的用处&#xff1f; 1.1 git也不是银弹 二、安装git 三、git基础操作 3.1 创建git本地仓库 3.2 配置Git 3.3 认识工作区、暂存区、版本库 3.4 添加文件 3.5 Git文件目录 3.6 版本回退 3.7 撤销修改 3.7.1 对于工作区的代码&#xff0c;还没有进行add操作…

【LAMMPS学习】八、基础知识(2.6)Lammps中的Walls墙

8. 基础知识 此部分描述了如何使用 LAMMPS 为用户和开发人员执行各种任务。术语表页面还列出了 MD 术语&#xff0c;以及相应 LAMMPS 手册页的链接。 LAMMPS 源代码分发的 examples 目录中包含的示例输入脚本以及示例脚本页面上突出显示的示例输入脚本还展示了如何设置和运行各…

从前端角度防范XSS攻击的策略与实践

XSS&#xff08;Cross-Site Scripting&#xff0c;跨站脚本攻击&#xff09;是一种常见的网络安全威胁&#xff0c;它允许攻击者将恶意脚本注入到正常的网页中&#xff0c;从而在其他用户的浏览器上执行这些脚本。这可能导致数据泄露、会话劫持、甚至是对受害者计算机的完全控制…