Ai 算法之Transformer 模型的实现: 一 、Input Embedding模块和Positional Embedding模块的实现


一 文章生成模型简介

比较常见的文章生成模型有以下几种:

  1. RNN:循环神经网络。可以处理长度变化的序列数据,比如自然语言文本。RNN通过隐藏层中的循环结构来传递时间序列中的信息,从而使当前的计算可以参照之前的信息。但这种模型有梯度爆炸和梯度消失的风险,所以只能做简单的生成任务。
  2. LSTM:长短记忆网络。通过引入门控制机制来控制信息传递。有效避免了梯度消失和梯度保障的问题。LSTM可以做些复杂的生成任务。
  3. Transformer:目前最火的,一种基于自注意力机制(self-attention mechanism)的神经网络模型。Transformer 和 以上所述的几个生成模型主要的区别是,RNN、LSTM的训练迭代是串行的,必须要处理完当前字才可以处理下一个。而 Transformer 所有字符是同时训练的,也就是并行的。因此它效率更高,同样,由于参考了全文位置信息,因此效果更好。

值得一提的是这几个模型的价值并不仅限于在文章生成中。所有需要"经验值"的应用场景应该都适合借鉴。比如19年我曾尝试用LSTM来实现物联网小车自动驾驶。将操作指令转换为文字编码,实现了自动巡航、避障、撞墙倒车等操作。效果还不错。相信更换为注意力机制效果会更好

本文无意重塑轮子,纯是基于兴趣学习,尝试复现模型构造过程,本文所使用环境为python3.9+pytorch,参考论文为Google的Attention Is All You Need 2017。欢迎骚扰探讨

关于RNN和LSTM的实现代码,请查看我博客中的相关文章

1.1 Transformer 结构图

左侧为外国原版,右侧为在下翻译版
请添加图片描述
Transformer 模型主要分为两大部分,分别是 Encoder 、 Decoder,即组码器和解码器。组码器负责把输入语言序列映射成隐藏层,然后解码器再把隐藏层映射为其他自然语言序列。在原文中解码器和编码器都被设为6层(N = 6)。据说这个6没有特殊的含义。只是根据经验平衡了训练和精度的尝试数字。
在输入语句进入组码器前需要对数据进行预处理。这就是本章的主要内容:Embedding模块的实现

二 Input Embedding 字符编码模块的实现

字符编码本质上就相当于映射,将现实中的物体用数学的方式映射到计算机中。以翻译任务为例,我们需要准备两种不同的语言数据,并使用索引将他们一一对应。比如英文字符[i, eat, shit], 中文[我,吃,屎],这就相当于我们知道了问题和答案,剩下的就是训练隐藏层的参数了。

在npl中,为了使字符可以计算,首先要先将输入的词汇进行数学转化。在比较在其的语言处理中,一般使用one hot(独热)编码。即指定一个表值范围数组,单独改变某个位置上的值来决定其特征。
独热编码示例:
[1,0 ,0 ,0] = 我
[0,1 ,0 ,0] = 吃
[0,0, 1 ,0] = 屎
独热编码简单清晰,但无法对比两个值之间的相似性,无法进行降维操作。所以在tranfomer中 使用多维向量来表示单词的编码信息。一个向量表示一个单词。多个单词在一起就是一个矩阵。相比较以前的独热编码,词向量可以便于计算单词之间的相似性(点积),也可以进行降维操作。
单词向量示例:
[11,23,31,32]
[23,21,31,23]
[13,32,33,93]

单词的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。以下是使用pythoch获取Embedding向量的代码脚本,复制可用。

import torch
import torch.nn as nn

# padding:当句子长度不一,有空白时用0补缺
embedding = nn.Embedding(单词数量, 向量维度,padding=0)
# 根据索引获取8个单词向量
input = torch.LongTensor([[1, 2, 3, 4], [11, 12, 13, 13]])
print(embedding(input))
print(embedding(input).shape)

三 Positional Embedding 位置编码模块的实现

位置编码模块负责将输入序列中的位置信息写入词向量,输入到transformer中的句子没有顺序信息,因此需要通过计算句子的长度,单词长度以及单词所在的位置通过编码来为输入系列添加位置信息。Tranformer原文作者使用的是正弦余弦编码

