Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能

前言

近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中加载和推理训练好的深度学习模型,尤其适用于需要人工智能(AI)或机器学习(ML)推理的游戏或应用。

YOLO(You Only Look Once)是一种用于目标检测的深度学习模型,它是由Joseph Redmon等人在2015年提出的。YOLO的核心思想是将目标检测问题转化为一个回归问题,在单一的神经网络中同时预测图像中的多个目标位置和类别标签。它通过将目标检测转化为回归问题,极大地提高了检测速度,并且在精度上也能达到非常好的水平。随着版本的更新和技术的不断进步,YOLO逐渐成为了计算机视觉领域中最重要和最广泛应用的模型之一,特别适用于实时处理、嵌入式设备和大规模部署。

本文依托上述两个技术,在Unity3d中实现YOLO的目标检测功能,基于Barracuda(2.0.0)的跨平台性,将实现包含移动端(目前测试了安卓)的目标检测功能,能检测出日常物体桌、椅、人、狗、羊、马等对象。

理论上本工程可以在Windows/Mac/iPhone/Android/Magic Leap/Switch/PS4/Xbox等系统和平台正常工作,目前仅测试了Windows和Android平台,相比Windows平台的流畅,Android手机上运行有明显的掉帧和卡顿,具体可以对比效果图。

官方给出支持的平台说明:
CPU 推理:支持所有 Unity 平台。
GPU 推理:支持所有 Unity 平台,但以下平台:
OpenGL ESon :使用 Vulkan/Metal。Android/iOS
OpenGL Core上:使用 Metal。Mac
WebGL:使用 CPU 推理。

关注并私信 U3D目标检测免费获取应用包(底部公众号)。

效果

手机端效果:
在这里插入图片描述

在这里插入图片描述

PC端效果:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

实现

Barracuda 是一个简单、对开发人员友好的API,只需编写少量代码即可开始使用Barracuda:

var model = ModelLoader.Load(filename);
var engine = WorkerFactory.CreateWorker(model, WorkerFactory.Device.GPU);
var input = new Tensor(1, 1, 1, 10);
var output = engine.Execute(input).PeekOutput();

Barracuda 神经网络导入管道基于ONNX(Open Neural Network Exchange)格式的模型,允许您从各种外部框架(包括Pytorch、TensorFlow和Keras)引入神经网络模型。

关于模型

Barracuda目前仅支持推理,所以模型靠TensorFlow/Pytorch/Keras训练、导入,而且必须先将其转换为 ONNX,然后将其加载到 Unity中。ONNX(Open Neural Network Exchange)是一种用于ML 模型的开放格式。它允许您在各种ML框架和工具之间轻松交换模型。
Pytorch将模型导出到ONNX很容易

# network
net = ...

# Input to the model
x = torch.randn(1, 3, 256, 256)

# Export the model
torch.onnx.export(net,                       # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "example.onnx",            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,           # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['X'],       # the model's input names
                  output_names = ['Y']       # the model's output names
                  )

我这里准备的是很简单的模型,如下图:
在这里插入图片描述
在这里插入图片描述

确保ONNX模型的输入尺寸、通道顺序(NCHW)与Barracuda兼容。
因为要兼顾移动端效果,所以模型检测识别对象较少,以防止在移动设备上的推理慢。

UI搭建

运行时的UI相对简单,两个button用于打开摄像头和打开视频功能,一个Slider用于控制标记框的显示阈值,就是检测的可信度从0-1(0%-100%)的范围;一个rawImage组件用于显示检测的画面:
在这里插入图片描述

其次是标记框的UI,由一个图片和Text构成:
在这里插入图片描述

编码

加载模型

var model = ModelLoader.Load(resources.model);

其中模型类型是NNModel。

创建推理引擎 (Worker)并执行模型:

_worker = model.CreateWorker();
using (var t = new Tensor(_config.InputShape, _buffers.preprocess))
  _worker.Execute(t);

提取神经网络输出:

_worker.CopyOutput("Identity", _buffers.feature1);
_worker.CopyOutput("Identity_1", _buffers.feature2);

将网络的两个输出复制到缓冲区。

第一阶段后处理,检测数据:

