30分钟吃掉 Pytorch 转 onnx

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:

《大模型面试宝典》(2024版) 发布!

圈粉无数!《PyTorch 实战宝典》火了!!!


PyTorch 是一个用于机器学习的开源深度学习框架,而ONNX(Open Neural Network Exchange)是一个用于表示深度学习模型的开放式格式。

将 PyTorch 模型转换为ONNX格式有几个原因和优势:

  1. 跨平台部署: ONNX是一个跨平台的格式,支持多种深度学习框架,包括PyTorch、TensorFlow等。将模型转换为ONNX格式可以使模型在不同框架和设备上进行部署和运行。

  2. 性能优化: ONNX格式可以在不同框架之间实现性能优化。例如,可以在PyTorch中训练模型,然后转换为ONNX格式,并在性能更高的框架(如TensorRT)中进行推理。

  3. 模型压缩: ONNX格式可以实现模型的压缩和优化,从而减小模型的体积并提高推理速度。这对于在资源受限的设备上部署模型尤为重要。

pytorch 模型线上部署最常见的方式是转换成onnx,然后再转成tensorRT 在cuda上进行部署推理。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了Pytorch 技术与面试交流群, 想要获取最新面试题、了解最新面试动态的、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

本文介绍将pytorch模型转换成onnx模型并进行推理的方法。

#!pip install onnx 
#!pip install onnxruntime
#!pip install torchvision

一,准备pytorch模型

我们先导入torchvision中的resnet18模型,演示它的推理效果。

以便和onnx的结果进行对比。

import torch
import torchvision.models as models
import numpy as np
import torchvision
import torchvision.transforms as T

from PIL import Image

def create_net():
    net = models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    return net 

net = create_net()

torch.save(net.state_dict(),'resnet18.pt')
net.eval();
def get_test_transform():
    return T.Compose([
        T.Resize([320, 320]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

image = Image.open("dog.png") # 289
img = get_test_transform()(image)
img = img.unsqueeze_(0) 
output = net(img)
score, indice = torch.max(torch.softmax(output,axis=-1),1)
info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}

def show_image(image, title):
    import matplotlib.pyplot as plt 
    ax=plt.subplot()
    ax.imshow(image)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([]) 
    plt.show()

show_image(image, title = info)

图片

二,pytorch模型转换成onnx模型

1, 简化版本

import onnxruntime
import onnx

batch_size = 1  
input_shape = (3, 320, 320)   

x = torch.randn(batch_size, *input_shape)
onnx_file = "resnet18.onnx"
torch.onnx.export(net,x,onnx_file,
                opset_version=10,
                do_constant_folding=True,  # 是否执行常量折叠优化
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    "input":{0:"batch_size"},  
                     "output":{0:"batch_size"}})

!du -s -h resnet18.pt
 45M	resnet18.pt
!du -s -h resnet18.onnx 

 45M	resnet18.onnx

可以在 https://netron.app/ 中拖入 resnet18.onnx 文件查看模型结构

2,全面版本

下面的代码包括了设置输入输出尺寸,以及动态可以变batch等等。

import argparse
from argparse import Namespace
import time
import sys
import os
import torch
import torch.nn as nn
import torchvision.models as models
import onnx
import onnxruntime

from io import BytesIO


ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

params = Namespace(weights='resnet18.pt',
                   img_size=[320,320],
                   batch_size=1,
                   half=False,
                   dynamic_batch=True
                  )

parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='checkpoint.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--dynamic-batch', action='store_true', help='export dynamic batch onnx model')
parser.add_argument('--trt-version', type=int, default=8, help='tensorrt version')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

args = parser.parse_args(args='',namespace=params)


args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
print(args)

t = time.time()

# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device(f'cuda:{args.device}' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'

# Load PyTorch model
model = create_net()
model.to(device)
model.load_state_dict(torch.load(args.weights)) # pytorch模型加载

# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device)  # image size(1,3,320,192) iDetection

# Update model
if args.half:
    img, model = img.half(), model.half()  # to FP16
model.eval()

prediction = model(img)  # dry run

# ONNX export
print('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx')  # filename
with BytesIO() as f:
    dynamic_axes = {"input":{0:"batch_size"}, "output":{0:"batch_size"} } if args.dynamic_batch else None
    torch.onnx.export(model, img, f, verbose=False, opset_version=13,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=dynamic_axes)
    f.seek(0)
    # Checks
    onnx_model = onnx.load(f)  # load onnx model
    onnx.checker.check_model(onnx_model)  # check onnx model
    
