机器学习笔记 - EANet 外部注意论文简读及代码实现

一、论文简述

        论文作者提出了一种新的轻量级注意力机制,称之为外部注意力。如图所示,计算自注意力需要首先通过计算自查询向量和自关键字向量之间的仿射关系来计算注意力图,然后通过用该注意力图加权自值向量来生成新的特征图。外部关注的作用不同。我们首先通过计算自查询向量和外部可学习密钥存储器之间的亲和力来计算注意力图,然后通过将该注意力图乘以另一个外部可学习值存储器来生成细化的特征图。

         在实践中,这两个存储器是用线性层实现的,因此可以通过端到端的反向传播来优化。它们独立于单个样本,并在整个数据集中共享,这起到了很强的正则化作用,提高了注意力机制的泛化能力。外部注意力的轻量级本质的关键在于,存储器中的元素数量远小于输入特征中的元素数目,从而产生输入中元素数目线性的计算复杂度。外部存储器旨在学习整个数据集中最具鉴别力的特征,捕捉信息量最大的部分,并排除其他样本中的干扰信息。类似的想法可以在稀疏编码[9]或字典学习中找到。然而,与这些方法不同的是,我们既没有尝试重建输入特征,也没有对注意力图应用任何显式稀疏正则化。

        尽管所提出的外部注意力方法很简单,但它对各种视觉任务都是有效的。由于其简单性,它可以很容易地融入现有流行的基于自注意的架构中,如DANet、SAGAN和T2T Transformer。图3展示了一个典型的架构,该架构将图像语义分割任务的自我注意力替换为外部注意力。我们使用不同的输入模式(图像和点云),对分类、对象检测、语义分割、实例分割和生成等基本视觉任务进行了广泛的实验。结果表明,我们的方法获得的结果与原始的自注意机制及其一些变体相当或更好,以低得多的计算成本。

使用我们提出的外部注意力进行语义分割的EANet架构。

二、相关参考代码

1、基于torch的实现

        外部注意力实现

import numpy as np
import torch
from torch import nn
from torch.nn import init

class ExternalAttention(nn.Module):

    def __init__(self, d_model,S=64):
        super().__init__()
        self.mk=nn.Linear(d_model,S,bias=False)
        self.mv=nn.Linear(S,d_model,bias=False)
        self.softmax=nn.Softmax(dim=1)
        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn=self.mk(queries) #bs,n,S
        attn=self.softmax(attn) #bs,n,S
        attn=attn/torch.sum(attn,dim=2,keepdim=True) #bs,n,S
        out=self.mv(attn) #bs,n,d_model

        return out


if __name__ == '__main__':
    input=torch.randn(50,49,512)
    ea = ExternalAttention(d_model=512,S=8)
    output=ea(input)
    print(output.shape)

        调用参考

input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)

2、基于tensorflow的EANet实现

        EANet只是用外部注意力代替了Vit中的自我注意力。

        这里的数据集是采用cifar100数据集,首先加载划分数据集,然后配置超参数

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt



num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")



weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patch
embedding_dim = 64  # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the transformer layer