var post1 = _resources.postprocess1;
post1.SetInt("ClassCount", _config.ClassCount);
post1.SetFloat("Threshold", threshold);
post1.SetBuffer(0, "Output", _buffers.post1);
post1.SetBuffer(0, "OutputCount", _buffers.counter);

var width1 = _config.FeatureMap1Width;
post1.SetTexture(0, "Input", _buffers.feature1);
post1.SetInt("InputSize", width1);
post1.SetFloats("Anchors", _config.AnchorArray1);
post1.DispatchThreads(0, width1, width1, 1);

var width2 = _config.FeatureMap2Width;
post1.SetTexture(0, "Input", _buffers.feature2);
post1.SetInt("InputSize", width2);
post1.SetFloats("Anchors", _config.AnchorArray2);
post1.DispatchThreads(0, width2, width2, 1);

聚合检测结果,使用两个特征图进行目标检测,执行目标定位(Bounding Box)预测。

第二阶段后处理,重叠移除:

var post2 = _resources.postprocess2;
post2.SetFloat("Threshold", 0.5f);
post2.SetBuffer(0, "Input", _buffers.post1);
post2.SetBuffer(0, "InputCount", _buffers.counter);
post2.SetBuffer(0, "Output", _buffers.post2);
post2.Dispatch(0, 1, 1, 1);

移除重叠的边界框。

上面的复杂处理是借Compute Shader的Preprocess、Postprocess1和postprocess2来实现的,Compute Shader 是一种图形编程中的着色器类型,专门用于执行计算任务,而不直接参与渲染。详细内容如下。

Common.hlsl:

// Compile-time constants
#define MAX_DETECTION 512
#define ANCHOR_COUNT 3

// Detection data structure - The layout of this structure must be matched
// with the one defined in Detection.cs.
struct Detection
{
    float x, y, w, h;
    uint classIndex;
    float score;
};

// Misc math functions

float CalculateIOU(in Detection d1, in Detection d2)
{
    float x0 = max(d1.x - d1.w / 2, d2.x - d2.w / 2);
    float x1 = min(d1.x + d1.w / 2, d2.x + d2.w / 2);
    float y0 = max(d1.y - d1.h / 2, d2.y - d2.h / 2);
    float y1 = min(d1.y + d1.h / 2, d2.y + d2.h / 2);

    float area0 = d1.w * d1.h;
    float area1 = d2.w * d2.h;
    float areaInner = max(0, x1 - x0) * max(0, y1 - y0);

    return areaInner / (area0 + area1 - areaInner);
}

float Sigmoid(float x)
{
    return 1 / (1 + exp(-x));
}

#endif

Postprocess1.compute:

#pragma kernel Postprocess1

#include "Common.hlsl"

// Input
Texture2D<float> Input;
uint InputSize;
float2 Anchors[ANCHOR_COUNT];
uint ClassCount;
float Threshold;

// Output buffer
RWStructuredBuffer<Detection> Output;
RWStructuredBuffer<uint> OutputCount; // Only used as a counter

[numthreads(8, 8, 1)]
void Postprocess1(uint2 id : SV_DispatchThreadID)
{
    if (!all(id < InputSize)) return;

    // Input reference point:
    // We have to read the input tensor in reversed order.
    uint ref_y = (InputSize - 1 - id.y) * InputSize + (InputSize - 1 - id.x);

    for (uint aidx = 0; aidx < ANCHOR_COUNT; aidx++)
    {
        uint ref_x = aidx * (5 + ClassCount);

        // Bounding box / confidence
        float x = Input[uint2(ref_x + 0, ref_y)];
        float y = Input[uint2(ref_x + 1, ref_y)];
        float w = Input[uint2(ref_x + 2, ref_y)];
        float h = Input[uint2(ref_x + 3, ref_y)];
        float c = Input[uint2(ref_x + 4, ref_y)];

        // ArgMax[SoftMax[classes]]
        uint maxClass = 0;
        float maxScore = exp(Input[uint2(ref_x + 5, ref_y)]);
        float scoreSum = maxScore;
        for (uint cidx = 1; cidx < ClassCount; cidx++)
        {
            float score = exp(Input[uint2(ref_x + 5 + cidx, ref_y)]);
            if (score > maxScore)
            {
                maxClass = cidx;
                maxScore = score;
            }
            scoreSum += score;
        }

        // Output structure
        Detection data;
        data.x = (id.x + Sigmoid(x)) / InputSize;
        data.y = (id.y + Sigmoid(y)) / InputSize;
        data.w = exp(w) * Anchors[aidx].x;
        data.h = exp(h) * Anchors[aidx].y;
        data.classIndex = maxClass;
        data.score = Sigmoid(c) * maxScore / scoreSum;

        // Thresholding
        if (data.score > Threshold)
        {
            // Detected: Count and output
            uint count = OutputCount.IncrementCounter();
            if (count < MAX_DETECTION) Output[count] = data;
        }
    }
}

