STM32 X-CUBE-AI:Pytorch模型部署全流程

文章目录

    • 概要
    • 版本:
    • 参考资料
    • STM32CUBEAI安装
    • CUBEAI模型支持
    • LSTM模型转换注意事项
    • 模型转换
    • 模型应用
      • 1 错误类型及代码
      • 2 模型创建和初始化
      • 3 获取输入输出数据变量
      • 4 获取模型前馈输出
      • 模型应用小结
    • 小结

概要

STM32 CUBE MX扩展包:X-CUBE-AI部署流程:模型转换、CUBEAI模型验证、CUBEAI模型应用
深度学习架构使用Pytorch模型,模型包括多个LSTM和全连接层(包含Dropout和激活函数层)。

版本:

STM32CUBEMX:6.8.1
X-CUBE-AI:8.1.0 (推荐该版本,对LSTM支持得到更新)
ONNX:1.14.0

参考资料

遇到ERROR和BUG可到ST社区提问:ST社区
CUBEAI入门指南下载地址:X-CUBE-AI入门指南手册
官方应用示例:部署示例

STM32CUBEAI安装

CUBEAI扩展包的安装目前已有许多教程,这里不再赘述。CUBEAI安装
需要注意的是,在STM32CUBEMX上安装CUBEAI时,可能并不能安装到最新版本的CUBEAI,因此可以前往ST官网下载最新版本(最新版本会对模型实现进行更新)。https://www.st.com/zh/embedded-software/x-cube-ai.html
在这里插入图片描述

CUBEAI模型支持

目前,CUBEAI支持三种类型的模型:

  1. Keras:.h5
  2. TensorFlow:.tflite
  3. 一切可以转换成ONNX格式的模型:.onnx

Pytorch部署CUBEAI需要将Pytorch生成模型.pth转换成.onnx。

LSTM模型转换注意事项

  1. 由于CUBEAI扩展包和ONNX对LSTM的转换限制,在Pytorch模型搭建时,需要设置LSTM的batch_first=False(并不影响模型训练和应用),设置后,需要注意输入输出数据的格式。
  2. LSTM模型内部全连接层的输入前,将数据切片,取最后时间步,可避免一些问题。
  3. 对于多LSTM,可以采取forward函数多输入x的形式,每一个x是一个LSTM的输入。

模型转换

  1. Pytorch->Onnx
    使用torch.onnx.export()函数,其中可动态设置变量,函数使用已有很多教程,暂不赘述。Pytorch转ONNX及验证
    由于部署后属于模型应用阶段,输入数据batch=1,seq_lengthinput_num可自行设置,也可设置动态参数(设置seq_length为动态参数后,CUBEMX验证的示例数据中的seq_length=1)。

  2. Onnx->STM32
    大致流程可参考: STM32 模型验证
    STM32CUBEMX中选中相应模型即可,可修改模型名称方便后续网络部署(不要使用默认network名字)。
    在这里插入图片描述
    验证阶段可能会发生多种问题,问题错误种类参见:
    在这里插入图片描述

模型应用

到目前为止,Pytorch模型已经成功转换成ONNX并在CUBEMX进行了验证且得到通过,下面则是STM32的模型应用部分。

通过keil打开项目后,在下面的目录下存放模型相关文件,主要用到的函数为:modelName.c、modelName.h,其中modelName为在CUBEMX中定义的模型名称。
在这里插入图片描述
主要使用的函数如下:

  1. ai_modelName_create_and_init:用于模型创建和初始化
  2. ai_modelName_inputs_get:用于获取模型输入数据
  3. ai_modelName_outputs_get:用于获取模型输出数据
  4. ai_pytorch_ftc_lstm_run:用于前馈运行模型得到输出
  5. ai_mnetwork_get_error:用于获取模型错误代码,调试用

各函数相关参数及用法如下。
相似代码可参见:X-CUBE-AI入门指南手册

1 错误类型及代码

/*!
 * @enum ai_error_type
 * @ingroup ai_platform
 *
 * Generic enum to list network error types.
 */
