12.6深度学习_模型优化和迁移_模型移植

八、模型移植

1. 认识ONNX

​ https://onnx.ai/

​ Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。

​ ONNX的规范及代码主要由微软,亚马逊 ,Face book 和 IBM等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, PaddlePaddle, TensorFlow等。

2. 导出ONNX

2.1 安装依赖包

pip install onnx
pip install onnxruntime

2.2 导出ONNX模型

import os
import torch
import torch.nn as nn
from torchvision.models import resnet18

if __name__ == "__main__":
    dir = os.path.dirname(__file__)
    weightpath = os.path.join(
        os.path.dirname(__file__), "pth", "resnet18_default_weight.pth"
    )
    onnxpath = os.path.join(
        os.path.dirname(__file__), "pth", "resnet18_default_weight.onnx"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = resnet18(pretrained=False)
    model.conv1 = nn.Conv2d(
        #
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=0,
        bias=False,
    )
    # 删除池化层
    model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)
	# 修改全连接层
    in_feature = model.fc.in_features
    model.fc = nn.Linear(in_feature, 10)
    model.load_state_dict(torch.load(weightpath, map_location=device))
    model.to(device)
    # 创建一个实例输入
    x = torch.randn(1, 3, 224, 224, device=device)
    # 导出onnx
    torch.onnx.export(
        model,
        x,
        onnxpath,
        #
        verbose=True, # 输出转换过程
        input_names=["input"],
        output_names=["output"],
    )
    print("onnx导出成功")

2.3 ONNX结构可视化

可以直接在线查看:https://netron.app/

也可以下载桌面版:https://github.com/lutzroeder/netron

3. ONNX推理

ONNX在做推理时不再需要导入网络,且适用于Python、JAVA、PyQT等各种语言,不再依赖于PyTorch框架;

3.1 简单推理

import onnxruntime as ort
import torchvision.transforms as transforms
import cv2 as cv
import os
import numpy as np

img_size = 224
transformtest = transforms.Compose(
    [
        transforms.ToPILImage(),  # 将numpy数组转换为PIL图像
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            # 均值和标准差
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2471, 0.2435, 0.2616],
        ),
    ]
)


def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=1, keepdims=True)


def cv_imread(file_path):
    cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)
    return cv_img

lablename = "飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车".split("、")

if __name__ == "__main__":
    dir = os.path.dirname(__file__)
    weightpath = os.path.join(
        os.path.dirname(__file__), "pth", "resnet18_default_weight.pth"
    )
    onnxpath = os.path.join(
        os.path.dirname(__file__), "pth", "resnet18_default_weight.onnx"
    )

    # 读取图片
    img_path = os.path.join(dir, "test", "5.jpg")
    img = cv_imread(img_path)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img_tensor = transformtest(img)

    # 将图片转换为ONNX运行时所需的格式
    img_numpy = img_tensor.numpy()
    img_numpy = np.expand_dims(img_numpy, axis=0)  # 增加batch_size维度
    # 加载onnx模型
    sess = ort.InferenceSession(onnxpath)

    # 运行onnx模型
    outputs = sess.run(None, {"input": img_numpy})
    output = outputs[0]

    # 应用softmax
    probabilities = softmax(output)
    print(probabilities)
    # 获得预测结果
    pred_index = np.argmax(probabilities, axis=1)
    pred_value = probabilities[0][pred_index[0]]

    print(pred_index)
    
    print(
        "预测目标:",
        lablename[pred_index[0]],
        "预测概率:",
        str(pred_value * 100)[:5] + "%",
    )

输出结果:

[[6.7321511e-05 9.7113671e-11 7.6417709e-05 2.8661249e-02 7.0206769e-04
  3.9052707e-04 9.7010124e-01 6.8206714e-07 4.1351362e-07 5.7089373e-09]]
[6]
预测目标: 青蛙 预测概率: 97.01%

3.2 使用GPU推理

需要安装依赖包:

pip install onnxruntime-gpu

代码:

# 导入FileSystemStorage
import time
import random
import os

# 人工智能推理用到的模块
import onnxruntime as ort
import torchvision.transforms as transforms
import numpy as np
import PIL.Image as Image

