Transformer学习理解

1.前言        

        本文介绍当下人工智能领域的基石与核心结构模型——Transformer,为什么说它是基石,因为以ChatGPT为代表的聊 天机器人以及各种有望通向AGI(通用人工智能)的道路上均在采用的Transformer。 Transformer也是当下NLP任务的底座,包括后续的BERT和GPT,都是Transformer架构,BERT主要由Transformer的 encoder构成,GPT则主要是decoder构成。

        本文将会通读Transformer原论文《Attention is all you need》,总结Transformer结构以及文章要点,然后采用pytorch 代码进行机器翻译实验,通过代码进一步了解Transformer在应用过程中的步骤。

2.论文阅读笔记

        Transformer的论文是《Attention is all you need》(https://arxiv.org/abs/1706.03762),由Google团队在2017提出的 一种针对机器翻译的模型结构,后续在各类NLP任务上均获取SOTA。

 Motivation:针对RNN存在的计算复杂、无法串行的问题,提出仅通过attention机制的简洁结构——Transformer,在多个序列任务、机器翻译中获得SOTA,并有运用到其他数据模态的潜力。

 模型结构: 模型由encoder和decoder两部分构成,分别采用了6个block堆叠, encoder的block有两层,multi-head attention和FC层 decoder的block有三层,处理自回归的输入masked multi-head attention,处理encoder的attention,FC层。

注意力机制:scale dot-production attention,采用QK矩阵乘法缩放后softmax充当权重,再与value进行乘法。 多头注意力机制:实验发现多个头效果好,block的输出,把多个头向量进行concat,然后加一个FC层。因此每个头的向量长度是总长度/头数,例:512/8=64,每个头是64维向量。

Transformer的三种注意力层:

(1)encoder:输入来自上一层的全部输出

(2)decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

(3)decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

FFN层:attention之后接入两个FC层,第一个FC层采用2048,第二个是512,第一个FC层采用max(0, x)作为激活函数。

embedding层的缩放:在embedding层进行尺度缩放,乘以根号d_model。

位置编码:采用正余弦函数构建位置向量,然后采用加法,融入embedding中。

实验:两个任务,450万句子(英-德),3600万句子(英-法),8*16GB显卡,分别训练12小时,84小时。10万step 时,采用4kstep预热;采用标签平滑0.1。

重点句子摘录

1. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

        由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。 这里是官方对多头注意力机制引入的解释。

2. We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 根号 dk

        在Q和K做点积时,若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此 需要进行缩放。缩放因子为除以根号dk。

多头注意力机制的三种情况:

        i. encoder:输入来自上一层的全部输出。

        ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

        iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

首次阅读遗留问题

        1. 位置编码是否重新学习? 不重新学习,详见 PositionalEncoding的         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

        2. qkv具体实现过程

                i. 通过3个w获得3个特征向量:

                ii. attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) # q.shape == (bs, n_head, len_seq, d_k/n_head) ,每个token用 一个向量表示,总向量长度是头数*每个头的向量长度。

                iii. attn = self.dropout(F.softmax(attn, dim=-1)) iv. output = torch.matmul(attn, v)

        3. decoder输入的处理细节

                 i. 训练阶段,无特殊处理,一个样本可以直接输入,根据下三角mask避免未卜先知。

                ii. 推理阶段,首先手动执行token的输入,然后for循环直至最大长度,期间输出的token拼接到输出列表中,并作为下一步decoder的输入。选择输出token时还采用了beam search来有效地平衡广度和深度。

3.数据集构建

        数据下载

        本案例数据集Tatoeba下载自这里。该项目是帮助不同语言的人学习英语,因此是英语与其它几十种语言的翻译文本。 其中就包括本案例使用的英中文本,共计29668条(Mandarin Chinese - English cmn-eng.zip (29668)) 数据以txt形式存储,一行是一对翻译文本,例如长这样:

        1.That mountain is easy to climb. 那座山很容易爬。

        2.It's too soon. 太早了。

        3.Japan is smaller than Canada. 日本比加拿大小。

        数据集划分

         对于29668条数据进行8:2划分为训练、验证,这里采用配套代码a_data_split.py进行划分,即可在统计目录下获得 train.txt和text.txt。

