PyTorch 模型转换为 ONNX 格式

PyTorch 模型转换为 ONNX 格式

在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH 和 ONNX 格式的区别,并介绍如何使用 Netron 可视化 ONNX 模型。

1. PTH 和 ONNX 的区别

PTH 格式

  • 定义:PTH 是 PyTorch 框架的专有格式,通常用于保存模型的状态字典(state_dict),包括模型的结构和训练好的参数。

  • 兼容性

    • PTH 文件只能在 PyTorch 中使用,无法直接在 C++ 环境中加载。虽然 PyTorch 提供了 C++ API(LibTorch),但 PTH 文件的加载和使用主要依赖于 Python 环境。
    • 在 C++ 中使用 PTH 文件需要将模型转换为 PyTorch 的 C++ 格式,这可能会增加复杂性和开发时间。
  • 用途

    • PTH 格式适合在 Python 环境中进行模型训练和调试,但在 C++ 中进行模型部署时,通常需要将模型转换为其他格式(如 ONNX)以便于跨平台使用。
    • 在 C++ 中,使用 PTH 文件的灵活性较低,尤其是在需要与其他框架或系统集成时。

ONNX 格式

  • 定义:ONNX(Open Neural Network Exchange)是一个开放的深度学习模型交换格式,旨在促进不同深度学习框架之间的互操作性。

  • 兼容性

    • ONNX 文件可以在多个深度学习框架中使用,包括 PyTorch、TensorFlow、Caffe2 等,这使得它在 C++ 环境中的兼容性更强。
    • ONNX 模型可以通过 ONNX Runtime、TensorRT、OpenVINO 等推理引擎在 C++ 中高效运行,支持多种硬件加速。
  • 用途

    • ONNX 格式非常适合模型的部署和推理,特别是在需要跨平台或跨框架使用时。它允许开发者在 C++ 中轻松加载和运行模型,而无需依赖于 Python 环境。
    • 在 C++ 中,使用 ONNX 模型可以简化工程化流程,便于与其他系统集成,提升模型的可移植性和可扩展性。

总结

在 C++ 进行深度学习模型的工程化时,选择 ONNX 格式通常更为合适,因为它提供了更好的跨平台兼容性和灵活性。PTH 格式虽然在 PyTorch 环境中非常方便,但在 C++ 中的使用受到限制,通常需要额外的转换步骤。ONNX 的开放性和广泛支持使其成为在多种环境中部署深度学习模型的首选格式。

2. 训练 MNIST 数据集的 CNN 模型

以下是使用 PyTorch 训练 MNIST 数据集的完整代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 1. 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 2. 定义 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入通道为1,输出通道为32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入通道为32,输出通道为64
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 全连接层
        self.fc2 = nn.Linear(128, 10)  # 输出层

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + 激活 + 池化
        x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + 激活 + 池化
        x = x.view(x.size(0), -1)  # 展平输入
        x = torch.relu(self.fc1(x))  # 第一个全连接层
        x = self.fc2(x)  # 输出层
        return x

# 3. 训练模型
model = SimpleCNN().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器

# 训练过程
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备
        optimizer.zero_grad()  # 清空梯度
        outputs = model(images)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 4. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

# 5. 转换为 ONNX 格式
onnx_file_path = 'mnist_cnn_model.onnx'
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # 示例输入,形状为 [batch_size, channels, height, width]
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True,
                  opset_version=11, do_constant_folding=True,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

print(f'Model has been converted to ONNX format and saved as {onnx_file_path}.')

3. 使用 Netron 可视化 ONNX 模型

一旦您将模型转换为 ONNX 格式,您可以使用 Netron 来可视化模型结构。Netron 是一个开源的模型可视化工具,支持多种深度学习框架的模型文件格式,包括 ONNX。

