神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

  • 学习前言
  • 代码下载
  • CA注意力机制的概念与实现
  • 注意力机制的应用

学习前言

CA注意力机制是最近提出的一种注意力机制,全面关注特征层的空间信息和通道信息。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-keras

复制该路径到地址栏跳转。

CA注意力机制的概念与实现

请添加图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[H, W, C],在经过宽方向的平均池化后,获得的特征层shape为[H, 1, C],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[1, W, C],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[1, H+W, C],利用卷积+标准化+激活函数获得特征。需要注意的是,这里的卷积通道数一般会小一点,做一个缩放,可以减少参数量。卷积后的特征层的shape为[1, H+W, C/r],其中r为缩放系数。

之后再次分开为两个并行阶段,再将宽高分开成为:[1, H, C/r]和[1, W, C/r],之后进行转置。获得两个特征层[H, 1, C/r]和[1, W, C/r]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况,前者在宽上拓展一下,后者在高上拓展一下,然后一起乘上原有的特征就是CA注意力机制。

实现的python代码为:

def ca_block(input_feature, ratio=16, name=""):
	channel = input_feature._keras_shape[-1]
	h		= input_feature._keras_shape[1]
	w		= input_feature._keras_shape[2]
 
	x_h = Lambda(lambda x: K.mean(x, axis=2, keepdims=True))(input_feature)
	x_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_h)
	x_w = Lambda(lambda x: K.max(x, axis=1, keepdims=True))(input_feature)
	
	x_cat_conv_relu = Concatenate(axis=2)([x_w, x_h])
	x_cat_conv_relu = Conv2D(channel // ratio, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv1_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = BatchNormalization(name = "ca_block_bn_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = Activation('relu')(x_cat_conv_relu)
 
	x_cat_conv_split_h, x_cat_conv_split_w = Lambda(lambda x: tf.split(x, num_or_size_splits=[h, w], axis=2))(x_cat_conv_relu)
	x_cat_conv_split_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_cat_conv_split_h)
	x_cat_conv_split_h = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv2_"+str(name))(x_cat_conv_split_h)
	x_cat_conv_split_h = Activation('sigmoid')(x_cat_conv_split_h)
 
	x_cat_conv_split_w = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv3_"+str(name))(x_cat_conv_split_w)
	x_cat_conv_split_w = Activation('sigmoid')(x_cat_conv_split_w)
 
	output = multiply([input_feature, x_cat_conv_split_h])
	output = multiply([output, x_cat_conv_split_w])
	return output

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention = [se_block, cbam_block, eca_block, ca_block]

#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
def yolo_body(input_shape, anchors_mask, num_classes, phi = 0):
    inputs = Input(input_shape)
    #---------------------------------------------------#
    #   生成CSPdarknet53_tiny的主干模型
    #   feat1的shape为26,26,256
    #   feat2的shape为13,13,512
    #---------------------------------------------------#
    feat1, feat2 = darknet_body(inputs)
    if phi >= 1 and phi <= 4:
        feat1 = attention[phi - 1](feat1, name='feat1')
        feat2 = attention[phi - 1](feat2, name='feat2')

    # 13,13,512 -> 13,13,256
    P5 = DarknetConv2D_BN_Leaky(256, (1,1))(feat2)
    # 13,13,256 -> 13,13,512 -> 13,13,255
    P5_output = DarknetConv2D_BN_Leaky(512, (3,3))(P5)
    P5_output = DarknetConv2D(len(anchors_mask[0]) * (num_classes+5), (1,1))(P5_output)
    
    # 13,13,256 -> 13,13,128 -> 26,26,128
    P5_upsample = compose(DarknetConv2D_BN_Leaky(128, (1,1)), UpSampling2D(2))(P5)
    if phi >= 1 and phi <= 4:
        P5_upsample = attention[phi - 1](P5_upsample, name='P5_upsample')

    # 26,26,256 + 26,26,128 -> 26,26,384
    P4 = Concatenate()([P5_upsample, feat1])
    
    # 26,26,384 -> 26,26,256 -> 26,26,255
    P4_output = DarknetConv2D_BN_Leaky(256, (3,3))(P4)
    P4_output = DarknetConv2D(len(anchors_mask[1]) * (num_classes+5), (1,1))(P4_output)
    
    return Model(inputs, [P5_output, P4_output])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391252.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

