【大规模语言模型:从理论到实践】Transformer中PositionalEncoder详解

书籍链接:大规模语言模型:从理论到实践

第15页位置表示层代码详解
PositionalEncoder

1. 构造函数 __init__()

def __init__(self, d_model, max_seq_len=80):
    super().__init__()
    self.d_model = d_model  # 嵌入的维度(embedding dimension)
  • d_model: 表示输入词向量的维度。
  • max_seq_len: 表示句子的最大长度(最大序列长度)。
  • self.d_model: 保存词嵌入的维度。
创建 PE 矩阵
pe = torch.zeros(max_seq_len, d_model)
for pos in range(max_seq_len):
    for i in range(0, d_model, 2):
        pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
        pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

这里,我们为所有可能的位置 pos 和维度 i 生成了位置编码矩阵 pe。编码规则是使用正弦和余弦函数来生成位置编码:

  • 对于每个位置 pos,在每个嵌入维度 i 上:

    • 奇数维度使用正弦函数 sin(pos / 10000^(2i/d_model))
    • 偶数维度使用余弦函数 cos(pos / 10000^(2i/d_model))

    这样做的好处是,正弦和余弦函数生成了一个平滑的周期性变化,使得位置编码具有一定的连续性和距离信息。

pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
  • pe.unsqueeze(0):将 pe 的第一个维度扩展为 1,这是为了便于后续将其与输入批次结合在一起。
  • register_buffer:将 pe 作为一个不可训练的参数(Tensor),并注册为模型的一部分,以确保其在模型的 .cuda().to(device) 等操作时也能够转移到对应设备上。

2. 前向传播 forward()

def forward(self, x):
    x = x * math.sqrt(self.d_model)  # 对输入乘以嵌入维度的平方根,使得它们的值更大一些
  • 这里的 x 是输入的词嵌入(word embeddings),即一个形状为 [batch_size, seq_len, d_model] 的张量。
  • x = x * math.sqrt(self.d_model):这一行操作是为了放大嵌入值,使得单词嵌入值的范围更加合适。
seq_len = x.size(1)  # 获取序列长度(句子长度)
x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
  • seq_len = x.size(1):获取当前输入序列的长度。
  • self.pe[:, :seq_len]:根据当前序列长度,从 pe 中提取对应的位置信息(只取前 seq_len 个位置的编码)。
  • x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda():将位置信息 pe 添加到输入词嵌入中。requires_grad=False 表示不对位置编码进行梯度更新。

3. 详细分析x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()

这行代码在位置编码器中的作用是将预计算好的位置编码矩阵 pe 加到输入的词嵌入矩阵 x 上。这是为了在词嵌入的基础上加入位置信息,使模型能够同时使用词汇语义和位置信息。我们分解这句话的各个部分:

x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
1. self.pe[:, :seq_len]
  • self.pe 是我们在初始化时生成的位置编码矩阵,其形状为 [1, max_seq_len, d_model]

    • 这里的 1 是 batch 维度,用来保持与输入张量 x 形状的一致性。
    • max_seq_len 是句子可能的最大长度,表示可以编码的最大序列长度。
    • d_model 是词嵌入的维度。
  • self.pe[:, :seq_len] 表示从 pe 矩阵中取出前 seq_len 个位置的编码。这个操作的作用是根据输入句子的实际长度(seq_len)来选择对应长度的位置信息。例如,如果 seq_len 是 50,则取出 pe 中前 50 行的编码。

2. Variable(self.pe[:, :seq_len], requires_grad=False)
  • Variable 是用于包裹张量,使其在反向传播中能够区分哪些需要计算梯度,哪些不需要。
    • requires_grad=False 表示位置编码 pe 不参与梯度计算,位置编码是一个固定值,不会像模型权重那样进行训练或更新。

注意: 在较新的版本的 PyTorch 中,Variable 已经被整合到了 Tensor 中,不再需要显式使用 Variable。直接使用张量即可,它们本身已经具有 requires_grad 属性。

