YOLOv8改进 添加可变形注意力机制DAttention

一、Deformable Attention Transformer论文

论文地址:arxiv.org/pdf/2201.00520.pdf

二、Deformable Attention Transformer注意力结构

Deformable Attention Transformer包含可变形注意力机制,允许模型根据输入的内容动态调整注意力权重。在传统的Transformer中,注意力是通过对查询和键向量之间的点积来确定的,然后将输入嵌入的加权和进行计算。然而,这种方法假设了一个刚性的注意力模式,其中每个查询都会参与固定的一组键。

在可变形注意力转换器中,注意力权重使用学习机制进行计算,该机制可以根据输入学习调整注意力模式。这使得模型能够捕捉输入的不同部分之间更复杂的关系,从而在需要建模长距离依赖或捕捉细粒度细节的任务上提供更好的性能。可变形注意力机制DAttention引入了额外的可学习参数,用于计算注意力权重。这些参数通过反向传播在训练过程中进行学习,使得模型能够调整注意力模式以更好地适应输入数据。

三、代码实现

1、在官方的yolov8包中ultralytics\ultralytics\nn\modules\__init__.py文件中的from .conv import和__all__中加入注意力机制DAttention。

2、在ultralytics\ultralytics\nn\modules\conv.py文件中上边引用包:

import einops
from timm.models.layers import trunc_normal_
import torch.nn.functional as F

__all__中同样添加DAttention:

并在该conv.py文件中输入DAttention的代码:

 #########            添加DAttention        #####################
class LayerNormProxy(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')
class DAttention(nn.Module):

    def __init__(
            self, q_size=(224, 224), kv_size=(224, 224), n_heads=8, n_head_channels=32, n_groups=1,
            attn_drop=0.0, proj_drop=0.0, stride=1,
            offset_range_factor=-1, use_pe=True, dwc_pe=True,
            no_off=False, fixed_pe=False, ksize=9, log_cpb=False
    ):

        super().__init__()
        n_head_channels = int(q_size / 8)
        q_size = (q_size, q_size)

        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0)

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    def forward(self, x):
        x = x
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
                grid=pos[..., (1, 0)],  # y, x -> x, y
                mode='bilinear', align_corners=True)  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        # self.proj_k.weight = torch.nn.Parameter(self.proj_k.weight.float())
        # self.proj_k.bias = torch.nn.Parameter(self.proj_k.bias.float())
        # self.proj_v.weight = torch.nn.Parameter(self.proj_v.weight.float())
        # self.proj_v.bias = torch.nn.Parameter(self.proj_v.bias.float())
        # 检查权重的数据类型
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)

        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels,
                                                                              H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                        q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                               n_sample,
                                                                                               2).unsqueeze(1)).mul(
                    4.0)  # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement)  # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                        q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                               n_sample,
                                                                                               2).unsqueeze(1)).mul(
                    0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads,
                                           g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True)  # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))
        h, w = pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

        return y

3、在 ultralytics\ultralytics\nn\tasks.py文件中开头引入DAttention。

并在该文件 def parse_model模块中加入DAttention注意力机制代码:

        elif m in {DAttention}:
            c2 = ch[f]
            args = [c2, *args]

4、创建yolov8+DAttention的yaml文件:

(可根据自己的需求选择DAttention注意力机制插入的位置,本文以插入yolov8结构中池化层SPPF后边为例)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
  - [-1, 1, DAttention, [[20, 20]]]

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

四、模型验证

可以看出模型中已经包含DAttention注意力机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/280987.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

pytorch03:transforms常见数据增强操作

目录 一、数据增强二、transforms--Crop裁剪2.1 transforms.CenterCrop2.2 transforms.RandomCrop2.3 RandomResizedCrop2.4 FiveCrop和TenCrop 三、transforms—Flip翻转、旋转3.1RandomHorizontalFlip和RandomVerticalFlip3.2 RandomRotation 四、transforms —图像变换4.1 t…

[Angular] 笔记 23:Renderer2 - ElementRef 的生产版本

chatgpt: Renderer2 简介 在 Angular 中,Renderer2 是一个服务,用于处理 DOM 操作的抽象层。它提供了一种安全的方式来操作 DOM,同时与平台无关,有助于维护应用程序的跨浏览器兼容性和安全性。 Renderer2 的作用是在 Angular 组…

【vim 学习系列文章 3.1 -- vim 删除 ^M】

请阅读【嵌入式开发学习必备专栏 之 VIM 专栏】 文章目录 ^M 来源^M 删除 ^M 来源 在 Vim 中打开文件时,您可能会遇到行尾的 ^M 字符,这通常是因为文件使用了 Windows 风格的回车换行符(CRLF),而不是 Unix/Linux 风格…

集合基础知识点

集合基础 1. 集合的由来 当 Java 程序中需要存放数据的时候,通常会定义变量来实现数据的存储,但是,当需要存储大量数据的时候该怎么办呢?这时首先想到的是数组,但是!数组只能存放同一类型的数据&#xff…

【计算机网络】第五,六章摘要重点

1.运输层协议概述 运输层提供的是进程之间的通信 2. 3.套接字指的是什么 ip地址端口号 4.每一条TCP语句唯一地被通信两端连接的两个端点 5.TCP传输如何实现 以字节为单位的滑动窗口 超时重传 选择确认 6.TCP流量控制和拥塞控制的区别 流量控制:点对点通信…

