Transformer详解:从放弃到入门(完结)

  前几篇文章中,我们已经拆开并讲解了Transformer中的各个组件。现在我们尝试使用这些方法实现Transformer的编码器。
在这里插入图片描述  如图所示,编码器(Encoder)由N个编码器块(Encoder Block)堆叠而成,我们依次实现。

class EncoderBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float,
        norm_first: bool = False,
    ) -> None:
        super().__init__()

        self.norm_first = norm_first

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = LayerNorm(d_model)

        self.ff = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    # self attention sub layer
    def _sa_sub_layer(
        self, x: Tensor, attn_mask: Tensor, keep_attentions: bool
    ) -> Tensor:
        x = self.attention(x, x, x, attn_mask, keep_attentions)
        return self.dropout1(x)

    def _ff_sub_layer(self, x: Tensor) -> Tensor:
        x = self.ff(x)
        return self.dropout2(x)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tuple[Tensor, Tensor]:
        # pass througth multi-head attention
        # src (batch_size, seq_length, d_model)
        # attn_score (batch_size, n_heads, seq_length, k_length)
        x = src
        if self.norm_first:
            x = x + self._sa_sub_layer(self.norm1(x), src_mask, keep_attentions)
            x = x + self._ff_sub_layer(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_sub_layer(x, src_mask, keep_attentions))
            x = self.norm2(x + self._ff_sub_layer(x))

        return x

  需要注意的是,层归一化的位置通过参数norm_first控制,默认norm_first=False,这种实现方式称为Post-LN,是Transformer的默认做法。但这种方式很难从零开始训练,把层归一化放到残差块之间,接近输出层的参数的梯度往往较大。然后在那些梯度上使用较大的学习率会使得训练不稳定。通常需要用到学习率预热(warm-up)技巧,在训练开始时学习率需要设成一个极小的值,但是一旦训练好之后的效果要优于Pre-LN的方式。而如果采用norm_first=True的方式,被称为Pre-LN,它的区别在于对于子层(*_sub_layer)的输入先进行层归一化,再输入到子层中。最后进行残差连接。
在这里插入图片描述  即实际上由上图左变成了图右,注意最后在每个Encoder或Decoder的输出上再接了一个层归一化。
  有了编码器块,我们再来实现编码器。

class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layers: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        norm_first: bool = False,
    ) -> None:
        super().__init__()
        # stack n_layers encoder blocks
        self.layers = nn.ModuleList(
            [
                EncoderBlock(d_model, n_heads, d_ff, dropout, norm_first)
                for _ in range(n_layers)
            ]
        )

        self.norm = LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tensor:
        x = src
        # pass through each layer
        for layer in self.layers:
            x = layer(x, src_mask, keep_attentions)

        return self.norm(x)

  这里要注意的是,最后对编码器和输出进行一次层归一化。至此,我们的编码器完成了,在其forward()中src是词嵌入加上位置编码,那么src_mask是什么?它是用来指示非填充标记的。我们知道,对于文本序列批数据,一个批次内序列长短不一,因此需要以一个指定的最长序列进行填充,而我们的注意力不需要在这些填充标记上进行。
  创建src_mask很简单,假设输入是填充后的批数据:

def make_src_mask(src: Tensor, pad_idx: int = 0) -> Tensor:
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask

  输出维度变成(batch_size, 1, 1, seq_length)为了与缩放点积注意力分数适配维度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/602923.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

粤港澳青少年信息学创新大赛 Python 编程竞赛(初中部分知识点整理)

