LLama2源码分析——Rotary Position Embedding分析

参考:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)

原理推导参考自上文,以下结合huggingface代码分析公式计算过程

1 旋转角度计算

计算公式如下,其中d为词嵌入维度,这部分和论文原文一样
θ j = 1000 0 − 2 ( j − 1 ) / d , j ∈ [ 1 , 2 , … , d / 2 ] \theta_j=10000^{-2(j-1)/d},j\in [1,2,\ldots,d/2] θj=100002(j1)/d,j[1,2,,d/2]

# 计算词向量元素两两分组之后,每组元素对应的旋转角度
# 维度:[dim / 2]
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

2 计算整个seq的cos_sin矩阵

def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_seq_len_cached = seq_len
    # 生成token长度序列
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
    # 计算两个矩阵的外积,结果维度[seq_len, dim // 2]
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
    # 类似[[0, 2, 4, ..., 0, 2, 4, ...], ...]形式,旋转角度两两一组相同
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
    self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

3 计算旋转式位置编码

f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ \begin{aligned}f_q(x_m,m)&=(W_qx_m)e^{im\theta} \\f_k(x_n,n)&=(W_kx_n)e^{in\theta}\end{aligned} fq(xm,m)fk(xn,n)=(Wqxm)eimθ=(Wkxn)einθ
公式根据欧拉公式转化后为
( q m ( 1 ) + i q m ( 2 ) ) ∗ ( cos ⁡ ( m θ ) + i sin ⁡ ( m θ ) ) (q_{m}^{(1)}+iq_{m}^{(2)})*(\cos(m\theta)+i\sin(m\theta)) (qm(1)+iqm(2))(cos(mθ)+isin(mθ))

展开后将结果重新表示为实数向量即为
q m e i m θ = [ q m ( 1 ) cos ⁡ ( m θ ) − q m ( 2 ) sin ⁡ ( m θ ) , q m ( 2 ) cos ⁡ ( m θ ) + q m ( 1 ) sin ⁡ ( m θ ) ] q_me^{im\theta}=[q_m^{(1)}\cos(m\theta)-q_m^{(2)}\sin(m\theta),q_m^{(2)}\cos(m\theta)+q_m^{(1)}\sin(m\theta)] qmeimθ=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]
key的计算同理,以上公式是2维embedding的旋转编码计算,实际代码中是将高纬度的embedding两两分组按照上述公式计算,同一组内的旋转角度相同,此处Llama代码中的分组计算方式与论文原文有所区别,论文原文中是将embedding_dim维度(最后一维)的向量按照相邻两个位置数字为一组,可以按照如下代码理解

>>> a
tensor([[1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8]])
>>> a.view(2, -1, 2)
tensor([[[1, 2],
         [3, 4],
         [5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4],
         [5, 6],
         [7, 8]]])

因此,单个token的位置编码是如下图方式计算
image
但以上的R矩阵比较稀疏,计算时浪费大量算力,因此Llama中采用不同的方式计算

  • Llama源码中计算方法

( q 0 q 1 ⋮ q d / 2 − 1 q d / 2 q d / 2 + 1 ⋮ q d − 1 ) ⊗ ( cos ⁡ m θ 0 cos ⁡ m θ 2 cos ⁡ m θ 4 ⋮ cos ⁡ m θ d − 2 cos ⁡ m θ 0 cos ⁡ m θ 2 ⋮ cos ⁡ m θ d − 2 ) + ( − q d / 2 − q d / 2 + 1 ⋮ − q d − 1 q 1 q 2 ⋮ q d / 2 − 1 ) ⊗ ( sin ⁡ m θ 0 sin ⁡ m θ 2 sin ⁡ m θ 4 ⋮ sin ⁡ m θ d − 2 sin ⁡ m θ 0 sin ⁡ m θ 2 ⋮ sin ⁡ m θ d − 2 ) \begin{pmatrix} {q_0}\\{q_1}\\{\vdots}\\{q_{d/2-1}}\\{q_{d/2}}\\{q_{d/2+1}}\\{\vdots}\\{q_{d-1}} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0\\\cos m\theta_2\\\cos m\theta_4\\\vdots\\\cos m\theta_{d-2}\\\cos m\theta_0\\\cos m\theta_2\\\vdots\\\cos m\theta_{d-2} \end{pmatrix} + \begin{pmatrix} {-q_{d/2}}\\{-q_{d/2+1}}\\\vdots\\{-q_{d-1}}\\{q_{1}}\\{q_{2}}\\\vdots\\{q_{d/2-1}} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0\\\sin m\theta_2\\\sin m\theta_4\\\vdots\\\sin m\theta_{d-2}\\\sin m\theta_0\\\sin m\theta_2\\\vdots\\\sin m\theta_{d-2} \end{pmatrix} q0q1qd/21qd/2qd/2+1qd1 cosmθ0cosmθ2cosmθ4cosmθd2cosmθ0cosmθ2cosmθd2 + qd/2qd/2+1qd1q1q2qd/21 sinmθ0sinmθ2sinmθ4sinmθd2sinmθ0sinmθ2sinmθd2

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/685584.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序下载、安装教程-2024年6月6日