3. .cuda()
  • .cuda() 将张量移动到 GPU 上进行计算,确保模型的所有张量在同一个设备上。如果你使用的是 CPU,这一部分会报错或需要改成 .to(device),以便适应不同设备。
4. x + self.pe[:, :seq_len]
  • x 是输入的词嵌入矩阵,形状为 [batch_size, seq_len, d_model]
  • self.pe[:, :seq_len] 是位置编码矩阵,形状为 [1, seq_len, d_model],即与 x 的第二、第三维度一致。
  • 加法操作x + self.pe[:, :seq_len] 表示将对应位置的词嵌入和位置编码逐元素相加。这个加法是一个广播操作,即 self.pe 的第一个维度为 1,自动扩展到与 xbatch_size 相同大小,然后再进行相加操作。
5. self.pe[:, :seq_len]self.pe[:, :seq_len, :]相互替换

两者在功能上是等价的,但后者更明确地表达了正在获取 pe 矩阵的所有维度。这种做法在某些情况下可以提高代码的可读性,特别是当你的张量具有多个维度时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/872525.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot的二手车交易系统的设计与实现

题目:基于springboot的二手车交易系统的设计与实现 摘 要 如今社会上各行各业,都喜欢用自己行业的专属软件工作,互联网发展到这个时候,人们已经发现离不开了互联网。新技术的产生,往往能解决一些老技术的弊端问题。因…

【书籍推荐】马斯克推荐的5部经典书籍

埃隆马斯克是谁想必已经不需要介绍,世界首富都推荐过哪些值得看的好书?今天这篇文章整理了5本马斯克曾推荐过或评价值得一读的书,或许可以从中一探他改变世界的方法和奥秘。 《结构是什么》 结构高于内容,结构决定内容。内容是表…

C++笔记15•数据结构:二叉树之二叉搜索树•

二叉搜索树 1.二叉搜索树 概念: 二叉搜索树又称二叉排序树也叫二叉查找树,它可以是一棵空树。 二叉树具有以下性质: 若它的左子树不为空,则左子树上所有节点的值都小于根节点的值 若它的右子树不为空,则右子树上所有节点的值都…

vue3+ts封装类似于微信消息的组件

组件代码如下&#xff1a; <template><div:class"[voice-message, { sent: isSent, received: !isSent }]":style"{ backgroundColor: backgroundColor }"click"togglePlayback"><!-- isSent为false在左侧&#xff0c;为true在右…

十分钟简单了解Java中的数据类型和变量!

一.字面常量 public class test{public static void main(String[] args){system.out.println("Hello world!");} }在上述代码中&#xff0c;system.out.println(“Hello world!”);语句不管何时运行&#xff0c;输出的结果都是Hello world!,其实Hello world&#xf…

Obsidian git sync error / Obsidian git 同步失敗

Issue: commit due to empty commit message Solution 添加commit資訊&#xff0c;確保不留空白 我的設置&#xff1a;auto-backup: {{hostname}}/{{date}}/

虚幻引擎(Unreal Engine)技术使得《黑神话悟空传》大火,现在重视C++的开始吃香了,JAVA,Go,Unity都不能和C++相媲美!

虚幻引擎&#xff08;Unreal Engine&#xff09;火了黑神话游戏。 往后&#xff0c;会有大批量的公司开始模仿这个赛道&#xff01; C 的虚拟引擎技术通常指的是使用 C 语言开发的游戏引擎&#xff0c;如虚幻引擎&#xff08;Unreal Engine&#xff09;等。以下是对 C 虚拟引…

ThreadPoolExecutor状态流转和源码分析

为什么使用线程池 降低资源消耗 &#xff0c;可以重复利用已创建的线程降低线程创建和销毁造成的消耗。提高响应速度&#xff0c;当任务到达时&#xff0c;任务可以不需要等到线程创建就能立即执行。提高线程的可管理性 &#xff0c;线程是稀缺资源&#xff0c;如果无限制地创…

如何从 AWS CodeCommit 迁移到极狐GitLab?

