PyTorch-Lightning:trining_step的自动优化

文章目录

  • PyTorch-Lightning:trining_step的自动优化
      • 总结:
    • class _ AutomaticOptimization()
      • def run
      • def _make_closure
      • def _training_step
        • class ClosureResult():
          • def from_training_step_output
      • class Closure

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

在这里插入图片描述

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        closure = self._make_closure(kwargs, optimizer, batch_idx)

        if (
            # when the strategy handles accumulation, we want to always call the optimizer step
            not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()
        ):
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            self._optimizer_step(batch_idx, closure)

        result = closure.consume_result()
        if result.loss is None:
            return {}
        return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:

        step_fn = self._make_step_fn(kwargs)
        backward_fn = self._make_backward_fn(optimizer)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)
        return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数

class Closure(AbstractClosure[ClosureResult]):

    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], ClosureResult],
        backward_fn: Optional[Callable[[Tensor], None]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn

    @override
    @torch.enable_grad()
    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
        step_output = self._step_fn()

        if step_output.closure_loss is None:
            self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

        if self._zero_grad_fn is not None:
            self._zero_grad_fn()

        if self._backward_fn is not None and step_output.closure_loss is not None:
            self._backward_fn(step_output.closure_loss)

        return step_output

    @override
    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        self._result = self.closure(*args, **kwargs)
        return self._result.loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/536232.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式单片机入职第二天-EEPROM与IIC

上午: 1.安装Jlink驱动,死活没反应,因为昨天才装完系统,领导让我装电脑主板驱动 领导方法进惠普官网通过查询电脑型号,里面几十个驱动搞得我眼花,领导告诉我进官网就去开会了,可能因为是外网&…

Win11 WSL2 install Ubuntu20.04 and Seismic Unix

Win11系统,先启用或关闭Windows功能,勾选“适用于Linux的Windows子系统”和“虚拟机平台”两项 设置wsl默认版本为wsl2,并更新 wsl --list --verbose # 查看安装版本及内容 wsl --set-default-version 2 # 设置wsl默认版本为wsl2 # 已安装…

nvm安装详细教程(安装nvm、node、npm、cnpm、yarn及环境变量配置)

一、安装nvm 1. 下载nvm 点击 网盘下载 进行下载 2、双击下载好的 nvm-1.1.12-setup.zip 文件 3.双击 nvm-setup.exe 开始安装 4. 选择我接受,然后点击next 5.选择nvm安装路径,路径名称不要有空格,然后点击next 6.node.js安装路径&#…

智慧矿山视频智能监控与安全监管方案

一、行业背景 随着全球能源需求的日益增长,矿业行业作为国民经济的重要支柱,其发展日益受到广泛关注。然而,传统矿山管理模式的局限性逐渐显现,如生产安全、人员监管、风险预警等方面的问题日益突出。因此,智慧矿山智…

【算法练习】30:快速排序学习笔记

一、快速排序的算法思想 原理:快速排序基于分治策略。它的基本思想是选择一个元素作为“基准”,将待排序序列划分为两个子序列,使得左边的子序列中的所有元素都小于基准,右边的子序列中的所有元素都大于基准。这个划分操作被称为分…

【Python函数和类3/6】函数的返回值

目录 知识回顾 目标 函数的返回值 Tips 练习 ​编辑return的其它特性 任意类型的返回值 返回多个值 return的位置 小结 局部变量 局部变量的作用域 全局变量 全局变量的作用域 同名变量 pi的作用域 总结 知识回顾 在上篇博客中,我们学习给函数设置参…

集群开发学习(一)(安装GO和MySQL,K8S基础概念)

完成gin小任务 参考文档: https://www.kancloud.cn/jiajunxi/ginweb100/1801414 https://github.com/hanjialeOK/going 最终代码地址:https://github.com/qinliangql/gin_mini_test.git 学习 1.安装go wget https://dl.google.com/go/go1.20.2.linu…

玩机进阶教程------手机定制机 定制系统 解除系统安装软件限制的一些步骤解析

定制机 在于各工作室与商家合作定制rom中有一些定制机。限制用户私自安装第三方软件。或者限制解锁 。无法如正常机登陆账号等等。定制机一般用于固定行业或者一些部门。专机专用。例如很多巴枪扫描机型等等。或者一些小牌机型。对于没有官方包的机型首先要导出各个分区来制作…

【OpenVINO™】使用 OpenVINO™ C# API 部署 YOLOv9 目标检测和实例分割模型(上篇)

YOLOv9模型是YOLO系列实时目标检测算法中的最新版本,代表着该系列在准确性、速度和效率方面的又一次重大飞跃。它通过引入先进的深度学习技术和创新的架构设计,如通用ELAN(GELAN)和可编程梯度信息(PGI)&…

复合数据类型

在C语言中,复合数据类型是指那些可以包含多个简单数据类型的数据类型。以下是一些常见的C语言复合数据类型以及相关的例子: 1. 数组(Arrays): 数组是一种可以存储多个相同类型数据的数据结构。例如: #in…

从像素游戏到 3A 大作的游戏引擎/框架

Bevy —— Rust 构建的游戏引擎 Bevy 是一款由 Rust 语言构建且简单明了的数据驱动的游戏引擎,并将永远保持开源且免费。 Mach —— Zig 游戏引擎和图形工具包 Mach 是一个 Zig 游戏引擎和图形工具包,用于构建高性能、真正跨平台、健壮且模块化的游戏&…

日程安排组件DHTMLX Scheduler v7.0新版亮点 - 拥有多种全新的主题

DHTMLX Scheduler是一个类似于Google日历的JavaScript日程安排控件,日历事件通过Ajax动态加载,支持通过拖放功能调整事件日期和时间,事件可以按天、周、月三个种视图显示。 备受关注的DHTMLX Scheduler 7.0版本日前正式发布了,如…

JS原生DOM操作 - 获得元素/网页大小/元素宽高

文章目录 获得元素的方法获取页面元素位置宽高概念方法获得网页/元素宽高clientHeight和clientWidth:scrollHeight和scrollWidth:window.innerWidth:element.style.width: offsetXXX 获得网页元素的宽高和相对父元素位置&#xff…

有道词典网页版接口分析与爬虫研究

说明:仅供学习使用,请勿用于非法用途,若有侵权,请联系博主删除 作者:zhu6201976 一、目标站点 有道词典网页版:网易有道 二、目标接口 url:https://dict.youdao.com/jsonapi_s?doctypejson&…

通过8种加锁情况来弄懂加锁对于线程执行顺序的影响

1个资源类对象,2个线程,2个同步方法,第二个线程等待1s后开启。 //资源类 public class Example {//2个同步方法public synchronized void method1(){System.out.println("线程1正在执行...");}public synchronized void method2()…

(2022级)成都工业学院数据库原理及应用实验三:数据定义语言DDL

唉,用爱发电连赞都没几个,博主感觉没有动力了 想要完整版的sql文件的同学们,点赞评论截图,发送到2923612607qq,com,我就会把sql文件以及如何导入sql文件到navicat的使用教程发给你的 基本上是无脑教程了,…

Banana Pi BPI-M7 RK3588开发板运行RKLLM软件堆AI大模型部署

关于Banana Pi BPI-M7 Banana Pi BPI-M7 采用Rockchip RK3588,板载8/16/32G RAM内存和 64/128G eMMC存储,支持无线wifi6和蓝牙5.2。2x2.5G网络端口,1个HDMIout标准 输出口,2x USB3.0,2xTYPE-C,2x MIPI CSI…

Day96:云上攻防-云原生篇Docker安全系统内核版本漏洞CDK自动利用容器逃逸

目录 云原生-Docker安全-容器逃逸&系统内核漏洞 云原生-Docker安全-容器逃逸&docker版本漏洞 CVE-2019-5736 runC容器逃逸(需要管理员配合触发) CVE-2020-15257 containerd逃逸(启动容器时有前提参数) 云原生-Docker安全-容器逃逸&CDK自动化 知识点&#xff1…

Vue3基础语法

在这个章节中&#xff0c;简单的看下Vue3的基础语法&#xff0c;有了这些基础后&#xff0c;对写vue3单页也就没有什么问题了。 模板语法 在写html时&#xff0c;我们希望在某个节点绑定一个动态值时&#xff0c;是使用dom操作执行的&#xff0c;如下&#xff1a; <!DOCT…

(Java)数据结构——排序(第一节)堆排序+PTA L2-012 关于堆的判断

前言 本博客是博主用于复习数据结构以及算法的博客&#xff0c;如果疏忽出现错误&#xff0c;还望各位指正。 堆排序&#xff08;Heap Sort&#xff09;概念 堆排序是一种基于堆数据结构的排序算法&#xff0c;其核心思想是将待排序的序列构建成一个最大堆&#xff08;或最小…