鹅厂面试官:Transformer 为何需要位置编码?

最近这一两周看到不少互联网公司都已经开始秋招发放Offer。

不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。

最近,我们又陆续整理了很多大厂的面试题,帮助一些球友解惑答疑,分享技术面试中的那些弯弯绕绕。

  • 《大模型面试宝典》(2024版) 正式发布!

喜欢本文记得收藏、关注、点赞。更多实战和面试交流,文末加入我们

技术交流

在这里插入图片描述

本文基于 llama 模型的源码,学习相对位置编码的实现方法,本文不细究绝对位置编码和相对位置编码的数学原理。

大模型新人在学习中容易困惑的几个问题:

  • 为什么一定要在 transformer 中使用位置编码?

  • 相对位置编码在 llama 中是怎么实现的?

  • 大模型的超长文本预测和位置编码有什么关系?

01 为什么需要位置编码

很多初学者都会读到这样一句话:transformer 使用位置编码的原因是它不具备位置信息。大家都只把这句话当作公理,却很少思考这句话到底是什么意思?

这句话的意思是,如果没有位置编码,那么 “床前明月”、“前床明月”、“前明床月” 这几个输入,会预测出完全一样的文本。

也就是说,不管你输入的 prompt 顺序是什么,只要 prompt 的文本是相同的,那么模型 decode 的文本就只取决于 prompt 的最后一个 token。

import torch
from torch import nn
import math


batch = 1
dim = 10
num_head = 2
embedding = nn.Embedding(5, dim)
q_matrix = nn.Linear(dim, dim, bias=False)
k_matrix = nn.Linear(dim, dim, bias=False)
v_matrix = nn.Linear(dim, dim, bias=False)


x = embedding(torch.tensor([1,2,3])).unsqueeze(0)
y = embedding(torch.tensor([2,1,3])).unsqueeze(0)


def attention(input):
    q = q_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)
    k = k_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)
    v = v_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)


    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim // num_head)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    outputs = torch.matmul(attn_weights, v).transpose(1, 2).reshape(1, len([1,2,3]), dim)
    print(outputs)


attention(x)
attention(y)

执行上面的代码会发现,虽然 x 和 y 交换了第一个 token 和第二个 token 的输入顺序,但是第三个 token 的计算结果完全没有发生改变,那么模型预测第四个 token 时,便会得到相同的结果。

如果有读者对矩阵运算感到混淆的话,可以看看下面的简单推导:

图片

可以看出,当第一个 token 与第二个 token 交换顺序后,模型输出矩阵的第一维和第二维也交换了顺序,但输出的值完全没有变化。

第三个 token 的输出结果也是完全没有受到影响,这也就是前面说的:如果没有位置编码,模型 decode 的文本就只取决于 prompt 的最后一个 token

不过需要注意的是,由于 attention_mask 的存在(前置位 token 看不到后置位 token),所以即使不加位置编码,transformer 的输出还是会受到 token 的位置影响。

02 相对位置编码的实现

我们以 modeling_llama.py 的源码为例,来学习相对位置编码的实现方法。

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)


        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)


    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

相对位置编码在 attention 中的应用方法如下:

self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)


query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)


if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=1)
    value_states = torch.cat([past_key_value[1], value_states], dim=1)

根据 value_states 矩阵的形状去调取 cos 和 sin 两个 tensor, cos 与 sin 的维度均是 batch_size * head_num * seq_len * head_dim;

利用 apply_rotary_pos_emb 去修改 query_states 和 key_states 两个 tensor,得到新的 q,k 矩阵

需要注意的是,在解码时,position_ids 的长度是和输入 token 的长度保持一致的,prompt 是 4 个 token 的话。

第一次解码时,position_ids: tensor([[0, 1, 2, 3]], device=‘cuda:0’),q 矩阵与 k 矩阵的相对位置编码信息通过 apply_rotary_pos_emb() 获得;

第二次解码时,position_ids: tensor([[4]], device=‘cuda:0’),当前 token 的相对位置编码信息通过 apply_rotary_pos_emb() 获得。

