30分钟吃掉pytorch转onnx及推理

pytorch模型线上部署最常见的方式是转换成onnx然后再转成tensorRT 在cuda上进行部署推理。

本文介绍将pytorch模型转换成onnx模型并进行推理的方法。

#!pip install onnx 
#!pip install onnxruntime
#!pip install torchvision

公众号算法美食屋后台回复关键词:源码,获取本文notebook源代码。

一,准备pytorch模型

我们先导入torchvision中的resnet18模型,演示它的推理效果。

以便和onnx的结果进行对比。

 
 
import torch
import torchvision.models as models
import numpy as np
import torchvision
import torchvision.transforms as T


from PIL import Image


def create_net():
    net = models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    return net 


net = create_net()


torch.save(net.state_dict(),'resnet18.pt')
net.eval();
 
 
def get_test_transform():
    return T.Compose([
        T.Resize([320, 320]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


image = Image.open("dog.png") # 289
img = get_test_transform()(image)
img = img.unsqueeze_(0) 
output = net(img)
score, indice = torch.max(torch.softmax(output,axis=-1),1)
info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}


def show_image(image, title):
    import matplotlib.pyplot as plt 
    ax=plt.subplot()
    ax.imshow(image)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([]) 
    plt.show()


show_image(image, title = info)

35c44ecc42bff2e1badd24ed674a4d86.png

二,pytorch模型转换成onnx模型

1, 简化版本

 
 
import onnxruntime
import onnx


batch_size = 1  
input_shape = (3, 320, 320)   


x = torch.randn(batch_size, *input_shape)
onnx_file = "resnet18.onnx"
torch.onnx.export(net,x,onnx_file,
                opset_version=10,
                do_constant_folding=True,  # 是否执行常量折叠优化
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    "input":{0:"batch_size"},  
                     "output":{0:"batch_size"}})
!du -s -h resnet18.pt
45M	resnet18.pt
!du -s -h resnet18.onnx
45M	resnet18.onnx

可以在 https://netron.app/ 中拖入 resnet18.onnx 文件查看模型结构

2,全面版本

下面的代码包括了设置输入输出尺寸,以及动态可以变batch等等。

 
 
import argparse
from argparse import Namespace
import time
import sys
import os
import torch
import torch.nn as nn
import torchvision.models as models
import onnx
import onnxruntime


from io import BytesIO




ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))


params = Namespace(weights='resnet18.pt',
                   img_size=[320,320],
                   batch_size=1,
                   half=False,
                   dynamic_batch=True
                  )


parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='checkpoint.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--dynamic-batch', action='store_true', help='export dynamic batch onnx model')
parser.add_argument('--trt-version', type=int, default=8, help='tensorrt version')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')


args = parser.parse_args(args='',namespace=params)




args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
print(args)


t = time.time()


# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device(f'cuda:{args.device}' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'


# Load PyTorch model
model = create_net()
model.to(device)
model.load_state_dict(torch.load(args.weights)) # pytorch模型加载


# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device)  # image size(1,3,320,192) iDetection


# Update model
if args.half:
    img, model = img.half(), model.half()  # to FP16
model.eval()


prediction = model(img)  # dry run


# ONNX export
print('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx')  # filename
with BytesIO() as f:
    dynamic_axes = {"input":{0:"batch_size"}, "output":{0:"batch_size"} } if args.dynamic_batch else None
    torch.onnx.export(model, img, f, verbose=False, opset_version=13,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=dynamic_axes)
    f.seek(0)
    # Checks
    onnx_model = onnx.load(f)  # load onnx model
    onnx.checker.check_model(onnx_model)  # check onnx model
    
if args.simplify:
    try:
        import onnxsim
        print('\nStarting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print(f'Simplifier failure: {e}')


onnx.save(onnx_model, export_file)


print(f'ONNX export success, saved as {export_file}')


# Finish
print('\nExport complete (%.2fs)' % (time.time() - t))
Namespace(weights='resnet18.pt', img_size=[320, 320], batch_size=1, half=False, dynamic_batch=True, inplace=False, simplify=False, trt_version=8, device='cpu')

Starting to export ONNX...
ONNX export success, saved as resnet18.onnx

Export complete (0.57s)

三,使用onnx模型进行推理

1,函数风格

 
 
onnx_sesstion = onnxruntime.InferenceSession(export_file)
 
 
def pipe(img_path,
         onnx_sesstion = onnx_sesstion):
    image = Image.open(img_path) 
    img = get_test_transform()(image)
    img = img.unsqueeze_(0) 


    to_numpy = lambda tensor: tensor.data.cpu().numpy()
    
    inputs = {onnx_sesstion.get_inputs()[0].name: to_numpy(img)}
    outs = onnx_sesstion.run(None, inputs)[0]


    score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
    info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}
    return info
 
 