Postprocess2.compute:

#pragma kernel Postprocess2

#include "Common.hlsl"

// Input
StructuredBuffer<Detection> Input;
RWStructuredBuffer<uint> InputCount; // Only used as a counter
float Threshold;

// Output
AppendStructuredBuffer<Detection> Output;

// Local arrays for data cache
groupshared Detection _entries[MAX_DETECTION];
groupshared bool _flags[MAX_DETECTION];

[numthreads(1, 1, 1)]
void Postprocess2(uint3 id : SV_DispatchThreadID)
{
    // Initialize data cache arrays
    uint entry_count = min(MAX_DETECTION, InputCount.IncrementCounter());
    if (entry_count == 0) return;

    for (uint i = 0; i < entry_count; i++)
    {
        _entries[i] = Input[i];
        _flags[i] = true;
    }

    for (i = 0; i < entry_count - 1; i++)
    {
        if (!_flags[i]) continue;

        for (uint j = i + 1; j < entry_count; j++)
        {
            if (!_flags[j]) continue;
            if (CalculateIOU(_entries[i], _entries[j]) < Threshold)
                continue;
            if (_entries[i].score < _entries[j].score)
            {
                _flags[i] = false;
                break;
            }
            else
                _flags[j] = false;
        }
    }
    for (i = 0; i < entry_count; i++)
        if (_flags[i]) Output.Append(_entries[i]);
}

Postprocess.compute:

#pragma kernel Preprocess

sampler2D Image;
RWStructuredBuffer<float> Tensor;
uint Size;

[numthreads(8, 8, 1)]
void Preprocess(uint2 id : SV_DispatchThreadID)
{
    // UV (vertically flipped)
    float2 uv = float2(0.5 + id.x, Size - 0.5 - id.y) / Size;

    // UV gradients
    float2 duv_dx = float2(1.0 / Size, 0);
    float2 duv_dy = float2(0, -1.0 / Size);

    // Texture sample
    float3 rgb = tex2Dgrad(Image, uv, duv_dx, duv_dy).rgb;

    // Tensor element output
    uint offs = (id.y * Size + id.x) * 3;
    Tensor[offs + 0] = rgb.r;
    Tensor[offs + 1] = rgb.g;
    Tensor[offs + 2] = rgb.b;
}

通过以上的处理,最后输出了一个目标检测的对象结果数组,主要包含如下数据:

public readonly struct Detection
{
    public readonly float x, y, w, h;
    public readonly uint classIndex;
    public readonly float score;
}

通过遍历这个数组,并将结果标记框和对象名称等信息显示出来:

public void SetAttributes(in Detection d)
{
    var rect = _parent.rect;

    var x = d.x * rect.width;
    var y = (1 - d.y) * rect.height;
    var w = d.w * rect.width;
    var h = d.h * rect.height;

    _xform.anchoredPosition = new Vector2(x, y);
    _xform.SetSizeWithCurrentAnchors(RectTransform.Axis.Horizontal, w);
    _xform.SetSizeWithCurrentAnchors(RectTransform.Axis.Vertical, h);

    var name = _labels[(int)d.classIndex];
    _label.text = $"{name} {(int)(d.score * 100)}%";

    var hue = d.classIndex * 0.073f % 1.0f;
    var color = Color.HSVToRGB(hue, 1, 1);

    _panel.color = color;

    transform.localScale = Vector3.one;
}

