【官方文档解读】torch.jit.script 的使用,并附上官方文档中的示例代码


由 OpenMMLab 的部署教程 所述,对于模型中存在有控制条件的(如 if,for 等),需要用 torch.jit.script 而非采样默认的 torch.jit.trace 方法。本文则详细介绍了下官方文档中对 torch.jit.script 的解释和示例代码。

torch.jit.script

torch.jit.script 用于将函数或 nn.Module 编译为 TorchScript。

函数签名
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)
功能概述

将函数或 nn.Module 脚本化,会检查源代码,并使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回一个 ScriptModuleScriptFunction。TorchScript 是 Python 语言的一个子集,因此并不是所有的 Python 功能都能在其中使用,但我们提供了足够的功能来对张量进行计算和执行控制相关操作。完整指南请参阅 TorchScript 语言参考。

脚本化字典或列表会将其中的数据复制到一个 TorchScript 实例中,该实例可以在 Python 和 TorchScript 之间以零复制开销传递引用。

torch.jit.script 可以作为函数用于模块、函数、字典和列表,并可以作为装饰器 @torch.jit.script 用于 TorchScript 类和函数。

参数
  • obj(Callable、类或 nn.Module) – 要编译的 nn.Module、函数、类类型、字典或列表。
  • example_inputs(Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例输入以注释函数或 nn.Module 的参数。
返回值

如果 obj 是 nn.Module,脚本会返回一个 ScriptModule 对象。返回的 ScriptModule 将具有与原始 nn.Module 相同的子模块和参数集。如果 obj 是独立函数,将返回 ScriptFunction。如果 obj 是字典,则脚本返回 torch._C.ScriptDict 实例。如果 obj 是列表,则脚本返回 torch._C.ScriptList 实例。

脚本化函数

@torch.jit.script 装饰器通过编译函数体来构建 ScriptFunction。

示例(脚本化函数):
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# 以 Python 代码查看编译后的图
print(foo.code)

# 使用 TorchScript 解释器调用函数
foo(torch.ones(2, 2), torch.ones(2, 2))
使用示例输入脚本化函数

示例输入可用于注释函数参数。

示例(脚本化前注释函数):
import torch

def test_sum(a, b):
    return a + b

# 注释参数为 int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# 以 Python 代码查看编译后的图
print(scripted_fn.code)

# 使用 TorchScript 解释器调用函数
scripted_fn(20, 100)
脚本化 nn.Module

默认情况下,脚本化 nn.Module 会编译 forward 方法,并递归编译 forward 调用的任何方法、子模块和函数。如果 nn.Module 仅使用 TorchScript 支持的功能,则无需对原始模块代码进行任何更改。脚本将构建一个 ScriptModule,其中包含原始模块的属性、副本和方法。

示例(脚本化包含参数的简单模块):
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # 此参数将被复制到新的 ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # 当使用此子模块时,它将被编译
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # 这会调用 `nn.Linear` 模块的 `forward` 方法,从而在此处将 `self.linear` 子模块编译为 `ScriptModule`
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))
示例(脚本化包含 traced 子模块的模块):
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.jit.trace 生成一个 ScriptModule 的 conv1 和 conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

要编译 forward 以外的方法(并递归编译它调用的任何内容),请将 @torch.jit.export 装饰器添加到方法上。要选择不编译,请使用 @torch.jit.ignore@torch.jit.unused

示例(模块中导出和忽略的方法):
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # 此函数不会被编译,因此可以使用任何 Python API
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
示例(使用示例输入注释 nn.Module 的 forward 方法):
import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
    result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# 在提供的输入下运行 pdt_model 并注释 forward 的参数
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# 使用实际输入运行 scripted_model
print(scripted_model([20]))

官方文档链接:https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/678827.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SemiDrive X9H 平台 QT 静态编译

一、 前言 芯驰 X9H 芯片,搭载多个操作系统协同运行,系统实现了仪表、空调、中控、副驾多媒体的四屏驱动控制,在人车智能交互上可以通过显示屏、屏幕触摸控制、语音控制、物理按键控制、车身协议的完美融合,使汽车更智能。让车主…

“论软件系统建模方法”必过范文,突击2024软考高项论文

论文部分 摘要 2023年03月,我参与了某艺术品公司线上拍卖管理平台的研发。该项目的目标是建立一个互联网在线拍卖平台,用户可以通过手机或PC浏览器进入拍卖平台,对喜欢的拍品进行参拍出价。平台提供了在线支付、在线出价、保证金管理、拍品…

通信技术信号源硬件实验

定义 1.RZ码(归零码) RZ编码也成为归零码,归零码的特性就是在一个周期内,用二进制传输数据位,在数据位脉冲结束后,需要维持一段时间的低电平。 2.NRZ码(不归零编码) NRZ编码也成为…

AVL树的介绍与实现

前言 我们上一期介绍了二叉搜索树并做了实现,本期我们来继续学习另一个更优的树即AVL树! 本期内容介绍 什么是AVL树? AVL树的实现 AVL树的性能分析 在正式的介绍AVL树之前,我们先来回忆一下二叉搜索树的特点:左子树的…

nnUNet保姆级使用教程!从环境配置到训练与推理(新手必看)

文章目录 写在前面nnUNet是什么?一、配置虚拟环境二、安装nnUNet框架1.安装nnUNet这一步我遇到的两个问题: 2.安装隐藏层hiddenlayer(可选) 三、数据集准备nnUNet对于你要训练的数据是有严格要求的,这第一点就体现在我…

