PyTorch 内 LibTorch/TorchScript 的使用

PyTorch 内 LibTorch/TorchScript 的使用

  • 1. .pt .pth .bin .onnx 格式
    • 1.1 模型的保存与加载到底在做什么?
    • 1.2 为什么要约定格式?
    • 1.3 格式汇总
      • 1.3.1 .pt .pth 格式
      • 1.3.2 .bin 格式
      • 1.3.3 直接保存完整模型
      • 1.3.4 .onnx 格式
      • 1.3.5 jit.trace
      • 1.3.6 jit.script
    • 1.4 总结
  • 2. TorchScript 的转换
    • 2.1 jit trace 注意事项
    • 2.2 jit trace 验证技巧
    • 2.3 混合使用 trace 和 script
    • 2.4 trace 和 script 的性能
    • 2.5 总结
  • 3. LibTorch 的使用
    • 3.1 LibTorch 的链接
    • 3.2 接口和实现

Reference:

  1. [Pytorch].pth转.pt文件
  2. Pytorch格式 .pt .pth .bin .onnx 详解
  3. pytorch 基于tracing/script方式转ONNX

1. .pt .pth .bin .onnx 格式

1.1 模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码
    1. 包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    2. 包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    3. 包含了我们如何加载数据和使用;
    4. 包含了我们如何测试评估模型。
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    1. 包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    2. 包含optimizer.state_dict(),这是模型的优化器中的参数;
    3. 包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集
    1. 包含了我们训练模型使用的所有数据;
    2. 可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档
    1. 根据使用文档的步骤,每个人都可以重现模型;
    2. 包含了模型的使用细节和我们相关参数的设置依据等信息。

可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。

现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。

1.2 为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的 README.md 才知道。而在使用 torch.load() 方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于 torch.load() 方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取

torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.

1.3 格式汇总

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据需要将 PyTorch 模型转换为通用的二进制格式的场景.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.git.trace 或 torch.git.script 函数将 PyTorch 模型转换为 TorchScript 格式需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景.pt 或 .pth

1.3.1 .pt .pth 格式

一个完整的 PyTorch 模型文件,包含了如下参数:

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值

下面是一个 .pt 文件的保存和加载示例(注意,后缀也可以是 .pth):

  • .state_dict():包含所有的参数和持久化缓存的字典,model 和 optimizer 都有这个方法
  • torch.save():将所有的组件保存到文件中

模型保存

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 初始化优化器

loss = nn.MSELoss()# 初始化损失函数

PATH = "model.pth" # 保存路径

# 保存模型
torch.save({
            'epoch': 10,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)

netron 可得:
在这里插入图片描述

模型加载

import torch
import torch.nn as nn

# 定义同样的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载模型
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
PATH = "model.pth"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

1.3.2 .bin 格式

.bin 文件是一个二进制文件,可以保存 PyTorch 模型的参数和持久化缓存。.bin 文件的大小较小,加载速度较快,因此在生产环境中使用较多。

下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀—后缀无意义):
保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn

# 定义相同的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

1.3.3 直接保存完整模型

可以看出来,我们在之前的保存方式中,都是保存了 .state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
保存模型

PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)

netron 可得:
在这里插入图片描述

可以看到与上面仅保存参数的方式相比,多了很多信息。

加载模型

model = torch.load("entire_model.pt")
model.eval()

1.3.4 .onnx 格式

上述保存的文件可以通过 PyTorch 提供的 torch.onnx.export 函数转化为ONNX格式,这样可以在其他深度学习框架中使用 PyTorch 训练的模型。转化方法如下:

import torch
import torch.onnx

# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")

# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])

加载 ONNX 格式的代码可以参考以下示例代码(注意 ONNX 只能推理不能训练,不包含反向信息的):

import onnx
import onnxruntime

# 加载ONNX文件
onnx_model = onnx.load("model.onnx")

# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")

# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)

# 运行模型
outputs = ort_session.run(None, {"input": input_data})

# 输出结果
print(outputs)

注意,需要安装 onnxonnxruntime 两个 Python 包。此外,还需要使用 numpy 等其他常用的科学计算库。

