YOLO11改进|注意力机制篇|引入ELA注意力机制

在这里插入图片描述

目录

    • 一、【ELA】注意力机制
      • 1.1【ELA】注意力介绍
      • 1.2【ELA】核心代码
    • 二、添加【ELA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【ELA】注意力机制

1.1【ELA】注意力介绍

在这里插入图片描述

这篇论文的作者通过分析Coordinate Attention(CA) method的局限性,确定了Batch Normalization中泛化能力的缺乏、降维对通道注意力的不利影响以及注意力生成过程的复杂性。为了克服这些挑战,提出了结合一维卷积和Group Normalization特征增强技术。这种方法通过有效地编码两个一维位置特征图,无需降维即可精确定位感兴趣区域,同时允许轻量级实现。与2D卷积相比,1D卷积更适合处理序列信号,并且更轻量、更快。GN与BN相比,展现出可比较的性能和更好的泛化能力。与 CA 类似,ELA 采用strip pooling在空间维度上获取水平和垂直方向的特征向量,保持窄核形状以捕获长程依赖关系,防止不相关区域影响标签预测,从而在各自方向上产生丰富的目标位置特征。ELA 针对每个方向独立处理上述特征向量以获得注意力预测,然后使用点乘操作将其组合在一起,从而确保感兴趣区域的准确位置信息。由下图可以看出ELA相比于CA显得更简单,所以模型也就更轻
在这里插入图片描述

1.2【ELA】核心代码

import torch
import torch.nn as nn


