程序员学长 | 快速学会一个算法,Transformer(下)

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学习一个算法,Transformer(二)

今天我们来继续分享 Transformer 模型的第二部分,解码器部分

建议大家先看完第一部分。程序员学长 | 快速学会一个算法,Transformer(上)-CSDN博客

解码器 Decoder

上篇文章我们已经介绍了编码器中的大部分概念,也基本知道了编码器的原理。现在让我们来看下, 编码器和解码器是如何协同工作的。

编码器一般有多层,第一个编码器的输入是一个序列文本,最后一个编码器输出是一组序列向量,这组序列向量会作为解码器的 K、V 输入,其中 K=V=解码器输出的序列向量表示。这些注意力向量将会输入到每个解码器的 Encoder-Decoder Attention 层,这有助于解码器把注意力集中到输入序列的合适位置,如下图所示。

图片

解码(decoding )阶段的每一个时间步都输出一个翻译后的单词(这里的例子是英语翻译),解码器当前时间步的输出又重新作为下一个时间步解码器的输入。然后重复这个过程,直到输出一个结束符。如下图所示。

图片

下面,我们来关注一下解码器中包含的核心组件。

Decoder 与 Encoder 一样会先跟 Positional Encoding 相加再进入 layer,不同的是 Decoder 有三个子层 Masked Multi-head Attention、Multi-head Attention、Feed Forward。此外,中间层 Multi-head Attention 的输入 q 来自于本身前一层的输出,而 k, v 则是来自于 Encoder 的输出。

下面我们来重点介绍一下 Masked Multi-head Attention 和 Multi-head Attention。

Masked Multi-head Attention

在标准的多头注意力机制中,每个位置的查询(Query)会与所有位置的键(Key)进行点积计算,得到注意力分数,然后与值(Value)加权求和,生成最终的输出

然而,在解码器中,生成序列时不能访问未来的信息。因此需要使用掩码(Mask)机制来屏蔽掉未来位置的信息。

例如在 “我看到 ______ 追着一只老鼠” 中,我们会直观地填充猫,因为这是最有可能的。因此,在编码单词时,它需要知道整个句子中的所有内容。这就是为什么在自注意力层中,查询是针对所有单词执行的。但是在解码时,当试图预测句子中的下一个单词时,从逻辑上讲,它不应该知道我们试图预测的单词之后有哪些单词。因此需要使用掩码(Mask)机制来屏蔽掉未来位置的信息。

掩码机制通过引入一个上三角矩阵来屏蔽未来的位置。具体来说,对于每个查询位置 t,只允许它与 t 及其之前的位置进行注意力计算,而屏蔽掉 t 之后的位置。这可以通过在计算注意力分数时将未来位置的分数设置为负无穷来实现,从而在 softmax 函数中得到接近于零的权重。

Multi-head Attention

在解码器中的 Multi-head Attention 也叫做 Encoder-Decoder Attention,它的 Query 来自解码器的 self-attention,而 Key、Value 则是编码器的输出。

下面,我们来看一下解码器的 python 代码实现

# code implementation of DECODER
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        
        # Masked self-attention layer
        self_attention_output= self.masked_self_attention(x, x, x, tgt_mask)
        self_attention_output = self.dropout(self_attention_output)
        x = x + self_attention_output
        x = self.norm1(x)
        
        # Encoder-decoder attention layer
        enc_dec_attention_output= self.enc_dec_attention(x, encoder_output, 
        encoder_output, src_mask)
        enc_dec_attention_output = self.dropout(enc_dec_attention_output)
        x = x + enc_dec_attention_output
        x = self.norm2(x)
        
        # Feed-forward layer
        feed_forward_output = self.feed_forward(x)
        feed_forward_output = self.dropout(feed_forward_output)
        x = x + feed_forward_output
        x = self.norm3(x)
        
        return x# Define the DecoderLayer parameters
d_model = 512  # Dimensionality of the model
num_heads = 8  # Number of attention heads
d_ff = 2048    # Dimensionality of the feed-forward network
dropout = 0.1  # Dropout probability
batch_size = 1 # Batch Size
max_len = 100  # Max length of Sequence

# Define the DecoderLayer instance
decoder_layer = DecoderLayer(d_model, num_heads, d_ff, dropout)


src_mask = torch.rand(batch_size, max_len, max_len) > 0.5
tgt_mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0) == 0

