视觉Transformers中的位置嵌入 - 研究与应用指南

视觉 Transformer 中位置嵌入背后的数学和代码简介。

自从 2017 年推出《Attention is All You Need》以来,Transformer 已成为自然语言处理 (NLP) 领域最先进的技术。 2021 年,An Image is Worth 16x16 Words² 成功地将 Transformer 应用于计算机视觉任务。从那时起,人们提出了许多基于Transformer的计算机视觉架构。

本文[1]研究了为什么位置嵌入是视觉Transformer的必要组成部分,以及不同的论文如何实现位置嵌入。它包括位置嵌入的开源代码以及概念解释。所有代码都使用 PyTorch 包。

为什么使用位置嵌入?

Attention is All You Need 指出,Transformer由于缺乏递归或卷积,无法学习有关一组标记顺序的信息。如果没有位置嵌入,Transformer对于标记的顺序是不变的。对于图像,这意味着可以对图像的补丁进行加扰,而不会影响预测的输出。

让我们看一下 Luis Zuno 的像素艺术《黄昏山》中补丁顺序的示例。原始艺术作品已被裁剪并转换为单通道图像。这意味着每个像素都有一个介于 0 和 1 之间的值。单通道图像通常以灰度显示;但是,我们将以紫色配色显示它,因为它更容易看到。

mountains = np.load(os.path.join(figure_path, 'mountains.npy'))

H = mountains.shape[0]
W = mountains.shape[1]
print('Mountain at Dusk is H =', H, 'and W =', W, 'pixels.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
plt.clim([0,1])
cbar_ax = fig.add_axes([0.95.110.050.77])
plt.clim([01])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountains.png'), bbox_inches='tight')
alt

我们可以将此图像分割成大小为 20 的块。

P = 20
N = int((H*W)/(P**2))
print('There will be', N, 'patches, each', P, 'by', str(P)+'.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_patches.png'), bbox_inches='tight')
alt

据称,视觉Transformer将无法区分原始图像和补丁被打乱的版本。

np.random.seed(21)
scramble_order = np.random.permutation(N)
left_x = np.tile(np.arange(0, W-P+120), 3)
right_x = np.tile(np.arange(P, W+120), 3)
top_y = np.repeat(np.arange(0, H-P+120), 5)
bottom_y = np.repeat(np.arange(P, H+120), 5)

scramble = np.zeros_like(mountains)
for i in range(N):
    t = scramble_order[i]
    scramble[top_y[i]:bottom_y[i], left_x[i]:right_x[i]] = mountains[top_y[t]:bottom_y[t], left_x[t]:right_x[t]]
    
fig = plt.figure(figsize=(10,6))
plt.imshow(scramble, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(N):
    plt.text(x_text[i], y_text[i], str(scramble_order[i]+1), color='w', fontsize='xx-large', ha='center')
    
i3 = np.where(scramble_order==2)[0][0]
plt.text(x_text[i3], y_text[i3], str(scramble_order[i3]+1), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_scrambled_patches.png'), bbox_inches='tight')
alt

显然,这是与原始图像非常不同的图像,并且您不希望视觉Transformer将这两个图像视为相同。

排列的注意力不变性

让我们研究一下视觉Transformer对于标记顺序不变的说法。Transformer中对 token 顺序不变的组件是注意力模块。

注意力是根据三个矩阵(查询、键和值)计算得出的,每个矩阵都是通过将token传递到线性层而生成的。生成 Q、K 和 V 矩阵后,将使用以下公式计算注意力。

alt

其中 Q、K、V 分别是查询、键和值; dₖ 是缩放值。为了证明注意力对 token 顺序的不变性,我们将从三个随机生成的矩阵开始来表示 Q、K 和 V。Q、K 和 V 的形状如下:

alt

在此示例中,我们将使用 4 个预计长度为 9 的标记。矩阵将包含整数以避免浮点乘法错误。生成后,我们将交换token 0 和token 2 在所有三个矩阵中的位置。具有交换标记的矩阵将用下标 s 表示。

n_tokens = 4
l_tokens = 9
shape = n_tokens, l_tokens
mx = 20 #max integer for generated matricies

# Generate Normal Matricies
np.random.seed(21)
Q = np.random.randint(1, mx, shape)
K = np.random.randint(1, mx, shape)
V = np.random.randint(1, mx, shape)

# Generate Row-Swapped Matricies
swapQ = copy.deepcopy(Q)
swapQ[[02]] = swapQ[[20]]
swapK = copy.deepcopy(K)
swapK[[02]] = swapK[[20]]
swapV = copy.deepcopy(V)
swapV[[02]] = swapV[[20]]

# Plot Matricies
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(8,8))
fig.tight_layout(pad=2.0)
plt.subplot(321)
mat_plot(Q, 'Q')
plt.subplot(322)
mat_plot(swapQ, r'$Q_S$')
plt.subplot(323)
mat_plot(K, 'K')
plt.subplot(324)
mat_plot(swapK, r'$K_S$')
plt.subplot(325)
mat_plot(V, 'V')
plt.subplot(326)
mat_plot(swapV, r'$V_S$')
alt

注意力公式中的第一个矩阵乘法是 Q·Kᵀ=A,其中得到的矩阵 A 是一个大小等于 token 数量的正方形。当我们用 Qₛ 和 Kₛ 计算 Aₛ 时,得到的 Aₛ 的行 [0, 2] 和列 [0,2] 都与 A 交换。

A = Q @ K.transpose()
swapA = swapQ @ swapK.transpose()
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows
modA[:, [20]] = modA[:, [02]] #swap cols

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(8,3))
fig.tight_layout(pad=1.0)
plt.subplot(131)
mat_plot(A, r'$A = Q*K^T$')
plt.subplot(132)
mat_plot(swapA, r'$A_S = Q_S * K_S^T$')
plt.subplot(133)
mat_plot(modA, 'A\nwith rows [0,2] swaped\n and cols [0,2] swaped')
alt