走近 Next.js:全栈框架的简介与应用

微信搜索“好朋友乐平”关注公众号。 1. Next.js Next.js 是一个使用 React 构建单页应用程序&#xff08;SPA&#xff09;的开源 JavaScript 框架。它使得构建服务端渲染&#xff08;SSR&#xff09;和静态网站生成&#xff08;SSG&#xff09;的 React 应用程序变得简单和高…

安全基础~通用漏洞5

文章目录 知识补充CSRFSSRFxss与csrf结合创建管理员账号 知识补充 NAT&#xff1a;网络地址转换&#xff0c;可以将IP数据报文头中的IP地址转换为另一个IP地址&#xff0c;并通过转换端口号达到地址重用的目的。即通过将一个外部IP地址和端口映射更大的内部IP地址集来转换IP地…

人工智能学习与实训笔记(七):神经网络之模型压缩与知识蒸馏

人工智能专栏文章汇总&#xff1a;人工智能学习专栏文章汇总-CSDN博客 本篇目录 七、模型压缩与知识蒸馏 7.1 模型压缩 7.2 知识蒸馏 7.2.1 知识蒸馏的原理 7.2.2 知识蒸馏的种类 7.2.3 知识蒸馏的作用 七、模型压缩与知识蒸馏 出于对响应速度&#xff0c;存储大小和能…

鸿蒙应用开发工程师招聘多吗?工资有多少呢?

随着鸿蒙操作系统的快速普及&#xff0c;越来越多的企业开始重视鸿蒙应用开发人才的培养和引进。那么&#xff0c;目前市场上鸿蒙应用开发工程师招聘多吗&#xff1f;工资有多少呢&#xff1f; 首先&#xff0c;我们来了解一下鸿蒙应用开发工程师的招聘情况。随着鸿蒙操作系统…

IEEE Internet of Things Journal投稿经验

期刊名&#xff1a;IEEE Internet of Things Journal 期刊分区&#xff1a;中科院一区 Top 影响因子&#xff1a;10.6 投稿状态 &#xff08;1&#xff09;2023.11.3&#xff0c;投稿成功&#xff0c;状态为&#xff1a;under review&#xff08;大u大r&#xff09;&#xff1…

无心剑小诗《爱的迷宫》

爱的迷宫 在心海深处悄然铺陈 情感线索纷乱中交织缠绵 一道道的拐角&#xff0c;星云的光辉 指引渴望的心追寻那深邃的真理 曲折通道&#xff0c;每次心跳 爱的重奏在无边探寻里 每一寸肌肤与每一根神经 都铭记你的气息&#xff0c;你的轮廓 迷宫的幻影&#xff0c;无尽的梦…

掌上新闻随心播控,HarmonyOS SDK助力新浪新闻打造精致易用的资讯服务新体验

原生智能是HarmonyOS NEXT的核心亮点之一&#xff0c;依托HarmonyOS SDK丰富全面的开放能力&#xff0c;开发者只需通过几行代码&#xff0c;即可快速实现AI功能。新浪新闻作为鸿蒙原生应用开发的先行者之一&#xff0c;从有声资讯入手&#xff0c;将基于Speech Kit朗读控件上线…

什么软件可以监控电脑屏幕?

随着信息技术的飞速发展&#xff0c;电脑已成为企业日常运营不可或缺的工具。然而&#xff0c;这也带来了一系列的管理挑战&#xff0c;如何确保员工高效工作、避免信息泄露、监控不当行为等成为了企业迫切需要解决的问题。在这种背景下&#xff0c;电脑屏幕监控软件应运而生&a…

