PyTorch概述(七)---Optim

  • torch.optim是一个实现多种优化算法的包;
  • 很多常用的方法已经被支持;
  • 接口丰富;
  • 容易整合更为复杂的算法;

如何使用一个优化器

  • 为了使用torch.optim包功能;
  • 用户必须构建一个优化器对象;
  • 该优化器将保持当前的参数状态且基于计算的梯度更新参数;

构建优化器

  • 要构建一个优化器;
  • 必须给优化器一个可迭代的对象;
  • 该对象包含可优化的参数(应当是变量s);
  • 然后,用户可以指定具体的优化器参数,比如学习率,权重衰减等;
import torch.optim as optim
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
optimizer=optim.Adam([var1,var2],lr=0.0001)

单一参数设置

  • 优化器也支持每一个参数的设置;
  • 为了这样做,不要给优化器传入一个可迭代的变量s;
  • 而是给优化器传入一个可迭代的字典s;
  • 每个字典将定义一个分离的参数组;
  • 参数组内应当包含一个参数键,该参数键包含一个属于他的参数列表;
  • 其他键应当匹配优化器可接受的关键字参数;
  • 且该键将被用作这个组内的优化选项;
  • 依然可以传递选项作为关键字参数,他们将被用于默认参数;
  • 组内对他们并不覆写;
  • 当用户想变化单一的选项时会很有用;
  • 同时保持其他的参数组一致;
  • 比如,当想指定每一层的学习率时:
optim.SGD([{'params':model.base.parameters()},
            {'params':model.classifier.parameters(),'lr':1e-3}],lr=1e-2,momentum=0.9)
  • 上述代码意味着model.base的参数将使用默认的学习率:1e^{-2};
  • model.classifier的参数将使用1e^{-3}的学习率;
  • momentum=0.9将会被所有的参数使用;

优化步骤

  • 所有的优化器都实现一个step()方法;
  • 该方法对参数进行更新;
  • 有两种使用方式:
  • optimizier.step()
  • optimizer.step(closure)

optimizer.step()

  • 大多数优化器都支持的一个简单版本用法;
  • 当使用backward()计算梯度后调用该方法;
for input,target in dataset:
    optimizer.zero_grad()
    output=model(input)
    loss=loss_fn(output,target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)

  • 一些优化算法,比如共轭梯度和LBFGS;
  • 需要多次评估该函数;
  • 用户必须传递closure参数以允许算法重新计算模型;
  • closure参数应当清理梯度,计算损失,并返回;
for input,target in dataset:
    def closure():
        optimizer.zero_grad()
        output=model(input)
        loss=loss_fn(output,target)
        loss.backward()
    optimizer.step(closure)

基础类

  • 类 torch.optim.Optimizer(params,defaults)
  • 是所有优化器的基础类;
  • 必须以集合的方式指定参数;
  • 集合内的参数具有确定的顺序且同实际运行中的一致;
  • 不满足要求的是set和字典键值迭代器;
  • params(iterable)---一个可迭代的torch.Tensor或者字典,指定需要优化的张量类型;
  • defaults(Dict[str,Any])---具有默认优化选项值得字典(当参数组没有指定时使用);

Optimizer.add_param_group

给优化器参数组增加成员

Optimizer.load_state_dict

加载优化器状态

Optimizer.state_dict

以字典的方式返回优化器状态

Optimizer.step

执行一个优化器步(参数更新)

Optimizer.zero_grad

对所有优化器张量重置梯度

算法 

Adadelta

实现 Adadelta 算法;

Adagrad

实现 Adagrad 算法.

Adam

实现Adam 算法.

AdamW

实现AdamW 算法.

SparseAdam

适合稀疏矩阵的Adam算法的掩码版本

Adamax

实现Adamax 算法(基于无线范数的Adam变体).

ASGD

实现平均随机梯度下降.

LBFGS

实现L-BFGS 算法.

NAdam

实现NAdam 算法.

RAdam

实现RAdam 算法.

RMSprop

实现RMSprop 算法.

Rprop

实现弹性反向传播算法.

SGD