4.词表构建

        文本任务首要任务是为文本构建词表,这里采用与上节一样的方法,首先对文本进行分词,然后统计语料库中所有的 词,最后根据最大上限、最小词频等约束,构建词表。本部分配套代码是b_gen_vocabulary.py

        词表的构建过程中,涉及两个知识点:中文分词和特殊token。

        1. 中文分词

        对于英文,分词可以直接采用空格。而对于中文,就需要用特定的分词方法,这里采用的是jieba分词工具,以下是英文 和中文的分词代码。

        source.append(parts[0].split(' '))

        target.append(list(jieba.cut(parts[1]))) # 分词

        2. 特殊token

         由于seq2seq任务的特殊性,在解码器部分,通常需要一个token告诉模型,现在是开始,同时还需要有个token让模型 输出,以此告诉人类,模型输出完毕,不要再继续生成了。

         因此相较于文本分类,还多了bos, eos,两个特殊token,有的时候,开始token也会用start表示。

        PAD_TAG = "" # 用PAD补全句子长度

         BOS_TAG = "" # 用BOS表示开始

        EOS_TAG = "" # 用EOS表示结束

        UNK_TAG = "" # 用EOS表示结束

        PAD = 0 # PAD字符对应的数字

        BOS = 1 # BOS字符对应的数字

         EOS = 2 # EOS字符对应的数字

         UNK = 3 # UNK字符对应的数字

        运行代码后,词表字典保存到了result目录下,并得到如下输出,表明英文中有2518个词,中文有3365,但经过最大长 度3000的截断后,只剩下2996,另外4个是特殊token。

        100%|██████████| 23635/23635 [00:00<00:00, 732978.24it/s]

        原始词表长度:2518,截断后长度:2518

        2522

        保存词频统计图:vocab_en.npy_word_freq.jpg

        100%|██████████| 23635/23635 [00:00<00:00, 587040.62it/s]

        保存统计图:vocab_en.npy_length_freq.jpg

        原始词表长度:3365,截断后长度:2996

        3000

5.Dataset编写

        NMTDataset的编写逻辑与上一小节的Dataset类似,首先在类初始化的时候加载原始数据,并进行分词;在getitem迭代 时,再进行token转index操作,这里会涉及增加结束符、填充符、未知符。 核心代码如下:

def __init__(self, path_txt, vocab_path_en, vocab_path_fra, max_len=32):
    self.path_txt = path_txt
    self.vocab_path_en = vocab_path_en
    self.vocab_path_fra = vocab_path_fra
    self.max_len = max_len
    self.word2index = WordToIndex()
    self._init_vocab()
    self._get_file_info()

def __getitem__(self, item):
# 获取切分好的句子list,一个元素是一个词
    sentence_src, sentence_trg = self.source_list[item], self.target_list[item]
# 进行填充, 增加结束符,索引转换
    token_idx_src = self.word2index.encode(sentence_src, self.vocab_en, self.max_len)
    token_idx_trg = self.word2index.encode(sentence_trg, self.vocab_fra, self.max_len)
    str_len, trg_len = len(sentence_src) + 1, len(sentence_trg) + 1 # 有效长度, +1是填充的结束符 <eos>.
    return np.array(token_idx_src, dtype=np.int64), str_len, np.array(token_idx_trg,         dtype=np.int64), trg_len

def _get_file_info(self):
    text_raw = read_data_nmt(self.path_txt)
    text_clean = text_preprocess(text_raw)
    self.source_list, self.target_list = text_split(text_clean)

6.模型构建

        Transformer代码梳理如下图所示,大体可分为三个层级

        1. Transformer的构建,包含encoder、decoder两个模块,以及两个mask构建函数。

        2. 两个coder内部实现,包括位置编码、堆叠的block。

        3. block实现,包含多头注意力、FFN,其中多头注意力将softmax(QK)*V拆分为ScaledDotProductAttention类。

        代码整体与论文中保持一致,总结几个不同之处:

        1. layernorm使用时机前置到attention层和FFN层之前

        2. position embedding 的序列长度默认采用了200,如果需要更长的序列,则要注意配置。 具体代码实现不再赘述,把论文中图2的结果熟悉,并掌握上面的代码结构,可以快速理解各模块、组件的运算和操作步 骤,如有疑问的点,再打开代码观察具体运算过程即可。