1.3.5 jit.trace

保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()

PATH = "model_trace.pth"

# 保存模型
example = torch.rand(1, 10)
traced_module = torch.jit.trace(model, example)
traced_module.save(PATH)

在这里插入图片描述

1.3.6 jit.script

保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()

PATH = "model_script.pth" # 保存路径

# 保存模型
scripted_module = torch.jit.script(model)
scripted_module.save(PATH)

netron 可得:
在这里插入图片描述

1.4 总结

综上,PyTorch 可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的 model.state_dict()optimizer.state_dict(),以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在 readme.md 中写清楚,导出了哪些信息。

保存场景保存方法文件后缀
整个模型(保存模型结构)model = Net()
torch.save(model, PATH)
.pt .pth .bin
仅模型参数(不保存模型结构)model = Net()
torch.save(model.state_dict(), PATH)
.pt .pth .bin
checkpoints使用model = Net()
torch.save({‘epoch’:10,‘model_state_dict’:model.state_dict(),‘optimizer_state_dict’: optimizer.state_dict(),‘loss’: loss,}, PATH)
.pt .pth .bin
ONNX通用保存model = Net()
model.load_state_dict(torch.load(“model.bin”))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, “model.onnx”, input_names=[“input”], output_names=[“output”])
.onnx
TorchScript 无 Python 环境使用model = Net()
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(‘model_scripted.pt’)
model = torch.jit.load(‘model_scripted.pt’)
model.eval()
.pt .pth

2. TorchScript 的转换

上文内提到 .pthpt 等价,而且后缀主要用于提示。不过相对来说,PyTorch 的模型文件一般保存为 .pth 文件的更多一点,而 C++ 接口一般读取的是 .pt 文件,因此,C++ 在调用 PyTorch 训练好的模型文件的时候,就需要转换为以 .pt 为代表的 TorchScript 文件,才能够读取。

Script mode 通过 torch.jit.trace 或者 torch.jit.script 来调用。这两个函数都是将 Python 代码转换为 TorchScript 的两种不同的方法。

  • torch.jit.trace:将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个 PyTorch 模型,torch.jit.trace 会跟踪此 input 在 model 中的计算过程,然后将其转换为 Torch 脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

  • torch.jit.script 直接将 Python 函数(或者一个 Python 模块)通过 Python 语法规则和编译转换为 Torch 脚本。torch.jit.script 更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于 RNN 或者一些具有可变序列长度的模型,使用 torch.jit.script 会更为方便。

在通常情况下,更应该倾向于使用 torch.jit.trace 而不是 torch.jit.script

在模型部署方面,ONNX 被大量使用。而导出 ONNX 的过程,也是 model 进行 torch.jit.trace 的过程,因此这里我们把 torch 的 trace 做稍微详细一点的介绍。

2.1 jit trace 注意事项

