【Transformer】transformer模型结构学习笔记

文章目录

      • 1. transformer架构
      • 2. transformer子层解析
      • 3. transformer注意力机制
      • 4. transformer部分释疑

 

图1 transformer模型架构
图2 transformer主要模块简介
图3 encoder-decoder示意图N=6
图4 encoder-decoder子层示意图

1. transformer架构

  • encoder-decoder框架是一种处理NLP或其他seq2seq转换任务中的常见框架, 机器翻译就是典型的seq2seq模型, 两个seq序列长度可以不相等

  • transformer也是encoder-decoder的总体架构, 如上图。transformer主要由4个部分组成:

    • 输入部分(输入输出嵌入与位置编码)
    • 多层编码器
    • 多层解码器
    • 以及输出部分(输出线性层与softmax)
  • 模块介绍

    • Input Embedding: 输入嵌入。将源文本中的词汇数字表示转换为向量表示,捕捉词汇间的关系
    • Positional Encoding: 位置编码。为输入序列的每个位置生成位置向量,以便模型能够理解序列中的位置信息
    • Output Embedding: 输出嵌入。将目标文本中的词汇数字表示转换为向量表示
    • Linear: 线性层。将decoder输出的向量转换为最终的输出维度
    • Softmax: softmax层。将线性层的输出转换为概率分布,以便进行最终的预测
    • encoder架构: encoder由6个相同的encoder层组成,每个层包括两个子层:一个多头自注意力层(multi-head self-attention)和一个逐位置的前馈神经网络(point-wise feed-forward network);每个子层后都会使用残差连接(residual connection)和层归一化(layer normalization)连接,即Add&Norm。如下图
    • decoder架构:decoder包含6个相同的decoder层,每层包含3个子层掩码自注意力层masked self-attention),encoder-decoder交叉注意力层逐位置的前馈神经网络。每个子层后都有残差连接层归一化操作,即Add&Norm。如下图

2. transformer子层解析

  • encoder和decoder的本质区别:self-attention的masked掩码机制
  • muitl-head进行masked的目的:在生成文本时,确保模型只依赖已知的信息,而不是未来的内容,对未来信息进行掩码处理,这样才能学会预测
  • multi-head的目的:让模型关注输入的不同部分或者不同信息,比如一个名词的修饰词,一个名词的分类,一个名词对象的情感、诗意等,从直观的到抽象的,捕获复杂的依赖关系
  • Add:残差连接。缓解梯度消失问题;网络输入x与网络输出F(x)相加,求导时相当于添加常数项1,缓解梯度消失问题
  • Norm:层归一化。在每个层上独立进行,使激活值具有相同的均值和方差,通常是0和1;在transformer中,Norm操作通常紧跟在Add之后,对残差连接结果进行归一化,以加速训练并稳定模型性能
  • 前馈网络:对输入进行非线性变换,提取更高级别的特征/信息
  • 逐位前馈神经网络:是一个简单的全连接神经网络,在模型中起到增加非线性和学习更复杂表示的作用。逐位的意思是逐个元素element或点进行独立且相同的操作,不是跨位置或跨元素来进行的。逐位前馈神经网络通常包括两个全连接层一个ReLU激活层,两个全连接层对应两个线性变换,第一个全连接层之后接ReLU激活函数引入非线性,使模型能够学习更复杂的表示。第一个全连接层通常对输入进行增维表示,第二个全连接层降维到模型输出所需的维度

3. transformer注意力机制

  • transformer的3种注意力层:在transformer架构中有3种不同的注意力层
    • self-attention layer自注意力层:编码器输入序列通过multi-head self-attention计算自注意力权重
    • casual attention layer因果自注意力层:解码器的单个序列通过masked multi-head self-attention计算自注意力权重
    • cross attention layer交叉注意力层:编码器-解码器两个序列通过multi-head cross attention进行注意力转移
  • 注意力机制的过程说明
  • 缩放点积注意力

上图是缩放点积注意力示意图,计算公式

其中,softmax内部是注意力分数,softmax整个是注意力权重,乘以缩放因子 1 d k \frac{1}{\sqrt{d_{k} } } dk 1是为了缓解可能的梯度消失问题(softmax值过大时), d k d_{k} dk是Q或者K的维度大小

  • 多头注意力机制

上图是多头注意力机制示意图,多个注意力头并行运行,每个头都会独立地计算注意力权重和输出,这里采用的是缩放点积注意力来计算;