if args.simplify:
    try:
        import onnxsim
        print('\nStarting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print(f'Simplifier failure: {e}')

onnx.save(onnx_model, export_file)

print(f'ONNX export success, saved as {export_file}')

# Finish
print('\nExport complete (%.2fs)' % (time.time() - t))
Namespace(weights='resnet18.pt', img_size=[320, 320], batch_size=1, half=False, dynamic_batch=True, inplace=False, simplify=False, trt_version=8, device='cpu')

Starting to export ONNX...
ONNX export success, saved as resnet18.onnx

Export complete (0.57s)

三,使用onnx模型进行推理

1,函数风格

onnx_sesstion = onnxruntime.InferenceSession(export_file)
def pipe(img_path,
         onnx_sesstion = onnx_sesstion):
    image = Image.open(img_path) 
    img = get_test_transform()(image)
    img = img.unsqueeze_(0) 

    to_numpy = lambda tensor: tensor.data.cpu().numpy()
    
    inputs = {onnx_sesstion.get_inputs()[0].name: to_numpy(img)}
    outs = onnx_sesstion.run(None, inputs)[0]

    score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
    info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}
    return info
img_path = 'dog.png'image = Image.open(img_path)info = pipe(img_path)show_image(image,info)

图片

2,对象风格

import os, sys

import onnxruntime
import onnx
    
class ONNXModel():
    def __init__(self, onnx_path):
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_names = [node.name for node in self.onnx_session.get_inputs()]
        self.output_names = [node.name for node in self.onnx_session.get_outputs()]
        print("input_name:{}".format(self.input_names))
        print("output_name:{}".format(self.output_names))
 
    def forward(self, x):
        if isinstance(x,np.ndarray):
            assert len(self.input_names)==1
            input_feed = {self.input_names[0]:x}
        elif isinstance(x,(tuple,list)):
            assert len(self.input_names)==len(x)
            input_feed = {k:v for k,v in zip(self.input_names,x)}
        else:
            assert isinstance(x,dict)
            input_feed = x
        outs = self.onnx_session.run(self.output_names, input_feed=input_feed)
        return outs
    
    def predict(self,img_path):
        image = Image.open(img_path) 
        img = get_test_transform()(image)
        img = img.unsqueeze_(0) 
        to_numpy = lambda tensor: tensor.data.cpu().numpy()
        outs = self.forward(to_numpy(img))[0]
        score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
        return {'score':score[0].data.numpy().tolist(),
            'indice':indice[0].data.numpy().tolist()}
onnx_model = ONNXModel(export_file)
info = onnx_model.predict(img_path)
show_image(image, title = info)
input_name:['input']
output_name:['output']

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/693869.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【算法小记】深度学习——时间序列数据分析 Time series Data Analysis

在本篇博客中将简单介绍常见的几种循环神经网络和一维卷积神经网络,并使用一些简答的数据进行拟合分析。本文相对适合刚入门的同学,同时也作为自己过去一段时间学习的总结和记录,现在神经网络框架已经非常完善的支持了很多常见和有效的深度学…

Bootstrap框架集成ECharts教程

最新公司项目要在原有的基础上增加一些饼状图和柱状图来统计一些数据给客户,下面就是集成的一个过程,还是很简单的,分为以下几步 1、引入ECharts的包 2、通过ECharts官网或者菜鸟教程直接拿示例代码过来修修改改直接用就可以了 注意&#xf…

Python爬虫入门与登录验证自动化思路

1、pytyon爬虫 1.1、爬虫简介 Python爬虫是使用Python编写的程序,可以自动访问网页并提取其中的信息。爬虫可以模拟浏览器的行为,自动点击链接、填写表单、进行登录等操作,从而获取网页中的数据。 使用Python编写爬虫的好处是,…

python中while循环实现九九乘法表

i 1while i < 9: # 控制行的循环j 1while j < i: # 控制每行的输出print(f"{j}*{i}{j * i}\t", end"")j 1print()i 1运行截图&#xff1a;

ARM-V9 RME(Realm Management Extension)系统架构之系统安全能力的架构差异

安全之安全(security)博客目录导读 RME系统中的应用处理单元&#xff08;PE&#xff09;之间的架构差异可能会带来潜在的安全风险并增加管理软件的复杂性。例如&#xff0c;通过在ID_AA64MMFR0_EL1.PARange中为每个PE设置不同的值来支持不同的物理范围&#xff0c;可能会妨碍内…

复数的概念

1. 虚数单位&#xff1a;i 引入一个新数 ‘i’&#xff0c;i又叫做虚数单位&#xff0c;并规定&#xff1a; 它的平方等于 -1&#xff0c;即 i -1。实数可以与它进行四则运算&#xff0c;并且原有的加&#xff0c;乘运算律依然成立。 2.定义 复数的定义&#xff1a;形如 a…

医学领域科技查新点提炼方法!---附案例分析

医学领域的查新项目研究范围较广&#xff0c;涉及基础医学、临床医学、中医学、预防医学、卫生学、特种医学等众多与人类健康和疾病有关的科学。查新目的主要包括立项、成果鉴定和报奖&#xff0c;有的期刊投稿也要求作者提供查新报告。 医学领域查新项目的两极化较明显&#…

(Proteus仿真设计)基于51单片机的电梯程序控制系统

