【深度学习】注意力机制(二)

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。

【深度学习】注意力机制(一)

【深度学习】注意力机制(三)

目录

一、EA(External Attention)

二、Multi Head Self Attention

三、SK(Selective Kernel Networks)

四、DA(Dual Attention)

五、EPSA(Efficient Pyramid Squeeze Attention)


一、EA(External Attention)

EA可以关注全局的空间信息,论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init

class External_attention(nn.Module):
    '''
    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()
        
        self.conv1 = nn.Conv2d(c, c, 1)

        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))        
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
 

    def forward(self, x):
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w
        x = x.view(b, c, h*w)   # b * c * n 

        attn = self.linear_0(x) # b, k, n
        attn = F.softmax(attn, dim=-1) # b, k, n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x)
        return x

二、Multi Head Self Attention

注意力机制的经典,Transformer的基石。论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out

三、SK(Selective Kernel Networks)

SK是通道注意力机制。论文地址:论文连接

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict



class SKAttention(nn.Module):

    def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
        super().__init__()
        self.d=max(L,channel//reduction)
        self.convs=nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
                    ('bn',nn.BatchNorm2d(channel)),
                    ('relu',nn.ReLU())
                ]))
            )
        self.fc=nn.Linear(channel,self.d)
        self.fcs=nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d,channel))
        self.softmax=nn.Softmax(dim=0)



    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs=[]
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats=torch.stack(conv_outs,0)#k,bs,channel,h,w

        ### fuse
        U=sum(conv_outs) #bs,c,h,w

        ### reduction channel
        S=U.mean(-1).mean(-1) #bs,c
        Z=self.fc(S) #bs,d

        ### calculate attention weight
        weights=[]
        for fc in self.fcs:
            weight=fc(Z)
            weights.append(weight.view(bs,c,1,1)) #bs,channel
        attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1
        attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1

        ### fuse
        V=(attention_weughts*feats).sum(0)
        return V

四、DA(Dual Attention)

DA融合了通道注意力和空间注意力机制。论文:论文地址

如下图:

代码(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention

class PositionAttentionModule(nn.Module):

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)
    
    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c
        y=self.pa(y,y,y) #bs,h*w,c
        return y


class ChannelAttentionModule(nn.Module):
    
    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)
    
    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1) #bs,c,h*w
        y=self.pa(y,y,y) #bs,c,h*w
        return y


class DAModule(nn.Module):

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
        self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
    
    def forward(self,input):
        bs,c,h,w=input.shape
        p_out=self.position_attention_module(input)
        c_out=self.channel_attention_module(input)
        p_out=p_out.permute(0,2,1).view(bs,c,h,w)
        c_out=c_out.view(bs,c,h,w)
        return p_out+c_out

五、EPSA(Efficient Pyramid Squeeze Attention)

论文:论文地址

如下图:

代码如下(代码连接):

import torch.nn as nn

class SEWeightModule(nn.Module):

    def __init__(self, channels, reduction=16):
        super(SEWeightModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)

        return weight


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """standard convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class PSAModule(nn.Module):

    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):
        super(PSAModule, self).__init__()
        self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,
                            stride=stride, groups=conv_groups[0])
        self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,
                            stride=stride, groups=conv_groups[1])
        self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,
                            stride=stride, groups=conv_groups[2])
        self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,
                            stride=stride, groups=conv_groups[3])
        self.se = SEWeightModule(planes // 4)
        self.split_channel = planes // 4
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])

        x1_se = self.se(x1)
        x2_se = self.se(x2)
        x3_se = self.se(x3)
        x4_se = self.se(x4)

        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/237823.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据在内存中的存储(浮点型篇)

1.例子:5.5:内存存储为101.1,十分位百分位依次为2的-1次方,2的-2次方,而使用科学计数法可以改写为1.011*2的2次方 2.国际标准公式:-1的D次方*M*2的E次方,x1负0正 3.M在存储时默认整数部分为1&…

C 语言指针学习笔记

C 语言中,指针存储的是变量的内存地址!!! 要彻底理解指针,首先要理解 C 语言中变量的存储本质,也就内存。 内存编址与内存空间 计算机的内存是一块用于存储数据的空间,由一系列连续的存储单元…

Python 反编译Il2Cpp APK

引入 https://github.com/Perfare/Il2CppDumper/ 实现 开源的Ii2Cpp Dumper可以帮助我们将So和globalmetadata.dat文件反编译出 Assembly-CSharp.dll 本博客教程可以帮助我们直接拖入APK反编译出来 调用方式 两种 第一种 拖入后回车运行 第二种 放入运行的根目录下 源码 i…

手动搭建koa+ts项目框架(日志篇)

文章目录 前言一、安装koa-logger二、引入koa-logger并使用总结如有启发,可点赞收藏哟~ 前言 本文基于手动搭建koats项目框架(路由篇)新增日志记录 一、安装koa-logger npm i -S koa-onerror and npm i -D types/koa-logger二、引入koa-lo…

IoTDB JavaAPI

文章目录 使用样例Java使用样例 官方已经给出了相关使用Demo,下载地址为: https://github.com/apache/iotdb 直接拉取相对应版本的源码 使用样例 Java使用样例 代码位置 iotdb/example/session/src/main/java/org/apache/iotdb/SessionExample.java iotdb/exa…

借助 AI 梳理知识:Quivr 帮你打造第二大脑 | 开源日报 No.103

fastlane/fastlane Stars: 37.8k License: MIT fastlane 是一个用于 iOS 和 Android 开发人员自动化繁琐任务的工具,如生成屏幕截图、处理配置文件和发布应用程序。 可以轻松地生成屏幕截图处理证书文件发布应用程序通过命令行快速执行操作 DrKLO/Telegram Sta…

血的教训,BigDecimal踩过的坑

很多人都用过Java的BigDecimal类型,但是很多人都用错了。如果使用不当,可能会造成非常致命的线上问题,因为这涉及到金额等数据的计算精度。 首先说一下,一般对于不需要特别高精度的计算,我们使用double或float类型就可…

【lesson7】数据类型之string类型

文章目录 数据类型分类string类型set类型测试 enum类型测试 string类型的内容查找找所有女生(enum中)找爱好有游泳的人(set中)找到爱好中有足球和篮球的人 数据类型分类 string类型 set类型 说明: set:集…

nrfutil工具安装

准备工作,下载相关安装包 链接:https://pan.baidu.com/s/1LWxhibf8LiP_Cq3sw0kALQ 提取码:2dlc 解压后,分别安装以下安装包 在C盘下创建目录nordic_tools,并将nrfutil复制到刚创建的目录下 环境变量path下添加C:\nor…

【截图版本】Linux常用指令详解

———————————————— 版权声明:本文为CSDN博主「小呆瓜历险记」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 原文链接:https://blog.csdn.net/m0_58963318/article/details/134713282

ProroBuf C++笔记

一.什么是protobuf Protocol Buffers是Google的⼀种语⾔⽆关、平台⽆关、可扩展的序列化结构数据的⽅法,它可⽤于(数据)通信协议、数据存储等。Protocol Buffers 类⽐于XML,是⼀种灵活,⾼效,⾃动化机制的结…

51单片机应用从零开始(十一)·数组函数、指针函数

51单片机应用从零开始(九)数组-CSDN博客 51单片机应用从零开始(十)指针-CSDN博客 目录 1. 用数组作函数参数控制流水花样 2. 用指针作函数参数控制 P0 口 8 位 LED 流水点亮 1. 用数组作函数参数控制流水花样 要在51单片机中…

0012Java安卓程序设计-ssm记账app

文章目录 **摘要**目 录系统设计5.1 APP端(用户功能)5.2后端管理员功能模块开发环境 编程技术交流、源码分享、模板分享、网课分享 企鹅🐧裙:776871563 摘要 网络的广泛应用给生活带来了十分的便利。所以把记账管理与现在网络相…

GaussDB数据库语法及gsql入门

一、GaussDB数据库语法入门 之前我们讲了如何连接数据库实例,那连接数据库后如何使用数据库呢?那么我们今天就带大家了解一下GaussDB,以下简称GaussDB的基本语法。 关于如何连接数据库,请戳这里。 学习本节课程之后&#xff0c…

【金华模式】双龙旅游引燃露营设计和文旅产融合新方式

(中国国际教育电视台 黎明)金华双龙风景旅游区位于浙江省金华市北郊的金华山麓,是一处融自然山水、溶洞群景观、科普探险、康体休闲、避暑度假、观光朝圣于一体的景区。旅游区人文积淀深厚,道、儒、释文化兼收并蓄,东汉…

大语言模型有什么意义?亚马逊训练自己的大语言模型有什么用?

近年来,大语言模型的崭露头角引起了广泛的关注,成为科技领域的一项重要突破。而在这个领域的巅峰之上,亚马逊云科技一直致力于推动人工智能的发展。那么,作为一家全球科技巨头,亚马逊为何会如此注重大语言模型的研发与…

Blender学习:走路机器人,骨骼绑定

文章目录 建模骨骼创建骨骼绑定 教程地址:八个案例教程带你从0到1入门blender【已完结】 建模 1 做头:新建立方体,Ctrl2细分并应用,进入编辑模式,删除一半点,然后添加镜像修改器,开启范围限制…

软件测试20个基础面试题及答案

什么是软件测试? 答案:软件测试是指在预定的环境中运行程序,为了发现软件存在的错误、缺陷以及其他不符合要求的行为的过程。 软件测试的目的是什么? 答案:软件测试的主要目的是保证软件的质量,并尽可能…

python socket编程9 - PyQt6界面实现UDP server/client 多客户端通讯的例子

本篇实现 UDP server和client多客户端通讯的例子。 在UDP单机通讯的基础上进行重构,实现UDP server与多个 client通讯的例子。 创建两个 PyQt6的项目,一个作为UDP server 项目,另一个作为UDP client项目。 一、效果图 1、udp server界面 …

在线学习平台-课程分页、用户管理、教师查询

在线学习平台------手把手教程👈 用户管理 添加功能增强 新增属性 若依里的用户模块(SysUser)是没有课程这一属性的,要实现我们自己的课程分页查询功能 这个位置传入的实体类SysUser要加上classId,记得加上get、set方法 更改sql语句 ctrl 鼠标左键不断点进去…