为了能够把模型编写的更能够被 jit trace,需要对代码做一些妥协,例如:

  1. 如果 model 中有 DataParallel 的子模块,或者 model 中有将 tensors 转换为 numpy arrays,或者调用了 OpenCV 的函数等,这种情况下,model 不是一个正确的在单个设备上、正确连接的 graph,这种情况下,不管是使用 torch.jit.script 还是 torch.jit.trace 都不能 trace 出正确的 TorchScript 来。

  2. model 的输入输出应该是 Union[Tensor, Tuple[Tensor], Dict[str, Tensor]] 的类型,而且在 dict 中的值,应该是同样的类型。但是对于 model 中间子模块的输入输出,可以是任意类型,例如 dicts of Any, classes, kwargs 以及 Python 支持的都可以。对于 model 输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

    outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
    # torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
    adapter = TracingAdapter(model, inputs)  # 使用Adapter,将model inputs包装为trace支持的类型
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功
    
    # Traced model的输出只能是tuple tensors类型:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 再通过adapter转换为想要的输出类型
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
  3. 一些数值类型的问题。比如下面的代码片段:

    import torch
    a=torch.tensor([1,2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))
    

    在eager mode下,这几个返回值的类型都是int型。上面代码的输出为:

    <class 'int'>
    <class 'int'>
    <class 'int'>
    

    但是在 trace mode 下,这几个表达式的返回值类型都是 Tensor 类型。因此,有些表达式使用不当,如果在 trace 过程中,一些 shape 表达式的返回值类型是 int 型,那么可能造成这块代码没有被 trace。在代码中,可以通过使用 torch.jit.is_tracing 来检查这块代码在 trace mode 下有没有被执行。

  4. 由于动态的 control flow,造成模型没有被完整的 trace。看下面的例子:

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
    

    可以看到 trace 后的 model 只保留了一条分支。因此由于输入造成的 dynamic 的 control flow,trace 后容易出现错误。

    这种情况下,我们可以使用 torch.jit.script 来进行 TorchScript 的转换。

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.script(f)
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:
      if bool(torch.gt(torch.sum(x), 0)):
        _0 = torch.sqrt(x)
      else:
        _0 = torch.square(x)
      return _0
    

    在大多数情况下,我们应该使用 torch.jit.trace,但是像上面的这种 dynamic control flow 的情况,我们可以混合使用 torch.jit.tracetorch.jit.script,在后面会进行阐述
    另外在一些 Blog 中,对于 dynamic control flow 的定义是有错误的,例如 if x[0] == 4: x += 1 是 dynamic control flow,但是:

    model: nn.Sequential = ...
    for m in model:
      x = m(x)
    

    以及:

    class A(nn.Module):
      backbone: nn.Module
      head: Optiona[nn.Module]
      def forward(self, x):
        x = self.backbone(x)
        if self.head is not None:
            x = self.head(x)
        return x
    

    都不是 dynamic control flowdynamic control flow 是由于对输入条件的判断造成的不同分支的执行

  5. trace 过程中,将变量 trace 成了常量。看下面一个例子:

    import torch
    a, b = torch.rand(1), torch.rand(2)
    
    def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))
    
    print(torch.jit.trace(f1, a)(b))
    # 输出: tensor([0, 1])
    # 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的
    
    print(torch.jit.trace(f2, a)(b))
    # 输出:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。
    
    # 我们打印一下两者的区别
    print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
    # 输出
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
    
    #  def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0
    
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    
    # 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
    

    我们导出 ONNX 的过程,也是进行 torch.jit.trace 的过程,在导出 ONNX 的时候,有时候也会遇到

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

    这样的提示信息,这时候要检查一下代码中是不是有可能 trace 过程中,变量会被当做常量的情况,有可能会导致导出的 ONNX 精度异常。

    • 关于 ONNX
      ONNX 默认基于 trace 的方式,运行一次模型,记录下和 tensor 的相关操作。trace 将不会捕获根据输入数据而改变的行为。比如 if 语句,只会记录执行的那一条分支,同样的,for 循环的次数,导出与跟踪运行完全相同的静态图。如果要使用动态控制流导出模型,则需要使用 torch.jit.script
      torch.jit.script:真正的去编译,在 PYTHON 的 AST 语法树做语法分析句法分析。因此可以使用if等动态控制流。返回 ScriptModule。
      torch.onnx.export 在运行时,先判断是否是 SriptModule,如果不是,则进行 torch.jit.trace,因此 export 需要一个随机生成的输入参数。
      import torch.nn as nn
      import torch
      import torch.nn.functional as F
      import cv2
      import numpy as np
      import onnx
      import onnxruntime as ort
      
      #from torch.onnx import register_custom_op_symbolic # 私有层支持
      
      class test_net(nn.Module):
          def __init__(self,):
              super(test_net, self).__init__()
              #self.model = nn.MaxPool3d(kernel_size=(1,3,3), stride=(2,1,2))
              #self.model = nn.AvgPool3d(kernel_size=(1,3,3), stride=(2,1,2)) #-> AveragePool
              self.model = nn.Conv3d(3,64,kernel_size=(1,3,3), stride=(2,1,2))
              self.relu = nn.ReLU()
              self.relu6 = nn.ReLU6()
              self.relu66 = nn.ReLU6()
      
          def forward(self, x):
              out1 = self.model(x)
              f_mean = torch.mean(out1) # -> ReduceMean
              #f_mean = torch.mean(out1).item() # item()会将f_mean转换为常数 会丢失 mean操作
              # script模式转onnx会报错 torch._C._jit_pass_erase_number_types(graph) RuntimeError: Unknown number type: Scalar
              out2 = torch.div(out1, f_mean)
              #outlist = list()
              #for i in range(3):
              #    if i in [0]:
              #        #outlist.append(nn.ReLU()(out2))  # script模式下报错 类对象要提前构建
              #        outlist.append(self.relu(out2))   # scrip_to_onnx 报错 找不到25 BUG
              #    else:
              #        #outlist.append(nn.ReLU6()(out2))
              #        outlist.append(self.relu6(out2))
              #out = torch.cat(outlist)
              # 上述 for循环构图在tracing模式下会展开
              # script模式下难转换,报错
              # 手动平铺
              o1 = self.relu(out2)
              o2 = self.relu6(out2)
              #o3 = self.relu6(out2)   # script模式下被优化掉了 BUG
              o3 = self.relu66(out2)   # script模式下被优化掉了
              out = torch.cat([o1,o2,o3])
      
              return out
      
      # 模型构建和运行
      imgh, imgw = 24, 94
      net = test_net().eval() # 若存在batchnorm、dropout层则一定要eval() 使得BN层参数不更新
      dummy_input = torch.randn(1,3,3,imgh, imgw)# n c d h w
      torch_out = net.forward(dummy_input)# net(dummy_input)
      
      
      # export onnx
      dynamic_axes = {'input': {3: 'height', 4: 'width'}, 'output': {3: 'height', 4: 'width'}} # 配置动态分辨率
      onnx_pth = "test-conv-relu.onnx"
      
      # 传入原model,采用默认trace方式捕获模型,需要运行模型
      torch.onnx.export(net, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes)
      # 也可传入 scriptModule
      #net_script= torch.jit.script(test_net())
      # 需要外加配置 example_outputs,用来获取输出的shape和dtype,无需运行模型
      #torch.onnx.export(net_script, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes, example_outputs=[torch_out])
      
      # ort run
      oxx_m = ort.InferenceSession(onnx_pth)
      onnx_blob = dummy_input.data.numpy()
      onnx_out = oxx_m.run(None, {'input':onnx_blob})[0]
      
      dummy_input2 = torch.randn(1,3,3,imgh*2, imgw*2)
      onnx_blob2 = dummy_input2.data.numpy()
      onnx_out2 = oxx_m.run(None, {'input':onnx_blob2})[0]
      
      # opencv run
      #cv_m = cv2.dnn.readNet(onnx_pth)
      
      print('mean diff = ', np.mean(onnx_out - torch_out.data.numpy()))
      

    除了 len 会导致 trace 错误,其他几个也会导致 trace 出现问题:

    • .item() 会在 trace 过程中将 tensors 转为 int/float

    • 任何将 torch 类型转为 numpy/python 类型的代码

    • 一些有问题的算子,例如 advanced indexing

    • torch.jit.trace 不会对传入的 device 生效

      import torch
      def f(x):
          return torch.arange(x.shape[0], device=x.device)
      m = torch.jit.trace(f, torch.tensor([3]))
      print(m.code)
      # 输出
      # def f(x: Tensor) -> Tensor:
      #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
      #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
      #   return _1
      print(m(torch.tensor([3]).cuda()).device)
      # 输出:device(type='cpu')
      

      trace 不会对传入的 cuda device 生效。

