从零开始实现大语言模型(四):简单自注意力机制

1. 前言

理解大语言模型结构的关键在于理解自注意力机制(self-attention)。自注意力机制可以判断输入文本序列中各个token与序列中所有token之间的相关性,并生成包含这种相关性信息的context向量。

本文介绍一种不包含训练参数的简化版自注意力机制——简单自注意力机制(simplified self-attention),后续三篇文章将分别介绍缩放点积注意力机制(scaled dot-product attention),因果注意力机制(causal attention),多头注意力机制(multi-head attention),并最终实现OpenAI的GPT系列大语言模型中MultiHeadAttention类。

2. 从循环神经网络到自注意力机制

解决机器翻译等多对多(many-to-many)自然语言处理任务最常用的模型是sequence-to-sequence模型。Sequence-to-sequence模型包含一个编码器(encoder)和一个解码器(decoder),编码器将输入序列信息编码成信息向量,解码器用于解码信息向量,生成输出序列。在Transformer模型出现之前,编码器和解码器一般都是一个循环神经网络(RNN, recurrent neural network)。

RNN是一种非常适合处理文本等序列数据的神经网络架构。Encoder RNN对输入序列进行处理,将输入序列信息压缩到一个向量中。状态向量 h 0 h_0 h0包含第一个token x 0 x_0 x0的信息, h 1 h_1 h1包含前两个tokens x 0 x_0 x0 x 1 x_1 x1的信息。以此类推, Encoder RNN最后一个状态 h m h_m hm是整个输入序列的概要,包含了整个输入序列的信息。Decoder RNN的初始状态等于Encoder RNN最后一个状态 h m h_m hm h m h_m hm包含了输入序列的信息,Decoder RNN可以通过 h m h_m hm知道输入序列的信息。Decoder RNN可以将 h m h_m hm中包含的信息解码,逐个元素地生成输出序列。

RNN的神经网络结构及计算方法使Encoder RNN必须用一个隐藏状态向量 h m h_m hm记住整个输入序列的全部信息。当输入序列很长时,隐藏状态向量 h m h_m hm对输入序列中前面部分的tokens的偏导数(如对 x 0 x_0 x0的偏导数 ∂ h m x 0 \frac{\partial h_m}{x_0} x0hm)会接近0。输入不同的 x 0 x_0 x0,隐藏状态向量 h m h_m hm几乎不会发生变化,即RNN会遗忘输入序列前面部分的信息。

本文不会详细介绍RNN的原理,大语言模型的神经网络中没有循环结构,RNN的原理及结构与大语言模型没有关系。对RNN的原理感兴趣读者可以参见本人的博客专栏:自然语言处理。

2014年,文章Neural Machine Translation by Jointly Learning to Align and Translate提出了一种改进sequence-to-sequence模型的方法,使Decoder每次更新状态时会查看Encoder所有状态,从而避免RNN遗忘的问题,而且可以让Decoder关注Encoder中最相关的信息,这也是attention名字的由来。

2017年,文章Attention Is All You Need指出可以剥离RNN,仅保留attention,且attention并不局限于sequence-to-sequence模型,可以直接用在输入序列数据上,构建self-attention,并提出了基于attention的sequence-to-sequence架构模型Transformer。

3. 简单自注意力机制

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。如下图所示,简单自注意力机制生成context向量的计算步骤如下:

  1. 计算注意力分数(attention score):简单注意力机制使用向量的点积(dot product)作为注意力分数,注意力分数可以衡量两个向量的相关性;
  2. 计算注意力权重(attention weight):将注意力分数归一化得到注意力权重,序列中每个token与序列中所有tokens之间的注意力权重之和等于1;
  3. 计算context向量:简单注意力机制将所有tokens对应Embedding向量的加权和作为context向量,每个token对应Embedding向量的权重等于其相应的注意力权重。

图一

3.1 计算注意力分数

对输入文本序列***“Your journey starts with one step.”* **做tokenization,将文本中每个单词分割成一个token,并转换成Embedding向量,得到 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6。自注意力机制分别计算 x i x_i xi x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力权重,进而计算 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6与其相应注意力权重的加权和,得到context向量 z i z_i zi