typedef enum {
  AI_ERROR_NONE                         = 0x00,     /*!< No error */
  AI_ERROR_TOOL_PLATFORM_API_MISMATCH   = 0x01,
  AI_ERROR_TYPES_MISMATCH               = 0x02,
  AI_ERROR_INVALID_HANDLE               = 0x10,
  AI_ERROR_INVALID_STATE                = 0x11,
  AI_ERROR_INVALID_INPUT                = 0x12,
  AI_ERROR_INVALID_OUTPUT               = 0x13,
  AI_ERROR_INVALID_PARAM                = 0x14,
  AI_ERROR_INVALID_SIGNATURE            = 0x15,
  AI_ERROR_INVALID_SIZE                 = 0x16,
  AI_ERROR_INVALID_VALUE                = 0x17,
  AI_ERROR_INIT_FAILED                  = 0x30,
  AI_ERROR_ALLOCATION_FAILED            = 0x31,
  AI_ERROR_DEALLOCATION_FAILED          = 0x32,
  AI_ERROR_CREATE_FAILED                = 0x33,
} ai_error_type;

/*!
 * @enum ai_error_code
 * @ingroup ai_platform
 *
 * Generic enum to list network error codes.
 */
typedef enum {
  AI_ERROR_CODE_NONE                = 0x0000,    /*!< No error */
  AI_ERROR_CODE_NETWORK             = 0x0010,
  AI_ERROR_CODE_NETWORK_PARAMS      = 0x0011,
  AI_ERROR_CODE_NETWORK_WEIGHTS     = 0x0012,
  AI_ERROR_CODE_NETWORK_ACTIVATIONS = 0x0013,
  AI_ERROR_CODE_LAYER               = 0x0014,
  AI_ERROR_CODE_TENSOR              = 0x0015,
  AI_ERROR_CODE_ARRAY               = 0x0016,
  AI_ERROR_CODE_INVALID_PTR         = 0x0017,
  AI_ERROR_CODE_INVALID_SIZE        = 0x0018,
  AI_ERROR_CODE_INVALID_FORMAT      = 0x0019,
  AI_ERROR_CODE_OUT_OF_RANGE        = 0x0020,
  AI_ERROR_CODE_INVALID_BATCH       = 0x0021,
  AI_ERROR_CODE_MISSED_INIT         = 0x0030,
  AI_ERROR_CODE_IN_USE              = 0x0040,
  AI_ERROR_CODE_LOCK                = 0x0041,
} ai_error_code;

2 模型创建和初始化

/*!
 * @brief Create and initialize a neural network (helper function)
 * @ingroup pytorch_ftc_lstm
 * @details Helper function to instantiate and to initialize a network. It returns an object to handle it;
 * @param network an opaque handle to the network context
 * @param activations array of addresses of the activations buffers
 * @param weights array of addresses of the weights buffers
 * @return an error code reporting the status of the API on exit
 */
AI_API_ENTRY
ai_error ai_modelName_create_and_init(
  ai_handle* network, const ai_handle activations[], const ai_handle weights[]);

重点关注输入参数networkactivations:数据类型均为ai_handle(即void*)。初始化方式如下:

	ai_error err;
	ai_handle network = AI_HANDLE_NULL;
	const ai_handle act_addr[] = { activations };
		
	// 实例化神经网络
	err = ai_modelName_create_and_init(&network, act_addr, NULL);
	if (err.type != AI_ERROR_NONE)
	{
		printf("E: AI error - type=%d code=%d\r\n", err.type, err.code);
	}

3 获取输入输出数据变量

/*!
 * @brief Get network inputs array pointer as a ai_buffer array pointer.
 * @ingroup pytorch_ftc_lstm
 * @param network an opaque handle to the network context
 * @param n_buffer optional parameter to return the number of outputs
 * @return a ai_buffer pointer to the inputs arrays
 */
AI_API_ENTRY
ai_buffer* ai_modelName_inputs_get(
  ai_handle network, ai_u16 *n_buffer);