微信小程序下载、安装教程-2024年6月6日 一、下载二、安装 一、下载 链接:https://pan.baidu.com/s/1pThpJEtOik9sgOI0F3mr_Q?pwdi1p3 提取码:i1p3 –来自百度网盘超级会员V6的分享 本文是用的网盘下载,具体都差不多。 或者从微信小程序官…

数据分析第一天(pandas简单的对快餐店数据进行操作获得想要的信息,使用apply,groupby)

前言 数据保存在 https://github.com/harkbox/DataAnalyseStudy 数据名称:快餐数据.tsv (tsv是用\t作为字符分隔符的文件格式;csv是逗号) 因此可以用pandas的read_csv函数读取数据 1.读取数据 import pandas as pd import matp…

Vue3学习记录(第一天)

Vue3学习记录_第一天 背景说明记录Vue3实现响应式前端的反射前端对象的属性赋值Vue3响应式实现过程稿前端移除对象的属性 背景 本次学习主要是看视频学习, 没有跟练, 但是很多知识点感觉又容易忘记. 所以通过笔记的方式输出一下. 说明 估计只能自己看懂, 如果能提供一些其他…

【Python报错】已解决ModuleNotFoundError: No Module Named ‘openyxl’

成功解决“ModuleNotFoundError: No Module Named ‘openyxl’”错误的全面指南 在Python编程中,遇到ModuleNotFoundError: No Module Named openyxl这样的错误通常意味着Python解释器无法找到名为openyxl的模块。然而,这里存在一个常见的拼写错误&#…

解决CSDN 导入Markdown图片失效不显示问题

每次将MarkDown文件导入CSDN的时候,有些图片总是由于防盗链的问题导致图片加载不出来,还得手动再导一遍,极其不方便。所以我们能不能建立一个属于自己的图片服务器或者说在线图库呢,而且每次使用Typora插入图片的时候都会自动的上…

Docker自定义镜像实现(SpringBoot程序为例)

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。🍎个人主页:Meteors.的博客💞当前专栏:知识备份✨特色专栏:知识分享&#x1f96…

华为HCIP-DATACOM 831最新题目

如图所示的网络,相邻的路由器之间使用直连接口建立EBGP邻居关系,AS号为6500x,其中X为路由器的编号。R1和R4均有到达192.168.1.0/24的静态路由,通过import方式引入BGP。在R3上配置EBGP负载分担的最大等价路由条数为8。缺省情况下&a…

搜索与图论:图中点的层次

搜索与图论&#xff1a;图中点的层次 题目描述参考代码 题目描述 输入样例 4 5 1 2 2 3 3 4 1 3 1 4输出样例 1参考代码 #include <cstring> #include <iostream> #include <algorithm>using namespace std;const int N 100010;int n, m; int h[N], e[N]…

VS2019 QT无法打开 源 文件 “QTcpSocket“

VS2019 QT无法打开 源 文件 "QTcpSocket" QT5.15.2_msvc2019_64 严重性 代码 说明 项目 文件 行 禁止显示状态 错误(活动) E1696 无法打开 源 文件 "QTcpSocket" auto_pack_line_demo D:\vs_qt_project\auto_pack_line_de…

UE5刷植物悬空了

UE5系列文章目录 文章目录 UE5系列文章目录前言一、解决办法 前言 在Unreal Engine5.3中使用植物模式刷各种植物时&#xff0c;有时会发现有的植物要么悬空&#xff0c;要不有刷不上地板的情况。而且悬空的植物还不能接触到地面&#xff0c;感觉很奇怪&#xff0c;就像下图所示…

