【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os

# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()

# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:
    os.system(f"wget {url} -O {filename}")
except Exception as e:
    print(f"Error downloading image: {e}")

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Perform inference on CPU
with torch.no_grad():
    output = model(input_tensor)

# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])

# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")

# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")

# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)

print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")

# Load the ONNX model
import onnx
import onnxruntime

onnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)

# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()

# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)

# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):
    print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())


得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image

# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Perform inference on CPU
with torch.no_grad():
    output = model(input_tensor)

# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])

# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image

# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()

# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)

# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):
    print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/159300.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

人生阶段总结

--回顾一下我迷茫、努力、不开心又失败的阶段人生自我介绍一下,我是一个智力平平,记忆力差,适合自学的长睡眠者。 大专之前 国内的应试教育基本上不适合我,厌恶补课厌恶机械式听课刷题,所有的优势学科都是自学&#xf…

Python hashlib 模块详细教程:安全哈希与数据完整性保护

更多Python学习内容:ipengtao.com 大家好,我是涛哥,今天为大家分享 Python hashlib 模块详细教程,文章6500字,阅读大约17分钟,大家enjoy~~ hashlib模块为Python提供了一种简便的方式来使用各种哈希算法&…

vue2 tinymce富文本插件

一、介绍 TinyMCE是一款易用、且功能强大的所见即所得的富文本编辑器。同类程序有:UEditor、Kindeditor、Simditor、CKEditor、wangEditor、Suneditor、froala等等。 TinyMCE的优势: 开源可商用,基于LGPL2.1插件丰富,自带插件基…

航天联志Aisino-AISINO26081R服务器通过调BIOS用U盘重新做系统(windows系统通用)

产品名称:航天联志Aisino系列服务器 产品型号:AISINO26081R CPU架构:Intel 的CPU,所以支持Windows Server all 和Linux系统(重装完系统可以用某60驱动管家更新所有硬件驱动) 操作系统:本次我安装的服务器系统为Serv…

瑞吉外卖Day06