如下图所示,将context向量 z i z_i zi对应的向量 x i x_i xi称为query向量,计算query向量 x 2 x_2 x2对应的context向量 z 2 z_2 z2的第一步是计算注意力分数。将query向量 x 2 x_2 x2分别点乘向量 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6,得到实数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26,其中 ω 2 i \omega_{2i} ω2i是query向量 x 2 x_2 x2与向量 x i x_i xi的注意力分数,可以衡量 x 2 x_2 x2对应token与 x i x_i xi对应token之间的相关性。

图二

两个向量的点积等于这两个向量相同位置元素的乘积之和。假如向量 x 1 = ( x 11 , x 12 , x 13 ) x_1=(x_{11}, x_{12}, x_{13}) x1=(x11,x12,x13),向量 x 2 = ( x 21 , x 22 , x 23 ) x_2=(x_{21}, x_{22}, x_{23}) x2=(x21,x22,x23),则向量 x 1 x_1 x1 x 2 x_2 x2的点积等于 x 11 × x 21 + x 12 × x 22 + x 13 × x 23 x_{11}\times x_{21} + x_{12}\times x_{22} + x_{13}\times x_{23} x11×x21+x12×x22+x13×x23

可以使用如下代码计算query向量 x 2 x_2 x2 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力分数:

import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

执行上面代码,打印结果如下:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

3.2 计算注意力权重

如下图所示,将注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化可得到注意力权重 α 21 , α 22 , ⋯   , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。每个注意力权重 α 2 i \alpha_{2i} α2i的值均介于0到1之间,所有注意力权重的和 ∑ i α 2 i = 1 \sum_i\alpha_{2i}=1 iα2i=1。可以用注意力权重 α 2 i \alpha_{2i} α2i表示 x i x_i xi对当前context向量 z 2 z_2 z2的重要性占比,注意力权重 α 2 i \alpha_{2i} α2i越大,表示 x i x_i xi x 2 x_2 x2的相关性越强,context向量 z 2 z_2 z2 x i x_i xi的信息量比例应该越高。使用注意力权重对 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6加权求和计算context向量,可以使context向量的数值分布范围始终与 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6一致。这种数值分布范围的一致性可以使大语言模型训练过程更稳定,模型更容易收敛。

图三

可以使用softmax函数将注意力分数归一化得到注意力权重:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

执行上面代码,打印结果如下:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

3.3 计算context向量

简单注意力机制使用所有tokens对应Embedding向量的加权和作为context向量,context向量 z 2 = ∑ i α 2 i x i z_2=\sum_i\alpha_{2i}x_i z2=iα2ixi

图四

可以使用如下代码计算context向量 z 2 z_2 z2

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

执行上面代码,打印结果如下:

tensor([0.4419, 0.6515, 0.5683])

3.4 计算所有tokens对应的context向量

将向量 x 2 x_2 x2作为query向量,按照3.1所述方法,可以计算出注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26。使用softmax函数将注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化,可以得到注意力权重 α 21 , α 22 , ⋯   , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。Context向量 z 2 z_2 z2是使用注意力权重对 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的加权和。

计算所有tokens对应的context向量,可以使用矩阵乘法运算,分别将各个 x i x_i xi作为query向量,一次性批量计算注意力分数及注意力权重,并最终得到context向量 z i z_i zi

如下面代码所示,可以使用矩阵乘法,一次性计算出所有注意力分数:

attn_scores = inputs @ inputs.T
print(attn_scores)

执行上面代码,打印结果如下:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

@操作符是PyTorch中的矩阵乘法运算符号,与函数torch.matmul运算逻辑相同。

将一个 n n n m m m列的矩阵 A A A与另一个 m m m n n n B B B的矩阵相乘,结果 C C C是一个 n n n n n n列的矩阵。其中矩阵 C C C i i i j j j列元素等于矩阵 A A A的第 i i i行与矩阵 B B B的第 j j j列两个向量的内积。

如下面代码所示,使用softmax函数注意力分数归一化,可以一次批量计算出所有注意力权重:

attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

执行上面代码,打印结果如下:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

可以同样使用矩阵乘法运算,一次性批量计算出所有context向量:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

执行上面代码,打印结果如下:

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

4. 结束语

自注意力机制是大语言模型神经网络结构中最复杂的部分。为降低自注意力机制原理的理解门槛,本文介绍了一种不带任何训练参数的简化版自注意力机制。

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。简单自注意力机制生成context向量共3个步骤,首先计算注意力分数,然后使用softmax函数将注意力分数归一化得到注意力权重,最后使用注意力权重对所有tokens对应的Embedding向量加权求和得到context向量。

接下来,该去看看大语言模型中真正使用到的注意力机制了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/779839.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32-PWR和WDG看门狗

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. PWR1.1 PWR简介1.2 电源框图1.3 上电复位和掉电复位1.4 可编程电压监测器1.5 低功耗模式1.6 模式选择1.7 睡眠模式1.8 停止模式1.9 待机模式1.10 库函数 2. WDG看门狗2.1 WDG简介2.2 IWDG框图2.3 IWDG键寄存器2.4 …

ACM ICPS独立出版 | 2024年第三届计算与人工智能国际会议(ISCAI 2024)

会议简介 Brief Introduction 2024年第三届计算与人工智能国际会议(ISCAI 2024) 会议时间:2024年11月22 -24日 召开地点:中国大理 大会官网:www.iscai.org 2024年第三届计算与人工智能国际会议(ISCAI 2024)将围绕“计算与人工智能”的最新研究…

排序 -- 冒泡排序和快速排序

一、 交换排序 1、基本思想 所谓交换,就是根据序列中两个记录键值的比较结果来对换这两个记录在序列中的位置,交换排序的特点是:将键值较大的记录向序列的尾部移动,键值较小的记录向序列的前部移动。 2、常见的交换排序 1、冒泡…

Java Selenium入门程序

需求:使用chrome浏览器打开百度首页 1.配置浏览器驱动 (1)下载浏览器驱动,浏览器版本需与驱动版本一致; (2)编辑系统环境变量-->编辑Path-->填入浏览器驱动路径: 2.maven工…

【反悔贪心 反悔堆】1642. 可以到达的最远建筑

本文涉及知识点 反悔贪心 反悔堆 LeetCode1642. 可以到达的最远建筑 给你一个整数数组 heights ,表示建筑物的高度。另有一些砖块 bricks 和梯子 ladders 。 你从建筑物 0 开始旅程,不断向后面的建筑物移动,期间可能会用到砖块或梯子。 当…

刷题之删除有序数组中的重复项(leetcode)

删除有序数组中的重复项 这题简单题&#xff0c;双指针&#xff0c;一个指针记录未重复的数的个数&#xff0c;另一个记录遍历的位置。 以下是简单模拟&#xff0c;可以优化&#xff1a; class Solution { public:int removeDuplicates(vector<int>& nums) {int l0…

STL--求交集,并集,差集(set_intersection,set_union,set_difference)