img_path = 'dog.png'
image = Image.open(img_path)
info = pipe(img_path)
show_image(image,info)

5bb3e0e3bd75dd3694f3ce98f42fb5b4.png

2,对象风格

 
 
import os, sys


import onnxruntime
import onnx
    
class ONNXModel():
    def __init__(self, onnx_path):
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_names = [node.name for node in self.onnx_session.get_inputs()]
        self.output_names = [node.name for node in self.onnx_session.get_outputs()]
        print("input_name:{}".format(self.input_names))
        print("output_name:{}".format(self.output_names))
 
    def forward(self, x):
        if isinstance(x,np.ndarray):
            assert len(self.input_names)==1
            input_feed = {self.input_names[0]:x}
        elif isinstance(x,(tuple,list)):
            assert len(self.input_names)==len(x)
            input_feed = {k:v for k,v in zip(self.input_names,x)}
        else:
            assert isinstance(x,dict)
            input_feed = x
        outs = self.onnx_session.run(self.output_names, input_feed=input_feed)
        return outs
    
    def predict(self,img_path):
        image = Image.open(img_path) 
        img = get_test_transform()(image)
        img = img.unsqueeze_(0) 
        to_numpy = lambda tensor: tensor.data.cpu().numpy()
        outs = self.forward(to_numpy(img))[0]
        score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
        return {'score':score[0].data.numpy().tolist(),
            'indice':indice[0].data.numpy().tolist()}
 
 
onnx_model = ONNXModel(export_file)
info = onnx_model.predict(img_path)
show_image(image, title = info)
input_name:['input']
output_name:['output']

e1e4da5206fcf00a85153d5ff34d3977.png

c3fc58f804e640de7223dfeb9cecba8a.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/693163.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

jmeter -n -t 使用非GUI模式运行脚本说明

命令模式下执行jmx文件 jmeter -n -t fatie.jmx -l results\t4.jtl -e -o results\h1 表示以命令行模式运行当前目录下的脚本fatie.jmx,将结果存入当前目录下的results\t1.jtl,并且生成html格式的报告,写入文件夹results\h1。 说明:生成结果的文件夹r…

《精通ChatGPT:从入门到大师的Prompt指南》第10章:案例分析

第10章:案例分析 10.1 优秀Prompt案例解析 在深入探讨如何精通ChatGPT的使用之前,理解并分析一些优秀的Prompt案例是至关重要的。这不仅有助于更好地掌握Prompt的构建技巧,还能提高与AI交互的效果。在这一节中,我们将详细解析一…

实用的 C 盘搬家软件

一、简介 1、一款专门用于 Windows 系统的文件夹移动工具,它允许用户将程序或游戏的安装文件夹从一台驱动器移动到另一台驱动器,或者同一个驱动器内的不同路径,而无需重新安装或破坏现有的程序安装。 二、下载 1、下载地址: 官网链…

1-Maven-settings配置

1-Maven-settings配置 整理下Maven工具的使用。 【本地仓库、私服、镜像仓库、远程仓库、中央仓库】 本文基于阅读其他博客和对公司Maven配置的学习整理出来的。希望通过本此学习能对Maven有个整体性的掌控。 顺序:profile.repository > pom文件中的repository &…

关于焊点检测(SJ-BIST)模块实现

关于焊点检测(SJ-BIST)模块实现 语言 :Verilg HDL 、VHDL EDA工具:ISE、Vivado、Quartus II 关于焊点检测(SJ-BIST)模块实现一、引言二、焊点检测功能的实现方法(1) 输入接口&#x…