2.2 jit trace 验证技巧

为了保证trace的正确,我们可以通过一下的一些方法来尽量保证 trace 后的模型不会出错:
1.注意 warnings 信息。类似这样的:

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个 warning 等级。
2. 做单元测试。需要验证一下 eager mode 的模型输出与 trace 后的模型输出是否一致。

assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
  1. 避免一些特殊的情况。例如下面的代码:
if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W))  # 会创建一个空的输出

避免一些特殊情况比如空的输入输出之类的。

  1. 注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:
  • 使用 torch.size(0) 来代替 len(tensor),因为 torch.size(0) 返回的是 Tensor,len(tensor) 返回的是 int。对于自定义类,实现一个 .size 方法或者使用 .__len__() 方法来代替 len() ,例如这个例子
  • 不要使用 int() 或者 torch.as_tensor 来转换 size 的类型,因为这些操作也会被视为常量。
  1. 混合 tracing 和 scripting 方法。可以使用 torch.jit.script 来转换一些 torch.jit.trace 不能搞定的小的代码片段,混合使用 tracing 和 scripting,基本可以解决所有的问题。

2.3 混合使用 trace 和 script

trace 和 script 都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用 torch.jit.trace,必要时才使用 torch.jit.script

  1. 在使用 torch.jit.trace 时,使用 @script_if_tracing 装饰器可以让被装饰的函数使用 script 方式进行编译

    def forward(self, ...):
      # ... some forward logic
      @torch.jit.script_if_tracing
      def _inner_impl(x, y, z, flag: bool):
          # use control flow, etc.
          return ...
      output = _inner_impl(x, y, z, flag)
      # ... other forward logic
    

    但是使用 @script_if_tracing 时,需要保证函数中没有 PyTorch 的 modules,如果有的话,需要做一些修改,例如下面的:

    # 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
      output = self.layers(x)
    else:
      # Create empty outputs
      output = torch.zeros(...)
    

    这里需要做如下修改:

    # 需要将self.layers移出if判断,这时候可以用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
    else:
      # Create empty inputs
      x = torch.zeros(...)
    # 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
    output = self.layers(x)
    
  2. 合并多次 trace 的结果
    使用 torch.jit.script 生成的模型相比使用 torch.jit.trace 有两个好处:

    • 可以使用条件控制流,例如模型中使用一个 bool 值来控制 forward 的 flow,在 traced modules 里面是不支持的
    • 使用 traced module,只能有一个 forward() 函数,但是使用 scripted module,可以有多个前向计算的函数
    class Detector(nn.Module):
      do_keypoint: bool
    
      def forward(self, img):
          box = self.predict_boxes(img)
          if self.do_keypoint:
              kpts = self.predict_keypoint(img, box)
    
      @torch.jit.export
      def predict_boxes(self, img): pass
    
      @torch.jit.export
      def predict_keypoint(self, img, box): pass
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