源码

https://download.csdn.net/download/qq_33789001/90242899

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951407.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

123.【C语言】数据结构之快速排序挖坑法和前后指针法

目录 1.挖坑法 执行流程 代码 运行结果 可读性好的代码 2.前后指针法(双指针法) 执行流程 单趟排序代码 将单趟排序代码改造后 写法1 简洁的写法 3.思考题 1.挖坑法 执行流程 "挖坑法"顾名思义:要有坑位,一开始将关键值放入临时变量key中,在数组中形成…

国产信创实践(国能磐石服务器操作系统CEOS +东方通TongHttpServer)

替换介绍&#xff1a; 国能磐石服务器操作系统CEOS 对标 Linux 服务器操作系统&#xff08;Ubuntu, CentOS&#xff09; 东方通TongHttpServer 对标 Nginx 负载均衡Web服务器 第一步&#xff1a; 服务器安装CEOS映像文件&#xff0c;可直接安装&#xff0c;本文采用使用VMware …

Linux中SSH服务(二)

一、基于公私钥的认证&#xff08;免密登录&#xff09; 1、Windows免密登录Linux Windows推荐安装Cygwin软件&#xff1a;Cygwin 1.1Windows上面生成公私钥 之前已经生成过了&#xff0c;所以显示公私钥已存在 lovezywLAPTOP-AABHB5ED ~ $ ssh-keygen Generating public/pr…

Linux-----进程通讯(管道Pipe)

目录 进程不共享内存 匿名管道 通过匿名管道实现通讯 有名管道 库函数mkfifo() 案例 进程不共享内存 不同进程之间内存是不共享的。是相互独立的。 #include <stdio.h> #include <stdlib.h> #include <errno.h>int num 0;int main(int argc, char con…

[工具]git克隆远程仓库到本地快速操作流程

一、新建空目录 二、初始化本地仓库 git init 初始化成功后&#xff0c;会在当前目录生成一个.git的目录。 三、关联远程仓库 git remote add origin <URL>这一步让本地仓库与远程仓库进行关联&#xff0c;origin是远程仓库的别名&#xff0c;可以自定义。 四、克隆…

机器学习之贝叶斯分类器和混淆矩阵可视化

贝叶斯分类器 目录 贝叶斯分类器1 贝叶斯分类器1.1 概念1.2算法理解1.3 算法导入1.4 函数 2 混淆矩阵可视化2.1 概念2.2 理解2.3 函数导入2.4 函数及参数2.5 绘制函数 3 实际预测3.1 数据及理解3.2 代码测试 1 贝叶斯分类器 1.1 概念 贝叶斯分类器是基于贝叶斯定理构建的分类…

基于phpstudy快速搭建本地php环境(Windows)

好好生活&#xff0c;别睡太晚&#xff0c;别爱太满&#xff0c;别想太多。 2025.1.07 声明 仅作为个人学习使用&#xff0c;仅供参考 对于CTF-Web手而言&#xff0c;本地PHP环境必不可少&#xff0c;但对于新手来说从下载PHP安装包到配置PHP环境是个非常繁琐的事情&#xff0…

张朝阳惊现CES展,为中国品牌 “代言”的同时,或将布局搜狐新战略!

每年年初&#xff0c;科技圈的目光都会聚焦在美国拉斯维加斯&#xff0c;因为这里将上演一场被誉为 “科技春晚” 的年度大戏 ——CES 国际消费电子展。作为全球规模最大、最具影响力的科技展会之一&#xff0c;CES 吸引了来自 160 多个国家的创新者和行业领导者&#xff0c;是…

Ollama VS LocalAI:本地大语言模型的深度对比与选择指南

随着人工智能技术的快速发展&#xff0c;大语言模型逐渐成为多个行业的重要工具。从生成内容到智能问答&#xff0c;大模型展现了强大的应用潜力。然而&#xff0c;云端模型的隐私性、使用成本和网络依赖等问题也促使更多用户关注本地化解决方案。Ollama 和 LocalAI 是近年来备…