Android Studio如何创建尺寸大小及API通用的模拟器

目录 前言 一、操作步骤 二、总结 三、更多资源 前言 在开发移动应用程序的过程中,使用模拟器进行测试是一种常见和方便的方式。Android Studio是一款功能强大的集成开发环境,它提供了创建和管理模拟器的功能。在本文中,我们将介绍如何创…

XTU-OJ-1452-完全平方数-笔记

参考博客 XTU-OJ 1452-完全平方数 题意 输入一个奇数&#xff0c;使得 n*(2*an-1)/2是一个完全平方数&#xff0c;求满足条件的最小的a 1<n<1e9 先输入样例数&#xff0c;再输入n 输入 2 1 3 输出 0 2 代码 #include<stdio.h>#define N 1000000010int a…

SpringIOC之ApplicationObjectSupport

博主介绍&#xff1a;✌全网粉丝5W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

数据库视图

使用视图就是执行视图的select操作。 为什么要使用视图&#xff1f; 安全性&#xff1a;只开放表中的部分数据&#xff08;只能看到视图开放的列&#xff09;便捷性&#xff1a;可以将复杂的多表关联查询放到视图中&#xff0c;屏蔽底层的复杂性。 create view viewa_name as…

2023年03月20日_对李开复3月20日线下媒体会的解读

最近这个AI大模型 因为GPT4.0 ChatGPT 文心一言等等这些事情呢 一下子就被推到了风口浪尖 我们也做了来介绍相关的进展 国内呢也不断有一些大佬开始下场 包括王慧文、张朝阳、李彦宏什么的 都开始说自己要搞AI大模型 就在昨天呢 创新工厂的董事长兼CEO李开复 也发朋友…

2022第十届中国互联网测试开发大会(MTSC2022)-核心PPT资料下载

一、峰会简介 MTSC中国互联网测试开发大会 Make Tester Super Cool&#xff08;简称 MTSC&#xff09;是由国内最大的移动测试技术社区 TesterHome 发起的软件测试行业技术会议&#xff0c;大会以“软件质量保障体系和测试研发技术交流”为主要目的。MTSC 大会于 2015 年举办第…

产品经理学习-怎么写PRD文档

目录 瀑布流方法论介绍 产品需求文档&#xff08;PRD&#xff09;介绍 产品需求文档的基本要素 撰写产品需求文档 优先产品需求文档的特点 其他相关文档 瀑布流方法论介绍 瀑布流模型是一种项目的开发和管理的方法论&#xff0c;是敏捷的开发管理方式相对应的另一种方法…

前端算法之双指针之快慢指针(Floyd 判圈法)

双指针与快慢指针快慢指针&#xff08;Floyd 判圈法&#xff09; 简介推导 在链表中&#xff0c;快指针和慢指针都可以指向头节点&#xff0c;然后根据问题的要求进行移动。 快指针通常会比慢指针移动得更快&#xff0c;例如每次移动两步&#xff0c;而慢指针则每次移动一步。…

Linux之磁盘分区,挂载

Linux分区 分区介绍 对linux来说无论有几个分区&#xff0c;分给哪个目录使用&#xff0c;归根结底只有一个根目录&#xff0c;linux中每个分区都是用来组成整个文件系统的一部分。linux采用“载入"的处理方法&#xff0c;他的整个文件系统中包含一整套的文件和目录&…

68.乐理基础-打拍子-大附点与变体

上一节内容&#xff1a;66.乐理基础-打拍子-小切分-CSDN博客&#xff0c;只所以没有67因为67可以不用知道&#xff0c;67节内容在&#xff1a;※-打拍子&#xff08;8&#xff09;-一拍内的变体1-乐理教程-腾讯课堂 (qq.com) 大附点&#xff1a;大附点这个名字不是通用的&…

如何使用python脚本生成redis格式的数据包

用python脚本生成redis格式的数据包 &#xff08;1&#xff09;使用下述网站下载开源的生成gopher协议规则的包的工具 https://github.com/firebroo/sec_tools/tree/master/redis-over-gopher &#xff08;2&#xff09;首先要修改redis.cmd中的内容 flushall config set di…

【面试】 Maven 的八大核心概念

Maven 的八大核心概念 在这里&#xff0c;举出这个标题&#xff0c;自然大家知道Maven是干啥的&#xff0c;就不过多进行赘述&#xff01;我们主要对于Maven的八大核心概念做一个解释补充&#xff0c;这也是我自己的一个学习历程&#xff0c;我们一起共勉&#xff01; 文章概述…

【网络安全常用术语解读】SCAP详解

本文主要介绍什么是SCAP&#xff0c;SCAP的产生背景是怎样的&#xff0c;SCAP有什么用途&#xff0c;有哪些组件&#xff0c;各个组件的用途是什么&#xff1f; SCAP产生背景 由于计算机和网络技术的快速发展&#xff0c;越来越多的软件和系统被应用到企业和机构中&#xff0c…

一文掌握Java注解之@SpringBootApplication知识文集(1)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

编程羔手解决Maven引入多个版本的依赖包,导致包冲突了

最近升级了些依赖发现有个hutool的方法老报错&#xff0c;java.lang.NoSuchMethodError: cn.hutool.core.util.ObjectUtil.defaultIfNull(Ljava/lang/Object;Ljava/util/function/Supplier;) 在 Maven 项目中&#xff0c;当不同的依赖模块引入 Hutool 的不同版本时&#xff0c…