2.4 trace 和 script 的性能

trace 总是会比 script 生成一样或者更简单的计算图,因此性能会更好一些。因为 script 会完整的表达 Python 代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

class A(nn.Module):
  def forward(self, x1, x2, x3):
    z = [0, 1, 2]
    xs = [x1, x2, x3]
    for k in z: x1 += xs[k]
    return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

2.5 总结

trace 具有明显的局限性:这篇文章的大部分篇幅都在谈论 trace 的局限性以及如何解决这些问题。实际上,这正是 trace 的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

相反,script 更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复 script 的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

trace 和 script 都会影响代码的编写方式,但 trace 因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
  • 它需要修改一些代码才能通用(例如在 trace 时添加一些 script),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。

3. LibTorch 的使用

在得到所需模型后,可以尝试在 C++ 环境下使用得到的模型,这里就用到了 LibTorch。

3.1 LibTorch 的链接

结合自己环境的 CUDA 版本,去官网下载对应版本的 libTorch。例如 CUDA 版本为 11.1,则需要在下载地址中找到 libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip 进行下载。

链接进需要再 cmake 内加上这几行即可:

set(TORCH_PATH "/home/yj/libtorch/share/cmake/Torch")
message("TORCH_PATH set to: ${TORCH_PATH}")
set(Torch_DIR ${TORCH_PATH})

find_package(Torch REQUIRED)
message(STATUS "Torch version is: ${Torch_VERSION}")

# <target> is your target's name
target_link_libraries(<target> 
  ${TORCH_LIBRARIES}
)