SpringBoot+Vue网上超市(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 角色对应功能 用户管理员 功能截图

C基础与SDK调试方法

REVIEW 上次学习了一下软件使用流程zynq PS点灯-CSDN博客 本次学习一下C编程基础与调试方法 1. 硬件编程原理 小梅哥视频链接: 07_Xilinx嵌入式裸机硬件编程原理_哔哩哔哩_bilibili 对应的课程笔记:【zynq课程笔记】【裸机】【第7课 】【硬件编程原理…

eNSP学习——配置RIP路由附加度量值

目录 主要命令 原理概述 实验目的 实验内容 实验拓扑 实验编址 实验步骤 1、基本配置 2、搭建RIP网络 3、配置RIP Metricin 4、配置RIP Metricout 需要eNSP各种配置命令的点击链接自取:华为eNSP各种设备配置命令大全PDF版_ensp配置命令大全资…

Vyper重入漏洞解析

什么是重入攻击 Reentrancy攻击是以太坊智能合约中最具破坏性的攻击之一。当一个函数对另一个不可信合约进行外部调用时,就会发生重入攻击。然后,不可信合约会递归调用原始函数,试图耗尽资金。 当合约在发送资金之前未能更新其状态时&#…

计算机网络-数制转换与子网划分

目录 一、了解数制 1、计算机的数制 2、二进制 3、八进制 4、十进制 5、十六进制 二、数制转换 1、二进制转十进制 2、八进制转十进制 3、十六进制转十进制 4、十进制转二进制 5、十进制转八进制 6、十进制转十六进制 三、子网划分 1、IP地址定义 2、IP的两种协…

Linux之进程信号详解【上】

🌎 Linux信号详解 文章目录: Linux信号详解 信号入门 技术应用角度的信号 信号及信号的产生       信号的概念       信号的处理方式 信号的产生方式         键盘产生信号         系统调用产生信号         软件…

Nvidia/算能 +FPGA+AI大算力边缘计算盒子:隧道和矿井绘图设备

RockMass 正在努力打入采矿业和隧道工程利基市场。 这家位于多伦多的初创公司正在利用 NVIDIA AI 开发一款绘图平台,帮助工程师评估矿井和施工中的隧道稳定性。 目前,作为安全预防措施,地质学家和工程师会站在离岩石五米远的地方&#xff0…

Llama模型家族之Stanford NLP ReFT源代码探索 (二)Intervention Layers层

LlaMA 3 系列博客 基于 LlaMA 3 LangGraph 在windows本地部署大模型 (一) 基于 LlaMA 3 LangGraph 在windows本地部署大模型 (二) 基于 LlaMA 3 LangGraph 在windows本地部署大模型 (三) 基于 LlaMA…

【Java 百“练”成钢】Java 基础:类和对象

Java 基础:类和对象 01.打印信息02.打印类的简单名称03.打印类的 ClassLoader04.获取类的方法05.获取类的Package06.创建一个对象数组07.计算圆的面积08.计算圆的周长09.创建具有私有访问修饰符的成员10.创建带访问修饰符的成员11.将对象作为参数传递12.通过类对象获…

开源多平台AI音乐生成器本地安装结合cpolar内网穿透实现远程访问

文章目录 前言1. 本地部署2. 使用方法介绍3. 内网穿透工具下载安装4. 配置公网地址5. 配置固定公网地址 前言 本文主要介绍如何在Windows系统电脑上快速本地部署一个文字生成音乐的AI创作工具MusicGPT,并结合cpolar内网穿透工具实现随时随地远程访问使用。 MusicG…

Linux 35.5 + JetPack v5.1.3@ ego-planner编译安装

Linux 35.5 JetPack v5.1.3 ego-planner编译安装 1. 源由2. 编译&安装Step 1:依赖库安装Step 2:建立工程Step 3:编译工程Step 4:安装工程 3. 问题汇总3.1 planner/plan_env - OpenCV3.2 uav_simulator/local_sensing - CUDA优…

基于非下采样小波包分析的滚动轴承故障诊断(MATLAB R2021B)

小波变换具有良好的时频局部化特性和多分辨率特性,可准确定位信号的突变点并可在不同尺度上描述信号的局部细节特征,被广泛应用于信号降噪。但标准正交小波变换不具有平移不变性,采用标准正交小波对信号消噪后,会在脉冲尖峰处产生…

进口电动防爆调节阀的性能-美国品牌

进口电动防爆调节阀以其卓越的防爆性能、安全可靠、精确控制以及广泛的技术规格和应用场合,成为工业领域中不可或缺的重要设备。 进口电动防爆调节阀的性能可以归纳如下: 一、防爆性能 防爆设计:进口电动防爆调节阀采用专门的防爆设计和制…

使用 PAI-DSW x Free Prompt Editing图像编辑算法,开发个人AIGC绘图小助理

教程简述 在本教程中,您将学习在阿里云交互式建模平台PAI-DSW x Free Prompt Editing(CVPR2024中选论文算法)图像编辑算法,开发个人AIGC绘图小助理,实现文本驱动的图像编辑功能单卡即可完成AIGC图片风格变化、背景变化…

5. 刷题

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/868 题目描述 小爱需要完成 𝑛n 道题目…