然后将所有头的输出拼接concat起来得到最终的输出;

多头其实是为了提取不同维度的信息,捕获复杂的依赖关系,增强模型的表示能力;最后多个头结果进行拼接,避免单个计算的误差,即避免只关注单方面维度信息的误差

计算公式:

在transformer原文中,head_num = 8,d_k=d_v=64

  • 交叉注意力机制

    • 自注意力机制,QKV都来自同一序列,如下
    • 交叉注意力机制,输入来自两个不同的序列,一个序列用作查询Q(来自decoder states的queries),另一个序列提供键K和值V(来自encoder states的keys和values),实现跨序列的交互和注意力转移,如下
  • 因果注意力机制

    • 为了确保模型在生成序列时只依赖于之前的输入信息,而不会受到未来信息的影响。casual self-attention通过掩码未来位置来实现这一点;使模型在预测某个位置的输出时,只看到该位置及之前的输入。如下图所示
    • 其中掩码未来位置的原因通过下图说明:
    • 掩码机制通过下图说明,加一个很大的负数,softmax之后就是0,如下

4. transformer部分释疑

  • 问题1:transformer相对RNN能处理长序列数据, 同时能进行并行计算, LSTM相对RNN进行改进的, 解决长时依赖问题, 那么transformer相对于LSTM有什么优势
    • (1)LSTM在解决长时依赖仍有局限。LSTM依赖cell state来传递长时信息,限制了其全局信息捕获能力;而transformer的自注意力机制可以考虑任意两个位置之间的依赖关系,能更好的捕捉全局的、长距离的依赖信息
    • (2)transformer的可解释性更强:transformer计算每个位置与所有位置的依赖关系,使得模型的预测结果更易于解释,LSTM的解释性相对较弱
    • (3)并行计算能力:transformer不用像LSTM等待上一时间步的输出作为下一时间步的输入,可以实现完全并行的计算,更容易进行分布式计算和加速
    • (4)扩展性和灵活性:transformer结构相对灵活,可以轻松扩展到更大的数据集和更复杂的任务中

 

  • 问题2:同问题1, transformer通过怎样的设计能够实现并行计算的?

    • 参考这个图,可以并行计算一个位置和其他所有位置的依赖关系
  • 问题3:层归一化Norm和batch normalization的区别

    • 都是归一化,但层归一化不是批量归一化;
    • LN是对每个样本的每个层进行的归一化,即对每个样本的所有特征做归一化;
    • 而BN是对每个batch数据进行归一化,即对batch_size内的每个特征做归一化;
    • LN保留了不同特征之间的大小关系,抹平了不同样本之间的大小关系,所以LN更适合NLP领域任务;
    • 而BN保留了不同样本之间的大小关系,抹平了不同特征之间的大小关系,所以BN更适合于依赖不同样本之间关系的任务,如CV领域
    • LN可以缓解梯度消失问题、改善系统对缩放摆幅变化的鲁棒性、更适用于小样本数据情况
    • 而BN旨在提高模型的训练速度和稳定性,使模型学习效率更高,降低测试错误率和泛化误差

 

  • 问题4:encoder和decoder的本质区别self-attention是否masked,如何理解
    • encoder中每个元素都能管住整个序列中的所有其他元素,生成新的输出表示。处理整个输入序列,不需要掩码未来的信息
    • decoder在生成序列时,只能依赖已经生成的部分,而不能依赖未来的信息。masked处理的是输出序列,将未来位置的注意力权重设置为0,从而限制模型的关注点在已生成的序列上,实现了类似条件语言模型的功能
    • decoder和encoder交叉注意力层,decoder允许关注encoder的输出,从而融合encoder中的信息到生成过程

 

  • 问题5:transformer训练的过程参数有哪些,除了W_Q/K/V这几个参数矩阵以外
    • (1)嵌入维度:输入和输出嵌入的维度,词嵌入和位置编码的维度。比如词嵌入矩阵大小为词汇表大小如50000 * d_词嵌入向量的维度
    • (2)multi-head attention的num_heads:注意力头数,决定模型并行关注输入序列不同部分的能力,每个头都会产生一个独立的注意力权重矩阵。论文中num_heads = 8
    • (3)隐藏层层数:每个encoder层和decoder层都保持一致
    • (4)前馈神经网络隐藏层大小:神经元个数,通常比层数大很多,以便能学习复杂的特征表示
    • (5)encoder和decoder的层数:定义了模型中encoder和decoder各自包含的层数,论文中n_layers = 6,即6个encoder层和6个decoder层
    • (6)位置编码的维度:输入输出序列进入encoder/decoder层时都要进行位置编码,通常与嵌入维度相同,以便和嵌入向量直接相加
    • (7)训练参数:像学习率,选用的优化器,batch_size,epoches等
    • (8)正则化参数:如dropout rate随机失活的神经元比例防止过拟合,L2正则化等
    • (9)权重初始化方法:如随机初始化,Xavier初始化,He初始化等,合理的初始化能加快训练的过程尽快找到最优解

 

  • 问题6:QKV计算的过程,W矩阵都是可以训练的

  • 问题7:self-attention和(cross)attention的区别

    • self-attention设置source=target,即query=key=value,然后计算内部依赖关系

 

  • 问题8:预训练模型BERT和transformer是什么关系
    • BERT(Bidirectional Encoder Representations from Transformers)使用transformer的encoder结构来构建的,输入与transformer类似,包括token/segment/position embedding等,这些embedding将输入文本序列转换为模型可以理解的向量表示;
    • 在BERT中可以选择encoder层的数量,轻量级模型通常使用12层,重量级模型通常使用24层;transformer的自注意力机制使BERT能够关注双向上下文的信息

 

  • 问题9:transformer模型训练的时候采用了什么损失函数
    • transformer训练过程主要采用了交叉熵损失函数(负对数似然损失函数)来衡量模型预测的概率分布真实分布之间的差异,也可以采用KL散度;
    • 并且可以计算向量空间距离MSE,即两组概率向量的空间距离

 


 