img_size = 32
transformtest = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            # 均值和标准差
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2471, 0.2435, 0.2616],
        ),
    ]
)


def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=1, keepdims=True)


def imgclass():
    # AI推理
    # 读取图片
    imgpath = os.path.join(os.path.dirname(__file__), "..", "static/ai", filename)
    # 加载并预处理图像
    image = Image.open(imgpath)
    input_tensor = transformtest(image)
    input_tensor = input_tensor.unsqueeze(0)  # 添加批量维度

    # 将图片转换为ONNX运行时所需的格式
    img_numpy = input_tensor.numpy()

    # 加载模型
    onnxPath = os.path.join(
        #
        os.path.dirname(__file__),
        "..",
        "onnx",
        "resnet18_default_weight_1.onnx",
    )
    # 设置 ONNX Runtime 使用 GPU
    providers = ["CUDAExecutionProvider"]
    sess = ort.InferenceSession(onnxPath, providers=providers)
    # 使用模型对图片进行推理运算
    output = sess.run(None, {"input": img_numpy})
    output = softmax(output[0])
    print(output)
    ind = np.argmax(output, axis=1)
    print(ind)
    lablename = "飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船、卡车".split("、")
    res = {"code": 200, "msg": "处理成功", "url": img, "class": lablename[ind[0]]}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/932113.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2-2-18-14 QNX系统架构之 TCP/IP 网络

阅读前言 本文以QNX系统官方的文档英文原版资料为参考,翻译和逐句校对后,对QNX操作系统的相关概念进行了深度整理,旨在帮助想要了解QNX的读者及开发者可以快速阅读,而不必查看晦涩难懂的英文原文,这些文章将会作为一个…

实战:MyBatis适配多种数据库:MySQL、Oracle、PostGresql等

概叙 很多时候,一套代码要适配多种数据库,主流的三种库:MySQL、Oracle、PostGresql,刚好mybatis支持这种扩展,如下图所示,在一个“namespace”,判断唯一的标志是iddatabaseId,刚好写…

电子信息工程自动化 单片机彩灯控制

摘要 随着社会经济和科学技术的不断进步,人们在保持发展的同时,环境带给人类的影响已经不足以让我们忽视,所以城市的美化问题慢慢的进入了人们的眼帘,PLC的产生给带电子产品带来了巨大变革,彩灯的使用在城市的美化中变…

【后台管理系统】-【组件封装】

目录 组件封装搜索组件table列表组件content组件自定义插槽定制modal组件动态获取options数据 组件封装 搜索组件 给页面写一个配置文件,将配置文件传入组件,可直接生成页面,以下面页面为例, 新建src/views/main/system/depart…

「嵌入式系统设计与实现」书评:学习一个STM32的案例

本文最早发表于电子发烧友论坛:【新提醒】【「嵌入式系统设计与实现」阅读体验】 学习一个STM32的案例 - 发烧友官方/活动 - 电子技术论坛 - 广受欢迎的专业电子论坛!https://bbs.elecfans.com/jishu_2467617_1_1.html 感谢电子发烧友论坛和电子工业出版社的赠书。 …

设计模式:20、状态模式(状态对象)

目录 0、定义 1、状态模式的三种角色 2、状态模式的UML类图 3、示例代码 0、定义 允许一个对象在其内部状态改变时改变它的行为,对象看起来似乎修改了它的类。 1、状态模式的三种角色 环境(Context):环境是一个类&#xff0…

idea中新建一个空项目

目的,为了在同一个目录下有多个小的项目:使用IDE为idea2022。 步骤: 点击新建项目,点击创建空项目,这里选择空项目是将其作为其他项目的一个容器,如图所示: 然后点击文件->项目结构&#xf…

Java基础复习

“任何时候我也不会满足,越是多读书,就越是深刻地感到不满足,越感到自己知识贫乏。科学是奥妙无穷的。” ——马克思 目录 一、方法&方法重载 二、运算符 三、数据类型 四、面向对象 1. 面向对象思想 2. 引用传递 3. 访问权限修饰…