/*!
 * @brief Get network outputs array pointer as a ai_buffer array pointer.
 * @ingroup pytorch_ftc_lstm
 * @param network an opaque handle to the network context
 * @param n_buffer optional parameter to return the number of outputs
 * @return a ai_buffer pointer to the outputs arrays
 */
AI_API_ENTRY
ai_buffer* ai_modelName_outputs_get(
  ai_handle network, ai_u16 *n_buffer);

需要先创建输入输出数据:

// 输入输出结构体
ai_buffer* ai_input;
ai_buffer* ai_output;

// 结构体内容如下
/*!
 * @struct ai_buffer
 * @ingroup ai_platform
 * @brief Memory buffer storing data (optional) with a shape, size and type.
 * This datastruct is used also for network querying, where the data field may
 * may be NULL.
 */
typedef struct ai_buffer_ {
  ai_buffer_format        format;     /*!< buffer format */
  ai_handle               data;       /*!< pointer to buffer data */
  ai_buffer_meta_info*    meta_info;  /*!< pointer to buffer metadata info */
  /* New 7.1 fields */
  ai_flags                flags;      /*!< shape optional flags */
  ai_size                 size;       /*!< number of elements of the buffer (including optional padding) */
  ai_buffer_shape         shape;      /*!< n-dimensional shape info */
} ai_buffer;

之后调用函数进行结构体赋值:

	ai_input = ai_modelName_inputs_get(network, NULL);
	ai_output = ai_modelName_outputs_get(network, NULL);

接下来需要对结构体中的data进行赋值,ai_input和ai_output均为输入输出地址,对于多输入形式的模型,可以数组索引多个输入:

// 单输入
ai_float *pIn;
ai_output[0].data = AI_HANDLE_PTR(pIn);

// 多输入
ai_float *pIn[]
for(int i=0; i<AI_MODELNAME_IN_NUM; i++)
	{
		ai_input[i].data = AI_HANDLE_PTR(pIn[i]);
	}
// 输出
ai_float *pOut;
ai_output[0].data = AI_HANDLE_PTR(pOut);

pIn为指针数组,数组内存储多个输入数据指针;AI_MODELNAME_IN_NUM为宏定义,表示输入数据数量。
AI_HANDLE_PTRai_handle类型宏定义,传入ai_float *指针,将数据转换成ai_handle类型。

#define AI_HANDLE_PTR(ptr_)           ((ai_handle)(ptr_))

4 获取模型前馈输出

/*!
 * @brief Run the network and return the output
 * @ingroup pytorch_ftc_lstm
 *
 * @details Runs the network on the inputs and returns the corresponding output.
 * The size of the input and output buffers is stored in this
 * header generated by the code generation tool. See AI_PYTORCH_FTC_LSTM_*
 * defines into file @ref pytorch_ftc_lstm.h for all network sizes defines
 *
 * @param network an opaque handle to the network context
 * @param[in] input buffer with the input data
 * @param[out] output buffer with the output data
 * @return the number of input batches processed (default 1) or <= 0 if it fails
 * in case of error the error type could be queried by 
 * using @ref ai_pytorch_ftc_lstm_get_error
 */
AI_API_ENTRY
ai_i32 ai_modelName_run(
  ai_handle network, const ai_buffer* input, ai_buffer* output);

函数传入网络句柄,输入输出buffer指针,返回处理的批次数量(应用阶段应该为1),可通过判断返回值是否为1,说明模型运行是否成功。

	printf("---------Running Network-------- \r\n");
	batch = ai_modelName_run(network, ai_input, ai_output);
	printf("---------Running End-------- \r\n");
	if (batch != BATCH) {
		err = ai_mnetwork_get_error(network);
		printf("E: AI error - type=%d code=%d\r\n", err.type, err.code);
		Error_Handler();
	}

运行后,可通过查看pOut数组数据得到模型输出。