前 4 个 token 的相对位置编码信息则是通过 key_states = torch.cat([past_key_value[0], key_states], dim=1) 集成到 k 矩阵中;

……

……

以上代码的公式,均可以从苏神原文中找到。

这些代码可以从 llama 模型中剥离出来直接执行,如果感到困惑,可以像下面一样,将 apply_rotary_pos_emb() 的整个过程给 print 出来观察一下:

head_num, head_dim, kv_seq_len = 8, 20, 5
position_ids = torch.tensor([[0, 1, 2, 3, 4]])
query_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
key_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
value_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
rotary_emb = LlamaRotaryEmbedding(head_dim)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
print(cos, sin)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

03 位置编码与长度外推

长度外推指的是,大模型在训练的只见过长度为 X 的文本,但在实际应用时却有如下情况:

图片

我们假设 X 的取值为 4096,那么也就意味着,模型自始至终没有见到过 pos_id >= 4096 的位置编码,进而导致模型的预测结果完全不可控。

因此,解决长度外推问题的关键便是如何让模型见到比训练文本更长的位置编码。

图片

以上关于文本外推的介绍均是比较大白话的理解,只是为了强调位置编码很重要这一观点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/901323.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端零基础入门到上班:【Day2】开发环境VSCode安装

VSCode 安装教程:图文保姆教程 引言 在前端开发中,选择合适的代码编辑器是提高工作效率的重要一步。Visual Studio Code(简称 VSCode)作为一款强大的开源编辑器,因其简洁易用、功能强大、扩展性好而广受开发者喜爱。…

【智能大数据分析 | 实验四】Spark实验:Spark Streaming

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈智能大数据分析 ⌋ ⌋ ⌋ 智能大数据分析是指利用先进的技术和算法对大规模数据进行深入分析和挖掘,以提取有价值的信息和洞察。它结合了大数据技术、人工智能(AI)、机器学习(ML&a…

Postman常见问题及解决方(全)

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 1、网络连接问题 如果Postman无法发送请求或接收响应,可以尝试以下操作: 检查网络连接是否正常,包括检查网络设置、代理设置…

前端零基础入门到上班:【Day3】从零开始构建网页骨架HTML

HTML 基础入门&#xff1a;从零开始构建网页骨架 目录 1. 什么是 HTML&#xff1f;HTML 的核心作用 2. HTML 基本结构2.1 DOCTYPE 声明2.2 <html> 标签2.3 <head> 标签2.4 <body> 标签 3. HTML 常用标签详解3.1 标题标签3.2 段落和文本标签3.3 链接标签3.4 图…

市面上热门的四款PDF转换器解析!!

在互联网普及的今天&#xff0c;PDF和Excel已经成为我们工作中不可或缺的两种文件格式。PDF常用于文档的阅读、打印和分享&#xff0c;而Excel则适用于数据的分析和处理。但是&#xff0c;有时候我们需要在两者之间进行转换&#xff0c;例如将PDF中的数据导入到Excel中进行进一…

物联网数据采集网关详细介绍-天拓四方

一、物联网数据采集网关的概述 物联网数据采集网关&#xff0c;简称数据采集网关&#xff0c;是物联网系统中的重要组成部分&#xff0c;位于物联网设备和云端平台之间。其主要职责是实现数据的采集、汇聚、转换、传输等功能&#xff0c;确保来自不同物联网设备的数据能够统一…

Hadoop 踩坑汇总

文章目录 一、完整教程二、解决问题问题①&#xff1a; DataNode 没有问题②&#xff1a; 网页打不开 三、大功告成&#xff01;&#xff01; 一、完整教程 这个教程比较详细&#xff0c;博主是按照这个来执行的 https://blog.csdn.net/qq_47831505/article/details/123806514…

VsCode插件:前端每日一题

Javascript本地存储的方式有哪些&#xff1f; 区别及应用场景? 1. Cookie Cookie 是网站为了辨别用户身份、进行session跟踪而储存在用户本地终端上的数据。Cookie 通常包含了用户的一些个人信息&#xff0c;如用户名、密码、浏览记录、偏好设置等。Cookie 一般在用户访问网站…

Excel:vba实现生成随机数

Sub 生成随机数字()Dim randomNumber As IntegerDim minValue As IntegerDim maxValue As Integer 设置随机数的范围(假入班级里面有43个学生&#xff0c;学号是从1→43)minValue 1maxValue 43 生成随机数(在1到43之间生成随机数)randomNumber Application.WorksheetFunctio…

智联招聘×Milvus:向量召回技术提升招聘匹配效率

01. 业务背景 在智联招聘平台&#xff0c;求职者和招聘者之间的高效匹配至关重要。招聘者可以发布职位寻找合适的人才&#xff0c;求职者则通过上传简历寻找合适的工作。在这种复杂的场景中&#xff0c;我们的核心目标是为双方提供精准的匹配结果。在搜索推荐场景下&#xff0c…

深入理解gPTP时间同步过程

泛化精确时间协议(gPTP)是一个用于实现精确时间同步的协议,特别适用于分布式系统中需要高度协调的操作,比如汽车电子、工业自动化等。 gPTP通过同步主节点(Time Master)和从节点(Time Slave)的时钟,实现全局一致的时间参考。 以下是gPTP实现主从时间同步的详细过程:…

奥迪一汽新能源:300台AGV、1000台机器人、24米立体库

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 位于长春的奥迪新能源工厂&#xff0c;占地面积广阔&#xff0c;达到了约150公顷&#xff0c;其规模之宏大&#xff0c;甚至超越了奥迪在欧洲的内卡苏姆工厂。 这座工厂不仅是奥迪在中…

一、在cubemx下RTC配置调试实例测试

一、rtc的时钟有lse提供。 二、选择rtc唤醒与闹钟功能 内部参数介绍 闹钟配置 在配置时间时&#xff0c;注意将时间信息存储起来&#xff0c;防止复位后时间重新配置。 if(HAL_RTCEx_BKUPRead(&hrtc, RTC_BKP_DR0)! 0x55AA)//判断标志位是否配置过&#xff0c;没有则进…

使用Angular构建动态Web应用

&#x1f496; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4bb; Gitee主页&#xff1a;瑕疵的gitee主页 &#x1f680; 文章专栏&#xff1a;《热点资讯》 使用Angular构建动态Web应用 1 引言 2 Angular简介 3 安装Angular CLI 4 创建Angular项目 5 设计应用结构 6 创建组件…

【每日一题】LeetCode - 盛最多水的容器

给定一个长度为 n 的整数数组 height。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i])。要求找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 输入示例&#xff1a; height [1,8,6,2,5,4,8,3,7]输出&#xff1a; 4…