# Pass the input tensors through the DecoderLayer
output = decoder_layer(input_sequence, encoder_output, src_mask, tgt_mask)

# Output shape
print("Output shape:", output.shape)
线性层和 softmax

Decoder 最终的输出是一个向量,其中每个元素是浮点数。我们怎么把这个向量转换为单词呢?这是线性层和 softmax 完成的。

线性层就是一个普通的全连接神经网络,可以把解码器输出的向量,映射到一个更大的向量,这个向量称为 logits 向量。假设我们的模型有 10000 个英语单词(模型的输出词汇表),此 logits 向量便会有 10000 个数字,每个数表示一个单词的分数。

然后,Softmax 层会把这些分数转换为概率(把所有的分数转换为正数,并且加起来等于 1)。然后选择最高概率的那个数字对应的词,就是这个时间步的输出单词。

到目前位置,Transformer 模型的所有模块都已经介绍完毕。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/740334.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

热虹吸管的传热计算

热对称管和热管通过使用中空管内的两相流体,在特定的距离上传输大量的热量。 更广泛使用的热管使用吸芯结构将液体输送回热端,而热虹吸管是一个简单的空心管,使用重力。 由于缺乏吸芯结构,使得热虹吸管比传统的热管便宜得多。 然…

自学指南:必备书籍清单--近100本R语言及生物信息相关书籍

R语言是一种功能丰富的编程语言,数据处理、统计分析是大家所熟知的基本功能。开源免费、活跃的全球社区、灵活可扩展等优点促使R语言飞速发展。目前,CRAN 软件包存储库包含 20446 个可用软件包,涵盖了从生物信息到金融分析等广泛的应用领域。…

vue3.0(十五)内置组件Teleport和Suspense

文章目录 一、Teleport1.基本用法2.禁用Teleport3.多个 Teleport 共享目标4.搭配组件 二、 Suspense1.什么是Suspense2.异步依赖3.加载中状态4.事件5.错误处理6.和其他组件结合注意 一、Teleport <Teleport> 是一个内置组件&#xff0c;它可以将一个组件内部的一部分模板…

第二十八篇——复盘:世界不完美,我们该怎么办?

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么&#xff1f; 四、总结五、升华 一、背景介绍 对于信息传递过程中的相关知识的总结&#xff0c;让我又仿佛回到了每一个…

Python列表比较:判断两个列表是否相等的多种方法

&#x1f4d6; 正文 1 通过排序的方式实现判断 list_a [a, b, c, d] list_b [c, d, a, b]if sorted(list_a) sorted(list_b):print(list_a与list_b的元素相等) else:print(list_a与list_b的元素不相等)通过排序&#xff0c;让两组列表中元素一直后进行判断&#xff0c;得到…

Linux常用基本命令

linux目录 1.查看linux本机ip ip addr 2.新建文件夹 mkdir 文件夹名 3.新建文件 touch 文件名.后缀 4.删除文件 rm 文件名.后缀 5.删除文件 rm -r 文件名 6.不询问直接删除 rm -rf 文件名/文件名/ 7.显示目录下文件&#xff0c;文件夹 作用&#xff1a;显示指定目…

人工ai智能写作,分享推荐三款好用软件!

在数字化时代&#xff0c;人工智能&#xff08;AI&#xff09;已经渗透到我们生活的方方面面&#xff0c;而在内容创作领域&#xff0c;AI智能写作软件更是如雨后春笋般涌现。今天&#xff0c;就为大家分享三款备受好评的AI智能写作软件&#xff0c;让你轻松掌握高效写作的秘密…

基于matlab的SVR回归预测

1 原理 SVR&#xff08;Support Vector Regression&#xff09;回归预测原理&#xff0c;基于支持向量机&#xff08;SVM&#xff09;的回归分支&#xff0c;其核心思想是通过寻找一个最优的超平面来进行回归预测&#xff0c;并处理非线性回归问题。以下是SVR回归预测原理的系统…

FFmpeg中位操作相关的源码:GetBitContext结构体,init_get_bits函数、get_bits1函数和get_bits函数分析

一、引言 由《音视频入门基础&#xff1a;H.264专题&#xff08;3&#xff09;——EBSP, RBSP和SODB》可以知道&#xff0c;H.264 码流中的操作单位是位(bit)&#xff0c;而不是字节。因为视频的传输和存贮是十分在乎体积的&#xff0c;对于每一个比特&#xff08;bit&#xf…