使用步骤:
  1. 下载 Netron

    • 您可以访问 Netron 的官方网站 在线使用,或者下载桌面版本。
  2. 打开 ONNX 模型

    • 如果使用在线版本,直接将 mnist_cnn_model.onnx 文件拖放到浏览器窗口中。
    • 如果使用桌面版本,打开 Netron 应用,选择“File” > “Open Model”,然后选择您的 ONNX 文件。
  3. 查看模型结构

    • 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/924744.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Three.js 和其他 WebGL 库 对比

在WebGL开发中,Three.js是一个非常流行的库,它简化了3D图形的创建和渲染过程。然而,市场上还有许多其他的WebGL库,如 Babylon.js、PlayCanvas、PIXI.js 和 Cesium,它们也有各自的特点和优势。本文将对Three.js 与这些常…

通过 JNI 实现 Java 与 Rust 的 Channel 消息传递

做纯粹的自己。“你要搞清楚自己人生的剧本——不是父母的续集,不是子女的前传,更不是朋友的外篇。对待生命你不妨再大胆一点,因为你好歹要失去它。如果这世上真有奇迹,那只是努力的另一个名字”。 一、crossbeam_channel 参考 cr…

摆脱复杂配置!使用MusicGPT部署你的私人AI音乐生成环境

文章目录 前言1. 本地部署2. 使用方法介绍3. 内网穿透工具下载安装4. 配置公网地址5. 配置固定公网地址 前言 今天给大家分享一个超酷的技能:如何在你的Windows电脑上快速部署一款文字生成音乐的AI创作服务——MusicGPT,并且通过cpolar内网穿透工具&…

挑战用React封装100个组件【001】

项目地址 https://github.com/hismeyy/react-component-100 组件描述 组件适用于需要展示图文信息的场景,比如产品介绍、用户卡片或任何带有标题、描述和可选图片的内容展示 样式展示 代码展示 InfoCard.tsx import ./InfoCard.cssinterface InfoCardProps {ti…

搭建帮助中心到底有什么作用?

在当今快节奏的商业环境中,企业面临着日益增长的客户需求和竞争压力。搭建一个有效的帮助中心对于企业来说,不仅是提升客户服务体验的重要途径,也是优化内部知识管理和提升团队效率的关键。以下是帮助中心在企业运营中的几个关键作用&#xf…

学习threejs,使用CubeCamera相机创建反光效果

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️CubeCamera 立方体相机 二、…

微前端-MicroApp

微前端即是由一个主应用来集成多个微应用(可以不区分技术栈进行集成) 下面是使用微前端框架之一 MicroApp 对 react微应用 的详细流程 第一步 创建主应用my-mj-app 利用脚手架 npx create-react-app my-mj-app 快速创建 安装 npm install --save rea…

深度学习—BP算法梯度下降及优化方法Day37

梯度下降 1.公式 w i j n e w w i j o l d − α ∂ E ∂ w i j w_{ij}^{new} w_{ij}^{old} - \alpha \frac{\partial E}{\partial w_{ij}} wijnew​wijold​−α∂wij​∂E​ α为学习率 当α过小时,训练时间过久增加算力成本,α过大则容易造成越过最…

wp the_posts_pagination 与分类页面搭配使用

<ul> <?php while( have_posts() ) : the_post(); <li > <a href"<?php the_permalink(); ?>"> <?php xizhitbu_get_thumbnail(thumb-pro); ?> </a> <p > <a href&q…

深度学习-49-AI应用实战之基于HyperLPR的车牌识别

文章目录 1 车牌识别系统1.1 识别原理1.1.1 车牌定位1.1.2 字符识别2 实例应用2.1 安装hyperlpr32.2 识别结果2.3 可视化显示2.4 结合streamlit3 附录3.1 PIL.Image转换成OpenCV格式3.2 OpenCV转换成PIL.Image格式3.3 st.image嵌入图像内容3.4 参考附录1 车牌识别系统 车牌识别…

ShuffleNet V2:高效卷积神经网络架构设计的实用指南