pushowl | 运用分销裂变模式实现业绩快速增长

一、公司简介 PushOwl公司是一家专注于为电子商务商店提供在线营销工具的印度初创企业。以下是对PushOwl公司的详细介绍: 基本信息: 所属公司:Creatorbox Softwares Private Limited 成立日期:2018年 所属地:印度 …

103.网络游戏逆向分析与漏洞攻防-ui界面的设计-加速功能的开关设计

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 如果看不懂、不知道现在做的什么,那就跟着做完看效果,代码看不懂是正常的,只要会抄就行,抄着抄着就能懂了 内容…

RDMA (1)

RDMA是什么 Remote Direct Memory Access(RDMA)是用来给有高速需求的应用释放网络消耗的。 RDMA在网络的两个应用之间进行低延迟,高吞吐的内存对内存的直接数据通信。 InfiniBand需要部署独立的协议。 RoCE(RDMA over Converged Ethernet),也是由InfiniBand Trade Associat…

【吊打面试官系列】Java高并发篇 - Java 线程数过多会造成什么异常?

大家好,我是锋哥。今天分享关于 【Java 线程数过多会造成什么异常?】面试题,希望对大家有帮助; Java 线程数过多会造成什么异常? 1、线程的生命周期开销非常高 1000道 互联网大厂Java工程师 精选面试题-Java资源分享…

postman教程-12-保存请求至Collections

领取资料,咨询答疑,请➕wei: June__Go 上一小节我们学习了Postman管理环境的方法,本小节我们讲解一下Postman保存请求至Collections集合的方法。 1、创建Collection 在保存Request请求之前,先创建一个Collection(集合)&#…

【网络编程开发】1.网络结构 2.IP地址与端口号 3.字节序

网络编程开发 两台计算机要互相传送文件需解决很多问题: 必须有一条传送数据的通路。发起方必须激活通路。要告诉网络如何识别接收方。发起方要清楚对方是否已开机,且与网络连接正常。发起方要清楚对方是否准备好接收和存储文件。若文件格式不兼容&…

藏品名称:龙凤呈祥摆件

藏品名称:龙凤呈祥摆件 规格:重约 14.3Kg 藏品类别:杂项 此器身布满繁缠纹饰。器表为一轮红日出于东方,照耀辽阔江海。红日旁边有两点黄金凸显其尊贵。一神龙首尾相接有祥云伴随,大山脚下栖息着一只凤凰与神龙遥相呼应。龙身和龙尾交接处有花和花蕊,花开富贵象征着吉祥。整个…

Esxi的安装问题处理: Failed to verify signatures of the following vib(s)

前言 在安装esxi的时候报错 如下图: 自己在安装过程中遇到点问题 Failed to verify signatures of the following vib(s) 一番查找,只要在bios里面关闭 Security boot 就可以解决 Prepping an ESXi 6.7 host for Secure Boot – Mike Foley

9.抽象类和接口

抽象类 抽象类概念 在面向对象的概念中,所有的对象都是通过类来描绘的,但是反过来,并不是所有的类都是用来描绘对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类 比如: 我…

Linux网络-自定义协议、序列化和反序列化、网络计算服务器的实现和Windows端客户端

文章目录 前言一、自定义协议传结构体对象 序列化和反序列化什么是序列化?反序列化 二、计算器服务端(线程池版本)1.main.cc2.Socket.hpp3.protocol.hpp4.Calculator.hpp5.serverCal.hpp6.threadPool.hpp7.Task.hpp8. log.hpp 客户端Windows客…

PS插件一键轻松搞定电商产品摄影图!

在电商行业中,一张高质量的产品摄影图往往能够吸引更多潜在消费者的目光,从而增加产品的销量。然而,对于许多电商卖家和摄影师来说,后期处理产品图片却是一个既耗时又费力的工作。 最近我发现一款PS插件可以一键生成电商产品摄影…

品牌舆情监测系统是什么?怎么监测?

品牌形象与口碑对于企业的重要性不言而喻,品牌舆情监测系统应运而生,成为企业守护品牌声誉的利器。品牌舆情监测系统是什么?怎么选择合适的舆情监测系统?接下来伯乐网络传媒就给大家分享一下。 一、品牌舆情监测系统的应用价值 1…

【一百零四】【算法分析与设计】【模板】二维差分,2132. 用邮票贴满网格图,LCP 74. 最强祝福力场,二位差分,差分思想,记录变化值,离散化技巧

【模板】二维差分 描述 给你一个n行m列的矩阵,下标从1开始。 接下来有q次操作,每次操作输入5个参数x1, y1, x2, y2, k 表示把以(x1, y1)为左上角,(x2,y2)为右下角的子矩阵的每个元素都加上k, 请输出操作后的矩阵。 输入描述: 第一…

读书-《蛤蟆先生去看心理医生》

书名蛤蟆先生去看心理医生作者罗伯特戴博德状态阅读中简介该书借用《柳林风声》的故事主角蛤蟆先生,讲述了他接受心理咨询的故事。作者通过陷入抑郁的蛤蟆先生和心理咨询师苍鹭的互动,探索蛤蟆先生爱炫耀、自卑性格和抑郁情绪的来源,指出童年…

二叉树的OJ题

1.二叉树的前序遍历 /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ /*** Note: The returned array must be malloced, assume caller calls free().*/int TreeeSize(struct Tre…