7.模型训练

        原论文进行两个数据集的机器翻译任务,采用的数据和超参数列举如下,供参考。

        英语-德语,450万句子对,英语-法语,3600万句子对。均进行base/big两种尺寸训练,分别进行10万step和30万step训 练,耗时12小时/84小时。10万step时,采用4千step进行warmup。正则化方面采用了dropout=0.1的残差连接,0.1的标 签平滑。

        本实验中有2.3万句子对训练,只能作为Transformer的学习,性能指标仅为参考,具体任务后续由BERT、GPT、T5来 实现更为具体的项目。 采用配套代码train_transformer.py,执行训练即可:

        采用配套代码train_transformer.py,执行训练即可:

python train_transformer.py -embs_share_weight -proj_share_weight -label_smoothing -b 256 -warmup 128000 -epoch 400

        训练完成后,在result文件夹下会得到日志与模型,接下来采用配套代码c_train_curve_plot.py

        对日志数据进行绘图可视化,Loss和Accuracy分别如下,可以看出模型拟合能力非常强,性能还在提高,但受限于数据量过少,模型在200个 epoch之后就已经出现了过拟合。

        这里面的评估指标用的acc,具体是直接复用github项目,也能作为模型性能的评估指标,就没有去修改为BLUE。

8.模型推理

        Transformer的推理过程与训练时不同,值得仔细学习。

        Transformer的推理输出是典型的自回归(Auto regressive),并且需要根据多个条件综合判断何时停止,因此推理部分的逻辑值得认真学习,具体步骤如下:

        第一步:输入序列经encoder,获得特征,每个token输出一个向量,这个特征会在decoder的每个step都用到,即 decoder中各block的第二个multi-head attention。需要enc_output去计算k和v,用decoder上一层输出特征向量去计算 q,以此进行decoder的第二个attention。

        enc_output, *_ = self.model.encoder(src_seq, src_mask)

        第二步:手动执行decoder第一步,输入是这个token,输出的是一个概率向量,由分类概率向量再决定第一个输出 token。

        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))

        dec_output =  self._model_decode(self.init_seq, enc_output, src_mask)

        第三步:循环max_length次执行decoder剩余步。第i步时,将前序所有步输出的token,组装为一个序列,输入到 decoder。

         在代码中用一个gen_seq维护模型输出的token,输入给模型时,只需要gen_seq[:, :step]即可,很巧妙。 在每一步输出时,都会判断是否需要停止输出字符。

for step in range(2, max_seq_len):
    dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
略
    if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
        _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
        ans_idx = ans_idx.item()
        break

        借助李宏毅老师2021年课件中一幅图,配合代码,可以很好的理解Transformer在推理时的流程。

         1. 输入序列经encoder,获得特征(绿色、蓝色、蓝色、橙色)

          2. decoder输入第一个序列(序列长度为1,token是),输出第一个概率向量,并通过max得到“机”。

          3. decoder输入第二个序列(序列长度为2,token是[BOS, 机]),输出得到“器”

          4. decoder输入第三个序列(序列长度为3,token是[BOS, 机,器]),输出得到“学”

          5. decoder输入第四个序列(序列长度为4,token是[BOS, 机,器,学]),输出得到“习”

          6. ...以此类推 

        在推理过程中,通常还会使用Beam Search 来最终确定当前步应该输出哪个token,此处不做展开。

        运行配套代码inference_transformer.py可以看到10条训练集的推理结果。

         从结果上看,基本像回事儿了。

        src: tom's door is open .

        trg: 湯姆的門開著 。

        pred: 汤姆的 <unk>了 。

        src:<unk> is a <unk> country .

        trg: <unk>是一個<unk>的國家 。

         pred: <unk>是一個<unk>的城市 。

         src: i can come at three .

         trg: <unk>可以 來 。

        pred: 我 可以 在 那裡