创作不易,如有帮助,请 点赞 收藏 支持
 


 

[参考文章]
[1]. transformer注意力机制解析
[2]. Seq2Seq的注意力机制
[3]. attention机制图示
[4]. LN与BN的区别
[5]. Seq2Seq的注意力机制
[6]. transformer的decoder结构
[7]. decoder-only和编解码器区别
[8]. Attention is All You Need翻译
[9]. transformer结构详解,推荐

created by shuaixio, 2024.06.23

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/781528.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI编程探索- iOS 实现类似苹果地图 App 中的半屏拉起效果

想要的效果 功能分析 想要实现这种效果,感觉有点复杂,于是就想搜一下相关资料看看,可问题是,我不知道如何描述这种效果😂。 当我们遇到这种效果看着很熟悉,但是不知道如何描述它具体是什么的时候&#…

有一个日期(Date)类的对象和一个时间(Time)类的对象,均已指定了内容,要求一次输出其中的日期和时间

可以使用友元成员函数。在本例中除了介绍有关友元成员函数的简单应用外,还将用到类的提前引用声明,请读者注意。编写程序: 运行结果: 程序分析: 在一般情况下,两个不同的类是互不相干的。display函…

华为云OBS 通过S3客户端访问

华为云好像没有对S3协议的支持说明其实底层是支持S3协议的。 使用S3的时候我们会需要endpoint,桶名字,region,AWS_ACCESS_KEY,AWS_SECRET_KEY 其中endpoint 就是图片中的,桶名字也很容易找到,region 就是你的endpoint…

Nestjs基础

一、创建项目 1、创建 安装 Nest CLI(只需要安装一次) npm i -g nestjs/cli 进入要创建项目的目录,使用 Nest CLI 创建项目 nest new 项目名 运行项目 npm run start 开发环境下运行,自动刷新服务 npm run start:dev 2、…

Maven一键配置阿里云远程仓库,让你的项目依赖飞起来!

文章目录 引言一、为什么选择阿里云Maven仓库?二、如何设置Maven阿里云远程仓库?三、使用阿里云Maven仓库的注意事项总结 引言 在软件开发的世界里,Maven无疑是一个强大的项目管理工具,它能够帮助我们自动化构建、依赖管理和项目…

C++初学者指南-5.标准库(第一部分)--迭代器

C初学者指南-5.标准库(第一部分)–迭代器 Iterators 文章目录 C初学者指南-5.标准库(第一部分)--迭代器 Iterators1.默认正向迭代器2.反向迭代器3.基于迭代器的循环4.示例:交换相邻的一对元素5.迭代器范围6.迭代器范围中的元素数量7. 总结:迭代器 指向某…

动态规划|剑指 Offer II 093. 最长斐波那契数列

