【AIGC入门一】Transformers 模型结构详解及代码解析

Transformers 开启了NLP一个新时代,注意力模块目前各类大模型的重要结构。作为刚入门LLM的新手,怎么能不感受一下这个“变形金刚的魅力”呢?

目录

Transformers ——Attention is all You Need

背景介绍

模型结构

位置编码

代码实现:

Attention

Scaled Dot-product Attention

Multi-head Attention

Position-Wise Feed-Forward Networks

Encoder and Decoder

Add & Norm

mask 机制

参考链接


论文链接:Attention Is All You Need

Transformers ——Attention is all You Need

背景介绍

        在Transformer提出之前,NLP主要基于RNN、LSTM等算法解救相关问题。这些模型在处理长序列时面临梯度消失和梯度爆炸等问题,且这些模型是串行计算的,运行时间较长。

        Transformer 模型的提出是为了摆脱序列模型的顺序依赖性,引入了注意力机制,使得模型能够在不同位置上同时关注输入序列的各个部分,且支持并行计算。该模型的提出对深度学习和自然语言处理领域产生了深远的影响,成为了现代NLP模型的基础架构,并推动了attention 机制在各种任务中的应用。

模型结构

位置编码

        任何一门语言,单词在句子中的位置以及排列顺序是非常重要的。一个单词在句子的位置或者排列顺序不同,整个句子的意义就发生了偏差。举个例子:

小明小王500块

小王小明500块

顺序不同,债主关系就发生变化了😑

        当采用了Attention之后,句子中的词序信息就会丢失,模型就没法知道每个词在句子中的相对和绝对的位置信息。目前位置编码有多种方法:

(1)整型值标记位置,即第一个token标记为1, 第二个token标记为2。。。以此类推

         可能存在的问题:

  • 随着序列长度的增加,位置值会越来越大;
  • 推理的序列长度比训练时所用的序列长度更长,不利于模型的泛化

(2)用[0,1] 范围标记位置

        将位置值的范围限制在[0,1]之内,即在第一种的方法进行归一化操作(除以序列长度)。比如有4个token,那么位置信息就是[0, 0.33, 0.69, 1]。 但这样产生的问题是,当序列长度不同时,token间的相对距离是不一样的。

        因此,一个好的位置编码方法应该满足以下特性:

(1)可以表示一个token 在序列中的绝对位置;

(2)在序列长度不同的情况下,不同序列中token 的相对位置/ 距离要保持一直;

(3)可以扩展到更长的句子长度;

        Transformers 中选择的是sincos编码法,其公式如下所示:

        其中,pos 是token在sentence中的位置,i是维度。

代码实现:

        假设句子长度是 s, embedding的维度是d, 最终生成的PE的shape是(s, d)。公式的核心是计算pos /10000\tfrac{2i}{d_{model}}, 这里可以借助对数和指数的性质进行如下操作:

a = e^{^{loga}}

        所以可以转换成1/ 10000^{^{2i/d_{model}}} = e^{ - log(10000) * 2i /d_{model})}(可对照代码进行推导理解)

class Position_Encoding(nn.Module):
    
    def __init__(self, max_length, d_model):
        self.max_length = max_length
        self.dim = d_model
        
        pe = torch.zeros(self.max_length, self.dim)
        
        position = torch.arange(0, self.max_length).unsqueeze(1)
        
        div_term = torch.exp( torch.arange(0, self.dim ,2) * (-1) *math.log(10000) / self.dim)
        
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        

    def forward(self, x):
        
        #  input_embedding + position_encoding
        #....

Attention

        Attention 是将query 和key、value映射为输出值,其中query 和 key 计算一个相似度,然后以这个相似度为权重,计算value的加权和,最终得到输出。

Scaled Dot-product Attention

        论文中用的是放缩点乘注意力(scaled dot-product attention),其公式是:

Attention(Q, K, V) = softmax(\frac{QK^{}{T}}{sqrt(d_k))})

