YOLOv5改进 | Head | 将yolov5的检测头替换为ASFF_Detect

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

在目标检测中,为了解决尺度变化的问题,通常采用金字塔特征表示。然而,对于基于特征金字塔的单次检测器来说,不同特征尺度之间的不一致性是一个主要限制。为此,研究人员提出了一种新颖的、基于数据的策略,用于金字塔特征融合,称为自适应空间特征融合(ASFF)。它学习了一种方法,用以在空间上过滤冲突信息,从而抑制不一致性,提高了特征的尺度不变性,并且几乎不引入额外的推理开销。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. 将ASFF_DETECT代码实现

2.1 ASFF_DETECT添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:Learning Spatial Fusion for Single-Shot Object Detection——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

自适应空间特征融合(ASFF)的主要原理旨在解决单次检测器中不同尺度特征的不一致性问题。具体来说,ASFF通过动态调整来自不同尺度特征金字塔层的特征贡献,确保每个检测对象的特征表示是一致且最优的。以下是ASFF的主要原理:

原理概述

  1. 多尺度特征的融合

    • 传统的特征金字塔网络(FPN)在不同尺度上提取特征,但这些特征在空间位置上可能存在不一致性,导致检测效果不佳。

    • ASFF通过一个自适应融合模块,动态地结合来自不同尺度的特征图,使得每个像素点能够获得来自各个尺度的最优特征表示。

  2. 自适应权重学习

    • ASFF在训练过程中通过一个轻量级的网络结构(如1x1卷积层)学习自适应权重,这些权重用于加权组合来自不同尺度的特征。

    • 这个学习过程是自适应的,即权重会根据输入图像的特征和目标物体的位置进行调整,从而确保融合后的特征在空间和语义上都是最优的。

  3. 特征一致性

    • 通过自适应权重,ASFF能有效地调节各尺度特征的贡献,解决了特征金字塔中不同层次特征之间的空间位置不一致性问题。

    • 这种融合方式不仅增强了特征的一致性,还提高了检测器对各种尺度目标的响应能力。

具体步骤

  1. 特征提取

    输入图像通过基础卷积神经网络(如ResNet)提取特征,并通过特征金字塔网络(FPN)生成不同尺度的特征图。
  2. 权重生成

    对每个尺度的特征图,ASFF使用一个小型网络(如1x1卷积层)生成对应的自适应权重图。
  3. 特征融合

    将不同尺度的特征图与其对应的权重图逐像素相乘,然后进行加权求和,生成最终的融合特征图。
  4. 检测输出

    最终的融合特征图输入到检测头中,生成检测结果(如边界框和类别预测)。

优势

  • 性能提升:通过自适应融合不同尺度的特征,ASFF显著提升了检测精度,特别是在复杂场景和多尺度目标检测任务中。

  • 高效性:ASFF在提高性能的同时,保持了较低的计算开销,仅增加了极少的推理时间,适合实时应用。

ASFF的方法通过动态调整特征贡献,确保每个像素点在不同尺度特征上的最优组合,从而提高了单次检测器的整体检测性能。

2. 将ASFF_DETECT代码实现

2.1 ASFF_DETECT添加到YOLOv5中

 关键步骤一:将下面代码粘贴到/yolov5-6.1/models/yolo.py文件中

class ASFF_Detect(nn.Module):   #add ASFFV5 layer and Rfb 
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), multiplier=0.5,rfb=False,inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.l0_fusion = ASFFV5(level=0, multiplier=multiplier,rfb=rfb)
        self.l1_fusion = ASFFV5(level=1, multiplier=multiplier,rfb=rfb)
        self.l2_fusion = ASFFV5(level=2, multiplier=multiplier,rfb=rfb)
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        result=[]
       
        result.append(self.l2_fusion(x))
        result.append(self.l1_fusion(x))
        result.append(self.l0_fusion(x))
        x=result      
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid() # https://github.com/iscyy/yoloair
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
    
    def _make_grid(self, nx=20, ny=20, i=0):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(y, x, indexing='ij')
        else:
            yv, xv = torch.meshgrid(y, x)
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        #print(anchor_grid)
        return grid, anchor_grid

2.2 新增yaml文件

关键步骤二在下/yolov5-6.1/models下新建文件 yolov5_ASFF.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, ASFF_Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.3 注册模块

关键步骤三:在yolo.py中注册,

首先在model的类下面添加下面内容,位置如图所示