9.总结

        本文通过Transformer论文的学习,了解Transformer基础架构,并通过机器翻译案例,从代码实现的角度深入剖析 Transformer训练和推理的过程。由于Transformer是目前人工智能的核心与基石,因此需要认真、仔细的掌握其中细节。

        本文内容值得注意的知识点:

         1. 多头注意力机制的引入:官方解释为,由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。

         2. 注意力计算时的缩放因子:QK乘法后需要缩放,是因为若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此需要进行缩放。缩放因子为除以根号dk。

        3. 多头注意力机制的三种情况:

          i. encoder:输入来自上一层的全部输出

         ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

         iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

        4. 输入序列的msk:代码实现时,由于输入数据是通过添加来组batch的,并且为了在运算时做并行运算,因此需要对 src中是pad的token做mask,这一点在论文是不会提及的。

        5. decoder的mask:根据下三角mask避免未卜先知。

        6. 推理时会采用beam search进行搜索,确定输出的token。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/727413.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于PCL实现多边形框选点云并进行裁剪(附C++源码)

文章目录 一.算法效果二.算法原理PNPoly算法直线相交性判断三.代码实现一.算法效果 通过在PCL可视化界面上绘制2D封闭多边形来提取位于该封闭多边形内部或者外部的 的点,算法效果如下: 图1多边形裁剪点云效果图 二.算法原理 PNPoly算法 2D多边形框选裁剪点云,实际上可以简…

SSRF学习,刷题

[HNCTF 2022 WEEK2]ez_ssrf 给了一个Apache2的界面&#xff0c;翻译一下 就是一个默认的界面,目录扫描 可以看到flag.php,肯定是不能直接访问得到的&#xff0c;还有index.php&#xff0c;访问这个 可以看到三个参数data,host,port 还有fsockopen() 函数是 PHP 中用于打开一个…

代码随想录训练营Day 64|卡码网98. 所有可达路径(深搜)

1.所有可达路径 98. 所有可达路径 | 代码随想录 代码&#xff1a; &#xff08;深搜&#xff09;邻接矩阵表示 #include <iostream> #include <vector> using namespace std; vector<int> path; vector<vector<int>> result; void dfs(const ve…

LeetCode 算法:两两交换链表中的节点 c++

原题链接&#x1f517;&#xff1a;两两交换链表中的节点 难度&#xff1a;中等⭐️⭐️ 题目 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交…

QT中利用动画弄一个侧边栏窗口,以及贴条效果