set_intersection(重要) 求两个有序的序列的交集. 函数声明如下: template<class InputIterator1, class InputIterator2, class OutputIterator>OutputIterator set_intersection(InputIterator1 _First1, //容器1开头InputIterator1 _Last1, //容器2结尾(不包含)Inp…

ChatGPT4深度解析:探索智能对话新境界

大模型chatgpt4分析功能初探 目录 1、探测目的 2、目标变量分析 3、特征缺失率处理 4、特征描述性分析 5、异常值分析 6、相关性分析 7、高阶特征挖掘 1、探测目的 1、分析chat4的数据分析能力&#xff0c;提高部门人效 2、给数据挖掘提供思路 3、原始数据&#xf…

Navicat终于免费了, 但是这个结果很奇葩

个人用下载地址: 点呀 好家伙, 每个机构最多5个用户, 对于正在审计的公司…

DAY1: 实习前期准备

文章目录 VS Code安装的插件C/CCMakeGitHub CopilotRemote-SSH收获 VS Code 下载链接&#xff1a;https://code.visualstudio.com 安装的插件 C/C 是什么&#xff1a;C/C IntelliSense, debugging, and code browsing. 为什么&#xff1a;初步了解如何在VS Code里使用C输出…

Vulnhub-Os-hackNos-1(包含靶机获取不了IP地址)

https://download.vulnhub.com/hacknos/Os-hackNos-1.ova #靶机下载地址 题目&#xff1a;要找到两个flag user.txt root.txt 文件打开 改为NAT vuln-hub-OS-HACKNOS-1靶机检测不到IP地址 重启靶机 按住shift 按下键盘字母"E"键 将图中ro修改成…

筛选Github上的一些优质项目

每个项目旁都有标签说明其特点&#xff0c;如今日热捧、多模态、收入生成、机器人、大型语言模型等。 项目涵盖了不同的编程语言和领域&#xff0c;包括人工智能、语言模型、网页数据采集、聊天机器人、语音合成、AI 代理工具集、语音转录、大型语言模型、DevOps、本地文件共享…

7-6 每日升学消息汇总

复旦附中清北比例大涨&#xff0c;从统计数据来看&#xff0c;今年复附的清北人数将创历史新高&#xff0c;达到前所未有年进43人。离上海7月9号中考出分&#xff0c;还有3天。小道消息说&#xff0c;画狮的数游天下又回来了&#xff0c;目前还未官方消息。2024第二届国际数学夏…

安卓虚拟位置修改1.25beta支持路线模拟、直接定位修改

导语:更新支持安卓14/15&#xff0c;支持路线模拟、直接定位修改&#xff0c;仅支持单一版本 无root需根据教程搭配下方链接所提供的虚拟机便可进行使用 有root且具备XP环境可直接真机运行 如你有特殊需求 重启问题设置打开XP兼容 针对具有虚拟机检测的软件 建议如下 度娘搜索…

多表查询sql

概述&#xff1a;项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互关联,所以各个表结构之间也存在着各种联系&#xff0c;分为三种&#xff1a; 一对多多对多一对一 一、多表关系 一对多 案例&#xff1a;部门与…

在CMD中创建虚拟环境并在VSCode中使用和管理

1. 使用Conda创建虚拟环境 在CMD或Anaconda Prompt中执行以下代码以创建一个新的虚拟环境&#xff1a; conda create -n my_env python 3.8 这样会创建一个名为 my_env 的环境&#xff0c;并在Anaconda环境目录下生成一个相应的文件夹&#xff0c;包含该虚拟环境所需的所有…

STM32-ADC+DMA

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. ADC模拟-数字转换器1.1 ADC模拟-数字转换器1.2 逐次逼近型ADC1.3 ADC框图1.4 ADC基本结构1.5 输入通道1.6 规则组的转换模式1.6.1 单次转换&#xff0c;非扫描模式1.6.2 连续转换&#xff0c;非扫描模式1.6.3 单次…

时间、查找、打包、行过滤与指令的运行——linux指令学习(二)

前言&#xff1a;本节内容标题虽然为指令&#xff0c;但是并不只是讲指令&#xff0c; 更多的是和指令相关的一些原理性的东西。 如果友友只想要查一查某个指令的用法&#xff0c; 很抱歉&#xff0c; 本节不是那种带有字典性质的文章。但是如果友友是想要来学习的&#xff0c;…

如何确保 PostgreSQL 在高并发写操作场景下的数据完整性?

文章目录 一、理解数据完整性二、高并发写操作带来的挑战三、解决方案&#xff08;一&#xff09;使用合适的事务隔离级别&#xff08;二&#xff09;使用合适的锁机制&#xff08;三&#xff09;处理死锁&#xff08;四&#xff09;使用索引和约束&#xff08;五&#xff09;批…

系统学习ElastricSearch(一)

不知道大家在项目中是否使用过ElastricSearch&#xff1f;大家对它的了解又有多少呢&#xff1f;官网的定义&#xff1a;Elasticsearch是一个分布式、可扩展、近实时的搜索与数据分析引擎。今天我们就来揭开一下它的神秘面纱&#xff08;以下简称ES&#xff09;。 ES 是使用 J…