策略模式 + 抽象工厂实现多方式登录验证

文章目录 1、需求背景2、常规想法3、工厂模式 配置文件解耦 策略模式4、具体实现5、其他场景6、一点思考 1、需求背景 以gitee为例&#xff0c;登录验证的方式有多种&#xff1a; 用户名密码登录短信验证码登录微信登录 先写一个登录接口&#xff0c;适配所有方式&#xff…

udp协议 服务器

1 TCP和UDP基本概念 TCP:(Transmission Control Protocol)是一种面向连接、可靠的基于字节流的传输层通信协议。并且提供了全双工通信&#xff0c;允许两个应用之间建立可靠的链接以进行数据交换 udp:(User Datagram Protocol):是一种无链接、不可靠、基于数据报文传输层协议&…

websocket服务执行playwright测试

上一篇博客从源码层面分析了playwright vscode插件实现原理&#xff0c;在上一篇博客中提到&#xff0c;backend服务是一个websocket服务。这遍博客将介绍如何封装一个websocket服务&#xff0c;通过发送消息来执行playwright测试。 初始化项目 第一步是初始化项目和安装必要的…

​【VMware】VMware Workstation的安装

目录 &#x1f31e;1. VMware Workstation是什么 &#x1f31e;2. VMware Workstation的安装详情 &#x1f33c;2.1 VMware Workstation的安装 &#x1f33c;2.2 VMware Workstation的无限使用 &#x1f31e;1. VMware Workstation是什么 VMware Workstation是一款由VMwar…

【K8s】专题六:Kubernetes 资源限制及服务质量等级

以下内容均来自个人笔记并重新梳理&#xff0c;如有错误欢迎指正&#xff01;如果对您有帮助&#xff0c;烦请点赞、关注、转发&#xff01;欢迎扫码关注个人公众号&#xff01; 目录 一、资源限制 1、基本介绍 2、工作原理 3、限制方法 二、服务质量等级 一、资源限制 1…

【软件测试入门】测试用例经典设计方法 — 因果图法

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 一、因果图设计测试用例的步骤 1、分析需求 阅读需求文档&#xff0c;如果User Case很复杂&am…

DIY灯光特效:霓虹灯动画制作教程

下面我们根据这张霓虹灯案例,教大家如何用智能动物霓虹灯闪烁的效果,大家可以根据思路,实现自己想要的动效效果,一起动手来做吧。 即时设计-可实时协作的专业 UI 设计工具 设置背景 新建画板尺寸为:800PX^600PX,设置背景色#120527。 绘制主题 输入自己喜欢文案,轮廓化,具体…

PHP-CGI的漏洞(CVE-2024-4577)

通过前两篇文章的铺垫&#xff0c;现在我们可以了解 CVE-2024-4577这个漏洞的原理 漏洞原理 CVE-2024-4577是CVE-2012-1823这个老漏洞的绕过&#xff0c;php cgi的老漏洞至今已经12年&#xff0c;具体可以参考我的另一个文档 简单来说&#xff0c;就是使用cgi模式运行的PHP&…

充电桩--充电桩智能化发展趋势

聚焦光伏产业、深耕储能市场、探究充电技术 小Q下午茶 相互交流学习储能和BMS相关内容 43篇原创内容 公众号 一、背景介绍 国家提出“新基建”以来&#xff0c;充电基础设施产业跃入人们的视线成为热门话题。充电基础设施作为充电网、车联网、能源网和物联网的连接器&…

JS对象、数组、字符串超详细方法

JavaScript 对象方法 对象创建的方式 对象字面量 var dog1 {name: "大黄",age: 2,speak: function () {console.log("汪汪");}, };使用Object构造函数 var dog2 new Object(); dog2.name "大黄"; dog2.age 2; dog2.speak function () …

卷积的通俗解释

以时间和空间两个维度分别理解卷积&#xff0c;先用文字来描述&#xff1a; 时间上&#xff0c;任何当前信号状态都是迄至当前所有信号状态的叠加&#xff1b;时间上&#xff0c;任何当前记忆状态都是迄至当前所有记忆状态的叠加&#xff1b;空间上&#xff0c;任何位置状态都…