位置 Embedding 用 PE表示,PE 的维度与单词 Embedding 是一样的。PE 可以通过训练得到,也可以使用某种公式计算得到。在 Transformer 中采用了后者,计算公式如下:

那么单词向量是怎么得来的呢?
单词向量 = 原始单词编码 + 单词位置编码
举个例子:我吃屎 = i eat shit

在这里插入图片描述
位置编码计算公式

偶数索引: P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d ) 偶数索引:PE(pos,2i)=sin(pos/10000^2i/d) 偶数索引:PE(pos,2i)=sin(pos/100002i/d)
单数索引: P E ( p o s , 2 i ) = c o s ( p o s / 1000 0 2 i / d ) 单数索引:PE(pos,2i)=cos(pos/10000^2i/d) 单数索引:PE(pos,2i)=cos(pos/100002i/d)

import torch
import torch.nn as nn
import ludash as ld
import cv2
import seaborn    
import matplotlib.pyplot as plt

term = (10000**2/i)
pe[:, 0::2] = torch.sin(position * term )
pe[:, 1::2] = torch.cos(position * term )

四 获取预处理数据

获取到字符编码和位置编码后,就可以计算出参考了字符位置的权重矩阵

公式: [ q , k , v ] = ( I n p u t E m b e d d i n g + p o s i t i o n a l E m b e d d i n g ) ∗ [ W q , W k , W v ] 公式: [q, k, v] =(Input Embedding + positional Embedding)* [Wq, Wk, Wv] 公式:[q,k,v]=InputEmbedding+positionalEmbedding[Wq,Wk,Wv]
q = 查询向量, k = 键值向量, v = 值向量 q = 查询向量,k = 键值向量,v = 值向量 q=查询向量,k=键值向量,v=值向量






取得这个值后就可以进行下一步:传入Transfrom的组码器进行组码处理了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/247181.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序过滤器之计算当前时间差

微信小程序过滤器之计算当前时间差 前言一、wxs简介二、使用步骤1.定义2.使用 前言 最近遇到了一个需求,将小程序里面的具体时间2023-12-11 09:41:06转为当前时间差10小时前,这块可以使用js逻辑函数对数据进行处理,但这里我们采用微信小程序…

网络安全——SQL注入实验

一、实验目的要求: 二、实验设备与环境: 三、实验原理: 四、实验步骤: 五、实验现象、结果记录及整理: 六、分析讨论与思考题解答: 七、实验截图: 一、实验目的要求: 1、…

【javascript】npm ERR! cb() never called!

错误 环境 windows 10 nvm node 14.17.0 如何解决 尝试了 5 种方法 1,npm cache clean --force 2, npm cache verify 3, 删掉package-lock.json (然鹅我的这个项目没有这个文件) 4, npm set strict-ssl false 5, 删除node_modules 这五种…

智能高效|AIRIOT智慧货运管理解决方案

随着全球贸易的增加和消费需求的不断扩大,货运行业面临更大的压力,传统货运行业运输效率低下、信息不透明,往往存在如下的运维问题和管理痛点: 无法实时定位和追踪信息:无法提供实时的货物位置信息,以便随…

【Spark精讲】RDD特性之数据本地化

目录 首选运行位置 数据的本地化级别 谁来负责数据本地化 数据本地化执行流程 调优 代码中的设置方法 首选运行位置 上图红框为RDD的特性五:每个RDD的每个分区都有一组首选运行位置,用于标识RDD的这个分区数据最好能够在哪台主机上运行。通过RDD的…

嵌入式系统挑战赛---多线程并发打印奇偶数

一、题目要求 编写一个C语言程序,实现多线程并发打印奇偶数。要求使用两个线程,一个线程打印奇数,另一个线程打印偶数,打印范围为1到100。要求奇数线程先打印,偶数线程后打印,且要保证线程按次序交替进行。…

32、应急响应——linux

文章目录 一、linux进程排查二、linux文件排查三、linux用户排查四、linux持久化排查4.1 历史命令4.2 定时任务排查4.3 开机启动项排查 五、linux日志分析六、工具应用 一、linux进程排查 查看资源占用:top查看所有进程:ps -ef根据进程PID查看进程详细信…

LeetCode(60)K 个一组翻转链表【链表】【困难】

