【mmengine】优化器封装(OptimWrapper)(入门)优化器封装 vs 优化器

  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装(OptimWrapper)进行单精度训练混合精度训练梯度累加,对比二者实现上的区别。

一、 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

二、 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
在这里插入图片描述

三、 基于 Pytorch 的 SGD 优化器实现混合精度训练

在这里插入图片描述

  • 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
    • 内存占用更少
    • 计算更快
from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

四、 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述
只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

七、 获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)

print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}

在这里插入图片描述

八、 导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)

# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)



# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)

# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)

在这里插入图片描述

九、 使用多个优化器

OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nn

from mmengine.optim import OptimWrapper, OptimWrapperDict

# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)

# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)

# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)

# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)

print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/886406.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虚拟机三种网络模式详解

在电脑里开一台虚拟机,是再常见不过的操作了。无论是用虚拟机玩只有旧版本系统能运行的游戏,还是用来学习Linux、跑跑应用程序都是很好的。而这其中,虚拟机网络是绝对绕不过去的。本篇文章通俗易懂的介绍了常见的虚拟网络提供的三种网络链接模…

数据结构-3.5.队列的顺序实现

一.队列的顺序实现&#xff0c;初始化操作以及判断队列是否为空&#xff1a; 1.图解&#xff1a; 2.代码&#xff1a; #include<stdio.h> #define MaxSize 10 //定义一个队列最多存储的元素个数 ​ typedef struct {int data[MaxSize]; //用静态数组存放队列元素int f…

【springboot】整合沙箱支付

目录 1. 配置沙箱应用环境2. 配置springboot项目1. 引入依赖2. 配置文件注册下载ngrok 3. 创建支付宝支付服务类4. 支付界面模板5. 控制类实现支付6. 测试 1. 配置沙箱应用环境 使用支付宝账号登录到开放平台控制台。 使用支付宝登录后&#xff0c;看到以下页面&#xff0c;下…

MFC工控项目实例二十二主界面计数背景颜色改变

承接专栏《MFC工控项目实例二十一型号选择界面删除参数按钮禁用切换》 1、在SEAL_PRESSUREDlg.h文件中添加代码 class CSEAL_PRESSUREDlg : public CDialog { public: CBrush m_brush1;CBrush m_brush2;CBrush m_brush3;... } 2、在SEAL_PRESSUREDlg.cpp文件中添加代码 BO…

在2核2G服务器安装部署MySQL数据库可以稳定运行吗?

阿里云2核2G服务器可以安装MySQL数据库吗&#xff1f;当然可以&#xff0c;并且可以稳定运行MySQL数据库&#xff0c;目前阿里云服务器网aliyunfuwuqi.com使用的就是阿里云2核2G服务器&#xff0c;在云服务器上安装MySQL数据库&#xff0c;可以稳定运行。 目前阿腾云用于运行M…

查看 git log的过程中看到 :说明日志输出可能超出屏幕大小,系统进入了分页模式

在命令行提示符中&#xff0c;通常 : 表示系统等待进一步的输入。如果你在查看 git log 的过程中看到 :&#xff0c;说明日志输出可能超出屏幕大小&#xff0c;系统进入了分页模式&#xff0c;默认使用 less 命令查看内容。 此时你可以&#xff1a; 按 q 退出日志查看。按 En…

算法笔记(五)——分治

文章目录 算法笔记&#xff08;五&#xff09;——分治快排颜色分类排序数组数组中的第K个最大元素库存管理 III 归并排序数组交易逆序对的总数计算右侧小于当前元素的个数翻转对 算法笔记&#xff08;五&#xff09;——分治 分治算法字面上的解释是“分而治之”&#xff0c;就…

Python 从入门到实战32(数据库MySQL)

我们的目标是&#xff1a;通过这一套资料学习下来&#xff0c;通过熟练掌握python基础&#xff0c;然后结合经典实例、实践相结合&#xff0c;使我们完全掌握python&#xff0c;并做到独立完成项目开发的能力。 上篇文章我们讨论了数据库编程接口操作的相关知识。今天我们将学习…

【框架篇】过滤器和拦截器的区别以及使用场景