2024.6.9周报

目录 摘要 ABSTRACT 一、文献阅读 1、相关信息 2、摘要 3、文献解读 1、Introduction 2、文章主要贡献 3、模型架构 4、实验 4、结论 二、代码实现 总结 摘要 本周我阅读了一篇题目为《Unlocking the Potential of Transformers in Time Series Forecasting with …

三十四篇:办公效率革命:深入探索办公自动化系统的全面策略

办公效率革命&#xff1a;深入探索办公自动化系统的全面策略 1. 引言 1.1 办公自动化系统&#xff08;OAS&#xff09;的定义与关键作用 在当前的企业环境中&#xff0c;办公自动化系统&#xff08;Office Automation System, OAS&#xff09;已成为提高效率和执行力的关键技…

全面守护你的健康ZL-0891A小动物多参数监护仪

简单介绍&#xff1a; 12.1英寸彩色TFT显示&#xff0c;分辨率800X600,采用数字血氧DSP算法&#xff0c;低灌注&#xff0c;小动物多参数监护仪具有优良的抗运动性能;动物用血压算法&#xff0c;支持测量各种动物类型,特有的中英文语音报警;支持USB数据导出&#xff0c;可以在…

嵌入式学习记录6.6(拷贝构造/友元函数/常成员函数)

一.拷贝构造函数和拷贝赋值函数 1.1拷贝构造函数功能,格式 拷贝构造函数是一种特殊的构造函数&#xff0c;用来将一个类对象给另一个类对象初始化使用的。 1> 用一个类对象给另一个类对象初始化时&#xff0c;会自动调用拷贝构造函数。 2> 当一个类对作为函数的实参&…

jdk快速配置

在系统变量新建两个变量先下载&#xff0c;直接安装 jdk-****-windows-x64 名称&#xff0c;看看面对java安装目录&#xff0c;我这里是默认目录为例 1.JAVA_HOME C:\Program Files\Java\jdk-1.82.CLASSPATH .;%JAVA_HOME%\lib;%JAVA_HOME%\lib\tools.jarpath里新建这两个…

韩顺平0基础学java——第18天

p374-395 类变量和类方法 类变量&#xff08;静态变量&#xff09; 例&#xff1a; class Child{ public static Int count&#xff1b;//这个count可以被所有Child实例共享 /..../ } 内存中&#xff0c;static在堆中是独立存放的&#xff0c;并不在某个对象的空间中。 由于…

【数据结构】C语言实现二叉树的基本操作——二叉树的遍历(先序遍历、中序遍历、后序遍历)

C语言实现二叉树的基本操作 导读一、二叉树的遍历二、先序遍历三、中序遍历四、后序遍历五、结点序列六、递归算法与非递归算法的转化结语 导读 大家好&#xff0c;很高兴又和大家见面啦&#xff01;&#xff01;&#xff01; 通过前面的介绍&#xff0c;我们已经认识了二叉树…

SAP 限制物料类型在BOM组件中简介

我们在创建BOM的时候通常是基于成品或者是半成品虚拟件创建BOM。正常情况下某些特殊的物料类型是不存在BOM中的。我们可以通过系统后台配置的方式对物料类型进行控制,控制对应的物料类型是否允许出现在BOM的组件中 1、后台配置路径: SPRO—生产—基本信息—物料清单—项目数…

【Linux取经路】网络套接字编程——TCP篇

文章目录 前言十、Tcp Server 端代码10.1 socket、bind10.1 listen——监听一个套接字10.2 accept——获取一个新连接10.3 read——从套接字中读取数据10.4 write——向套接字中进行写入10.5 Tcp Service 端完整代码&#xff08;单进程版&#xff09;10.6 Tcp Server 端代码&am…

【ZYNQ】CPU 私有定时器

Zynq 的每个 Cortex-A9 处理器都有自己的专用 32 位定时器和 32 位看门狗定时器&#xff0c;两个处理器共享一个全局 64 位定时器&#xff0c;这些计时器的时钟频率始终为 CPU 频率的 1/2。本文主要介绍 Zynq 芯片 CPU 私有定时器的工作特性&#xff0c;以及私有定时器的基本使…