其中, 计算时需要用到的矩阵Q(查询),K(键值),V(值)是输入单词的embedding 变换或者 上一个Encoder block的输出 。注意Q、K、V的shape 会存在一定的联系(因为需要做矩阵乘法运算)。

        公式中会除以dk的平方根,从而避免内积过大。还有解释是说 softmax 在 绝对值较大的区域梯度较小,梯度下降的速度比较慢。因此希望softmax的点乘数值尽可能小。

        论文中解释了为什么点积会变大。假设q 和 k中的元素满足独立分布,且均值是0,方差为1。点积 q*k = \sum_{i=1}^{dk} q_i * k_i   的均值是0, 方差是dk 。

Multi-head Attention

        作者发现,相比直接在dmodel 维度上的 q、k、v进行attention计算 ,使用不同的、可学习的linear function 分别地对q、k、v 进行多次映射(映射的维度是dk, dk, dv)  , 然后对每一组映射的q、k、v进行attention 并行计算,并concat得到最终输出。后一种方法更有效。就像卷积层可以用多个卷积核生成多个通道的特征,在Transformers中可以用多组self attention 生成多组注意力的结果,从而增加特征表示。其计算公式和流程图如下:

        注意: head的数量 * 每一组head中q的维度 = dmodel (输入Q的维度)

Position-Wise Feed-Forward Networks

        前馈网络比较简单,是一个两层的全连接层,第一层的激活函数是ReLU,第二层不使用激活函数,对应的公式如下所示:

Encoder and Decoder

       Transformer 从结构上可以分为Encoder 和Decoder 两个部分,这两者结构上比较类似,但也存在一些差异。

        上图红色区域对应的是Encoder部分,可以看出是由 Input Embedding 、Position Encoding 和6层的EncoderLayer组成。 EncoderLayer 主要包括Multi-head Attention, Add&Norm, Feed Forward ,Add&Norm。

        上图绿色区域对应的是Decoder部分,相比Encoder,需要注意Decoder中的Multi-head Attention 有所不同。首先是Masked Multi-head Attention, 是为了实现串行推理;第二个Multi-head Attention输入的Q、K、V来自不同的地方,其中Q是Masked Multi-head Attention 的输出, K和V是Encoder 的输出。

Add & Norm

        这部分主要由Add 和 Norm 组成,其计算公式如下所示:

        Add 是一种残差结构,和ResNet中的是一样的,可以帮助网络收敛。Norm 是指Layer Norm。

mask 机制

        Transformers中比较重要的一个知识点就是mask设置。mask主要来源有两个:第一个是填充操作的空白字符(为了保证batch内句子的长度一样会进行padding操作);第二个是因为模拟串行推理需要用到mask(Decoder部分)。

        一般情况下, query 和 key都是一样的,但是在Decoder的第二个多头注意力层中,query 来自目标语言,key来自源语言。为了生成mask, 首先要知道query 和 key中<pad> 字符的分布情况,它们的形状为[n, seq_len]。如果某处是True, 表明这个地方的字符是<pad>。

src_pad_mask = x == pad_idx
dst_pad_mask = y == pad_idx

        为了实现串行推理,即某字符只能知道该字符以及该字符之前的内容,即一个下三角全1矩阵。mask矩阵需要取反,实现方式如下所示:

mask = 1 - torch.tril(torch.ones(mask_shape))

        最后根据<pad>字符分布情况分别将mask对应的行或者列置1。

参考链接

  1. GitHub - P3n9W31/transformer-pytorch: Transformer model for Chinese-English translation.
  2. PyTorch Transformer 英中翻译超详细教程 - 知乎
  3. Transformer模型详解(图解最完整版) - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/324939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设计模式之开闭原则:如何优雅地扩展软件系统

在现代软件开发中&#xff0c;设计模式是解决常见问题的最佳实践。其中&#xff0c;开闭原则作为面向对象设计的六大基本原则之一&#xff0c;为软件系统的可维护性和扩展性提供了强大的支持。本文将深入探讨开闭原则的核心理念&#xff0c;以及如何在实际项目中运用这一原则&a…