一、考试大纲梳理 知识内容 知识目标 计算机基础与编程环境,历史,存储与网络变量定义和使用基本数据类型(整型,浮点型,字符型,布尔型),数据类型的转换控制语句结构(顺序…

车规级低功耗汽车用晶振SG-9101CGA

车规级晶振SG-9101CGA属于爱普生9101系列,是一款可编程晶振。SG-9101CGA车规级晶振采用2.5x2.0mm封装,利用PLL技术生产,此款振荡器的频率范围从0.67M~170MHZ任一频点可选,步进1ppm,采用标准CMOS输出,最大输…

极验4 一键解混淆

提示!本文章仅供学习交流,严禁用于任何商业和非法用途,未经许可禁止转载,禁止任何修改后二次传播,如有侵权,可联系本文作者删除! AST简介 AST(Abstract Syntax Tree)&a…

【Redis7】10大数据类型之Zset类型

文章目录 1.Zset类型2.常用命令3.示例3.1 ZADD,ZRANGE和ZREVRANGE3.2 ZSCORE,ZCARD和ZREM3.3 ZRANGEBYSCORE和ZCOUNT3.4 ZRANK和ZREVRANK3.5 Redis7新命令ZMPOP 1.Zset类型 Redis的Zset(Sorted Set,有序集合)是一种特殊的数据结构&#xff0…

【教学类-18-03】20240508《蒙德里安“红黄蓝黑格子画”-A4横版》(大小格子)

作品展示 背景需求: 2年前制作了蒙德里安随机线条 【教学类-18-02】20221124《蒙德里安“红黄蓝黑格子画”-A4竖版》(大班)_蒙德里安模版-CSDN博客文章浏览阅读1k次。【教学类-18-02】20221124《蒙德里安“红黄蓝黑格子画”-A4竖版》(大班)…

ChIP-seq or CUTTag,谁能hold住蛋白质与DNA互作主战场?

DNA与蛋白质的相互作用作为表观遗传学中的一个重要领域,对理解基因表达调控、DNA复制与修复、表观遗传修饰(组蛋白修饰)及染色质结构等基本生命过程至关重要。 自1983年James Broach首次公布染色质免疫共沉淀(ChIP)技…

Docker-Compose 容器集群的快速编排

Docker-compose 简介 Docker-Compose项目是Docker官方的开源项目,负责实现对Docker容器集群的快速编排。 Docker-Compose将所管理的容器分为三层,分别是 工程(project),服务(service)以及容器&…

LLM 安全 | 大语言模型应用安全入门

一、背景 2023年以来,LLM 变成了相当炙手可热的话题,以 ChatGPT 为代表的 LLM 的出现,让人们看到了无限的可能性。ChatGPT能写作,能翻译,能创作诗歌和故事,甚至能一定程度上做一些高度专业化的工作&#x…

滤除纹波的方法:

空载时看起来纹波比较小但是一加负载的时候纹波一下子上来是因为空载时候电流比较小MOS大部分时间处于关断状态,而加上负载后电流变大MOS打开关闭更为频繁,因而纹波更大。 下图中的尖峰可能是因为电路的寄生电感导致的,理论上只能削弱不能避…

LINUX 入门 8

LINUX 入门 8 day10 20240507 耗时:90min 有点到倦怠期了 课程链接地址 第8章 TCP服务器 1 TCP服务器的介绍 开始讲服务器端,之前是客户端DNShttps请求 基础:网络编程并发服务器:多客户端 一请求,一线程 veryold…

推荐网站(6)33台词,通过台词找电影、电视剧、纪录片等素材

今天推荐一个网站33台词,你可以根据电影、电视剧、纪录片等某一段台词,来找到来源,帮你精确到多少分多少秒出现的,非常的好用,尤其是对那种只记得一些经典台词,不知道是哪个电影的人来说,帮助巨…

揭秘全网热门话题:抖音快速涨粉方法,巨量千川投流助你日增10000粉

在当今社交媒体的时代( 千川投流:hzzxar)抖音成为了年轻人分享自己才华和生活的平台。然而,要在抖音上快速获得关注和粉丝,却不是一件容易的事情。今天,我们将揭秘全网都在搜索的抖音快速涨1000粉的秘籍,带…

网络文件共享

存储类型分三类 直连式存储:DAS存储区域网络:SAN网络附加存储:NAS 三种存储架构的应用场景 DAS虽然比较古老了,但是还是很适用于那些数据量不大,对磁盘访问速度要求较高的中小企业SAN多适用于文件服务器&#xff0c…

pyfilesystem2,一个超级实用的 Python 库!

更多Python学习内容:ipengtao.com 大家好,今天为大家分享一个超级实用的 Python 库 - pyfilesystem2。 Github地址:https://github.com/pyfilesystem/pyfilesystem2 PyFilesystem2是一个强大的Python库,用于抽象化文件系统操作。它…

静态分析-RIPS-源码解析记录-02

这部分主要分析scanner.php的逻辑,在token流重构完成后,此时ini_get是否包含auto_prepend_file或者auto_append_file 取出的文件路径将和tokens数组结合,每一个文件都为一个包含require文件名的token数组 接着回到main.php中,此时…

【GUI软件】调用YouTube的API接口,采集关键词搜索结果,并封装成界面工具!

文章目录 一、背景介绍1.1 爬取目标1.2 演示视频1.3 软件说明 二、代码讲解2.1 调用API-搜索接口2.2 调用API-详情接口2.3 API_KEY说明2.4 软件界面模块2.5 日志模块 三、获取源码及软件 一、背景介绍 1.1 爬取目标 您好!我是马哥python说,一名10年程序…

动态规划day.2

62.不同路径 链接:. - 力扣(LeetCode) 题目描述: 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#…

能否直接上手 Qt ?——看完 C++ 课本后怎么做?

在开始前我有一些资料,是我根据网友给的问题精心整理了一份「Qt的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!如果你已经阅读了 C 课本,但仍然感到…

TMS320F28335学习笔记-时钟系统

第一次使用38225使用了普中的clocksystem例程进行编译,总是编译失败。 问题一:提示找不到文件 因为工程的头文件路径没有包含,下图的路径需要添加自己电脑的路径。 问题二 找不到库文件 例程种的header文件夹和common文件夹不知道从何而来…

7.删除有序数组中的重复项(快慢指针)

文章目录 题目简介题目解答解法一:暴力解法复杂度分析: 解法二:双指针(快慢指针)代码:复杂度分析: 题目链接 大家好,我是晓星航。今天为大家带来的是 相关的讲解!😀 题目简介 题目解…