实现随机梯度下降算法(动量选项可选).

  •  很多算法对于优化性能\可读性和通用性具有不同的实现;
  • 如果用户没有特别的指定算法的实现方法,默认情况下针对用户设备尝试最快的实现方法;
  • 有三大类主要的实现:for-loop,foreach(多张量),fused;
  • 最直接的是对参数的for-loop循环实现,并进行大量的计算;
  • for-loop实现通常较foreach实现更慢,foreach实现一次合并参数到多张量中且进行大量计算;
  • foreach实现节省了很多序列化的内核调用;
  • 一些优化器具有更快速的融合实现;
  • 这些优化器融合大量的计算到一个内核中;
  • 我们可以认为foreach实现是水平融合,融合实现是foreach实现水平融合的垂直融合;
  • 一般来讲,三大类实现的性能排序为:fused>foreach>for-loop;
  • 应用时,默认情况下优先采用foreach;
  • 可应用意味着foreach实现是可用的,用户没有指定任何实现的细节参数(比如fused,foreach,differentiable),且所有张量是本地的在CUDA上;
  • 注意,虽然融合的实现较foreach实现应当更快;
  • 但这些实现是比较新的,在任何地方应用之前应该具有更多的实验时间;
  • 欢迎大家尝试;

目前算法的状态

Algorithm

Default

Has foreach?

Has fused?

Adadelta

foreach

yes

no

Adagrad

foreach

yes

no

Adam

foreach

yes

yes

AdamW

foreach

yes

yes

SparseAdam

for-loop

no

no

Adamax

foreach

yes

no

ASGD

foreach

yes

no

LBFGS

for-loop

no

no

NAdam

foreach

yes

no

RAdam

foreach

yes

no

RMSprop

foreach

yes

no

Rprop

foreach

yes

no

SGD

foreach

yes

no