极狐GitLab 是 GitLab 在中国的发行版&#xff0c;可以私有化部署&#xff0c;对中文的支持非常友好&#xff0c;是专为中国程序员和企业推出的企业级一体化 DevOps 平台&#xff0c;一键就能安装成功。安装详情可以查看官网指南。 本文将分享如何从 AWS CodeCommit 服务无缝迁…

2024年六月英语四级真题及解析PDF共9页

2024年六月英语四级真题及解析PDF共9页&#xff0c;真题就是最好的复习资料&#xff0c;希望对大家有所帮助。

Python爬虫(一文通)

Python爬虫&#xff08;基本篇&#xff09; 一&#xff1a;静态页面爬取 Requests库的使用 1&#xff09;基本概念安装基本代码格式 应用领域&#xff1a;适合处理**静态页面数据和简单的 HTTP 请求响应**。 Requests库的讲解 含义&#xff1a;requests 库是 Python 中一个…

基于百度AIStudio飞桨paddleRS-develop版道路模型开发训练

基于百度AIStudio飞桨paddleRS-develop版道路模型开发训练 参考地址&#xff1a;https://aistudio.baidu.com/projectdetail/8271882 基于python35paddle120env环境 预测可视化结果&#xff1a; &#xff08;一&#xff09;安装环境&#xff1a; 先上传本地下载的源代码Pad…

如何在IDEA的一个工程中创建多个项目?

在IDEA中&#xff0c;可以通过Module来创建新的工程。

​如何通过Kimi强化论文写作中的数据分析?

在学术研究领域&#xff0c;数据分析是验证假设、发现新知识和撰写高质量论文的关键环节。Kimi&#xff0c;作为一款先进的人工智能助手&#xff0c;能够在整个论文写作过程中提供支持&#xff0c;从文献综述到数据分析&#xff0c;再到最终的论文修订。本文将详细介绍如何将Ki…

torch.backends.cudnn.benchmark和torch.use_deterministic_algorithms总结学习记录

经常使用PyTorch框架的应该对于torch.backends.cudnn.benchmark和torch.use_deterministic_algorithms这两个语句并不陌生&#xff0c;在以往开发项目的时候可能专门化花时间去了解过&#xff0c;也可能只是浅尝辄止简单有关注过&#xff0c;正好今天再次遇到了就想着总结梳理一…

Redis安装步骤——离线安装与在线安装详解

Linux环境下Redis的离线安装与在线安装详细步骤 环境信息一、离线安装1、安装环境2、下载redis安装包3、上传到服务器并解压4、编译redis5、安装redis6、配置redis&#xff08;基础配置&#xff09;7、启动redis8、本机访问redis9、远程访问redis 二、在线安装1、更新yum源2、安…

【LeetCode】01.两数之和

题目要求 做题链接&#xff1a;1.两数之和 解题思路 我们这道题是在nums数组中找到两个两个数使得他们的和为target&#xff0c;最简单的方法就是暴力枚举一遍即可&#xff0c;时间复杂度为O&#xff08;N&#xff09;&#xff0c;空间复杂度为O&#xff08;1&#xff09;。…

域内安全:委派攻击

目录 域委派 非約束性委派攻击&#xff1a; 主动访问&#xff1a; 被动访问&#xff08;利用打印机漏洞&#xff09; 约束性委派攻击&#xff1a; 域委派 域委派是指将域内用户的权限委派给服务账户&#xff0c;使得服务账号能够以用户的权限在域内展开活动。 委派是域中…

P4560 [IOI2014] Wall 砖墙

*原题链接* 做法&#xff1a;线段树 一道比较基础的线段树练手题&#xff0c;区间赋值&#xff0c;在修改时加些判断剪枝。 对于add操作&#xff0c;如果此时区间里的最小值都大于等于h的话&#xff0c;就没必要操作&#xff0c;如果最大值都小于h的话&#xff0c;就直接区间…

坐牢第三十五天(c++)

一.作业 1.使用模版类自定义栈 代码&#xff1a; #include <iostream> using namespace std; template<typename T> // 封装一个栈 class stcak { private:T *data; //int max_size; // 最大容量int top; // 下标 public:// 无参构造函数stcak();// 有参…