目录 1.题目2.答案3.提交结果截图 链接: K 个一组翻转链表 1.题目 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍&#xf…

【源码】车牌检测+QT界面+附带数据库

目录 1、基本介绍2、基本环境3、核心代码3.1、车牌识别3.2、车牌定位3.3、车牌坐标矫正 4、界面展示4.1、主界面4.2、车牌检测4.3、查询功能 5、演示6、链接 1、基本介绍 本项目采用tensorflow,opencv,pyside6和pymql编写,pyside6用来编写UI界…

Java架构师-数据机构与算法实战(第一篇)

数学知识回顾 指数 指数函数是重要的基本初等函数之一。一般地,ya^x函数(a为常数且以a>0,a≠1)叫做指数函数,函数的定义域是 R 。注意,在指数函数的定义表达式中,在a^x前的系数必须是数1,自变量x必须在…

ubantu22.04.3 安装4080驱动

新电脑安装驱动网卡EX211只适配22.04的内核,其他系统升级内核易出问题不推荐。 安装系统为系统盘安装制作Ubuntu22.04启动盘_ubuntu下制作pe启动盘-CSDN博客,参考此作者,选择系统为22.04.3 其他版本不推荐因前面用22.04安装显卡后出现兼容性…

Power BI - 5分钟学习增加索引列

每天5分钟,今天介绍Power BI增加索引列。 什么是增加索引列? 增加索引列就是向表中添加一个具有显式位置值的新列,一般从0或者从1开始。 举例: 首先,导入一张【Sales】样例表(Excel数据源导入请参考每天5分钟第一天)…

【Linux】tree命令使用

tree命令 tree命令用于以树状图列出目录的内容。 语法 tree [参数] [目录] tree 命令 -Linux手册页 bash: tree: 未找到命令... 安装tree yum -y install tree如果你系统中有安装tree 但是还是执行找不到该命令的话,那原因就是:环境变量错误&#x…

Google Shopping Action

Google Shopping Action是Google推出的一项在线购物服务,可以帮助零售商将产品推广和销售到Google平台上的消费者中。通过Google Shopping Action,用户可以在谷歌搜索页面上直接购买商品,而不需要离开搜索结果页面。 Google Shopping Action的…

神通数据库字段空与非空

神通数据库可以在建表时指定字段非空或可空, -- 指定column1字段非空 CREATE TABLE SYSDBA.tmp_test1(column1 varchar(100) NOT NULL)--尝试向column1字段插入空值 INSERT INTO SYSDBA.tmp_test1(column1) VALUES(NULL) 会收到插入失败的提示: 而如果…

基于JavaWeb实现的勤工俭学管理系统

一、系统架构 前端:jsp | js | css | jquery | layui 后端:spring | springmvc | mybatis 环境:jdk1.8 | mysql 二、代码及数据库 三、功能介绍 01. web端-首页 02. web端-论坛 03. web端-个人中心 04. web端-平台公告 05. web端-平…

音视频技术开发周刊 | 323

每周一期,纵览音视频技术领域的干货。 新闻投稿:contributelivevideostack.com。 Meta牵头组建开源「AI复仇者联盟」,AMD等盟友800亿美元力战OpenAI英伟达 超过50家科技大厂名校和机构,共同成立了全新的人工智能联盟。以开源为旗号…

CBTC上海新能源锂电池展览会奋战华东!2024携手共赢!

2024CBTC上海新能源锂电池技术展览会|上海锂离子电池生产设备展览会 时 间:2024年7月24~26日 地 点:国家会展中心(上海虹桥) 发展前景: 随着科技的不断进步,锂电池市场逐渐成为全球能源市场的…

@Transactional注解详细使用

Transactional注解详细使用 Transactional注解是Spring框架中用于管理事务的注解,它可以应用于类或方法上。使用该注解可以确保一个方法或类中的操作要么全部成功提交,要么全部回滚,从而保证数据的完整性和一致性。下面是Transactional注解的…

Gradio入门详细教程

常用的两款AI可视化交互应用比较: Gradio Gradio的优势在于易用性,代码结构相比Streamlit简单,只需简单定义输入和输出接口即可快速构建简单的交互页面,更轻松部署模型。适合场景相对简单,想要快速部署应用的开发者。便…