Transformer模型:Postion Embedding实现

前言

        这是对上一篇WordEmbedding的续篇PositionEmbedding。

视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

上一篇链接:Transformer模型:WordEmbedding实现-CSDN博客


正文

        先回顾一下原论文中对Position Embedding的计算公式:pos表示位置,i表示维度索引,d_model表示嵌入向量的维度,position分奇数列和偶数列。

        Position Embedding也是二维的,行数是训练的序列最大长度,列是d_model。首先定义position的最大长度,这里定为12,也就是训练中的长度最大值都是12。

max_position_len = 12

        这里先循环遍历得到pos,构造Pos序列,pos是从0到最大长度的遍历,决定行:

pos_mat = torch.arange(max_position_len)

        但是此时得到的是一维的,我们要将它转为二维矩阵的,也就是得到目标行数,使用.reshape()函数,这样就构造好了行矩阵

pos_mat = torch.arange(max_position_len).reshape((-1,1))

tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11]]) 

        接下来要构造列矩阵,构造 i 序列,首先是是2i/d_model部分,这里的8是因为我们设定的d_model=8,2是步长:

i_mat = torch.arange(0, 8, 2)/model_dim

        这时候再把分母的完整形式实现,幂次使用pow()函数:

i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)

tensor([   1.,   10.,  100., 1000.]) 

         此时就得到了列向量,这时候就有疑问了为什么列只有4列,我们的d_model不是8吗,应该有8列才对啊。这是因为区分了奇数列跟偶数列的计算,所以这里才要求步长为2生成的只有4列。

        先初始化一个max_position_len*model_dim的零矩阵(12*8),然后再分别使用sin和cos填充偶数列和奇数列:

pe_embedding_table = torch.zeros(max_position_len, model_dim)

pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)   # 从第0列到结束,步长为2,也就是填充偶数列
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)   # 从第1列到结束,步长为2,也就是填充奇数列

        得到的就是Position Embedding的权重矩阵了: 

        这下面采用的是使用nn.Embedding()的方法,得到的跟上面的结果还是一样的,只不过这里的pe_embedding是可以传入位置的,之后的调用就是这样得到的:

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)

         这里就要构造位置索引了:

src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)

        然后传入位置索引,就得到了src跟tgt的Position Embedding:

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

         这里我很疑惑的点是生成的结果src_pe_embedding跟tgt_pe_embedding内容是一样的,并且单个里面的一个内容也就是position embedding,刚入门听得我还是有点不太能理解。

src_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
tgt_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
src_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])
tgt_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])

 代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
 
# 句子数
batch_size = 2
 
# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10
 
# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12
 
# 模型的维度
model_dim = 8
 
# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)
 
#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)
 
# 构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)  
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)    
print(src_embedding)    
print(tgt_embedding)

# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))    
i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)

# 构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)    
pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)
print("pe_embedding_table is:\n",pe_embedding_table)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)

# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)
print("src_pos is:\n",src_pos)
print("tgt_pos is:\n",tgt_pos)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print("src_pe_embedding is:\n",src_pe_embedding)
print("tgt_pe_embedding is:\n",tgt_pe_embedding)

参考

Python的reshape的用法:reshape(1,-1)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/799544.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

张量分解(5)——Tucker分解

🍅 写在前面 👨‍🎓 博主介绍:大家好,这里是hyk写算法了吗,一枚致力于学习算法和人工智能领域的小菜鸟。 🔎个人主页:主页链接(欢迎各位大佬光临指导) ⭐️近…

【C++】构造函数详解

📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…

【开源 Mac 工具推荐之 1】gibMacOS:方便快捷的 macOS 完整包下载 Shell 工具

简介 gibMacOS 是由 GitHub 开发者 corpnewt 编写的一款 Shell 工具。它采用 Python 编程语言,可以让用户打开后在纯文本页面中轻松选择并下载来源于 Apple 官方的 macOS 完整安装包。 Repo 地址:https://github.com/corpnewt/gibMacOS (其…

阿里通义音频生成大模型 FunAudioLLM 开源