嵌入式里的“移植”概念

这里因为最近一年看到公司某项目很多代码上有直接硬件的操作,这里有感而发,介绍移植的概念。 一、硬件 先上一个图: 举个例子,大学里应该都买过开发板,例如st的,这里三个层次, 内核&#xff…

量子计算与商业转型之旅

近年来,各组织对量子计算的应用有所增加,并使全球商业运营发生了显著变化。《财富商业洞察》的一份报告显示,2022 年量子计算市场价值为 7.17 亿美元,预计到 2030 年将达到 65.28 亿美元。 从本质上讲,量子计算机与经…

周末和男朋友户外运动美好时光

周末的时光总是那么令人期待,仿佛一周的忙碌和疲惫都在这一刻得到了释放。这次,我和男朋友决定到户外去锻炼身体,享受一下大自然的馈赠。 清晨的阳光透过窗帘的缝隙洒进房间,我懒洋洋地睁开眼睛,看到男朋友已经在一旁整…

【Git】:标签管理

目录 理解标签 创建标签 操作标签 理解标签 标签的作用 标记版本:标签 tag ,可以简单的理解为是对某次 commit 的⼀个标识,相当于起了⼀个别名。例如,在项目发布某个版本的时候,针对最后⼀次 commit 起⼀个 v1.0 这样…

QT的ui界面显示不全问题(适应高分辨率屏幕)

//自动适应高分辨率 QCoreApplication::setAttribute(Qt::AA_EnableHighDpiScaling);一、问题 电脑分辨率高,默认情况下,打开QT的ui界面,显示不全按钮内容 二、解决方案 如果自己的电脑分辨率较高,可以尝试以下方案:自…

【Elasticsearch】初始化默认字段及分词

1、添加分词插件 1)在线安装 执行命令 需要指定相同的版本 bin/elasticsearch-plugin.bat install https://get.infini.cloud/elasticsearch/analysis-ik/7.17.24 2)离线安装 将安装包解压到 /plugins 目录下 安装包可以从对应的资源处下载 启动成…

MATLAB直流电机模型,直流电机控制

直流电机控制简介 直流电机(DC motor)广泛应用于各种机械驱动和电力控制系统中,其运行性能的控制至关重要。为了精准地控制直流电机的输出特性,可以通过不同的控制方式进行调节。常见的控制方式包括电枢电流控制、速度控制、电机位…

Linux之封装线程库和线程的互斥

Linux之封装线程库和线程的互斥与同步 一.封装线程库二.线程的互斥2.1互斥量的概念2.2初始化和销毁互斥量2.3加锁和解锁2.4互斥量的原理2.5可重入和线程安全2.6死锁 一.封装线程库 其实在我们C内部也有一个线程库而C中的线程库也是封装的原生线程库的函数,所以我们…

PHP语法学习(第九天)—PHP连接mysql详解(下)

首先,温馨提示,该部分内容跟昨天“PHP语法学习(第八天)—PHP连接mysql详解(上)”一起食用更佳噢!! 学习本篇内容必须掌握数据库基础命令点击“MYSQL 数据库”~~ 本文是接着PHP连接mysql的知识点接着讲,今天主要讲述PHP…

qt基本部分控件用法(一)

前言: 以前 windows下做工具主要是MFC,趁有点空时间,研究了QT,感觉跟MFC 差不多,VS 比 QT CREATOR 还是强大,不过QT可以跨平台,功能更强大,MFC 只能在win平台下.; 1:环境…

Mysql索引,聚簇索引,非聚簇索引,回表查询

什么是索引 数据库索引是为了实现高效数据查询的一种有序的数据数据结构,类似于书的目录,通过目录可以快速的定位到想要的数据,因为一张表中的数据会有很多,如果直接去表中检索数据效率会很低,所以需要为表中的数据建立…

以MP6924A为核心的LLC拓扑学习【一】

PFCLLC: 在PFC(功率因数校正)和LLC(谐振变换器)组成的电源系统中,各个电路有特定的作用,它们协同工作以实现高效率和高功率因数的电能转换。 1. PFC(功率因数校正)电路的作用 PFC电…