print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")

        进行数据增强

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomContrast(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

        实现补丁提取和编码层‘

class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size, self.patch_size, 1),
            strides=(1, self.patch_size, self.patch_size, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        patch_num = patches.shape[1]
        return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = tf.range(start=0, limit=self.num_patch, delta=1)
        return self.proj(patch) + self.pos_embed(pos)

        实现外部注意力块

def external_attention(
    x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
):
    _, num_patch, channel = x.shape
    assert dim % num_heads == 0
    num_heads = num_heads * dim_coefficient

    x = layers.Dense(dim * dim_coefficient)(x)
    # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
    x = tf.reshape(
        x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
    )
    x = tf.transpose(x, perm=[0, 2, 1, 3])
    # a linear layer M_k
    attn = layers.Dense(dim // dim_coefficient)(x)
    # normalize attention map
    attn = layers.Softmax(axis=2)(attn)
    # dobule-normalization
    attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
    attn = layers.Dropout(attention_dropout)(attn)
    # a linear layer M_v
    x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
    x = tf.transpose(x, perm=[0, 2, 1, 3])
    x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
    # a linear layer to project original dim
    x = layers.Dense(dim)(x)
    x = layers.Dropout(projection_dropout)(x)
    return x

        实施 MLP

def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
    x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
    x = layers.Dropout(drop_rate)(x)
    x = layers.Dense(embedding_dim)(x)
    x = layers.Dropout(drop_rate)(x)
    return x

        实现变压器模块,基于参数配置选择外部注意或是自我关注。

def transformer_encoder(
    x,
    embedding_dim,
    mlp_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
    attention_type="external_attention",
):
    residual_1 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    if attention_type == "external_attention":
        x = external_attention(
            x,
            embedding_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
        )
    elif attention_type == "self_attention":
        x = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
        )(x, x)
    x = layers.add([x, residual_1])
    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    x = mlp(x, embedding_dim, mlp_dim)
    x = layers.add([x, residual_2])
    return x

        实施 EANet 模型

def get_model(attention_type="external_attention"):
    inputs = layers.Input(shape=input_shape)
    # Image augment
    x = data_augmentation(inputs)
    # Extract patches.
    x = PatchExtract(patch_size)(x)
    # Create patch embedding.
    x = PatchEmbedding(num_patches, embedding_dim)(x)
    # Create Transformer block.
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(
            x,
            embedding_dim,
            mlp_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
            attention_type,
        )

    x = layers.GlobalAvgPool1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

        进行训练并可视化

model = get_model(attention_type="external_attention")

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)


model.save('eanet_cifar100.h5')

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

         进行验证

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

三、相关参考

arxiv.org/pdf/2105.02358.pdfhttps://arxiv.org/pdf/2105.02358.pdfExternal-Attention-pytorch/ExternalAttention.py at master · xmu-xiaoma666/External-Attention-pytorch · GitHubhttps://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/ExternalAttention.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/27282.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智慧加油站解决方案,提高加油区和卸油区的安全性和效率

英码科技智慧加油站解决方案是一个综合应用了AI智能算法的视觉分析方案,旨在提高加油区和卸油区的安全性和效率。 加油区算法: 吸烟检测:通过AI算法分析视频流,检测是否有人在加油区域吸烟,以防止火灾风险。 打电话…

STM32开发——串口通讯(非中断+中断)

目录 1.串口简介 2.非中断接收发送字符 3.中断接收字符 1.串口简介 通过中断的方法接受串口工具发送的字符串,并将其发送回串口工具。 串口发送/接收函数: HAL_UART_Transmit(); 串口发送数据,使用超时管理机制HAL_UART_Receive(); 串口…

案例26:基于Springboot校园社团管理系统开题报告

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

苹果MacOS系统傻瓜式本地部署AI绘画Stable Diffusion教程

Stable Diffusion的部署对小白来说非常麻烦,特别是又不懂技术的人。今天分享两个一键傻瓜式安装包,对小白来说非常有用。下面两个任选一个安装就可以。 一、DiffusionBee 简单介绍 DiffusionBee是基于stable diffusion的一个安装包,有图形…

python文字转语音(pyttsx3+flask)

提示:文章结尾有全部代码 目录 前言一、Flaskpyttsx基本使用Flask导入Flask框架配置基础环境初始Flask代码 pyttsx3库基本使用导入pyttsx3初始化pyttsx3文字转语音运行 二、具体实现1.引入库 总结 前言 本文主要讲解如何用python的pyttsx3库flask框架,手…

有了IP地址,还需要MAC地址嘛?二选一可否?

概要 在计算机网络中,IP地址和MAC地址是两个最基本的概念。IP地址在互联网中是用于标识主机的逻辑地址,而MAC地址则是用于标识网卡的物理地址。虽然它们都是用于标识一个设备的地址,但是它们的作用和使用场景是不同的。 IP地址是在网络层&am…

数据结构之树与二叉树——算法与数据结构入门笔记(五)

本文是算法与数据结构的学习笔记第五篇,将持续更新,欢迎小伙伴们阅读学习。有不懂的或错误的地方,欢迎交流 引言 前面章节介绍的都是线性存储的数据结构,包括数组、链表、栈、队列。本节带大家学习一种非线性存储的数据结构&…

chatgpt赋能python:Python分词:从原理到实践

Python分词:从原理到实践 分词是自然语言处理中的关键步骤之一,它是指将一句话或一段文本分成若干个词语(token)并进行标注。Python作为一种非常流行的编程语言,具备强大的文本处理能力,而分词也是它的强项…