基于Springboot+Vue实现的宿舍管理系统

基于SpringbootVue的宿舍管理系统 1.系统相关性介绍1.1 系统架构1.2 设计思路 2.功能模块介绍2.1 用户信息模块2.2 宿舍管理模块2.3 信息管理模块 3. 源码获取以及远程部署 前言&#xff1a; 在现代教育环境中&#xff0c;学生宿舍的管理显得尤为重要&#xff0c;需要一套能…

明年再见了,

过年家里待了好天&#xff0c;也停更了好些天&#xff0c;明天准备启程回深。 年前跟朋友kt聊天&#xff0c;他先是跟我说了他的投资方向「这是一件非常隐秘的事情」&#xff0c;也说了他对我的建议&#xff0c;2024年也想有新的计划&#xff0c;今年也计划减少更新文章频率&am…

DNS出现问题了,怎么处理?-提供完整解决方案

DNS作用 将域名(网址)解析为IP地址 DNS的作用是将域名(网址)解析为IP地址,方便用户访问互联网。通过DNS,用户可以轻松地通过域名来获取对应的IP地址,无需记住复杂的数字串。 负载均衡 负载均衡是DNS的一种功能,它能够将访问请求转发到不同的服务器,从而实现负载均衡。…

【网站项目】154智能无人仓库管理

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

2024年通信安全员ABC证证模拟考试题库及通信安全员ABC证理论考试试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年通信安全员ABC证证模拟考试题库及通信安全员ABC证理论考试试题是由安全生产模拟考试一点通提供&#xff0c;通信安全员ABC证证模拟考试题库是根据通信安全员ABC证最新版教材&#xff0c;通信安全员ABC证大纲整理…

python-自动化篇-运维-网络-IP

文章目录 IP自我介绍IPy安装模块windowsLinux IPy介绍支持大多数 IP 地址格式IPv4 地址IPv6 地址网络掩码和前缀 派生网络地址将地址转换为字符串使用多个网络多网络计算方法 IP自我介绍 IP地址规划是网络设计中非常重要的一个环节&#xff0c;规划的好坏会直接影响路由协议算…

(07)Hive——窗口函数详解

一、 窗口函数知识点 1.1 窗户函数的定义 窗口函数可以拆分为【窗口函数】。窗口函数官网指路&#xff1a; LanguageManual WindowingAndAnalytics - Apache Hive - Apache Software Foundationhttps://cwiki.apache.org/confluence/display/Hive/LanguageManual%20Windowing…

使用 RAG 创建 LLM 应用程序

如果您考虑为您的文件或网站制作一个能够回应您的个性化机器人&#xff0c;那么您来对地方了。我可以帮助您使用Langchain和RAG策略来创建这样一个机器人。 了解ChatGPT的局限性和LLMs ChatGPT和其他大型语言模型&#xff08;LLMs&#xff09;经过广泛训练&#xff0c;以理解…

Redis篇----第三篇

系列文章目录 文章目录 系列文章目录前言一、缓存雪崩**二、缓存穿透三、缓存预热四、缓存更新五、缓存降级 前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站&#xff0c;这篇文章男女…

【漏洞复现】蓝网科技临床浏览系统信息泄露漏洞

Nx01 产品简介 蓝网科技临床浏览系统是一个专门用于医疗行业的软件系统&#xff0c;主要用于医生、护士和其他医疗专业人员在临床工作中进行信息浏览、查询和管理。 Nx02 漏洞描述 蓝网科技临床浏览系统存在信息泄露漏洞&#xff0c;攻击者可以利用该漏洞获取敏感信息。 Nx03…

MySQL数据表的约束

已经一个星期没更了&#xff0c;因为过年嘛&#xff0c;比较忙&#xff0c;我玩心也大&#xff0c;就没继续更新&#xff0c;在这里给大家道歉&#xff0c;也祝大家新的一年快快乐乐&#xff0c;新年快乐。 为防止数据表中插入错误的数据&#xff0c;MySQL定义了一些规则维护数…