&#xff08;Proteus仿真设计&#xff09;基于51单片机的电梯程序控制系统 一.项目介绍 本设计模拟的是一个五层的&#xff0c;各楼层间隔为4.5m的电梯程序控制系统&#xff0c;能够完成各楼层乘客的接送任务。形象地说&#xff0c;就是要对不同楼层乘客的不同需求&#xff0…

【NI国产替代】产线综测仪:锂电池保护板测试仪,支持快速定制

• 精度等级01% • 支持直流电压、电流、nA 级待机电流电阻等&#xff0c;常规测试 • 支持过压、欠压、过冲、过放、过温,短路等&#xff0c;保护测试 • 通讯总线电平可编程&#xff0c;兼容多种 • 支持 SWD 或IIC 固件烧录 • 测试速度快&#xff0c;支持最多 24 通道…

WPF-UI布局

WPF布局元素有如下几个&#xff1a; Grid&#xff1a;网格。可以自定义行和列并通过行列的数量、行高和列宽来调整控件的布局。StackPanel&#xff1a;栈式面板。可将包含的元素在竖直或水平方向上排成一条直线&#xff0c;当移除一个元素后&#xff0c;后面的元素会自动向前移…

leetcode:不同的二叉树

class Solution { public:int numTrees(int n) {vector<int> dp(n1);dp[0] 1;dp[1] 1;for(int i 2;i < n;i){for(int j 1;j < i;j) // 当根节点为j时{dp[i] dp[j-1] * dp[i-j];}}return dp[n];} }; /* dp[i] i个不同的数组成的二叉搜索数的个数假设 i 5当根…

【栈】895. 最大频率栈

本文涉及知识点 栈 LeetCode895. 最大频率栈 设计一个类似堆栈的数据结构&#xff0c;将元素推入堆栈&#xff0c;并从堆栈中弹出出现频率最高的元素。 实现 FreqStack 类: FreqStack() 构造一个空的堆栈。 void push(int val) 将一个整数 val 压入栈顶。 int pop() 删除并返…

webapi跨越问题

由于浏览器存在同源策略&#xff0c;为了防止 钓鱼问题&#xff0c;浏览器直接请求才不会有跨越的问题 浏览器要求JavaScript或Cookie只能访问同域下的内容 浏览器也是一个应用程序&#xff0c;有很多限制&#xff0c;不能访问和使用电脑信息&#xff08;获取cpu、硬盘等&#…

在开源处理器架构RISC-V中发现可远程利用的中危漏洞

在RISC-V SonicBOOM处理器设计中发现中度危险的漏洞 最近&#xff0c;西北工业大学的网络空间安全学院胡伟教授团队在RISC-V SonicBOOM处理器设计中发现了一个中度危险的漏洞。这个团队的研究人员发现了一个可远程利用的漏洞&#xff0c;该漏洞存在于开源处理器架构RISC-V中。…

JAVAEE值网络编程(2)_TCP流套接字及通信模型、TCP网络编程及代码实例

前言 在上一节内容中&#xff0c;我们介绍了什么是套接字&#xff0c;以及使用UDP数据报套接字网络编程&#xff0c; 最后我们还介绍了Java数据报套接字通信模型以及相关代码实例。在这一节我们将会介绍TCP流套接字编程。 一、流套接字及通信模型 1.1 TCP套接字 TCP&#xff0…

总结七大排序算法

插入排序 直接插入排序是一种简单的插入排序法&#xff0c;其基本思想是&#xff1a;把待排序的记录按其关键码值的大小逐个插入到一个已经排好序的有序序列中&#xff0c;直到所有的记录插入完为止&#xff0c;得到一个新的有序序列 。实际中我们玩扑克牌时&#xff0c;就用了…

kNN算法-概述

所谓kNN算法就是K-nearest neigbor algorithm。这是似乎是最简单的监督机器学习算法。在训练阶段&#xff0c;kNN算法存储了标签训练样本数据。简单地说&#xff0c;就是调用训练方法时传递给它的标签训练样本会被它存储起来。 kNN算法也叫lazy learning algorithm懒惰学习算法…

给浮躁的面试者一个建议

哈喽&#xff0c;大家好&#xff0c;我叫人宅&#xff0c;关于找工作&#xff0c;大家心态非常浮躁&#xff0c;尤其是零零后&#xff0c;或者是九五后。本次为大家分享一下关于就业问题和就业态度。 我讲解的这些其实适合所有高科技行业。我这边就拿程序员为例。 如果你是刚毕…

【React】Redux与React - 环境准备

配套工具 在React中使用redux&#xff0c;官方要求安装俩个其他插件 - Redux Toolkit 和 react-redux 配置基础环境 使用 CRA 快速创建 React 项目 npx create-react-app react-redux安装配套工具 npm i reduxjs/toolkit react-redux启动项目 npm run start

【实战项目二】Python爬取豆瓣影评

目录 一、环境准备 二、编写代码 一、环境准备 pip install beautifulsoup4 pip intall lxml pip install requests我们需要爬取这些影评 二、编写代码 我们发现每个影评所在的div的class都相同&#xff0c;我们可以从这入手 from bs4 import BeautifulSoup import request…