YOLOv8改进 | DAttention (DAT)注意力机制实现极限涨点

论文地址: DAT论文地址

官方地址:官方代码的地址

代码地址:文末有修改了官方代码BUG的代码块复制粘贴即可

一、本文介绍

本文给大家带来的是YOLOv8改进DAT(Vision Transformer with Deformable Attention)的教程,其发布于2022年CVPR2022上同时被评选为Best Paper,由此可以证明其是一种十分有效的改进机制,其主要的核心思想是:引入可变形注意力机制和动态采样点(听着是不是和可变形动态卷积DCN挺相似)。同时在网络结构中引入一个DAT计算量由8.9GFLOPs涨到了9.4GFLOPs。本文的讲解主要包含三方面:DAT的网络结构思想、DAttention的代码复现,如何添加DAttention到你的结构中实现涨点,下面先来分享我测试的对比图(因为资源有限,我只用了100张图片的数据集进行了100个epoch的训练,虽然这个实验不能产生确定性的结论,但是可以作为一个参考)。

适用检测对象->各种检测目标都可以使用,并不针对于某一特定的目标有效。

视频讲解->暂未更新

前文回顾->YOLOv8改进有效涨点专栏->持续复现各种最新机制

目录

一、本文介绍

二、DAT的网络结构思想

2.1 DAT的主要思想和改进

2.2 DAT的网络结构图 

2.3 DAT和其他机制的对比

三、DAT即插即用的代码块

四、添加DAT到你的网络中

五、DAT可添加的位置

5.1推荐DAT可添加的位置 

5.2图示DAT可添加的位置 

六、本文总结 


二、DAT的网络结构思想

2.1 DAT的主要思想和改进

DAT(Vision Transformer with Deformable Attention)是一种引入了可变形注意力机制的视觉Transformer,DAT的核心思想主要包括以下几个方面:

  1. 可变形注意力(Deformable Attention):传统的Transformer使用标准的自注意力机制,这种机制会处理图像中的所有像素,导致计算量很大。而DAT引入了可变形注意力机制,它只关注图像中的一小部分关键区域。这种方法可以显著减少计算量,同时保持良好的性能。

  2. 动态采样点:在可变形注意力机制中,DAT动态地选择采样点,而不是固定地处理整个图像。这种动态选择机制使得模型可以更加集中地关注于那些对当前任务最重要的区域。

  3. 即插即用:DAT的设计允许它适应不同的图像大小和内容,使其在多种视觉任务中都能有效工作,如图像分类、对象检测等。

总结:DAT通过引入可变形注意力机制,改进了视觉Transformer的效率和性能,使其在处理复杂的视觉任务时更加高效和准确。

2.2 DAT的网络结构图 

(a) 展示了可变形注意力的信息流。左侧部分,一组参考点均匀地放置在特征图上,这些点的偏移量是由查询通过偏移网络学习得到的。然后,如右侧所示,根据变形点从采样特征中投影出变形的键和值。相对位置偏差也通过变形点计算,增强了输出转换特征的多头注意力。为了清晰展示,图中仅显示了4个参考点,但在实际实现中实际上有更多的点。

(b) 展示了偏移生成网络的详细结构,每层输入和输出特征图的大小都有标注(这个Offset network在网络的代码中需要控制可添加可不添加)。

通过上面的方式产生多种参考点分布在图像上,从而提高检测的效率,最终的效果图如下->

2.3 DAT和其他机制的对比

DAT与其他视觉Transformer模型和CNN模型中的DCN(可变形卷积网络)的对比图如下,突出了它们处理查询的不同方法(图片展示的很直观,不给大家描述过程了)

三、DAT即插即用的代码块

下面的代码是DAT的网络结构代码,官方的代码中存在许多bug而且参数都未定义,这里我替大家都行了修改而且在使用时无需手动添加任何参数,我都设置了通过模型进行了自动计算,使用方法看章节四。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from timm.models.layers import to_2tuple, trunc_normal_

class LayerNormProxy(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')


class DAttentionBaseline(nn.Module):

    def __init__(
            self, q_size=(224,224), kv_size=(224,224), n_heads=8, n_head_channels=32, n_groups=1,
            attn_drop=0.0, proj_drop=0.0, stride=1,
            offset_range_factor=-1, use_pe=True, dwc_pe=True,
            no_off=False, fixed_pe=False, ksize=9, log_cpb=False
    ):

        super().__init__()
        n_head_channels = int(q_size / 8)
        q_size = (q_size, q_size)

        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0


        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0)

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    def forward(self, x):
        x = x
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
                grid=pos[..., (1, 0)],  # y, x -> x, y
                mode='bilinear', align_corners=True)  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        # self.proj_k.weight = torch.nn.Parameter(self.proj_k.weight.float())
        # self.proj_k.bias = torch.nn.Parameter(self.proj_k.bias.float())
        # self.proj_v.weight = torch.nn.Parameter(self.proj_v.weight.float())
        # self.proj_v.bias = torch.nn.Parameter(self.proj_v.bias.float())
        # 检查权重的数据类型
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)

        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels,
                                                                              H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                            q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                                   n_sample,
                                                                                                   2).unsqueeze(1)).mul(
                    4.0)  # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement)  # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                            q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                                   n_sample,
                                                                                                   2).unsqueeze(1)).mul(
                    0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads,
                                           g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True)  # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))
        h, w = pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

        return y