如何调整学习率

  • torch.optim.lr_scheduler 提供一些方法基于训练的代数调整学习率;
  • torch.optim.lr_scheduler.ReduceLROnPlateau 基于一些验证测量允许动态的减少学习率;
  • 学习率调度应当在优化器更新后再应用;
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
scheduler=ExponentialLR(optimizer,gamma=0.9)
for epoch in range(20):
    for input,target in dataset:
        optimizer.zero_grad()
        output=model(input)
        loss=loss_fn(output,target)
        loss.backward()
        optimizer.step()
    scheduler.step()
  • 很多学习率调度器被称为背靠背调度器(也成为链式调度器);
  • 结果是调度器被一个个的应用到另一个之前的调度器获取到的学习率上;
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
scheduler1=ExponentialLR(optimizer,gamma=0.9)
scheduler2=MultiStepLR(optimizer,milestones=[30,80],gamma=0.1)
for epoch in range(20):
    for input,target in dataset:
        optimizer.zero_grad()
        output=model(input)
        loss=loss_fn(output,target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

注意

  • 在1.1.0版本之前,学习率调度器被期望在优化器更新之前调用;
  • 1.1.0版本改变了这一特性;
  • 如果用户在优化器更新之前使用了学习率调度器;
  • 将忽略学习率调度器中的第一个值;
  • 如果在更新PyTorch1.1.0之后无法重新生成结果;
  • 请检查是否在错误的位置调用了学习率调度器;

lr_scheduler.LambdaLR

设置每一个参数组的学习率为初始学习率乘以一个给定函数.

lr_scheduler.MultiplicativeLR

将每个参数组的学习率乘以指定函数中给出的系数.

lr_scheduler.StepLR

每个步长的epoch以gamma衰减每个参数组的学习率.

lr_scheduler.MultiStepLR

一旦训练的代数达到了一个里程碑以gamma衰减每组参数的学习率;

lr_scheduler.ConstantLR

用一个小的常值系数衰减每个参数组的学习率直到训练的代数达到了预定义的里程碑total_iters;

lr_scheduler.LinearLR

通过线性改变小的乘法系数衰减每个参数组的学习率,直到训练的代数达到了预定义的里程碑:total_iters;

lr_scheduler.ExponentialLR

每一代通过gamma衰减每一个参数组的学习率;

lr_scheduler.PolynomialLR

在给定的total_iters内,使用多项式函数衰减每一个参数组的学习率;

lr_scheduler.CosineAnnealingLR

使用余弦退火时间表设置每一个参数组的学习率,这里\eta_{max}设置为初始的学习率lr,T_{cur}为训练的代数自从在SGDR中重新启动以来.

lr_scheduler.ChainedScheduler

学习率调度器链表.

lr_scheduler.SequentialLR

接收一个在优化过程中被期望顺序调用的调度器列表和提供具体的间隔以反映在一个给定的代数哪一个调度器被猜测调用的里程碑点.

lr_scheduler.ReduceLROnPlateau

当度量停止改进时减少学习率;

lr_scheduler.CyclicLR

根据周期学习率政策(CLR)设置每一个参数组的学习率;

lr_scheduler.OneCycleLR

根据1周期学习率政策设置每一个参数组的学习率.

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火时间表设置每一个参数组的学习率,这里\eta_{max}设置为初始的学习率lr,T_{cur}为训练的代数自从在SGDR中重新启动以来.T_i为SGDR中的两次热重启之间的代数.

权值平均(SWA和EMA)

  • torch.optim.swa.utils 实现随机权值平均(SWA)和指数移动平均(EMA);
  • 特别的torch.optim.swa_utils.AveragedModel类实现SWA和EMA模型;
  • torch.optim.swa_utils.SWALR实现SWA学习率调度器;
  • torch.optim.swa_utils.update_bn()是一个工具函数用于在训练结束时更新SWA/EMA批次标准化统计;
  • SWA在Averaging Weights Leads to Wider Optima and Better Generalization.中被提出;
  • EMA是一个通过减少需要更新的权重的数量减少训练时间的广为人知的技术;
  • EMA是一个Polyak averaging的变体,但是在迭代中使用等权重而不是指数权重;

构建平均模型

  • AveragedModel类服务于计算SWA和EMA模型的权重;
  • 创建SWA平均模型:
averaged_model=AveragedModel(model)
  • 通过指定multi_avg_fn参数构建EMA模型:
decay=0.999
averaged_model=AveragedModel(model,multi_avg_fn=get_ema_multi_avg_fn(decay))
  • decay是一个在0和1之间的参数,控制平均化的参数以多快的速度衰减;
  • 如果decay参数没有提供给get_ema_multi_avg_fn,默认值为0.999;
  • get_ema_multi_avg_fn返回一个应用以下EMA权重公式的函数:

W_{t+1}^{EMA}=\alpha W_t^{EMA}+(1-\alpha)W_t^{model}

  • 这里,\alpha是EMA衰减因子,model可以是任何的torch.nn.Moudle对象;
  • averaged_model将保持追踪运行中的模型的参数均值化;
  • 为了更新这些均值,用户应当使用update_parameters()函数在optimizer.step()之后;
averaged_model.update_parameters(model)
  • 对于SWA和EMA,该函数的调用通常紧随optimizer step()函数;
  • 在SWA中,在训练起始的一些次数下被略过;

制定均值策略

  • 默认情况下,torch.optim.swa_utils.AveragedModel计算一个用户提供的参数的运行等效均值;
  • 但是用户可以使用定制的均值函数,使用avg_fn或者multi_avg_fn参数;
  • avg_fn允许定义一个函数操作每一个参数元组(平均的参数,模型参数)且应当返回平均的参数;
  • multi_avg_fn允许定义对元组参数列表(均值的参数列表,模型参数列表)的更有效的操作;
  • 同时,比如使用torch._foreach* 函数,该函数必须在位更新均值化的参数;
  • 下例中ema_model计算一个指数移动平均使用avg_fn参数:
import torch.optim.swa_utils
ema_avg=lambda averaged_model_parameter,model_parameter,
num_averaged:0.9*averaged_model_parameter+0.1*model_parameter
ema_model=torch.optim.swa_utils.AveragedModel(model,avg_fn=ema_avg)
  • 以下的实例ema_model计算一个指数移动平均使用更为高效的multi_avg_fn参数:
ema_model=AveragedModel(model,multi_avg_fn=get_ema_multi_avg_fn(0.9))

swa学习率计划

  • 通常,SWA中学习率被设置为一个大的常量数值;
  • SWALR是一个学习率调度器,将学习率调整到一个固定的值,并保持为常量;
  • 比如以下实例代码创建一个调度器线性调整学习率在5代训练的每个参数组中从初始值到0.05;
swa_scheduler=torch.optim.swa_utils.SWALR(optimizer,
                                      anneal_strategy="linear",
                                        anneal_epochs=5,swa_lr=0.05)
  • 可以使用余弦退火到一个固定的学习率值而不是使用线性退火通过设置annel_strategy='cos';

关注批量规范化

  • update_bn()是一个有用的函数允许计算SWA模型在一个给定加载器loader中的批规范统计,在训练的末尾;
torch.optim.swa_utils.update_fn(loader,swa_model)
  • update_bn()应用swa_model模型到数据加载器中的每一个单元;
  • 且在模型中每一个批标准化层计算激活数据;
  • update_fn()假设数据加载器loader中的每一批不是张量就是张量列表;
  • 且张量或张量列表中的第一个元素是网络swa_model应当应用到的张量;
  • 如果用户的加载器具有不同的结构;
  • 可以更新swa_model模型的批标准化数据;
  • 通过对数据集的每个元素使用swa_model做前向传递;

SWA综述

  • 以下实例,swa_model是一个累积权重均值的SWA模型;
  • 对模型训练300代,调整学习率计划,在训练代160时手机SWA参数的均值:
import torch
loader,optimizer,model,loss_fn=...
swa_model=torch.opim.swa_utils.AveragedModel(model)
scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=300)
swa_start=160
swa_scheduler=SWALR(optimizer,swa_lr=0.05)