if isinstance(m, ASFF_Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            try:
                self._initialize_biases()  # only run once    
                LOGGER.info('initialize_biases done')
            except:
                LOGGER.info('decoupled no biase ')

 然后修改_profile_one_layer函数下的代码为

c = isinstance(m, Detect) or isinstance(m, ASFF_Detect) # is final layer, copy input as inplace fix

 修改后如下图所示

修改_apply的内容

if isinstance(m, Detect) or isinstance(m, ASFF_Detect):

修改后如下

 在parse_model函数中注册模块

内容如下位置如下

elif m is ASFF_Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_ASFF.yaml的路径

建议大家写绝对路径,确保一定能找到

  🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1C98TemcSlia0n4ngAb9guQ?pwd=z6n4

提取码: z6n4 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs

5. 进阶

现在的代码只能适配yolov5s版本,你能将他们扩展为更大的模型吗?

6. 总结

ASFF检测头的核心在于自适应地融合来自不同尺度的特征,以提高单次检测器的精度和鲁棒性。ASFF检测头首先通过基础卷积神经网络提取输入图像的基本特征,并通过特征金字塔网络(FPN)生成多个尺度的特征图。然后,ASFF模块在每个尺度上使用一个轻量级的网络(例如1x1卷积层)生成自适应权重图,这些权重图用来表示各个尺度特征对最终融合特征的贡献。接下来,不同尺度的特征图与对应的权重图逐像素相乘,再进行加权求和,生成一个融合后的特征图,该特征图在空间和语义上都更加一致。最后,这个融合特征图输入到检测头中,用于生成检测结果,包括物体的边界框和类别预测。通过这种自适应的特征融合方法,ASFF检测头有效地解决了不同尺度特征之间的不一致性问题,显著提高了检测精度,同时保持了较低的计算开销,使其适用于实时应用场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/704305.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

凡尔码来访登记卡助力来访安全

来访登记制度是指为了加强对来访人员的管理和安全控制,确保组织内部秩序和安全的一项制度。通过来访登记制度,可以对来访人员的身份进行核实,了解来访目的,并采取相应的安全措施,为组织内部的工作和人员安全提供保障。…

Sass实战运用,如何利用好Sass

Sass(Syntactically Awesome Stylesheets)是一种CSS预处理器,它提供了许多强大的功能,如变量、嵌套规则、混合(Mixins)、函数等,使得CSS的编写更加高效、灵活和易于维护。以下是关于Sass实战运用…

Go基础编程 - 05 - 数组与切片

目录 1. 数组2. 切片2.1. slice 声明、初始化2.2. slice 操作2.3. append() 追加切片、扩容2.4. 字符串和切片 3. Copy4. Array、Slice 内存布局 上一篇:基本类型、常量和变量 1. 数组 数组是同一种类型固定长度的序列(有长度、类型构成)。…

Postgres 正在吞噬数据库世界

Postgres 正在吞噬数据库世界 作者:Ruohang Feng(Vonng)|微信| Medium | 2024-03-04 标签: PostgreSQL生态系统 PostgreSQL 不仅仅是一个简单的关系型数据库,它还是一个数据管理框架,具有席卷整个数据库领…

基于WPF技术的换热站智能监控系统04--实现左侧历史曲线

1、区域划分 左侧分5行,第一行信息标题,第二行历史曲线 2、安装livecharts图表控件 3、引入图表控件命名空间 4、使用控件 5、运行效果 走过路过不要错过,点赞关注收藏又圈粉,共同致富,为财务自由作出贡献

IP地址乱成一团?用Shell一键搞定!

在日常的运维工作中,我们经常需要对各种数据进行处理和分析,其中包括对IP地址的管理和排序。排序后的IP地址列表可以帮助我们更好地进行日志分析、网络流量监控和故障排除。 本文将模拟一个运维场景,展示如何对IP地址进行排序,并探…

Mongodb使用$pop删除数组中的元素

学习mongodb,体会mongodb的每一个使用细节,欢迎阅读威赞的文章。这是威赞发布的第67篇mongodb技术文章,欢迎浏览本专栏威赞发布的其他文章。如果您认为我的文章对您有帮助或者解决您的问题,欢迎在文章下面点个赞,或者关…

编译和连接

目录1. 翻译环境和运行环境2. 翻译环境:预编译编译汇编链接1. 翻译环境和运行环境 在ANSI C 的任何一种实现中,存在两个不同环境。 (1) 翻译环境,在这种环境中源代码被转换为可执行的机器指令(二进制指令)。 (2) 执行环境,它用于实际执行的代…

PostgreSQL 多表连接不同维度聚合统计查询

摘要:在本文中,你将学习到如何使用 PostgreSQL 完全外连接,从两个或多个表中聚合维度统计数据。 文章目录 一、常用的连接类型图示二、数据库表设计示例三、连接查询示例1. inner join 内连接(不能满足维度统计需求)2. full join 完全外连接(满足维度统计需求)一、常用的…

Golang免杀-分离式加载器(传参)AES加密

目录 enc.go 生成: dec.go --执行dec.go...--上线 cs生成个c语言的shellcode. enc.go go run .\enc.go shellcode 生成: --key为公钥. --code为AES加密后的数据, ----此脚本每次运行key和code都会变化. package mainimport ("bytes""crypto/aes"&…

java1.8运行arthas-boot.jar运行报错解决

报错内容 输入java -jar arthas-boot.jar,后报错。 [INFO] JAVA_HOME: D:\developing\jdk\jre1.8 [INFO] arthas-boot version: 3.7.2 [INFO] Can not find java process. Try to run jps command lists the instrumented Java HotSpot VMs on the target system.…

Spring Boot集成antlr实现词法和语法分析

1.什么是antlr? Antlr4 是一款强大的语法生成器工具,可用于读取、处理、执行和翻译结构化的文本或二进制文件。基本上是当前 Java 语言中使用最为广泛的语法生成器工具。Twitter搜索使用ANTLR进行语法分析,每天处理超过20亿次查询&#xff1…

20240612在飞凌的OK3588-C开发板的linux系统下测试以太网

20240612在飞凌的OK3588-C开发板的linux系统下测试以太网 2024/6/12 17:56 欢迎您入坑飞凌的OK3588-C开发板,使用飞凌的预编译的固件:OK3588-linuxfs-img.tar.bz2 Z:\rockdev\update.img tar jxvf OK3588-linuxfs-img.tar.bz2 首先,刷Android…

自己用pip下载好模块啦,但是在pycharm里面不显示?

问题: 今天在cmd里面用pip命令安装第三方模块,最后用pip list 命令发现已经成功安装,但是在pycharm里面用该模块的时候,还是爆红,显示没有该库 。 解决方法: 第一种(项目刚创建)&am…

虚拟声卡实现音频回环

虚拟声卡实现音频回环 一、电脑扬声器播放声音路由到麦克风1. Voicemeeters安装设置2. 音频设备选择 二、回声模拟 一、电脑扬声器播放声音路由到麦克风 1. Voicemeeters安装设置 2. 音频设备选择 以腾讯会议为例 二、回声模拟 选中物理输入设备“Stereo Input 1”和物理输出设…

GUI listbox

GUI listbox (自用笔记) 功能details拆分 同时打开多个文件,可以是不同类型的,在listbox中显示出路径和文件名; 计算每个数据文件(.txt或.dat)掉帧出现的行数,存储到元胞数组&…

Vue10-事件修饰符

一、示例&#xff1a;<a>标签不执行默认的跳转行为 1-1、方式一 <a href"http://www.baidu.com" onclick"event.preventDefault();">点击我</a> 1-2、方式二 1-3、方式三&#xff1a;事件修饰符 二、Vue的六种事件修饰符 2-1、prevent&…

今日早报 每日精选15条新闻简报 每天一分钟 知晓天下事 6月13日,星期四

每天一分钟&#xff0c;知晓天下事&#xff01; 2024年6月13日 星期四 农历五月初八 1、 财政部&#xff1a;将在19日第一次续发行2024年20年期超长期特别国债。 2、 成本低&#xff0c;商载高&#xff0c;我国自主研制HH-100商用无人运输机首飞成功。 3、 四川甘孜州石渠县1…

Mongodb在UPDATE操作中使用$pull操作

学习mongodb&#xff0c;体会mongodb的每一个使用细节&#xff0c;欢迎阅读威赞的文章。这是威赞发布的第68篇mongodb技术文章&#xff0c;欢迎浏览本专栏威赞发布的其他文章。如果您认为我的文章对您有帮助或者解决您的问题&#xff0c;欢迎在文章下面点个赞&#xff0c;或者关…

网站线上模板建设的优缺点

优点&#xff1a; 1.搭建的时间短&#xff0c;在线建站&#xff0c;只需要登录注册然后选择网站模板创建网站即可管理自己的网站后台&#xff0c;就几步操作就可以实现。 2.网站出错率少&#xff0c;因为有很多用户在使用&#xff0c;前期所报出来的问题就已经一一…