四、添加DAT到你的网络中

添加教程这里不再重复介绍、因为专栏内容有许多,添加过程又需要截特别图片会导致文章大家读者也不通顺如果你已经会添加注意力机制了,可以跳过本章节,如果你还不会,大家可以看我下面的文章,里面详细的介绍了拿到一个任意机制(C2f、Conv、Bottleneck、Loss、DetectHead)如何添加到你的网络结构中去。

添加教程->YOLOv8改进 | 如何在网络结构中添加注意力机制、C2f、卷积、Neck、检测头

五、DAT可添加的位置

5.1推荐DAT可添加的位置 

DAT可以是一种即插即用的注意力机制,其可以添加的位置有很多,添加的位置不同效果也不同,所以我下面推荐几个添加的位,置大家可以进行参考,当然不一定要按照我推荐的地方添加。

  1. 残差连接中:在残差网络的残差连接中加入注意力机制(这个位置我推荐的原因是因为DCN放在残差里面效果挺好的大家可以尝试)

  2. 特征金字塔(SPPF):在特征金字塔网络之前,可以帮助模型更好地融合不同尺度的特征。

  3. Neck部分:YOLOv8的Neck部分负责特征融合,这里添加注意力机制可以帮助模型更有效地融合不同层次的特征。

  4. 输出层前:在最终的输出层前加入注意力机制可以使模型在做出最终预测之前,更加集中注意力于最关键的特征。

大家可能看我描述不太懂,大家可以看下面的网络结构图中我进行了标注。

5.2图示DAT可添加的位置 

六、本文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

本专栏其它内容(持续更新) 

YOLOv8改进 | 如何在网络结构中添加注意力机制、C2f、卷积、Neck、检测头

YOLOv8改进 | ODConv附修改后的C2f、Bottleneck模块代码

YOLOv8改进有效涨点系列->手把手教你添加动态蛇形卷积(Dynamic Snake Convolution)

YOLOv8性能评估指标->mAP、Precision、Recall、FPS、IoU

YOLOv8改进有效涨点系列->适合多种检测场景的BiFormer注意力机制(Bi-level Routing Attention)

 YOLOv8改进有效涨点系列->多位置替换可变形卷积(DCNv1、DCNv2、DCNv3) 

详解YOLOv8网络结构/环境搭建/数据集获取/训练/推理/验证/导出/部署

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/161373.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用 Python进行量化交易:前向验证分析

运行环境:Google Colab 1. 利用 yfinance 下载数据 import yfinance as yfticker AAPL df yf.download(ticker) df下载苹果的股票数据 df df.loc[2018-01-01:].copy()dfdf[change_tomorrow] df[Adj Close].pct_change(-1) df.change_tomorrow df.change_tom…

YOLOv8中训练参数中文解释

预测函数: from ultralytics import YOLO# Load a model model YOLO(yolov8n.pt) # Train the model model.train(datarD:\yolov8\ultralytics-main\data1.yaml, workers0, epochs100, batch16) 可选参数:

el-table树形数据隐藏子选择框

0 效果 1 代码 type是table数据中用来区分一级和二级的标识 // 隐藏子合同选择框 cellNone(row) {if (row.row.type 3 || row.row.type 4) {return "checkNone";} }, <style lang"scss" scoped>::v-deep {.checkNone .el-checkbox__input {displa…

数据结构与算法设计分析——常用搜索算法

目录 一、穷举搜索二、图的遍历算法&#xff08;一&#xff09;深度优先搜索&#xff08;DFS&#xff09;&#xff08;二&#xff09;广度优先搜索&#xff08;BFS&#xff09; 三、回溯法&#xff08;一&#xff09;回溯法的定义&#xff08;二&#xff09;回溯法的应用 四、分…

轻量级 Java 日志组件

日志记录功能在开发中很常用&#xff0c;不仅可以记录程序运行的细节&#xff0c;方便调试&#xff0c;也可以记录用户的行为&#xff0c;是框架中不可或缺的组件。为最大程度复用现有的组件&#xff0c;我们就地取材使用了 JDK 自带的 JUL&#xff08;java.util.logging&#…

学习模拟简明教程【Learning to simulate】

深度神经网络是一项令人惊叹的技术。 有了足够的标记数据&#xff0c;他们可以学习为图像和声音等高维输入生成非常准确的分类器。 近年来&#xff0c;机器学习社区已经能够成功解决诸如对象分类、图像中对象检测和图像分割等问题。 上述声明中的加黑字体警告是有足够的标记数…

手把手教你用C语言写出“走迷宫”小游戏(能看懂文字就会自己敲系列)