1、效果 2、关键代码 void Widget::on_sliderBtn_clicked() {m_sliderWidget->show();QPropertyAnimation* animation = new QPropertyAnimation(m

政策更新记录:敏感信息访问权限与API使用变更

我们将更新“健康数据共享”政策,简化“健康数据共享”申请流程,并与“健康类应用”政策保持一致。此外,我们将于今年晚些时候在 Play 管理中心推出一项新的声明,取代当前使用表单进行申请的方式。 公布日期:2024-04-03 Health Connect 政策要求及常见问题解答 初步认识对…

[AIGC] 使用Google的Guava库中的Lists工具类:常见用法详解

在Java程序设计中&#xff0c;集合是我们最常用的数据结构之一。为了方便我们操作集合&#xff0c;Google的Guava库提供了一个名为Lists的工具类&#xff0c;它封装了许多用于操作List对象的实用方法。在本文中&#xff0c;我们将详细介绍其常见的用法&#xff0c;以帮助您更好…

volatile关键字(juc编程)

volatile关键字 3.1 看程序说结果 分析如下程序&#xff0c;说出在控制台的输出结果。 Thread的子类 public class VolatileThread extends Thread {// 定义成员变量private boolean flag false ;public boolean isFlag() { return flag;}Overridepublic void run() {// 线…

钡铼BL101网关助力智慧城市路灯远程智能管控

在迈向智慧城市的征途中&#xff0c;基础设施的智能化改造是关键一环&#xff0c;而路灯作为城市脉络的照明灯塔&#xff0c;其智能化升级对于节能减排、提升城市管理效率具有重要意义。钡铼BL101网关&#xff0c;作为Modbus转MQTT的专业桥梁&#xff0c;正以其卓越的性能和广泛…

数据仓库与数据库的区别

在数据管理和分析的过程中&#xff0c;我们常常会听到“数据库”和“数据仓库”这两个术语。 虽然它们看起来相似&#xff0c;但实际上它们在设计目的、结构和使用场景上都有显著的区别。 数据库是什么&#xff1f; 数据库&#xff08;Database&#xff09;是一个用于存储和管…

[创业之路-120] :全程图解:软件研发人员如何从企业的顶层看软件产品研发?

目录 一、企业全局 二、供应链 三、团队管理 四、研发流程IPD 五、软件开发流程 六、项目管理 七、研发管理者的自我修炼 一、企业全局 二、供应链 三、团队管理 四、研发流程IPD 五、软件开发流程 六、项目管理 七、研发管理者的自我修炼

时空预测 | 基于深度学习的碳排放时空预测模型

时空预测 模型描述 数据收集和准备&#xff1a;收集与碳排放相关的数据&#xff0c;包括历史碳排放数据、气象数据、人口密度数据等。确保数据的质量和完整性&#xff0c;并进行必要的数据清洗和预处理。 特征工程&#xff1a;根据问题的需求和领域知识&#xff0c;对数据进行…

【C++】基础知识--inline(内联)关键字以及与宏的区别

c语言中的小小白-CSDN博客c语言中的小小白关注算法,c,c语言,贪心算法,链表,mysql,动态规划,后端,线性回归,数据结构,排序算法领域.https://blog.csdn.net/bhbcdxb123?spm1001.2014.3001.5343 给大家分享一句我很喜欢我话&#xff1a; 知不足而奋进&#xff0c;望远山而前行&am…

Nginx Rewrite技术

一&#xff1a;理解地址重写 与 地址转发的含义。二&#xff1a;理解 Rewrite指令 使用三&#xff1a;理解if指令四&#xff1a;理解防盗链及nginx配置 简介&#xff1a;Rewrite是Nginx服务器提供的一个重要的功能&#xff0c;它可以实现URL重定向功能。 一&#xff1a;理解地…

医学图像预处理之z分数归一化

在医学图像处理中&#xff0c;Z分数标准化&#xff08;Z-score normalization&#xff09;是一种常用的数据标准化方法&#xff0c;其目的是将数据集中的每个图像像素值转换为具有均值为0和标准差为1的标准化值。这种标准化方法有助于改善图像的质量&#xff0c;便于后续图像处…

RS485中继器的作用你还不知道?

RS485是一种串行通信协议&#xff0c;支持设备间长距离通信。RS485中继器则像“传声筒”&#xff0c;能放大衰减信号&#xff0c;延长通信距离&#xff0c;隔离噪声&#xff0c;扩展分支。在实际场景中&#xff0c;如工厂内&#xff0c;通过中继器可确保控制室与远距离机器间通…

嵌入式Linux 中常见外设屏接口分析

今天将梳理下嵌入式外设屏幕接口相关的介绍,对于一个嵌入式驱动开发工程师,对屏幕都可能接触到一些相关的的调试,这里首先把基础相关的知识梳理。 1. 引言 在嵌入式开发过程中,使用到的液晶屏有非常多的种类,根据不同技术和特性分类,会接触到TN液晶屏,TN液晶屏 VA液晶屏…

Java基础16(集合框架 List ArrayList容器类 ArrayList底层源码解析及扩容机制)

目录 一、什么是集合&#xff1f; 二、集合接口 三、List集合 四、ArrayList容器类 1. 常用方法 1.1 增加 1.2 查找 int size() E get(int index) int indexOf(Object c) boolean contains(Object c) boolean isEmpty() List SubList(int fromindex,int …

H3C防火墙抓包(图形化)

一.报文捕获 &#xff0c;然后通过wireshark查看报文 二.报文示踪 &#xff0c; 输入源目等信息&#xff0c; 查看报文的详情

鸿蒙Harmony实战—通过登录Demo了解ArkTS

ArkTS是HarmonyOS优选的主力应用开发语言。ArkTS围绕应用开发在TypeScript&#xff08;简称TS&#xff09;生态基础上做了进一步扩展&#xff0c;继承了TS的所有特性&#xff0c;是TS的超集。 ArkTS在TS的基础上主要扩展了如下能力&#xff1a; 基本语法&#xff1a;ArkTS定义…