下一个矩阵乘法是 A·V=A,其中生成的矩阵 A 与初始 Q、K 和 V 矩阵具有相同的形状。当我们用 Aₛ 和 Vₛ 计算 Aₛ 时,得到的 Aₛ 的行 [0,2] 与 A 交换。

A = A @ V
swapA = swapA @ swapV
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(127))
fig.tight_layout(pad=1.0)
plt.subplot(221)
mat_plot(A, r'$A = A*V$')
plt.subplot(222)
mat_plot(swapA, r'$A_S = A_S * V_S$')
plt.subplot(224)
mat_plot(modA, 'A\nwith rows [0,2] swaped')
axs[1,0].axis('off')
alt

这表明,更改注意力层输入中标记的顺序会导致输出注意矩阵的相同标记行发生变化。因为注意力是对标记之间关系的计算。如果没有位置信息,更改token顺序不会改变token的关联方式。

示例

定义位置嵌入

现在,我们可以看看正弦位置嵌入的细节。该代码基于 Tokens-to-Token ViT 的公开可用 GitHub 代码。从功能上来说,位置嵌入是一个与 token 形状相同的矩阵。这看起来像:

alt

正弦位置嵌入公式如下所示

alt

其中 PE 是位置嵌入矩阵,i 是沿着标记的数量,j 是沿着标记的长度,d 是标记长度。代码实现:

def get_sinusoid_encoding(num_tokens, token_len):
    """ Make Sinusoid Encoding Table

        Args:
            num_tokens (int): number of tokens
            token_len (int): length of a token
            
        Returns:
            (torch.FloatTensor) sinusoidal position encoding table
    """


    def get_position_angle_vec(i):
        return [i / np.power(100002 * (j // 2) / token_len) for j in range(token_len)]

    sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

让我们生成一个示例位置嵌入矩阵。我们将使用 176 个tokens。每个token的长度为 768,这是 T2T-ViT代码中的默认长度。一旦生成了矩阵,我们就可以绘制它。

PE = get_sinusoid_encoding(num_tokens=176, token_len=768)

fig = plt.figure(figsize=(108))
plt.imshow(PE[0, :, :], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([-11])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'fullPE.png'), bbox_inches='tight')
alt

放大标记的开头。

fig = plt.figure()
plt.imshow(PE[0, :, 0:301], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95.20.050.6])
plt.clim([-11])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'zoomedinPE.png'), bbox_inches='tight')
alt

具有正弦结构!

将位置嵌入应用于tokens

现在,我们可以将位置嵌入添加到我们的tokens中!我们将使用《Mountain at Dusk》,并具有与上述相同的补丁标记化。这将为我们提供 15 个长度为 20²=400 的token。

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center')
cbar_ax = fig.add_axes([0.95.110.050.77])
plt.clim([01])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountain_patches_w_colorbar.png'), bbox_inches='tight')
alt