void printData_(ai_float *pOut, ai_i8 num)
{
	printf("(Total Num: %d): ", num);
	for (int i=0; i < num; i++)
	{
		if (i == num-1)
		{
			printf("%.4f. \r\n", pOut[i]);
		}
		else
		{
			printf("%.4f, ", pOut[i]);
		}
	}
}

模型应用小结

可以根据官方部署示例中的方法对AI_InitAI_Run进行封装。

小结

遇到的BUG持续更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/141341.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Accelerate 0.24.0文档 一:两万字极速入门

文章目录 一、概述1.1 PyTorch DDP1.2 Accelerate 分布式训练简介1.2.1 实例化Accelerator类1.2.2 将所有训练相关 PyTorch 对象传递给 prepare()方法1.2.3 启用 accelerator.backward(loss) 1.3 Accelerate 分布式评估1.4 accelerate launch1.4.1 使用accelerate launch启动训…

k8s集群搭建(ubuntu 20.04 + k8s 1.28.3 + calico + containerd1.7.8)

环境&需求 服务器&#xff1a; 10.235.165.21 k8s-master 10.235.165.22 k8s-slave1 10.235.165.23 k8s-slave2OS版本&#xff1a; rootvms131:~# lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 20.04.5 LTS Release: …

Java自学第11课:电商项目(4)重新建立项目

经过前几节的学习&#xff0c;我们已经找到之前碰到的问题的原因了。那么下面接着做项目学习。 1 新建dynamic web project 建立时把web.xml也生成下&#xff0c;省的右面再添加。 会询问是否改为java ee环境&#xff1f;no就行&#xff0c;其实改过来也是可以的。这个不重要。…

css3 初步了解

1、css3的含义及简介 简而言之&#xff0c;css3 就是 css的最新标准&#xff0c;使用css3都要遵循这个标准&#xff0c;CSS3 已完全向后兼容&#xff0c;所以你就不必改变现有的设计&#xff0c; 2、一些比较重要的css3 模块 选择器 1、标签选择器&#xff0c;也称为元素选择…

滚珠螺杆在注塑机械手中起什么作用?

注塑机械手的配件中滚珠螺杆是重要的一环&#xff0c;在注塑机械手中起着重要的作用。注塑机械手是一种自动化设备&#xff0c;可以在注塑生产中实现自动化操作&#xff0c;而滚珠螺杆则是实现这一操作的关键部件之一。 滚珠螺杆是一种将旋转运动转化为直线运动的精密传动部件&…

在Linux系统下微调Llama2(MetaAI)大模型教程—Qlora

Llama2是Meta最新开源的语言大模型&#xff0c;训练数据集2万亿token&#xff0c;上下文长度是由Llama的2048扩展到4096&#xff0c;可以理解和生成更长的文本&#xff0c;包括7B、13B和70B三个模型&#xff0c;在各种基准集的测试上表现突出&#xff0c;最重要的是&#xff0c…

【Linux网络】本地DNS服务器搭建

目录 一、什么是DNS&#xff0c;相关介绍 1、dns是什么&#xff1a; 2、域名的分类&#xff1a; 3、服务器的类型 二、DNS解析的过程 三、DNS的相关配置文件学习 1、本地主机有关的DNS文件学习 2、本地的DNS缓存服务器的文件 3、bind软件的相关配置文件&#xff1a; 4…

Jmeter执行接口自动化测试-如何初始化清空旧数据

需求分析&#xff1a; 每次执行完自动化测试&#xff0c;我们不会执行删除接口把数据删除&#xff0c;而需要留着手工测试&#xff0c;此时会导致下次执行测试有旧数据我们手工可能也会新增数据&#xff0c;导致下次执行自动化测试有旧数据 下面介绍两种清空数据的方法 一、通…

nginx代理docker容器服务

场景描述 避免暴力服务端口&#xff0c;使用nginx代理 一个前端&#xff0c;一个后端&#xff0c;docker方式部署到服务器&#xff0c;使用docker创建的nginx代理端口请求到前端端口 过程 1 docker 安装nginx 1.1 安装一个指定版本的nginx docker pull nginx#启动一个ngi…