class ELA(nn.Module):
    def __init__(self, in_channels, phi='T'):
        super(ELA, self).__init__()
        '''
        ELA-T 和 ELA-B 设计为轻量级,非常适合网络层数较少或轻量级网络的 CNN 架构
        ELA-B 和 ELA-S 在具有更深结构的网络上表现最佳
        ELA-L 特别适合大型网络。
        '''
        Kernel_size = {'T': 5, 'B': 7, 'S': 5, 'L': 7}[phi]
        groups = {'T': in_channels, 'B': in_channels, 'S': in_channels // 8, 'L': in_channels // 8}[phi]
        num_groups = {'T': 32, 'B': 16, 'S': 16, 'L': 16}[phi]
        pad = Kernel_size // 2
        self.con1 = nn.Conv1d(in_channels, in_channels, kernel_size=Kernel_size, padding=pad, groups=groups, bias=False)
        self.GN = nn.GroupNorm(num_groups, in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        b, c, h, w = input.size()
        x_h = torch.mean(input, dim=3, keepdim=True).view(b, c, h)
        x_w = torch.mean(input, dim=2, keepdim=True).view(b, c, w)
        x_h = self.con1(x_h)  # [b,c,h]
        x_w = self.con1(x_w)  # [b,c,w]
        x_h = self.sigmoid(self.GN(x_h)).view(b, c, h, 1)  # [b, c, h, 1]
        x_w = self.sigmoid(self.GN(x_w)).view(b, c, 1, w)  # [b, c, 1, w]
        return x_h * x_w * input


if __name__ == "__main__":
    # 创建一个形状为 [batch_size, channels, height, width] 的虚拟输入张量
    input = torch.randn(2, 256, 40, 40)
    ela = ELA(in_channels=256, phi='T')
    output = ela(input)
    print(output.size())

同时大家可以根据下面代码替换自己想要的型号,只需要将phi=''中的字母替换即可
在这里插入图片描述

二、添加【ELA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个ELA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示

在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【ELA】注意力机制在大目标检测层中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1,1,ELA,[]]


  - [[16, 19, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【ELA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/887252.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java Supplier和Consumer接口

Supplier 在Java中,Supplier接口是一个重要的函数式接口,它属于java.util.function包,Supplier通常用于延迟计算或生成值的场景。Supplier接口是一个泛型接口,其get()方法不接受任何参数但返回一个泛型类型T的值。 这个接口被注解…

STM32新建工程-基于库函数

目录 一、创建一个新工程 二、为工程添加文件和路径 三、创建一个main.c文件,并调试 四、修改一些配置 五、用库函数进行写程序 1、首先加入一些库函数和头文件 2、编写库函数程序 一、创建一个新工程 我这里选择STM32F103C8的型号,然后点击OK。 …

Maven下载、安装与环境配置详解:从零开始搭建高效Java开发环境

下载 官方网站:http://maven.apache.org/ 下载页面:http://maven.apache.org/download.cgi 官网 下载页面 注:本教程使用的是3.3.9版本的maven。 安装 maven安装包下载完成后是一个压缩文件,如下图所示: 我们需要将…

java 数据存储方式

1. 变量存储 这是最基本的数据存储方式,通过声明变量来存储数据。变量可以是基本数据类型(如int、float、char等),也可以是引用数据类型(如对象、数组等)。变量存储的数据通常存储在内存中,随着…

Redis --- 第三讲 --- 通用命令

一、get和set命令 Redis中最核心的两个命令 get 根据key来取value set 把key和value存储进去 redis是按照键值对的方式存储数据的。必须要先进入到redis客户端。 语法 set key value : key和value都是字符串。 对于上述这里的key value 不需要加上引号&#…

【D3.js in Action 3 精译_028】3.4 小节 DIY 实战:使用 Observable 在线绘制 D3 条形图

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一部分 D3.js 基础知识 第一章 D3.js 简介(已完结) 1.1 何为 D3.js?1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践(上)1.3 数据可…

关于Fake Location定位,运动世界校园问题

不好意思,之前那个文章其实是很早之前的,不知道为什么审核了很久一直没有通过,然后前几周莫名其妙点了一下重新发布,竟然发布成功了,这个方法已经失效了,要可以稳定,我建议是买一台root的手机&a…

Discord:报错:A fatal Javascript error occured(解决办法)

按 Windows 键 R 并输入 %appdata% 选择 discord 文件夹并将其删除。 再次按 Windows 键 R 并输入 %LocalAppData% 选择 discord 文件夹并再次将其删除。 附加: 如果还不行,就通过官网下载吧,这个问题通过epic下载可能会有

初识算法 · 滑动窗口(1)

目录 前言: 长度最小的子数组 题目解析 算法原理 算法编写 无重复长度的最小字符串 题目解析 算法原理 算法编写 前言: 本文开始,介绍的是滑动窗口算法类型的题目,滑动窗口本质上其实也是双指针,但是呢&#…

算法笔记(七)——哈希表

文章目录 两数之和判定是否互为字符重排存在重复元素存在重复元素 II字母异位词分组 哈希表:一种存储数据的容器; 可以快速查找某个元素,时间复杂度O(1); 当频繁查找某一个数时,我们可以使用哈希表 创建一个容器&#…

YOLOv4和Darknet实现坑洼检测

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝,拥有2篇国家级人工智能发明专利。 社区特色…

插画共享系统小程序的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,插画信息管理,基础数据管理,论坛管理,公告信息管理,轮播图信息管理 微信端账号功能包括:系统首页,插画信…

【JAVA开源】基于Vue和SpringBoot的服装生产管理系统

本文项目编号 T 066 ,文末自助获取源码 \color{red}{T066,文末自助获取源码} T066,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

Vue的基本用法及模板语法

Vue.js使用了基于 HTML 的模板语法,允许开发者声明式地将 DOM 绑定至底层 Vue实例的数据。所有 Vue.js的模板都是合法的 HTML,所以能被遵循规范的浏览器和 HTML 解析器解析。 在底层的实现上,Vue将模板编译成虚拟 DOM 渲染函数。结合响应系…

10.2 Linux_进程_进程相关函数

创建子进程 函数声明如下: pid_t fork(void); 返回值:失败返回-1,成功返回两次,子进程获得0(系统分配),父进程获得子进程的pid 注意:fork创建子进程,实际上就是将父进程复制一遍作为子进程&…

【基础算法总结】链表篇

目录 一, 链表常用技巧和操作总结二,算法原理和代码实现2.两数相加24.两两交换链表中的节点143.重排链表23.合并k个升序链表25.k个一组翻转链表 三,算法总结 一, 链表常用技巧和操作总结 有关链表的算法题也是一类常见并且经典的题…

STM32-HAL库驱动DHT11温湿度传感器 --2024.9.28

目录 一、教程简介 二、驱动原理讲解 (一)通信4步骤 (二)传感器数据解析 三、CubeMX生成底层代码 (一)基础配置 (二)配置DHT11的驱动引脚 (三)配置串口 四…

pytest(三)——参数化@pytest.mark.parametrize

目录 前言 参数化场景 实际Web UI自动化中的开发场景,比如是一个登录框 parametrize单参数 “笛卡尔积”,多个参数化装饰器 重点知识 参考文献 前言 pytest.mark.parametrize 允许在测试函数或类中定义多组参数和fixtures pytest_generate_tests 允…

对于基础汇编的趣味认识

汇编语言 机器指令 机器语言是机器指令的集合 机器指令展开来讲就是一台机器可以正确执行的命令 电子计算机的机器指令是一列二进制数字 (计算机将其转变为一列高低电平,使得计算机的电子器件受到驱动,进行运算 寄存器:微处理器…

C(九)while循环 --- 军训匕首操情景

匕首操,oi~oi~oi~~~~~ 接下来的几篇推文,杰哥记录的是三大循环结构的运行流程及其变式。 本篇的主角是while循环。👉 目录: while循环 的组成、运行流程及其变式关键字break 和 continue 在while 循环中的作用while 循环的嵌套题目…