for epoch in range(300):
    for input,target_in_loader:
        optimizer.zero_grad()
        loss_fn(model(input),target).backward()
        optimizer.step()
    if epoch>swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()
torch.optim.swa_utils.update_bn(loader,swa_model)
preds=swa_model(test_input)

EMA综述

  • 以下实例中,ema_model是一个EMA模型的实例;
  • 该实例累积权重均值的指数衰减;
  • 衰减率为0.999;
  • 训练模型300代,从训练一开始就收集EMA均值;
import torch
loader,optimizer,model,loss_fn=...
ema_model=torch.opim.swa_utils.AveragedModel(model,
                                             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
for epoch in range(300):
    for input,target_in_loader:
        optimizer.zero_grad()
        loss_fn(model(input),target).backward()
        optimizer.step()
        ema_model.update_parameters(model)
torch.optim.swa_utils.update_bn(loader,ema_model)
preds=swa_model(test_input)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/417869.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【一个上下拉且松手回弹的自定义ScrollView】

文章目录 UserDefineScrollView举例使用activity_main.xmlMainActivity.java文件运行效果下拉前下拉后上拉 普通的scrollView下拉到顶部时就不动了,而如qq设置界面中的布局,下拉到顶端时还能下拉一段距离。本文介绍一个自定义scrollView就可以实现这样的…

遥感、航拍、影像等用于深度学习的数据集集合

遥感图像的纹理特征异常繁杂,地貌类型多变,人工提取往往存在特征提取困难和特征提取不准确的问题,同时,在这个过程中还会耗费海量的人力物力。随着计算力的突破、数据洪流的暴发和算法的不断创新,在具有鲜明“大数据”…

嵌入式中14 个超级牛的免费开源小工具

Homebrew for macOS 地址:https://brew.sh Mac 上非常好用的包管理工具,很多常见的安装都可以通过 brew install app 或者 brew cask install app 直接安装,类似 apt-get 。 Oh My Zsh 地址:https://github.com/robbyrussell…

Machine Vision Technology:Lecture2 Linear filtering

Machine Vision Technology:Lecture2 Linear filtering Types of ImagesImage denoising图像去噪Defining convolution卷积的定义Key properties卷积的关键属性卷积的其它属性Annoying details卷积练习Sharpening锐化Gaussian KernelNoise噪声 分类Gaussian noise高…

江科大stm32学习笔记——【5-2】对射式红外传感器计次旋转编码计次

一.对射式红外传感器计次 1.原理 2.硬件连接 3.程序 CountSensor.c: #include "stm32f10x.h" // Device header #include "Delay.h"uint16_t CountSensor_Count;void CountSensor_Init(void) {//配置RCC时钟:RCC_APB2Perip…

改进YOLO系列 | YOLOv5/v7 引入通用高效层聚合网络 GELAN | YOLOv9 新模块

今天的深度学习方法专注于如何设计最合适的目标函数,以使模型的预测结果最接近真实情况。同时,必须设计一个合适的架构,以便为预测提供足够的信息。现有方法忽视了一个事实,即当输入数据经过逐层特征提取和空间转换时,会丢失大量信息。本文将深入探讨数据通过深度网络传输…

DH秘钥交换算法

1 应用 关于加密,对称加密和非对称加密各有优劣,最佳方案是先使用非对称加密实现秘钥交换,后面再利用协商的结果作为对称加密的秘钥,具体可以参考 《嵌入式算法6---AES加密/解密算法》、《嵌入式算法18---RSA非对称加密算法》。 …

TikTok运营应该使用什么IP?网络问题大全

想要迈过TikTok新手门槛,首先必须要学习的就是网络问题。很多人开始做TikTok账号或者TikTok小店时,都会遇到一些先前没有遇到的词汇和概念,比如原生IP,独享IP,甚至专线,那么一个IP可以做几个账号呢&#xf…

多人同时导出 Excel 干崩服务器?我们来实现一个排队导出功能!

考虑到数据库数据日渐增多,导出会有全量数据的导出,多人同时导出可以会对服务性能造成影响,导出涉及到mysql查询的io操作,还涉及文件输入、输出流的io操作,所以对服务器的性能会影响的比较大; 结合以上原因…

CPD点云配准

一、CPD点云配准 Python 这是github上一位大佬写的Python包,链接:neka-nat/probreg: Python package for point cloud registration using probabilistic model (Coherent Point Drift, GMMReg, SVR, GMMTree, FilterReg, Bayesian CPD) (github.com)你…

深入理解与应用工厂方法模式

文章目录 一、模式概述**二、适用场景****三、模式原理与实现****四、采用工厂方法模式的原因****五、优缺点分析****六、与抽象工厂模式的比较**总结 一、模式概述 ​ 工厂方法模式是一种经典的设计模式,它遵循面向对象的设计原则,特别是“开闭原则”&…

EasyX的使用(详解版)

EasyX的基础概念&#xff1a; 图形化——EasyX的安装-CSDN博客 创建图形化窗口 #include<graphics.h> #include<conio.h> int main() {//创建绘图窗口&#xff0c;大小为100x100像素。//更改为大窗口&#xff0c;像素增大&#xff1b;更改为小窗口&#xff0c;像素…

华为数通方向HCIP-DataCom H12-821题库(单选题:481-500)

第481题 以下关于基于SD-WAN思想的EVPN互联方案的描述,错误的是哪一项? A、通过部署独立的控制面,将网络转发和控制进行了分离,从而实现了网络控制的集中化 B、通过对WAN网络抽象和建模,将上层网络业务和底层网络具体实现架构进行解耦,从而实现网络自动化 C、通过集中的…

上位机图像处理和嵌入式模块部署(当前机器视觉新形态)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 过去的机器视觉处理&#xff0c;大部分都是集中在上位机、或者是服务器领域&#xff0c;这种形式维持了很长的时间。这种业务形态下&#xff0c;无…

javaee教程郑阿奇课后答案,三年经验月薪50k我是怎么做到的

个人背景 如标题所示&#xff0c;我的个人背景非常简单&#xff0c;Java开发经验1年半&#xff0c;学历普通&#xff0c;2本本科毕业&#xff0c;毕业后出来就一直在Crud&#xff0c;在公司每天重复的工作对我的技术提升并没有什么帮助&#xff0c;但小镇出来的我也深知自我努…

这一步一步爬的伤痕累累

一、网安专业名词解释 ① CTF CTF&#xff08;Capture The Flag&#xff09;中文一般译作夺旗赛&#xff0c;在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会&#xff0c;以代替之前黑客们通过互相发起真实攻击进…

数据结构-----再谈String,字符串常量池,String对象的创建、intern方法的作用

文章目录 1.字符串常量池1.1. 创建对象的思考2.2. 字符串常量池(StringTable)1.3. 再谈String对象创建1.4. intern方法 1.字符串常量池 1.1. 创建对象的思考 下面两种创建String对象的方式相同吗&#xff1f; public static void main(String[] args) {String s1 "hel…

Jmeter系列(5)线程数到底能设置多大

疑惑 一台设备的线程数到底可以设置多大&#xff1f; 线程数设置 经过一番搜索找到了这样的答案&#xff1a; Linux下&#xff0c;2g的 java内存&#xff0c;1m 的栈空间&#xff0c;最大启动线程数2000线程数建议不超过1000jmeter 能启动多少线程&#xff0c;由你的堆内存…

Decision Transformer

DT个人理解 emmm, 这里的Transformer 就和最近接触到的whisper一样,比起传统Transformer,自己还设计了针对特殊情况的tokens。比如whisper里对SOT,起始时间,语言种类等都指定了特殊tokens去做Decoder的输入和输出。 DT这里的作为输入的Tokens由RL里喜闻乐见的历史数据:…

docker save 命令 docker load 命令 快速复制容器

docker save 命令 docker load 命令 1、docker save 命令2、docker load 命令 1、docker save 命令 docker save 命令用于在系统上把正在使用的某个容器镜像 导出成容器镜像文件保存下载&#xff0c;以便在其他系统上导入这个容器镜像文件 以便快速在其他服务器上启动相同的容…