如果数组 arr 中存在三个下标 i、j、k 满足 arr[i]>arr[j]>arr[k] 且 arr[k]arr[j]arr[i],则 arr[k]、arr[j] 和 arr[i] 三个元素组成一个斐波那契式子序列。由于数组 arr 严格递增,因此 arr[i]>arr[j]>arr[k] 等价于 i>j>k。 把这道题…

记录第一次使用air热更新golang项目

下载 go install github.com/cosmtrek/airlatest 下载时提示: module declares its path as: github.com/air-verse/air but was required as: github.com/cosmtrek/air 此时,需要在go.mod中加上这么一句: replace github.com/cosmtrek/air &…

jmeter-beanshell学习4-beanshell截取字符串

再写个简单点的东西,截取字符串,参数化文件统一用csv,然后还要用excel打开,如果是数字很容易格式就乱了。有同事是用双引号把数字引起来,报文里就不用加引号了,但是这样beanshell处理起来,好像容…

插入排序——C语言

假设我们现在有一个数组,对它进行排序,插入排序的算法如同它的名字一样,就是将元素一个一个插入到合适的位置,那么,该如何做呢? 如果我们要从小到大进行排序的话,步骤如下: 1.对于…

LabVIEW机器视觉系统中的图像畸变、校准和矫正

在机器视觉应用中,图像畸变、校准和矫正是确保图像准确性的关键步骤。LabVIEW作为一种强大的图像处理和分析工具,提供了一系列功能来处理这些问题。以下是对图像畸变、校准和矫正的详细介绍。 图像畸变 图像畸变 是指由于摄像镜头的光学特性或拍摄角度问…

二分法查找有序表的通用算法(可查链表,数组,字符串...等等)

find_binary函数 注意事项: (1)你设计的迭代器模板中必须有using value_type T,且有加减运算功能,其本上能与C标准库std中一样。 (2)集合必须是有序的。 下面是函数代码: /// &…

土豆炒肉做法

菜单:土豆、葱、铁辣子、纯瘦肉、淀粉、生抽、酱油、刀、案板、十三香、盐巴、擦板 流程: 洗土豆,削皮,擦成条,用凉水过滤两遍淀粉,顺便放个燥里洗肉,切成条,按照生抽、酱油、淀粉、…

react dangerouslySetInnerHTML将html字符串以变量方式插入页面,点击后出现编辑状态

1.插入变量 出现以下编辑状态 2.解决 给展示富文本的标签添加css样式 pointerEvents: none

JAVA之(方法的重载与重写、this关键字、super关键字)

方法的重载与重写 一、方法的重载与重写1、回顾方法的定义2、重载的概念3、重写 二、this关键字1、何为this方法2、使用方法(1)在构造方法中指构造器所创建的新对象(2) 方法中指调用该方法的对象(3) 在类本…

【植物大战僵尸杂交版】获取+存档插件

文章目录 一、还记得《植物大战僵尸》吗?二、在哪下载,怎么安装?三、杂交版如何进行存档功能概述 一、还记得《植物大战僵尸》吗? 最近,一款曾经在15年前风靡一时的经典游戏《植物大战僵尸》似乎迎来了它的"文艺复…

自用款 复制粘贴工具 Paste macOS电脑适配

Paste是一款专为Mac和iOS用户设计的剪贴板管理工具,它提供了强大的剪贴板增强功能。Paste能够实时记录用户复制和剪切的内容,包括文本、图片、链接等多种数据类型,并形成一个可视化的剪贴板历史记录,方便用户随时访问和检索。此外…

嵌入式鸿蒙系统openharmony编译方法详解

大家好,时光如梭,今天主要给大家分享一下,鸿蒙系统的使用方法,以及源码该如何编译,其中要注意的细节有哪些? 第一:OpenHarmony系统简介 OpenHarmony 是由开放原子开源基金会(OpenAtom Foundation)孵化及运营的开源项目, 目标是面向全场景、全连接、全智能时代,基于…

vite简介

vite是新一代前端构建工具,vite具有优势如下: 轻量快速的热重载(HMR),能实现快速的服务启动。对TypeScript、JSX、CSS等支持开箱即用。真正的按需编译,不再等待整个应用编译完成。webpack构建与vite构建对…

html+css+JavaScript 实现两个输入框的反转动画

开发时遇到了一个输入框交换的动画 做完之后觉得页面上加些许过渡或动画,其变化虽小,却能极大的提升页面质感,给人一种顺畅、丝滑的视觉体验。它的实现过程主要是通过css中的transition和animation来实现的。平时在开发的时候增加一些动画效…