1.用户地址 1.1实体类 /*** 地址簿*/ Data public class AddressBook implements Serializable {private static final long serialVersionUID 1L;private Long id;//用户idprivate Long userId;//收货人private String consignee;//手机号private String phone;//性别 0 女…

安装应用与免安装应用差异对比

差异 安装的程序和免安装的应用程序之间有以下几个方面的差别: 安装过程:安装的程序需要通过一个安装程序或安装脚本进行安装。这个过程通常会将应用程序的文件和依赖项复制到指定的目录,并进行一些配置和注册操作。免安装的应用程序则不需要…

HP惠普暗影精灵9笔记本电脑OMEN by HP Transcend 16英寸游戏本16-u0000原厂Windows11系统

惠普暗影9恢复出厂开箱状态,原装出厂Win11-22H2系统ISO镜像 下载链接:https://pan.baidu.com/s/17ftbBHEMFSEOw22tnYvPog?pwd91p1 提取码:91p1 适用型号:16-u0006TX、16-u0007TX、16-u0008TX、16-u0009TX、16-u0017TX 原厂系…

element ui修改select选择框背景色和边框色

一、修改时间输入框的背景和边框字体颜色 <div class"hright"><el-date-picker :popper-append-to-body"false" class"custom-timeselect" v-model"form.timevalue" type"daterange" range-separator"至"…

SpringBoot-过滤器Filter+JWT令牌实现登录验证

登录校验-Filter 分析 过滤器Filter的快速入门以及使用细节我们已经介绍完了&#xff0c;接下来最后一步&#xff0c;我们需要使用过滤器Filter来完成案例当中的登录校验功能。 我们先来回顾下前面分析过的登录校验的基本流程&#xff1a; 要进入到后台管理系统&#xff0c;我…

Web前后端漏洞分析与防御

第1章 课程介绍 试看2 节 | 15分钟 介绍安全问题在web开发中的重要性&#xff0c;并对课程整体进行介绍 收起列表 视频&#xff1a; 1-1 Web安全课程介绍 (09:24) 试看 视频&#xff1a; 1-2 项目总览 (04:47) 第2章 环境搭建2 节 | 26分钟 本章节我们会搭建项目所需要的环境 …

Linux上使用Python源码编译安装Python

安装python apt install python3-dev python3 python3-venv yum install python38-devel源码安装Python 1.下载需要的Python版本 Python源码地址&#xff1a;https://www.python.org/downloads/source/ 2.安装gcc&#xff08;yum install gcc&#xff09; 3.解压&#xff0c…

vim模式用法总结

0.前言 我们用gcc编译文件的时候&#xff0c;如果发生了下面的错误&#xff0c;那么如何用vim打开的时候就定位到&#xff1f; 我们可以知道&#xff0c;这是第6行出现了错误&#xff1b; 所以我们使用vim打开的时候多输入个这个&#xff0c;我们就可以快速定位了 vim test.c 6…

一阶低通滤波器(一阶巴特沃斯滤波器)

连续传递函数G(s) 离散传递函数G(z) 转换为差分方程形式 一阶巴特沃斯滤波器Filter Designer参数设计&#xff1a;参考之前的博客Matlab的Filter Designer工具设计二阶低通滤波器 设计采样频率100Hz&#xff0c;截止频率20Hz。 注意&#xff1a;设计参数使用在离散系统中&…

03.webpack中hash,chunkhash和contenthash 的区别

hash、contenthash 和 chunkhash 是通过散列函数处理之后&#xff0c;生成的一串字符&#xff0c;可用于区分文件。 作用&#xff1a;善用文件的哈希值&#xff0c;解决浏览器缓存导致的资源未及时更新的问题 1.文件名不带哈希值 const path require(path) const HtmlWebpac…

键盘方向键移动当前选中的table单元格,并可以输入内容

有类似于这样的表格&#xff0c;用的<table>标签。原本要在单元格的文本框里面输入内容&#xff0c;需要用鼠标一个一个去点以获取焦点&#xff0c;现在需要不用鼠标选中&#xff0c;直接用键盘的上下左右来移动当前正在输入的单元格文本框。 const currentCell React.u…

简单漂亮的首页

效果图 说明 这个首页我也是构思了很久&#xff0c;才想出这个界面&#xff0c;大家喜欢的话&#xff0c;可以拿走去使用 技术的话&#xff0c;采用的就是vue的语法&#xff0c;但是不影响&#xff0c;很多样式我都是直接手敲出来的 代码实现 标语 <!-- 标语 start-->&…

《QT从基础到进阶·三十》QVariant的基础用法

很多时候&#xff0c;需要几种不同的数据类型需要传递&#xff0c;如果用结构体&#xff0c;又不大方便&#xff0c;容器保存的也只是一种数据类型&#xff0c;而QVariant则可以统统搞定。 QVariant可以保存QT和C常用类型&#xff0c;如果是自定义类型&#xff0c;比如struct,c…

【配置环境】VS Code怎么使用JavaScript的Mocha测试框架和Chai断言库

一&#xff0c;环境 Windows 11 家庭中文版&#xff0c;64 位操作系统, 基于 x64 的处理器VS Code 版本: 1.83.1 (user setup)Node.js 版本&#xff1a;20.9.0 二&#xff0c;安装背景 在运行测试用例时遇到 ReferenceError: describe is not defined 错误&#xff0c;网上搜寻…

JUC工具类_CyclicBarrier与CountDownLatch

最近被问到CyclicBarrier和CountDownLatch相关的面试题&#xff0c;CountDownLatch平时工作中经常用到&#xff0c;但是CyclicBarrier没有用过&#xff0c;一时答不上来&#xff0c;因此简单总结记录一下 1.什么是CyclicBarrier&#xff1f; 1.1 概念 CyclicBarrier&#xff…

MyBatis #{} 和 ${} 的区别

前言&#xff1a; #{} 和 ${} 的区别是 MyBatis 中一个常见的面试题&#xff0c;#{} 和 ${} 是MyBatis 中获取参数的两种方式&#xff0c;但我们在项目中大多数使用的都是 #{} 来获取参数&#xff0c;那么它们两个有什么区别呢&#xff1f; 区别 一. #{} 采用预编译 SQL&…