【C++】B2101 计算矩阵边缘元素之和

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 &#x1f4af;前言&#x1f4af;题目背景与描述题目描述输入格式输出格式输入输出样例说明与提示 &#x1f4af;分析与解决方案解法一&#xff1a;我的做法代码实现解题思路优点与局限性 解法二&#xff1…

保护性暂停原理

什么是保护性暂停&#xff1f; 保护性暂停&#xff08;Guarded Suspension&#xff09;是一种常见的线程同步设计模式&#xff0c;常用于解决 生产者-消费者问题 或其他需要等待条件满足后再继续执行的场景。通过这种模式&#xff0c;一个线程在执行过程中会检查某个条件是否满…

穷举vs暴搜vs深搜vs回溯vs剪枝系列一>字母大小写全排列

题目&#xff1a; 解析&#xff1a; 代码&#xff1a; private List<String> ret;private StringBuffer path;public List<String> letterCasePermutation(String s) {ret new ArrayList<>();path new StringBuffer();dfs(s,0);return ret;}private voi…

解决nginx多层代理后应用部署后访问发现css、js、图片等样式加载失败

一般是采用前后端分离部署方式&#xff0c;被上一层ng代理后&#xff0c;通过域名访问报错&#xff0c;例如&#xff1a;sqx.com.cn/应用代理路径。 修改nginx配置&#xff0c;配置前端页面的路径&#xff1a; location / {proxy_pass http://前端页面所在服务器的IP:PORT;pro…

前端-计算机网络篇

一.网络分类 1.按照网络的作用范围进行分类 &#xff08;1&#xff09;广域网WAN(Wide Area Network) 广域网的作用范围通常为几十到几千公里,因而有时也称为远程网&#xff08;long haul network&#xff09;。广域网是互联网的核心部分&#xff0c;其任务是长距离运送主机…

挑战20天刷完leecode100

2025.1.5 二分查找 1 搜索插入位置 就是简单的二分查找 注意开闭就行 这里有一句话就是nums是升序的 如果他不是严格递增 就是有相同的数字的情况下应该怎么写? int lower_bound(vector<int>& nums, int target) {int left 0, right (int) nums.size() - 1; …

Android原生开发同一局域网内利用socket通信进行数据传输

1、数据接收端代码如下&#xff0c;注意&#xff1a;socket 接收信息需要异步运行&#xff1a; // port 端口号自定义一个值&#xff0c;比如 8888&#xff0c;但需和发送端使用的端口号保持一致 ServerSocket serverSocket new ServerSocket(port); while (true) {//这里为了…

Linux 获取文本部分内容

Linux获取文本部分内容 前言场景获取前几行内容获取末尾几行内容获取中间内容head 命令 tail 命令 结合sed 命令awk 命令 前言 test.log 文本内容如下&#xff1a; &#xff08;注意&#xff1a;内容 a1004和a1005之间有一空行&#xff09; [rootgaussdb002 tmp]# cat test.…

常见的端口号大全,2025年整理

端口号是网络通信的基础&#xff0c;它定义了不同服务的入口和出口。了解服务端口号不仅有助于网络配置&#xff0c;还能提升问题排查效率。在实际应用中&#xff0c;熟悉常见端口号可以帮助你快速定位网络故障、优化服务性能&#xff0c;并确保网络安全。 一、常见的网络服务…

音视频入门基础:MPEG2-PS专题(6)——FFmpeg源码中,获取PS流的视频信息的实现

音视频入门基础&#xff1a;MPEG2-PS专题系列文章&#xff1a; 音视频入门基础&#xff1a;MPEG2-PS专题&#xff08;1&#xff09;——MPEG2-PS官方文档下载 音视频入门基础&#xff1a;MPEG2-PS专题&#xff08;2&#xff09;——使用FFmpeg命令生成ps文件 音视频入门基础…

【Arthas命令实践】heapdump实现原理

&#x1f3ae; 作者主页&#xff1a;点击 &#x1f381; 完整专栏和代码&#xff1a;点击 &#x1f3e1; 博客主页&#xff1a;点击 文章目录 使用原理 使用 dump java heap, 类似 jmap 命令的 heap dump 功能。 【dump 到指定文件】 heapdump arthas-output/dump.hprof【只 …