3.2 接口和实现

  1. 头文件引入 :

    #include <torch/script.h>
    #include <torch/torch.h>
    
  2. 加载模型

    module = torch::jit::load(PATH);
    
  3. 函数实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/336627.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

小程序宿主环境-组件button

button <button>普通按钮</button> <button type"primary">主色调按钮</button> <button type"warn">警告按钮</button><button size"mini">普通按钮</button> <button type"primary&q…

OpenCV-Python(49):图像去噪

目标 学习使用非局部平均值去噪算法去除图像中的噪音学习函数cv2.fastNlMeansDenoising()、cv2.fastNlMeansDenoisingColored等 原理 在前面的章节中我们已经学习了很多图像平滑技术&#xff0c;比如高斯平滑、中值平滑等。当噪声比较小时&#xff0c;这些技术的效果都是很好…

钡铼 楼宇暖通网关之 BACnet网关在空气源热泵智能控制系统中的应用介绍

前言 在刚刚过去的2023年&#xff0c;空气源热泵市场依然火爆&#xff0c;全线市场销量递增&#xff0c;各种新品层出不穷&#xff0c;市场认可度持续攀升&#xff0c;在整个采暖市场&#xff0c;空气源热泵已然成为当红明星。 热泵组管道比较复杂&#xff0c;传感器分布比较分…

JUC-Java内存模型JMM

JMM概述 Java Meory Model java内存模型。在不同的硬件和不同的操作系统上&#xff0c;对内存的访问方式是不一样的。这就造成了同一套java代码运行在不同的操作系统上会出问题。JMM就屏蔽掉硬件和操作系统的差异&#xff0c;增加java代码的可移植性。这是一方面。 另一方面JM…

Java项目:11 Springboot的垃圾回收管理系统

作者主页&#xff1a;舒克日记 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 功能介绍 本系统通过利用系统的垃圾回收流程&#xff0c;提高垃圾回收效率&#xff0c;通过垃圾回收的申请&#xff0c;增删改查&#xff0c;垃圾运输申…

TCP服务器最多支持多少客户端连接

目录 一、理论数值 二、实际部署 参考 一、理论数值 首先知道一个基础概念&#xff0c;对于一个 TCP 连接可以使用四元组&#xff08;src_ip, src_port, dst_ip, dst_port&#xff09;进行唯一标识。因为服务端 IP 和 Port 是固定的&#xff08;如下图中的bind阶段&#xff0…

Mysql运维篇(一) 日志类型

一路走来&#xff0c;所有遇到的人&#xff0c;帮助过我的、伤害过我的都是朋友&#xff0c;没有一个是敌人&#xff0c;如有侵权请留言&#xff0c;我及时删除。 一、mysql相关日志 首先&#xff0c;我们能接触到的&#xff0c;一般我们排查慢查询时&#xff0c;会去看慢查询…

关于java的方法重写

关于java的方法重写 我们之前在学习方法的时候&#xff0c;了解到了方法的重载&#xff0c;但是本篇文章我们要了解的是方法的重写&#xff0c;是不一样的&#xff0c;千万不能混淆在一起&#x1f600; 一、初识重写 1、首先我们建立一个新的包&#xff0c;然后新建一个A类&…

快速幂 算法

暴力算法 我们可以采用暴力算法 #include<bits/stdc.h> using namespace std; #define ll long long int main() {ll a, b, c;cin >> a >> b >> c;ll ans 1;for (ll i 1; i < b; i) {ans * a;}ans % c;cout << ans; } 不过这样肯定会超时…

torchtext安装及常见问题

Pytorch 、 torchtext和Python之间有严格的对应关系&#xff1a; 在命令窗中安装torchtext pip install torchtext 注意这种安装方式&#xff0c;在pytorch版本与python版本不兼容时动会自动更新并安装pytorchcpu版本&#xff0c;安装的新版本pytorch可能会不兼容。慎用。 …

Qt QCustomPlot 绘制子轴