当我们将这些补丁转换为token时,它看起来像

tokens = np.zeros((1520**2))
for i in range(15):
    patch = gray_mountains[top_y[i]:bottom_y[i], left_x[i]:right_x[i]]
    tokens[i, :] = patch.reshape(120**2)
tokens = tokens.astype(int)
tokens = tokens/255

fig = plt.figure(figsize=(10,6))
plt.imshow(tokens, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

现在,我们可以以正确的形状进行位置嵌入:

PE = get_sinusoid_encoding(num_tokens=15, token_len=400).numpy()[0,:,:]
fig = plt.figure(figsize=(10,6))
plt.imshow(PE, aspect=5, cmap='PuOr_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

我们现在准备将位置嵌入添加到标记中。位置嵌入中的紫色区域将使令牌变暗,而橙色区域将使它们变亮。

mountainsPE = tokens + PE
resclaed_mtPE = (position_mountains - np.min(position_mountains)) / np.max(position_mountains - np.min(position_mountains))

fig = plt.figure(figsize=(10,6))
plt.imshow(resclaed_mtPE, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

您可以从原始token中看到结构,以及位置嵌入中的结构!这两条信息都将被转发到Transformer中。

Reference
[1]

Source: https://towardsdatascience.com/position-embeddings-for-vision-transformers-explained-a6f9add341d5

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/431389.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【go语言开发】yaml文件配置和解析

本文主要介绍使用第三方库来对yaml文件配置和解析。首先安装yaml依赖库;然后yaml文件中配置各项值,并给出demo参考;最后解析yaml文件,由于yaml文件的配置在全局中可能需要,可定义全局变量Config,便于调用 文…

面试题HTML+CSS+网络+浏览器篇

文章目录 Css预处理sass less是什么?为什么使用他们怎么转换 less 为 css?重绘和回流是什么http 是什么?有什么特点HTTP 协议和 HTTPS 区别什么是 CSRF 攻击HTML5 新增的内容有哪些Css3 新增的特性flex VS grid清除浮动的方式有哪些&#xff…

SAR ADC学习笔记(3)

一、SAR ADC采样电路 1.采样网络的时域响应:采保信号 2.采样网络的KT/C噪声 3.采样抖动 采样开关的种类 1.单MOS管开关 2.传输门开关 3.栅极自举(Bootstrap)开关 结论:M4的衬底需要和B点短接,保证B点能够到达高压&…

完美解决Iframe嵌入帆软报表出现跨域cookie写不进去的问题

随着google chrome对第三方cookie的限制越来越狠,现在发现之前使用iframe嵌入的帆软报表已经不好使了。官方现在解决iframe嵌入帆软报表出现跨域导致cookie写不进去的方案是主推 统一主域名的方案(谷歌浏览器单点登录失败- FineReport帮助文档 - 全面的报表使用教程和学习资料…

大唐杯学习笔记:Day5

1.1 小区搜索 搜索流程 PLMN选择 自动模式:UE根据NAS的请求或自主地向NAS报告可用的PLMN 手动模式:通过手动选择一个可用的VPLMN获取正常服务 频点选择 5G NR中,3GPP主要指定了两个频率范围,一个是6GHZ以下,另一个是毫米波,分别称之为FR1和FR2。 N…

稀碎从零算法笔记Day5-LeetCode:轮转数组

题型:数组、数学、双指针 前言:LC说你得用三种方法做出来(悲) 链接:189. 轮转数组 - 力扣(LeetCode) 来源:LeetCode 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 …

专业145+总分410+西工大西北工业大学827信号与系统考研经验电子信息与通信工程,海航,真题,大纲,参考书。

经过一年的努力,分数终于出来。今年专业课827信号与系统145(很遗憾差了一点点满分,没有达到Jenny老师的最高要求),数一130,英语和政治也都比较平衡,总分410分,当然和信息通信考研Jen…

学习java第一天(下载并配置环境+写第一个java程序)

一.安装 1.下载 直接去官网上选择与你电脑符合的版本下载 官网链接Java Archive Downloads - Java SE 8u211 and later (拿我的为例 Windows x64版本) ​ 2.然后安装好exe(要让自己知道在哪) 3.配置环境 大佬链接&#xff1…

蓝桥杯前端Web赛道-新鲜的蔬菜

蓝桥杯前端Web赛道-新鲜的蔬菜 题目链接:1.新鲜的蔬菜 - 蓝桥云课 (lanqiao.cn) 题目要求如下: 其实很容易联想到使用flex布局,这是flex布局一种非常经典的骰子布局,推荐Flex 布局教程:实例篇 - 阮一峰的网络日志 (r…

Java基于SpringBoot网上超市的设计与实现论文

摘 要 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计算机运行的软件更是受到各界的关注。加上现在人们已经步入信息时代,所以对于信息的宣传和管理就很关键。因此超市商品销售信…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:点击回弹效果)

设置组件点击时回弹效果。 说明: 从API Version 10开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 clickEffect clickEffect(value: ClickEffect | null) 设置当前组件点击回弹效果。 系统能力: SystemCapabilit…

SQL技巧笔记(一):连续3人的连号问题—— LeetCode601.体育馆的人流量

SQL 技巧笔记 前言:我发现大数据招聘岗位上的应聘流程都是需要先进行笔试,其中占比很大的部分是SQL题目,经过一段时间的学习之后,今天开了一个力扣年会员,我觉得我很有必要去多练习笔试题目,这些题目是有技…

Linux - 进程概念

1、冯诺依曼体系结构 我们常见的计算机,如笔记本。我们不常见的计算机,如服务器,大部分都遵守冯诺依曼体系; 截至目前,我们所认识的计算机,都是有一个个的硬件组件组成: 输入单元:…

浏览器发出一个请求到收到响应步骤详解

前言 在网络通信中,浏览器向Web服务器发送HTTP请求消息的过程是一个复杂而精密的环节,涉及到URL解析、DNS解析、数据拆分、路由表规则和MAC头部添加等一系列步骤。本文将深入探讨这一过程的每个环节,帮助读者更全面地了解浏览器与Web服务器之…

肠道菌群参与利那洛肽对便秘型肠易激综合征 (IBS-C) 患者的影响:一项多中心、前瞻性、前后对照研究

谷禾健康 肠易激综合征 (IBS) 是一种普遍存在的、症状驱动的慢性疾病,其特征是腹部不适和排便不规律,估计影响全球 11.2% 的人口。这些患者中大约三分之一被诊断患有便秘型肠易激综合征(IBS-C),这是IBS 的一种亚型。 鉴…

JVM 第二部分-1(程序计数器,虚拟机栈,本地方法栈)

第二部分:运行时数据区 1.程序计数器: 全称是程序计数寄存器,像CPU的寄存器一样,存放线程的下一条指令的地址。每个线程都有一个 (区域小,执行速度快,不会有垃圾回收,也不会报oom错…

【外汇天眼】外汇投资策略:区间突破交易系统

RangeBreak系统介绍 RangeBreak区间突破交易系统被市场广泛用于日内交易,曾经连续多年在《美国期货杂志》盈利交易系统排行榜中位居前十。 目前该交易系统也仍旧被很多专业机构和个人投资者所推崇。 交易者可根据自己的交易习惯和性格特点进行改进,并不…

网络编程 24/3/4 作业

1、广播 发送端 #include <myhead.h> int main(int argc, const char *argv[]) {//创建套接字int sfdsocket(AF_INET,SOCK_DGRAM,0);if(sfd-1){perror("socket error");return -1;}//设置当前套接字允许广播属性int broadcast1;if(setsockopt(sfd,SOL_SOCKET…

numpy数据操作

numpy数据操作 读取数据 # numpy读取数据np.loadtxt(fname, 文件、字符串或产生器, 也可以是压缩文件dtypenp.float, 数据类型, 可选, 即确认csv的字符串以什么数据类型读入数组中, 默认为np.floatdelimiterNone, 分割读取的字符串, 默认是…

基于灰狼算法GWO的城市三维无人机路径规划(复杂地形三维航迹路径规划)

摘要 本文提出了一种利用灰狼算法GWO来解决城市环境下无人机三维路径规划问题的方法。这种方法将复杂的无人机航迹规划任务转化为一个优化问题&#xff0c;然后运用灰狼算法GWO来解决这个优化问题。灰狼算法GWO是一种模拟灰狼种群捕猎行为的优化算法&#xff0c;它具备强大的全…