yolov5模型Detection输出内容与源码详细解读

文章目录

  • 前言
  • 一、Detiction类源码说明
  • 二、Detection类初始化参数解读
  • 三、Detection的训练输出源码解读
  • 四、Detection的预测输出源码解读
    • 1、self.grid内容解读
    • 2、xy/wh内容解读
    • 3、推理输出解读
  • 总结


前言

最近,需要修改yolov5推理结果,通过推理特征添加一些其它操作(如蒸馏)。显然,你需要对yolov5推理输出内容有详细了解,方可被你使用。为此,本文将记录个人对yolov5输出内容源码解读,这样对于你修改源码或蒸馏操作可提供理论参考。


一、Detiction类源码说明

yolov5的detection类输出包含2个部分,一个是训练的输出,一个是预测的输出。而我将在这里解释类参数与训练、预测输出内容。为什么我使用一篇文章来说明?显然是训练与预测输出内容与原始图像尺寸、模型尺寸、特征尺寸以及回归box含义与使用细节。

整体源码如下:

class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

二、Detection类初始化参数解读

我已voc数据为例说明,voc数据类别为20。我将其具体解释注释其代码中,如下:

def __init__(self, nc=20, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
	 self.nc = nc  #   =20 类别数量
	 self.no = nc + 5  # =25 每个类别需添加位置与置信度
	 self.nl = len(anchors)  # =3 yolov5检测层为3,每层特征有一个anchor,可使用anchors替代 
	 self.na = len(anchors[0]) // 2  # =3 获得每个grid的anchor数量
	 self.grid = [torch.zeros(1)] * self.nl  # =[tensor([0.]), tensor([0.]), tensor([0.])],初始化grid,后期会每个特征变成[1,3,feature_w,feature_h,2],共三个
	 self.anchor_grid = [torch.zeros(1)] * self.nl  # 初始化anchor grid,与上面self.grid类似,直白说就是占位初始化
	 self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # 将anchors值变成shape(3,3,2)与对应格式
	 self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # 3个模型数据输出转换,每个特征图输出转换,
	 self.inplace = inplace  # use in-place ops (e.g. slice assignment)

anchors值如下:

  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

self.m模块:

ModuleList(
  (0): Conv2d(320, 75, kernel_size=(1, 1), stride=(1, 1))
  (1): Conv2d(640, 75, kernel_size=(1, 1), stride=(1, 1))
  (2): Conv2d(1280, 75, kernel_size=(1, 1), stride=(1, 1))
)

其中ch是输入,与模型图像大小关联。

三、Detection的训练输出源码解读

为便于快速理解模型训练输出内容,我直接去掉推理输出代码,这样一目了然知道模型输出内容。

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # 预测输出
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()    

        return x 

输出x值为三个特征层的值输出,都表示[batch,3,h,w,cls+5],含义为batch size图像的h*w个像素有3个预测,每个预测都是类别概率、置信度与box位置。其输出如下:

在这里插入图片描述
注:特别说明,x输出都没有经过sigmoid函数。

四、Detection的预测输出源码解读

为便于快速理解模型推理输出内容,我直去掉不必要判断,给出推理输出代码,这样一目了然知道模型输出内容。

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

          	 # inference
             if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                 self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

             y = x[i].sigmoid()
             if self.inplace:
                 y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
             else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                 xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                 wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                 y = torch.cat((xy, wh, y[..., 4:]), -1)
             z.append(y.view(bs, -1, self.no))

        return   (torch.cat(z, 1), x)

1、self.grid内容解读

grid为网格,可以理解是对应特征图的网格,以第一个特征宽高为80说明,grid[0]存的是像素位置坐标如下tensort示意。相当于,每个batch的高宽对应像素重复3次分别从0到79的x坐标和y坐标得到grid。而grid有三个特征,其它二个特征图也是同样原理,如下图。

在这里插入图片描述

tensor([[[[[ 0.,  0.],
           [ 1.,  0.],
           [ 2.,  0.],
           ...,
           [77.,  0.],
           [78.,  0.],
           [79.,  0.]],

          [[ 0.,  1.],
           [ 1.,  1.],
           [ 2.,  1.],
           ...,
           [77.,  1.],
           [78.,  1.],
           [79.,  1.]],

          [[ 0.,  2.],
           [ 1.,  2.],
           [ 2.,  2.],
           ...,
           [77.,  2.],
           [78.,  2.],
           [79.,  2.]],

          ...,

2、xy/wh内容解读

yolov5在经历一次sigmoid为y = x[i].sigmoid(),获得dx、dy、dw、dh值(我暂时这么理解),在使用以下代码实现在该特征图对应xy值与wh值。特别注意,假设该特征图是8080尺寸,仅是获得在该特征8080层恢复到模型输入特征图像尺寸,还未到原始图像输入尺寸。源码如下:

y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh

这里也说明一下,dw/dh转到对应尺寸是直接乘以self.anchor_grid,而self.anchor_grid保留是anchors值如下:

  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

最终扩展到对应特征图维度,如下显示:
在这里插入图片描述

注:这里知道dx/dy/dw/dh恢复到模型输入尺寸(如640)即可。dw/dh并没有*strid,而是直接*anchors值。

3、推理输出解读

最后将输出内容通过view转换为[batch,-1,25]形式。

z.append(y.view(bs, -1, self.no))

即可获得(torch.cat(z, 1), x)这样的推理输出。我想说x实际还是训练输出内容,没有变化。

推理输出为2个列表,第一个列表就是上面说到的z值,其经历了sigmoid,包含类、置信度、box都使用了sigmoid,且box转为了模型输出对应尺寸位置。第二个列表为三个特征层的值输出,都表示[batch,3,h,w,cls+5],含义为batch size图像的h*w个像素有3个预测,每个预测都是类别概率、置信度与box位置,是没有经历过sigmoid,实际就是训练输出内容。其结果示意如下:

在这里插入图片描述

注:x没有经过sigmoid,而z经过sigmoid后转成模型输出尺寸box。


总结

掌握yolov5的Detection类训练与预测输出内容,有利于对源码更改提供理论依据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/321248.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

青动CRM-E售后 售后工单CRM系统 erp系统 带前端小程序全开源可二开

应用介绍 一款基于FastAdminThinkPHP和uniapp开发的CRM售后管理系统,旨在助力企业销售售后全流程精细化、数字化管理,主要功能:客户、合同、工单、任务、报价、产品、库存、出纳、收费,适用于:服装鞋帽、化妆品、机械机…

change事件传递多个参数

1.传递value页面参数 change"handleChange($event,123)" 2.传递选中的keyvalue或是选中的item 我用的是a-auto-complete,试验了用a-select也可以 就是在option里面,:value"JSON.stringify(d)" 然后在eval(( value ))转化就可…

zepplin记录1

zepplin记录1 文章目录 zepplin记录1前言一、配置python环境二、测试可用性1.配置interpreter2.测试代码 总结 前言 Apache Zeppelin是一个开源的数据分析和可视化的交互式笔记本,类似于Jupyter Notebook。它支持多种编程语言(如Scala、Python、R、SQL等…

智慧园区数字孪生智能可视运营平台解决方案:PPT全文82页,附下载

关键词:智慧园区解决方案,数字孪生解决方案,数字孪生应用场景及典型案例,数字孪生可视化平台,数字孪生技术,数字孪生概念,智慧园区一体化管理平台 一、基于数字孪生的智慧园区建设目标 1、实现…

【技术选型】Doris vs starRocks

比对结论 仅从当前能看到的数据中,相比于doris,starRocks在性能方面具备优势,且更新频率高(降低维护成本)。 目标诉求 并发性不能太低——相比于clickhouse不到100的QPS支持大表关联——降低数据清洗的压力&#xf…

【PlantUML】- 时序图

写在前面 本篇文章,我们来介绍一下PlantUML的时序图。这个相对类图来讲,比较简单,也不需要布局。读完文章,相信你就能实际操作了。 目录 写在前面一、基本概念二、具体步骤1.环境说明2.元素3.语法4.示例 三、参考资料写在后面系列…

Spring Boot 3 + Vue 3实战:引入数据库实现用户登录功能

文章目录 一、实战概述二、实战步骤(一)创建数据库(二)创建用户表(三)后端项目引入数据库1、添加相关依赖2、用户实体类保持不变3、编写应用配置文件4、创建用户映射器接口5、创建用户服务类6、修改登录控制…

【Fiddler抓包】微信扫码访问链接打不开网页

又来每天进步一点点~~~ 背景:某天发版的时候,手机连接电脑抓包查看用户登录之前的sessionID,由于业务需要,是需要用户登录微信扫码跳转至某一页面的,微信(分身)扫码成功,跳转时打不…

HarmonyOS 通过 animateTo讲解尺寸动画效果

上文 HarmonyOS讲解并演示 animateTo 动画效果 我们已经做出了基本的动画效果 也对 animateTo 的使用比较熟悉了 第一个参数是 配置动画参数的json 第二个参数 则是改变我们元素属性值的事件 但属性值 远远不止位置属性 本文 我们来说 通过尺寸变化 完成动画效果 如果你有看过…

指针理解C部分

目录 1.二级指针 2.指针数组 2.1指针数组的定义和表现形式 2.2指针数组模拟实现二维数组 2.2.1二维数组 2.2.2使用指针数组模拟实现二维数组 3.字符指针 2.数组指针 3.二维数组传参 4.函数指针 4.1函数指针变量的定义和创建 4.2函数指针变量的使用 4.3两段有趣的代码 4.…

【NI国产替代】USB‑7846 Kintex-7 160T FPGA,500 kS/s多功能可重配置I/O设备

Kintex-7 160T FPGA,500 kS/s多功能可重配置I/O设备 USB‑7846具有用户可编程FPGA,可用于高性能板载处理和对I/O信号进行直接控制,以确保系统定时和同步的完全灵活性。 您可以使用LabVIEW FPGA模块自定义这些设备,开发需要精确定时…

NLP论文阅读记录 - 2022 | WOS 一种新颖的优化的与语言无关的文本摘要技术

文章目录 前言0、论文摘要一、Introduction1.1目标问题1.2相关的尝试1.3本文贡献 二.前提三.本文方法四 实验效果4.1数据集4.2 对比模型4.3实施细节4.4评估指标4.5 实验结果4.6 细粒度分析 五 总结思考 前言 A Novel Optimized Language-Independent Text Summarization Techni…

Linux系统编程(十):线程同步(下)

参考引用 UNIX 环境高级编程 (第3版)嵌入式Linux C应用编程-正点原子 1. 为什么需要线程同步? 线程同步是为了对共享资源的访问进行保护 共享资源指的是多个线程都会进行访问的资源(如:全局变量) 保护的目的是为了解决数据一致性…

前端对接电子秤、扫码枪设备serialPort 串口使用教程

因为最近工作项目中用到了电子秤,需要对接电子秤设备。以前也没有对接过这种设备,当时也是一脸懵逼,脑袋空空。后来就去网上搜了一下前端怎么对接,然后就发现了SerialPort串口。 Serialport 官网地址:https://serialpo…

C# 静态代码织入AOP组件之肉夹馍

写在前面 关于肉夹馍组件的官方介绍说明: Rougamo是一个静态代码织入的AOP组件,同为AOP组件较为常用的有Castle、Autofac、AspectCore等,与这些组件不同的是,这些组件基本都是通过动态代理IoC的方式实现AOP,是运行时…

Mysql-redoLog

Redo Log redo log进行刷盘的效率要远高于数据页刷盘,具体表现如下 redo log体积小,只记录了哪一页修改的内容,因此体积小,刷盘快 redo log是一直往末尾进行追加,属于顺序IO。效率显然比随机IO来的快Redo log 格式 在MySQL的InnoDB存储引擎中,redo log(重做日志)被用…

【EMC专题】浪涌的成因与ICE 61000-4-5标准

什么是浪涌? 浪涌是一种无法预料的瞬态电压或电流尖峰,由附近的电子产品或是环境导致。 了解浪涌非常重要,因为浪涌有可能会导致设备的电气过应力损坏,造成系统故障等。 对于系统设计来说,重要的一点是我们如果无法控制浪涌的产生,那么只能通过将瞬态峰值电流导入到地,…

Mysql查询与更新语句的执行

一条SQL查询语句的执行顺序 FROM&#xff1a;对 FROM 子句中的左表<left_table>和右表<right_table>执行笛卡儿积&#xff08;Cartesianproduct&#xff09;&#xff0c;产生虚拟表 VT1 ON&#xff1a;对虚拟表 VT1 应用 ON 筛选&#xff0c;只有那些符合<join_…

Kafka消费全流程

Kafka消费全流程 1.Kafka一条消息发送和消费的流程图(非集群) 2.三种发送方式 准备工作 创建maven工程&#xff0c;引入依赖 <dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-clients</artifactId><version>3.3.1&l…

UDS 诊断通讯

UDS有哪些车型支持 UDS(统一诊断服务)协议被广泛应用于汽车行业中,支持多种车型。具体来说,UDS协议被用于汽车电子控制单元(ECU)之间的通讯,以实现故障诊断、标定、编程和监控等功能。 支持UDS协议的车型包括但不限于以下几种: 奥迪(Audi)车型:包括A3、A4、A5、A6…