CSS行块标签的显示方式

块级元素 标签&#xff1a;h1-h6&#xff0c;p,div,ul,ol,li,dd,dt 特点&#xff1a; &#xff08;1&#xff09;如果块级元素不设置默认宽度&#xff0c;那么该元素的宽度等于其父元素的宽度。 &#xff08;2&#xff09;所有的块级元素独占一行显示. &#xff08;3&#xff…

安卓在windows连不上fastboot问题记录

fastboot在windows连不上fastboot 前提是android studio安装 google usb driver 搜索设备管理器 插拔几次找安卓设备 在其他设备 或者串行总线设备会出现安卓 右键更新驱动 下一步下一步然后可以了

【FISCO BCOS】二十二、使用Key Manager加密区块链节点

#1024程序员节&#xff5c;征文# 落盘加密是对节点存储在硬盘上的内容进行加密&#xff0c;加密的内容包括&#xff1a;合约的数据、节点的私钥。具体的落盘加密介绍&#xff0c;可参考&#xff1a;落盘加密的介绍&#xff0c;今天我们来部署并对节点进行落盘加密。 环境&a…

高效文本编辑与导航:Vim中的三种基本模式及粘滞位的深度解析

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

Bode图(波特图)

波特图&#xff1a; 通常用波特图分析信号的频率响应。 对设计滤波器的人来说&#xff0c;比较关注的是在特定的频率内&#xff0c;到底有怎样的增益和相移。根据前面分析的内容&#xff0c;波特图刚好是研究增益和相移。所以要想设计一个满足性能的滤波器&#xff0c;必须要…