chatgpt赋能python:Python如何切换中文

Python 如何切换中文 Python 是一种广泛使用的编程语言,被用于多种目的,包括数据分析、机器学习、Web 应用程序等。在使用 Python 进行开发时,需要处理不同的语言,其中中文也是包括在内的。对于需要切换中文的情况,本…

学生考试作弊检测系统 yolov8

学生考试作弊检测系统采用yolov8网络模型人工智能技术,学生考试作弊检测系统过在考场中安装监控设备,对学生的作弊行为进行实时监测。当学生出现作弊行为时,学生考试作弊检测系统将自动识别并记录信息。YOLOv8 算法的核心特性和改动可以归结为…

关于数据生成二维码保存和解密删除二维码

文章目录 前言一、pom配置依赖二、文件引入1.BufferedImageLuminanceSource2.QRCodeUtil3.MyPicConfig4.UploadUtils三、测试前言 所需文件: MyPicConfig 主要解决上传图片实时刷新BufferedImageLuminanceSource 算法文件QRCodeUtil 生成二维码工具类UploadUtils 主要解决上传…

软考A计划-系统架构师-官方考试指定教程-(13/15)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&am…

【Java基础学习打卡03】计算机中数据的表示、存储与处理

目录 前言一、数据的表示1.数据与信息2.计算机中的数据3.计算机中数据的单位 二、数据的存储三、数据的处理1.进位计数值2.进制间转换 四、字符编码总结 前言 本小节主要介绍在计算机中数据的表示、存储与处理。要知道计算机内部使用二进制数据,也就是0和1组成的数…

2.3 YARN伪分布式集群搭建

任务目的 重点掌握 YARN 集群的相关配置学会启动和关闭 YARN 集群的两种方式能够使用 jps 命令查看进程的启动情况能够通过 UI 查看 YARN 集群的运行状态任务清单 任务1:YARN 集群主要配置文件讲解任务2:YARN 集群测试任务步骤 任务1:YARN 集群主要配置文件讲解 1.1 配置环…

【新版】系统架构设计师 - 计算机系统基础知识

个人总结,仅供参考,欢迎加好友一起讨论 文章目录 架构 - 计算机系统基础知识考点摘要计算机系统计算机硬件组成浮点数Flynn分类法CISC与RISC流水线技术超标量流水线存储系统层次化存储结构CacheCache的命中率Cache的页面淘汰主存编址磁盘管理&#xff08…

Linux 信号

文章目录 1. 信号1.1 前言1.2 信号的位置1.3 接口1.3.1 sigset_t1.3.2 信号集操作接口1.3.3 signal1.3.4 sigprocmask1.3.5 sigpending 2. 信号的处理2.1 内核态和用户态2.2 信号的监测和处理 1. 信号 1.1 前言 在 Linux 中,信号是一种用于进程之间的通信机制&…

地震勘探基础(十一)之水平叠加处理

水平叠加处理 地震资料经过预处理,静校正,反褶积,速度分析和动校正处理后就要进行水平叠加处理。地震水平叠加处理是地震常规处理的重要环节。 假设一个共中心点道集有三个地震道,经过速度分析和动校正以后,水平叠加…

【数据结构】何为数据结构。

🚩 WRITE IN FRONT 🚩 🔎 介绍:"謓泽"正在路上朝着"攻城狮"方向"前进四" 🔎🏅 荣誉:2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2022博客之星T…

Tik Tok的海外娱乐公会(中亚、巴西、美国、台湾)怎么申请?

TIKTOK 公会海外市场潜力巨大 自 2016 年始,多家直播平台陆续拓展至东南亚、中东、俄罗斯、日韩、 欧美、拉美等地区 海外市场作为直播发展新蓝海,2021 年直播行业整体规模达百亿美元, 并维持高速增长 TikTok 直播市场空间 TikTok 已经成…

【 Python 全栈开发 - WEB开发篇 - 31 】where条件查询

文章目录 一、where条件查询1.关系运算符查询2.IN关键字查询3.BETWEEN AND关键字查询4.空值查询5.AND关键字查询6.OR关键字查询7.LIKE关键字查询普通字符串含有%通配的字符串含有_通配的字符串 一、where条件查询 MySQL 的 where 条件查询是指在查询数据时,通过 wh…