简介 近年来,人工智能(AI)技术的进步极大地改变了人类与机器的互动方式,特别是在语音处理领域。阿里巴巴通义实验室最近开源了一个名为FunAudioLLM的语音大模型项目,旨在促进人类与大型语言模型(LLMs&…

HTML+CSS博客文章列表

源代码在图片后面 点赞❤️收藏⭐️关注&#x1f495; 图示 感谢各位大佬支持 &#x1f618;&#x1f618;&#x1f618; 源代码 <!DOCTYPE html> <html lang"zh-CN"> <head> <meta charset"UTF-8"> <title>博…

解决ESLint和Prettier冲突的问题

在配置了ESLint的项目中使用Prettier进行格式化可能会出现冲突&#xff0c;不如Prettier配置了使用双引号&#xff0c;ESLint配置了单引号&#xff0c;当然可以一个一个改成一样的配置&#xff0c;但是比较麻烦。我发现可以直接使用ESLint的规则进行格式化。在VSCode配置过程如…

springmvc1

以前的servlet程序&#xff1a; springmvc 不同的处理器&#xff1a;不同的方法或者处理类 所有的请求都会经过dispathcherservlet的doservice方法&#xff1a; mvc原理&#xff1a; 前端控制器&#xff1a;jsp或者什么东西

AutoMQ 中的元数据管理

本文所述 AutoMQ 的元数据管理机制均基于 AutoMQ Release 1.1.0 版本 [1]。 01 前言 AutoMQ 作为新一代基于云原生理念重新设计的 Apache Kafka 发行版&#xff0c;其底层存储从传统的本地磁盘替换成了以对象存储为主的共享存储服务。对象存储为 AutoMQ 带来可观成本优势的…

【C++】初识C++(下)

前言 本篇博客继续总结一下C入门的一些小知识 &#x1f493; 个人主页&#xff1a;小张同学zkf ⏩ 文章专栏&#xff1a;C 若有问题 评论区见&#x1f4dd; &#x1f389;欢迎大家点赞&#x1f44d;收藏⭐文章 ​ 目录 1.引用 1.1引用的概念 1.2const引用 1.3指针和引用的…

外包干了1个月,技术明显退步。。。

有一种打工人的羡慕&#xff0c;叫做“大厂”。 真是年少不知大厂香&#xff0c;错把青春插稻秧。 但是&#xff0c;在深圳有一群比大厂员工更庞大的群体&#xff0c;他们顶着大厂的“名”&#xff0c;做着大厂的工作&#xff0c;还可以享受大厂的伙食&#xff0c;却没有大厂…

软件测试面试题全网独家没有之一的资深测试工程师面试题集锦

1.自我介绍&#xff1f; 我是谁、工作几年、你上家公司做什么、负责什么、你的优势、为什么适合这个职位、我想做什么、在这个职位上想得到什么 有自信、不能吞吞吐吐 时间长度2-3分钟 2编写测试用例有哪几种方法&#xff1f; 等价类、边界值、因果图、流程分析、错误分析、…

【Pytorch】数据集的加载和处理(一)

Pytorch torchvision 包提供了很多常用数据集 数据按照用途一般分为三组&#xff1a;训练&#xff08;train&#xff09;、验证&#xff08;validation&#xff09;和测试&#xff08;test&#xff09;。使用训练数据集来训练模型&#xff0c;使用验证数据集跟踪模型在训练期间…

ectype:拓展ctype

拓展C库的ctype模块&#xff0c;将字节块或字符串进行分类或转换。

SQL Server 创建用户并授权

创建用户前需要有一个数据库&#xff0c;创建数据库命令如下&#xff1a; CREATE DATABASE [数据库名称]; CREATE DATABASE database1; 一、创建登录用户 方式1&#xff1a;SQL命令 命令格式&#xff1a;CREATE LOGIN [用户名] WITH PASSWORD 密码; 例如&#xff0c;创建…

全球DeepFake攻防挑战赛DataWhale AI 夏令营——图像赛道

全球DeepFake攻防挑战赛&DataWhale AI 夏令营——图像赛道 赛题背景 随着人工智能技术的迅猛发展&#xff0c;深度伪造技术&#xff08;Deepfake&#xff09;正成为数字世界中的一把双刃剑。这项技术不仅为创意内容的生成提供了新的可能性&#xff0c;同时也对数字安全构…

Mac 息屏不断网

这里息屏指的是屏幕不黑&#xff0c;屏幕黑了好像必断网 我的系统是 14.5 我调整了两个地方&#xff0c;一个是电池——选项——唤醒以供访问 另外一个地方是锁定屏幕——延长关闭显示器的时间&#xff08;让显示器不黑&#xff09;

如何批量删除重复数据?推荐两种方法

在日常的办公中&#xff0c;很多用户都会使用Excel。借助这款软件&#xff0c;用户可以完成对各种数据的处理。但很多时候我们会发现&#xff0c;同一张表格里有很多重复的数据&#xff0c;这或许会为统计带来错误。为此&#xff0c;我们就需要删除重复项才可以帮助我们很好的解…

STM32+三色LED智能调光系统源程序 易安卓APP 原理图

资料下载地址&#xff1a;STM32三色LED智能调光系统源程序 易安卓APP 原理图 三色LED手机智能调光系统概述&#xff1a; 利用开发的智能手机软件&#xff0c;对照明三色LED进行智能调光。包含的功能有&#xff0c;支持多手机同时连接服务端&#xff0c;互动调光。支持关…

前端Vue组件化实践:打造仿京东天猫商品属性选择器组件

在前端开发领域&#xff0c;随着业务需求的日益复杂和技术的不断进步&#xff0c;传统的整体式应用开发模式已逐渐显得捉襟见肘。面对日益庞大的系统&#xff0c;每次微小的功能修改或增加都可能导致整个逻辑结构的重构&#xff0c;形成牵一发而动全身的困境。为了解决这一问题…

昇思25天学习打卡营第22天|ResNet50图像分类

上回说到RESNET50做迁移学习&#xff0c;今天看一下他在图片分类上面的表现。图像分类是最基础的计算机视觉应用&#xff0c;属于有监督学习类别&#xff0c;先喂一堆猫的照片告诉他这是猫&#xff0c;再喂一堆狗的图片告诉他这是狗。 CIFAR-10数据集共有60000张32*32的彩色图…