在项目开发中&#xff0c;常常会同时配置拦截器&#xff08;Interceptor&#xff09;和过滤器&#xff08;Filter&#xff09;&#xff0c;以下就是它们两个主要的区别&#xff1a; 过滤器&#xff08;Filter&#xff09; 配置和实现 Filter的实现还是很简单的&#xff0c;可…

【微服务】组件、基础工程构建(day2)

组件 服务注册和发现 微服务模块中&#xff0c;一般是以集群的方式进行部署的&#xff0c;如果我们调用的时候以硬编码的方式&#xff0c;那么当服务出现问题、服务扩缩容等就需要对代码进行修改&#xff0c;这是非常不好的。所以微服务模块中就出现了服务注册和发现组件&…

计算机毕业设计 基于Python的广东旅游数据分析系统的设计与实现 Python+Django+Vue Python爬虫 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

华为云+WordPress+Puock主题搭建个人博客

网站访问地址&#xff1a;qingxuly.cn 搭建网站 购买华为云服务器&#xff0c;购买域名&#xff0c;进行备案&#xff0c;配置域名解析等操作&#xff0c;请参考华为云文档。 安装Ubuntu系统 华为云控制台中给云服务器安装Ubuntu2204。 配置服务器安全组 华为云安全组中创建安…

【嵌入式系统】第18章 脉宽调试器(PWM)

目录 18.1 结构框图 18.3 功能说明 18.3.4 PWM 信号发生器 18.3.5 死区发生器 18.3.6 中断/ADC 触发选择器 18.3.7 同步方法 18.3.8 故障条件 18.3.9 输出控制块 LES 硬件介绍&#xff08;12&#xff09;正交编码接口QEI 19.1 结构框图 19.2 信号描述 19.3 功能说明…

GPG error golang 1.19

1. 问题描述及原因分析 在飞腾2000的服务器&#xff0c;OS为Kylin Linux Advanced Server release V10环境下&#xff0c;docker版本为18.09.0&#xff08;docker-engine-18.09.0-101.ky10.aarch64&#xff09;&#xff0c;基于容器镜像golang:1.19编译新的容器镜像&#xff0…

C++黑暗迷宫

目录 开头程序程序的流程图程序游玩的效果下一篇博客要说的东西 开头 大家好&#xff0c;我叫这是我58。 程序 #include <iostream> #include <cstdlib> #include <ctime> using namespace std; struct near {int i;int ia;int ix;int iy;int iwalk; }; v…

22.1 k8s不同role级别的服务发现

本节重点介绍 : 服务发现的应用3种采集的k8s服务发现role 容器基础资源指标 role :nodek8s服务组件指标 role :endpoint部署在pod中业务埋点指标 role :pod 服务发现的应用 所有组件将自身指标暴露在各自的服务端口上&#xff0c;prometheus通过pull过来拉取指标但是promet…

SQL中基本SELECT语句及常见关键字的使用(内连接,左/右连接)

这里写目录标题 SQL中基本SELECT语句的使用SQL语法简介DDL、DML、DCLSEECT SELECT常用关键词group by分组having筛选limit限定条数UION和UION ALL合并SQL执行顺序 联表查询多表查询示例特殊用法&#xff1a;笛卡尔积&#xff08;交叉连接&#xff09;等值连接vs非等值连接自连接…

VScode 自定义代码配色方案

vscode是一款高度自定义配置的编辑器, 我们来看看如何使用它自定义配色吧 首先自定义代码配色是什么呢? 看看我的代码界面 简而言之, 就是给你的代码的不同语义(类名, 函数名, 关键字, 变量)等设置不同的颜色, 使得代码的可读性变强. 其实很多主题已经给出了定制好的配色方案…

D3.js中国地图可视化

1、项目介绍 该项目来自Github&#xff0c;基于D3.js中国地图可视化。 D3.js is a JavaScript library for manipulating documents based on data. It uses HTML, SVG, and CSS to display data. The full name of D3 is "Data-Driven Documents," which means it a…

【Flume Kafaka实战】Using Kafka with Flume

一 目标 在Cloudera Manager中创建两个Flume的Agent&#xff0c;Agent1从local file中获取内容&#xff0c;写入到kafka的队列中。Agent2以Agent1的sink作为source&#xff0c;将数据从kafka中读取出来&#xff0c;写入到HDFS中。 二 实战 2.1 Kafka Sink 第一步&#xff0…