抄大神杰作&#xff1a;QCustomplot&#xff08;五&#xff09;QCPAxisRect进行子绘图-CSDN博客 需求来源&#xff1a;试验数据需要多轴对比。 实现多Y轴、单X轴、X轴是时间轴、X轴range联动、rect之间的间距是0&#xff0c;每个图上有legend(这里有个疑问&#xff0c;每添加…

【⭐AI工具⭐】实用工具推荐

目录 壹 实用工具工具合集TinyWowHiPDF 公式处理SimpleTex公式中常用的希腊字母符号公式在论文中的格式 图像处理BgRemoverPix Fix像素蒸发Photopea 音频处理啦啦爱 笔记整理飞书妙记 素材整理Eagle 其它一次性临时电子邮件近邻词汇检索据意查句诗三百能不能好好说话&#xff1…

2023 年值得一读的技术文章 | NebulaGraph 技术社区

在之前的产品篇&#xff0c;我们了解到了 NebulaGraph 内核及周边工具在 2023 年经历了什么样的变化。伴随着这些特性的变更和上线&#xff0c;在【文章】博客分类中&#xff0c;一篇篇的博文记录下了这些功能背后的设计思考和研发实践。当中&#xff0c;既有对内存管理 Memory…

Python爬虫IP池

目录 一、介绍 1.1 为什么需要IP池&#xff1f; 1.2 IP池与代理池的区别 二、构建一个简单的IP池 三、注意事项 一、介绍 在网络爬虫的世界中&#xff0c;IP池是一个关键的概念。它允许爬虫程序在请求网页时使用多个IP地址&#xff0c;从而降低被封禁的风险&#xff0c;提高…

【大坑】MyBatisPlus使用updateById莫名将数据四舍五入了

问题描述 我目前在为本地的一所高中开发一个成绩分析的网站&#xff0c;后端使用的是SpringBootMyBatisPlus&#xff0c;业务逻辑是用户在前端上传EXCEL文件&#xff0c;后端从文件中读取成绩存到数据库用于分析。但是奇怪的是&#xff1a;在后端&#xff0c;进入数据库之前的…

DBA技术栈MongoDB: 索引和查询优化

2.1 批量插入数据 单条数据插入db.collection.insertOne()多条数据插入db.collection.insertMany() db.inventory.insertMany( [{ item: "journal", qty: 25, size: { h: 14, w: 21, uom: "cm" }, status: "A" },{ item: "notebook"…

【MATLAB源码-第119期】基于matlab的GMSK系统1bit差分解调误码率曲线仿真,输出各个节点的波形以及功率谱。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 GMSK&#xff08;高斯最小频移键控&#xff09;是一种数字调制技术&#xff0c;广泛应用于移动通信&#xff0c;例如GSM网络。它是一种连续相位调频制式&#xff0c;通过改变载波的相位来传输数据。GMSK的关键特点是其频谱的…

vue3通过ref调用子组件方法,第一次点击报找不到该方法,ref和v-if冲突

通过ref实现父子组件通信&#xff0c;但在第一次点击按钮的时候报找不到子组件的方法 原因&#xff1a;ref和v-if冲突,ref只有在组件渲染完成才注册引用信息&#xff0c;v-if首次为false没有把元素或子组件渲染&#xff0c;所以没有注册引用信息。 父组件 <uni-popup ref…

GO 中高效 int 转换 string 的方法与高性能源码剖析

文章目录 使用 strconv.Itoa使用 fmt.Sprintf使用 strconv.FormatIntFormatInt 深入剖析1. 快速路径处理小整数2. formatBits 函数的高效实现 结论 Go 语言 中&#xff0c;将整数&#xff08;int&#xff09;转换为字符串&#xff08;string&#xff09;是一项常见的操作。 本文…

Peter算法小课堂—拓扑排序与最小生成树

拓扑排序 讲拓扑排序前&#xff0c;我们要先了解什么是DAG树。所谓DAG树&#xff0c;就是指“有向无环图”。请判断下列图是否是DAG图 第一幅图&#xff0c;它不是DAG图&#xff0c;因为它形成了一个环。第二幅图&#xff0c;它也不是DAG图&#xff0c;因为它没有方向。第三幅…