Rust-借用和生命周期

生命周期 一个变量的生命周期就是它从创建到销毁的整个过程。其实我们在前面已经注意到了这样的现象&#xff1a; 然而&#xff0c;如果一个变量永远只能有唯一一个入口可以访问的话&#xff0c;那就太难使用了。因此&#xff0c;所有权还可以借用。 借用 变量对其管理的内存…

C#编程-自定义属性

命名自定义属性 让我们继续漏洞修复示例,在这个示例中新的自定义属性被命名为BugFixingAttribute。通常的约定是在属性名称后添加单词Attribute。编译器通过允许您调用具有短版名称的属性来支持附加。 因此,可以如以下代码段所示编写该属性: [ BugFixing ( 122,"Sara…

C#用double.TryParse(String, Double)方法将字符串类型数字转换为数值类型

目录 一、定义 二、实例 命名空间: System 程序集: System.Runtime.dll 一、定义 将数字的字符串表示形式转换为它的等效双精度浮点数。 一个指示转换是否成功的返回值。 public static bool TryParse (string? s, out double result…

蓝桥杯备赛 | 洛谷做题打卡day3

蓝桥杯备赛 | 洛谷做题打卡day3 sort函数真的很厉害&#xff01; 文章目录 蓝桥杯备赛 | 洛谷做题打卡day3sort函数真的很厉害&#xff01;【深基9.例1】选举学生会题目描述输入格式输出格式样例 #1样例输入 #1 样例输出 #1 我的一些话 【深基9.例1】选举学生会 题目描述 学校…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例4-2 常用表单控件

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>常用表单控件</title> <style> form {width: 260px;margin: 0 auto;border: 1px solid #ccc;padding: 20px; } .right {float: right; } </style&g…

【学习笔记】2、逻辑代数与硬件描述语言基础

2.1 逻辑代数 &#xff08;1&#xff09;逻辑代数的基本定律和恒等式 基本定律或 “”与 “”非 “—”0-1律A0AA11AAAA A ‾ \overline{A} A1(互补律)A00A1AAAAA A ‾ \overline{A} A0 A ‾ ‾ \overline{\overline{A}} AA结合律(AB)C A(BC)(AB)CA(BC)ABC交换律AB BAABBA分…

广州市生物医药及高端医疗器械产业链大会暨联盟会员大会召开,天空卫士数据安全备受关注

12月20日&#xff0c;广州市生物医药及高端医疗器械产业链大会暨联盟会员大会在广州举办。在本次会议上&#xff0c;作为大会唯一受邀参加主题分享的技术供应商&#xff0c;天空卫士南区技术总监黄军发表《生物制药企业如何保护数据安全》的主题演讲。 做好承上启下“连心桥”…

【Spring实战】29 @Value 注解

文章目录 1. 定义2. 好处3. 示例1&#xff09;注入基本类型2&#xff09;注入集合类型3&#xff09;使用默认值4&#xff09;注入整数和其他类型 总结 在实际的应用中&#xff0c;我们经常需要从外部配置文件或其他配置源中获取参数值。Spring 框架提供了 Value 注解&#xff0…

感染了后缀为.mallox勒索病毒如何应对?数据能够恢复吗?

尊敬的读者&#xff1a; 在数字时代&#xff0c;勒索病毒如.mallox已经成为网络威胁中的重要一环。这篇文章将深入介绍.mallox勒索病毒的特征、应对策略以及如何预防这一威胁。面对复杂的勒索病毒&#xff0c;您需要数据恢复专家作为坚强后盾。我们的专业团队&#xff08;技术…

adb wifi 远程调试 安卓手机 命令

使用adb wifi 模式调试需要满足以下前提条件&#xff1a; 手机 和 PC 需要在同一局域网下。手机需要开启开发者模式&#xff0c;然后打开 USB 调试模式。 具体操作步骤如下&#xff1a; 将安卓手机通过 USB 线连接到 PC。&#xff08;连接的时候&#xff0c;会弹出请求&#x…

收银系统源码-智慧新零售系统框架

智慧新零售系统是一套线下线上打通的收银系统&#xff0c;主要给门店提供含线下收银、线上小程序商城、ERP进销存、精细化会员管理、丰富营销插件等为一体的智慧行业解决方案。智慧新零售系统有合伙人、代理商、商户、门店、收银员/导购员等角色&#xff0c;每个角色有相应的权…

Fine-tuning:个性化AI的妙术

在本篇文章中&#xff0c;我们将深入探讨Fine-tuning的概念、原理以及如何在实际项目中运用它&#xff0c;以此为初学者提供一份入门级的指南。 一、什么是大模型 ChatGPT大模型今年可谓是大火&#xff0c;在正式介绍大模型微调技术之前&#xff0c;为了方便大家理解&#xf…

C++系统笔记教程----vscode远程连接ssh

C系统笔记教程 文章目录 C系统笔记教程前言开发环境配置总结 前言 开发环境配置 Ubuntu20.24VScode 如果没有linux系统&#xff0c;但是想用其编译&#xff0c;可以使用ssh远程连接。 首先进入vscode,打开远程连接窗口&#xff08;蓝色的小箭头这&#xff09; 选择连接到主机…

Redis主从架构、哨兵集群原理实战

1.主从架构简介 背景 单机部署简单&#xff0c;但是可靠性低&#xff0c;且不能很好利用CPU多核处理能力生产环境必须要保证高可用&#xff0c;一般不可能单机部署读写分离是可用性要求不高、性能要求较高、数据规模小的情况 目标 读写分离&#xff0c;扩展主节点的读能力&…

案例:应用内字体大小调节

文章目录 介绍相关概念完整实例 代码结构解读保存默认大小获取字体大小修改字体大小 介绍 本篇Codelab将介绍如何使用基础组件Slider&#xff0c;通过拖动滑块调节应用内字体大小。要求完成以下功能&#xff1a; 实现两个页面的UX&#xff1a;主页面和字体大小调节页面。拖动…

【电子通识】开漏输出和推挽输出有什么差别?

在看一些MCU芯片手册的时候&#xff0c;能发现GPIO的功能有开漏输出和推挽式输出。那么这两种输出到底有什么差别&#xff1f; 如下是STM32F10xxx参考手册中对于GPIO的功能描述&#xff1a; 如下为GPIO内部框图&#xff1a; 在一些其他的芯片规格书中也同样看到不同的GPIO工作…

java基于Spring Boot的灾害应急救援评估调度平台

灾害应急救援平台的目的是让使用者可以更方便的将人、设备和场景更立体的连接在一起。能让用户以更科幻的方式使用产品&#xff0c;体验高科技时代带给人们的方便&#xff0c;同时也能让用户体会到与以往常规产品不同的体验风格。&#xff08;1&#xff09;鉴于该系统是一款面向…

详细讲解Python连接Mysql的基本操作

目录 前言1. mysql.connector2. pymysql 前言 连接Mysql一般有几种方法&#xff0c;主要讲解mysql.connector以及pymysql的连接 后续如果用到其他库还会持续总结&#xff01; 对于数据库中的表格,本人设计如下:(为了配合下面的操作) 1. mysql.connector mysql.connector 是一…

高通平台开发系列讲解(PCIE篇)MHI (Modem Host Interface)驱动详解

文章目录 一、MHI驱动代码二、MHI读数据流程三、MHI写数据流程沉淀、分享、成长,让自己和他人都能有所收获!😄 📢MHI (Modem Host Interface)我们通过名字顾名思义知道,它是Modem与Host的桥梁。 MHI 可以很容易地适应任何外围总线,但它主要用于基于 PCIe 的设备。 MHI(…