目录 设计迷宫地图 设计主角——小球 完整代码 这次教大家编写一个简单的“走迷宫”小游戏&#xff0c;我们可以通过键盘上的‘W’、‘S’、‘A’、‘D’四个键来控制一个“小球”向上&#xff0c;下&#xff0c;左&#xff0c;右移动&#xff0c;目的就是让这个“小球”从起…

Python3语法总结-数据转换②

Python3语法总结-数据转换② Python3语法总结二.Python数据类型转换隐式类型转换显示类型转换 Python3语法总结 二.Python数据类型转换 有时候我们&#xff0c;需要对数据内置的类型进行转换&#xff0c;数据类型的转换。 Python 数据类型转换可以分为两种&#xff1a; 隐式类…

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

文章目录 原型网络进行分类的基本流程一、原始代码---计算欧氏距离&#xff0c;设计原型网络&#xff08;计算原型开始训练&#xff09;二、每一行代码的详细解释总结 原型网络进行分类的基本流程 利用原型网络进行分类&#xff0c;基本流程如下&#xff1a; 1.对于每一个样本…

信号完整性分析基础知识之有损传输线、上升时间衰减和材料特性(十):有损传输线在时域中的表现

如果高频衰减大于低频衰减&#xff0c;随着信号传输&#xff0c;上升时间将会增加。上升时间通常定义为边沿在最终值的 10% 到 90% 之间过渡的时间。这假设信号的边缘轮廓看起来有点高斯分布&#xff0c;中间是最快的斜率区域。对于该波形&#xff0c;10%−90% 的上升时间是有意…

MIB 6.1810实验Xv6 and Unix utilities(4)primes

难度: hard/moderate Write a concurrent prime sieve program for xv6 using pipes and the design illustrated in the picture halfway down this page and the surrounding text. This idea is due to Doug McIlroy, inventor of Unix pipes. Your solution should be in …

让你的Mac体验更便捷,快速启动工具Application Wizard为你助力!

亲爱的Mac用户们&#xff0c;你是否经常感到在繁琐的软件启动过程中浪费了太多时间&#xff1f;你是否希望能够以更快的速度找到并启动你所需的应用程序&#xff1f;如果是的话&#xff0c;那么不要犹豫&#xff0c;让我们来介绍一款强大的软件快速启动工具——Application Wiz…

23年宁波职教中心CTF竞赛-决赛

Web 拳拳组合 进去页面之后查看源码&#xff0c;发现一段注释&#xff0c;写着小明喜欢10的幂次方&#xff0c;那就是10、100、1000、10000 返回页面&#xff0c;在点击红色叉叉的时候抓包&#xff0c;修改count的值为10、100、1000、10000 然后分别获得以下信息 ?count1…

Spring面试题:(八)Spring事务

Spring事务概述 Spring事务基于数据库&#xff0c;基于数据库的事务封装了统一的接口。 编程式事务和声明式事务。 声明式事务分为Xml声明式或者注解声明式 实现事务相关的三个类 事务管理器 事务定义 事务状态 XML声明式事务的使用方法 导入坐标配置目标类配置切面 导入…

JS判断是否存在某个元素(includes、indexOf、find、findeIndex、some)(every 数组内所有值是否相同)

方法一&#xff1a;array.includes(searcElement[,fromIndex]) 此方法判断数组中是否存在某个值&#xff0c;如果存在返回true&#xff0c;否则返回false。 searchElement&#xff1a;需要查找的元素&#xff0c;必选。fromIndex&#xff1a;可选&#xff0c;从该索引处开始查…

浏览器页面被恶意控制时的解决方法

解决360流氓软件控制浏览器页面 提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、接受360安全卫士的好意&#xff08;尽量不要选&#xff09;二、拒绝360安全卫士的好意&#xff08;强烈推荐&#xff09;第…

【Vue渲染】 条件渲染 | v-if | v-show | 列表渲染 | v-for

目录 前言 v-if和v-show的区别和联系 v-show和v-if如何选择 条件渲染|v-if|v-show v-if v-if v-else v-if v-else-if v-else template v-show 列表渲染|v-for v-for 前言 本文介绍Vue渲染&#xff0c;包含条件渲染v-if和v-show的区别和联系以及列表渲染v-for v-if和…

“腾易视连”构建汽车生态新格局 星选计划赋能创作者价值提升

11月16日&#xff0c;在2023年广州国际车展前夕&#xff0c;以“腾易视连&#xff0c;入局视频号抓住增长新机会”为主题的腾易创作者大会在广州隆重举办。此次大会&#xff0c;邀请行业嘉宾、媒体伙伴、生态伙伴、视频号汽车领域原生达人等共济一堂&#xff0c;结合汽车行业数…

轻量级的资源授权:基于 OAuth 规范

了解 OAuth 感觉 OAuth 太负盛名了&#xff0c;以至于后来在 OIDC 反而难以企及前辈 OAuth。倒是大家谈论比较多的是 JWT&#xff08;例如https://www.cnblogs.com/lyzg/p/6132801.html&#xff09;&#xff0c;——实际谈 JWT 就是在实现 OIDC&#xff0c;反而 OIDC 大家不怎…