摘要 https://arxiv.org/pdf/1807.11164 当前&#xff0c;神经网络架构设计大多以计算复杂度的间接指标&#xff0c;即浮点运算数&#xff08;FLOPs&#xff09;为指导。然而&#xff0c;直接指标&#xff08;例如速度&#xff09;还取决于其他因素&#xff0c;如内存访问成本…

在C#中使用OpenCV的.net包装器EmguCV

Emgu.CV OpenCvSharp 两个库都是OpenCV的C#封装。这里不讨论优劣&#xff0c;两个都有相应的用途。 下载安装4.6.0.5131&#xff0c;执行文件exe https://sourceforge.net/projects/emgucv/files/emgucv/4.6.0/ 安装到一个目录下&#xff0c;这里安装到H:\Emgu\ 目录下。…

HarmonyOS:@Provide装饰器和@Consume装饰器:与后代组件双向同步

一、前言 Provide和Consume&#xff0c;应用于与后代组件的双向数据同步&#xff0c;应用于状态数据在多个层级之间传递的场景。不同于上文提到的父子组件之间通过命名参数机制传递&#xff0c;Provide和Consume摆脱参数传递机制的束缚&#xff0c;实现跨层级传递。 其中Provi…

webrtc 3A移植以及实时处理

文章目录 前言一、交叉编译1.Pulse Audio webrtc-audio-processing2.交叉编译 二、基于alsa进行实时3A处理1.demo源码2.注意项3.效果展示 总结 前言 由于工作需要&#xff0c;硬件3A中的AEC效果实在太差&#xff0c;后面使用SpeexDSP的软3A&#xff0c;效果依旧不是很好&#…

Java 反射(Reflection)

Java 反射&#xff08;Reflection&#xff09; Java 反射&#xff08;Reflection&#xff09;是一个强大的特性&#xff0c;它允许程序在运行时查询、访问和修改类、接口、字段和方法的信息。反射提供了一种动态地操作类的能力&#xff0c;这在很多框架和库中被广泛使用&#…

深入浅出剖析典型文生图产品Midjourney

2022年7月,一个小团队推出了公测的 Midjourney,打破了 AIGC 领域的大厂垄断。作为一个精调生成模型,以聊天机器人方式部署在 Discord,它创作的《太空歌剧院》作品,甚至获得了美国「数字艺术/数码摄影」竞赛单元一等奖。 这一事件展示了 AI 在绘画领域惊人的创造力,让人们…

【Linux】磁盘 | 文件系统 | inode

&#x1fa90;&#x1fa90;&#x1fa90;欢迎来到程序员餐厅&#x1f4ab;&#x1f4ab;&#x1f4ab; 主厨&#xff1a;邪王真眼 主厨的主页&#xff1a;Chef‘s blog 所属专栏&#xff1a;青果大战linux 总有光环在陨落&#xff0c;总有新星在闪烁 模电好难啊&#xff…

PHP 去掉特殊不可见字符 “\u200e“

描述 最近在排查网站业务时&#xff0c;发现有数据匹配失败的情况 肉眼上完全看不出问题所在 当把字符串 【M24308/23-14F‎】复制出来发现 末尾有个不可见的字符 使用删除键或左右移动时才会发现 最后测试通过 var_dump 打印 发现这个"空字符"占了三个长度 &#xf…

【C#设计模式(15)——命令模式(Command Pattern)】

前言 命令模式的关键通过将请求封装成一个对象&#xff0c;使命令的发送者和接收者解耦。这种方式能更方便地添加新的命令&#xff0c;如执行命令的排队、延迟、撤销和重做等操作。 代码 #region 基础的命令模式 //命令&#xff08;抽象类&#xff09; public abstract class …

使用zabbix监控k8s

一、 参考文献 小阿轩yx-案例&#xff1a;Zabbix监控kubernetes云原生环境 手把手教你实现zabbix对Kubernetes的监控 二、部署经验 关于zabbix监控k8s&#xff0c;总体来说是分为两块内容&#xff0c;一是在k8s集群部署zabbix-agent和zabbix- proxy。二是在zabbix进行配置。…