vuejs - - - - - 移动端设备兼容(pxtorem)

pxtorem的使用 1. 依赖安装2. vue.config.js配置3. 动态设置html的font-size大小4. 效果如图&#xff1a; 1. 依赖安装 yarn add postcss-pxtorem -D 2. vue.config.js配置 module.exports {...css: {loaderOptions: {postcss: {plugins: [require("postcss-pxtorem&quo…

22.能被7整除,并且求和。

#include<stdio.h>int main(){int i ,sum0;printf("1-1000能被7整除的数字有&#xff1a;\n");for(i1;i<1000;i){if(i%70){printf("%d ",i);sumsumi;} }printf("\n");printf("能被7整除的数字的和是&#xff1a;%d ",sum);re…

这样书写Python代码的方式,实在是太优雅了~

文章目录 前言一、在Python中配合pipe灵活使用链式写法二 、pipe中常用的管道操作函数1.使用traverse()展平嵌套数组2.使用dedup()进行顺序去重3.使用filter()进行值过滤4.使用groupby()进行分组运算5.使用select()对上一步结果进行自定义遍历运算6.使用sort()进行排序 总结关于…

线性表->栈

文章目录 前言概述栈的初始化销毁压栈出栈判断栈为不为空栈的有效个数 前言 栈相对于链表&#xff0c;稍微简单一点&#xff0c;但是栈的难点在于通过栈去理解递归算法。 概述 **栈&#xff1a;**一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。…

Redis解决缓存问题

目录 一、引言二、缓存三、Redis缓存四、缓存一致性1.缓存更新策略2.主动更新 五、缓存穿透六、缓存雪崩七、缓存击穿1.基于互斥锁解决具体业务2.基于逻辑过期解决具体业务 一、引言 在一些大型的网站中会有十分庞大的用户访问流量&#xff0c;而过多的用户访问对我们的MySQL数…

初学UE5 C++①

游戏类 1.创建所需项的类 2.创建游戏模式类&#xff0c;在该类上实现所需项&#xff0c;引入头文件和构造函数时实例化 三种时间函数类型函数和提示类型 FName、FString、FText类型相互转化 FName用FName FString用ToString&#xff08;&#xff09; FText用FText&#xff1a;…

零代码搭建:无需编程基础,轻松搭建数据自己的能源监测管理平台

零代码搭建能源管理平台&#xff0c;其核心是通过使用图形用户界面和可视化建模工具&#xff0c;来减少编写代码的工作量以及技能要求。平台拥有丰富的预定义组件&#xff0c;可以帮助管理人员快速构建应用程序。并可自定义区域框架&#xff0c;在搭建自己区域时&#xff0c;能…

说说对React Hooks的理解?解决了什么问题?

一、是什么 Hook 是 React 16.8 的新增特性。它可以让你在不编写 class 的情况下使用 state 以及其他的 React 特性 至于为什么引入hook,官方给出的动机是解决长时间使用和维护react过程中常遇到的问题,例如: 难以重用和共享组件中的与状态相关的逻辑逻辑复杂的组件难以开…

Juniper PPPOE双线路冗余RPM配置

------------------ 浮动静态路由 set routing-options static route 0.0.0.0/0 next-hop pp0.0 qualified-next-hop pp0.1 preference 10 ----------------- RPM测试的内容,包括从哪个接口发起测试,测试ping等等 #指定探针类型用ICMP请求 #探测的目标地址 #探测间隔 #探测阈…

编译原理-语法分析-自上而下分析

文章目录 语法分析器的功能自上而下分析面临的问题LL&#xff08;1&#xff09;分析法左递归的消除直接左递归非直接左递归 消除左递归的算法消除回溯、提左因子FIRST提左因子FOLLOW集 LL(1)的分析条件LL(1)文法构造FIRST和FOLLOW集合构造每个文法符号的FIRST集合构造FOLLOW集合…