NLP实战8:图解 Transformer笔记

目录

1.Transformer宏观结构

2.Transformer结构细节

2.1输入

2.2编码部分

2.3解码部分

2.4多头注意力机制

2.5线性层和softmax

2.6 损失函数

3.参考代码


🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

Transformer整体结构图,与seq2seq模型类似,Transformer模型结构中的左半部分为编码器(encoder),右半部分为解码器(decoder),接下来拆解Transformer。

1.Transformer宏观结构

Transformer模型类似于seq2seq结构,包含编码部分和解码部分。不同之处在于它能够并行计算整个序列输入,无需按时间步进行逐步处理。

其宏观结构如下:

6层编码和6层解码器

其中,每层encoder由两部分组成:

  • Self-Attention Layer
  • Feed Forward Neural Network(前馈神经网络,FFNN)

decoder在encoder的Self-Attention和FFNN中间多加了一个Encoder-Decoder Attention层。该层的作用是帮助解码器集中注意力于输入序列中最相关的部分。

单层encoder和decoder

2.Transformer结构细节

2.1输入

Transformer的数据输入与seq2seq不同。除了词向量,Transformer还需要输入位置向量,用于确定每个单词的位置特征和句子中不同单词之间的距离特征。

2.2编码部分

编码部分的输入文本序列经过处理后得到向量序列,送入第一层编码器。每层编码器输出一个向量序列,作为下一层编码器的输入。第一层编码器的输入是融合位置向量的词向量,后续每层编码器的输入则是前一层编码器的输出。

2.3解码部分

最后一个编码器输出一组序列向量,作为解码器的K、V输入。

解码阶段的每个时间步输出一个翻译后的单词。当前时间步的解码器输出作为下一个时间步解码器的输入Q,与编码器的输出K、V共同组成下一步的输入。重复此过程直到输出一个结束符。

解码器中的 Self-Attention 层,和编码器中的 Self-Attention 层的区别:

  • 在解码器里,Self-Attention 层只允许关注到输出序列中早于当前位置之前的单词。具体做法是:在 Self-Attention 分数经过 Softmax 层之前,屏蔽当前位置之后的那些位置(将Attention Score设置成-inf)。
  • 解码器 Attention层是使用前一层的输出来构造Query 矩阵,而Key矩阵和Value矩阵来自于编码器最终的输出。

2.4多头注意力机制

Transformer论文引入了多头注意力机制(多个注意力头组成),以进一步完善Self-Attention。

  • 它扩展了模型关注不同位置的能力
  • 多头注意力机制赋予Attention层多个“子表示空间”。

残差链接&Normalize: 编码器和解码器的每个子层(Self-Attention 层和 FFNN)都有一个残差连接和层标准化(layer-normalization),细节如下图

2.5线性层和softmax

Decoder最终输出一个浮点数向量。通过线性层和Softmax,将该向量转换为一个包含模型输出词汇表中每个单词分数的logits向量(假设有10000个英语单词)。Softmax将这些分数转换为概率,使其总和为1。然后选择具有最高概率的数字对应的词作为该时间步的输出单词。

2.6 损失函数

在Transformer训练过程中,解码器的输出和标签一起输入损失函数,以计算损失(loss)。最终,模型通过方向传播(backpropagation)来优化损失。

3.参考代码

class SelfAttention(nn.Module):
	def __init__(self, embed_size, heads):
		super(SelfAttention, self).__init__()
		self.embed_size = embed_size
		self.heads = heads
		self.head_dim = embed_size // heads

		assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

		self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

	def forward(self, values, keys, query, mask):
		N =query.shape[0]
		value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

		# split embedding into self.heads pieces
		values = values.reshape(N, value_len, self.heads, self.head_dim)
		keys = keys.reshape(N, key_len, self.heads, self.head_dim)
		queries = query.reshape(N, query_len, self.heads, self.head_dim)
		
		values = self.values(values)
		keys = self.keys(keys)
		queries = self.queries(queries)

		energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
		# queries shape: (N, query_len, heads, heads_dim)
		# keys shape : (N, key_len, heads, heads_dim)
		# energy shape: (N, heads, query_len, key_len)

		if mask is not None:
			energy = energy.masked_fill(mask == 0, float("-1e20"))

		attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

		out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
		# attention shape: (N, heads, query_len, key_len)
		# values shape: (N, value_len, heads, heads_dim)
		# (N, query_len, heads, head_dim)

		out = self.fc_out(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, embed_size, heads, dropout, forward_expansion):
		super(TransformerBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm1 = nn.LayerNorm(embed_size)
		self.norm2 = nn.LayerNorm(embed_size)

		self.feed_forward = nn.Sequential(
			nn.Linear(embed_size, forward_expansion*embed_size),
			nn.ReLU(),
			nn.Linear(forward_expansion*embed_size, embed_size)
		)
		self.dropout = nn.Dropout(dropout)

	def forward(self, value, key, query, mask):
		attention = self.attention(value, key, query, mask)

		x = self.dropout(self.norm1(attention + query))
		forward = self.feed_forward(x)
		out = self.dropout(self.norm2(forward + x))
		return out


class Encoder(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length,
		):
		super(Encoder, self).__init__()
		self.embed_size = embed_size
		self.device = device
		self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)

		self.layers = nn.ModuleList(
			[
				TransformerBlock(
					embed_size,
					heads,
					dropout=dropout,
					forward_expansion=forward_expansion,
					)
				for _ in range(num_layers)]
		)
		self.dropout = nn.Dropout(dropout)


	def forward(self, x, mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
		for layer in self.layers:
			out = layer(out, out, out, mask)

		return out


class DecoderBlock(nn.Module):
	def __init__(self, embed_size, heads, forward_expansion, dropout, device):
		super(DecoderBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm = nn.LayerNorm(embed_size)
		self.transformer_block = TransformerBlock(
			embed_size, heads, dropout, forward_expansion
		)

		self.dropout = nn.Dropout(dropout)

	def forward(self, x, value, key, src_mask, trg_mask):
		attention = self.attention(x, x, x, trg_mask)
		query = self.dropout(self.norm(attention + x))
		out = self.transformer_block(value, key, query, src_mask)
		return out


class Decoder(nn.Module):
	def __init__(
			self,
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length,
	):
		super(Decoder, self).__init__()
		self.device = device
		self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)
		self.layers = nn.ModuleList(
			[DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
			for _ in range(num_layers)]
			)
		self.fc_out = nn.Linear(embed_size, trg_vocab_size)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x ,enc_out , src_mask, trg_mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

		for layer in self.layers:
			x = layer(x, enc_out, enc_out, src_mask, trg_mask)

		out =self.fc_out(x)
		return out


class Transformer(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			trg_vocab_size,
			src_pad_idx,
			trg_pad_idx,
			embed_size = 256,
			num_layers = 6,
			forward_expansion = 4,
			heads = 8,
			dropout = 0,
			device="cuda",
			max_length=100
		):
		super(Transformer, self).__init__()
		self.encoder = Encoder(
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length
			)
		self.decoder = Decoder(
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length
			)


		self.src_pad_idx = src_pad_idx
		self.trg_pad_idx = trg_pad_idx
		self.device = device


	def make_src_mask(self, src):
		src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
		# (N, 1, 1, src_len)
		return src_mask.to(self.device)

	def make_trg_mask(self, trg):
		N, trg_len = trg.shape
		trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
			N, 1, trg_len, trg_len
		)
		return trg_mask.to(self.device)

	def forward(self, src, trg):
		src_mask = self.make_src_mask(src)
		trg_mask = self.make_trg_mask(trg)
		enc_src = self.encoder(src, src_mask)
		out = self.decoder(trg, enc_src, src_mask, trg_mask)
		return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/47428.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【云原生】Docker网络及Cgroup资源控制

一、Docker网络 1.docker网络实现原理 Docker使用Linux桥接,在宿主机虚拟一个Docker容器网桥(docker0),Docker启动一个容器时会根据Docker网桥的网段分配给容器一个IP地址,称为Container-IP,同时Docker网桥是每个容器的默认网关。…

QT--day3(定时器事件、对话框)

头文件代码&#xff1a; #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTimerEvent> //定时器事件处理时间头文件 #include <QTime> //时间类 #include <QtTextToSpeech> #include <QPushButton> #include <QLabel&g…

分布式文件存储与数据缓存 FastDFS

一、FastDFS概述 1.1 什么是分布式文件系统 单机时代 初创时期由于时间紧迫&#xff0c;在各种资源有限的情况下&#xff0c;通常就直接在项目目录下建立静态文件夹&#xff0c;用于用户存放项目中的文件资源。如果按不同类型再细分&#xff0c;可以在项目目录下再建立不同的…

已实现商业化却仍陷亏损泥潭,瑕瑜错陈的觅瑞集团求上市

撰稿|行星 来源|贝多财经 7月25日&#xff0c;Mirxes Holding Company Limited-B&#xff08;以下简称“觅瑞集团”&#xff09;向港交所递交上市申请材料&#xff0c;计划在港交所主板上市&#xff0c;中金公司和建银国际为其联席保荐人。 据招股书介绍&#xff0c;成立于2…

Spring中如何用注解方式存取JavaBean?有几种注入方式?

博主简介&#xff1a;想进大厂的打工人博主主页&#xff1a;xyk:所属专栏: JavaEE进阶 本篇文章将讲解如何在spring中使用注解的方式来存取Bean对象&#xff0c;spring提供了多种注入对象的方式&#xff0c;常见的注入方式包括 构造函数注入&#xff0c;Setter 方法注入和属性…

字节跳动 EB 级 Iceberg 数据湖的机器学习应用与优化

深度学习的模型规模越来越庞大&#xff0c;其训练数据量级也成倍增长&#xff0c;这对海量训练数据的存储方案也提出了更高的要求&#xff1a;怎样更高性能地读取训练样本、不使数据读取成为模型训练的瓶颈&#xff0c;怎样更高效地支持特征工程、更便捷地增删和回填特征。本文…

[OnWork.Tools]系列 02-安装

下载地址 百度网盘 历史版本连接各种版本都有,请下载版本号最高的版本 链接&#xff1a;https://pan.baidu.com/s/1aOT0oUhiRO_L8sBCGomXdQ?pwdn159提取码&#xff1a;n159 个人链接 http://on8.top:5000/share.cgi?ssiddb2012fa6b224cd1b7f87ff5f5214910 软件安装 双…

Rust之泛型、特性和生命期(四):验证有生存期的引用

开发环境 Windows 10Rust 1.71.0 VS Code 1.80.1 项目工程 这里继续沿用上次工程rust-demo 验证具有生存期的引用 生存期是我们已经在使用的另一种泛型。生存期不是确保一个类型具有我们想要的行为&#xff0c;而是确保引用在我们需要时有效。 我们在第4章“引用和借用”一…

SpringCloud学习路线(13)——分布式搜索ElasticSeach集群

前言 单机ES做数据存储&#xff0c;必然面临两个问题&#xff1a;海量数据的存储&#xff0c;单点故障。 如何解决这两个问题&#xff1f; 海量数据的存储问题&#xff1a; 将索引库从逻辑上拆分为N个分片&#xff08;shard&#xff09;&#xff0c;存储到多个节点。单点故障…

新增WebDB和ChatGPT组件,支持对ChatGPT资产进行纳管,JumpServer堡垒机v3.5.0发布

2023年7月24日&#xff0c;JumpServer开源堡垒机正式发布v3.5.0版本。在这一版本中&#xff0c;新生代数据库连接组件——问题终结者Chen强势来袭&#xff0c;替代原有的OmniDB组件&#xff0c;在兼容旧版本的同时&#xff0c;解决了旧组件性能不足的问题&#xff0c;为用户提供…

微信小程序开发之配置菜单跳转到自定义页面

需求: 用户点击公众号菜单跳转到自定义带引流码的链接 公众号相关文档: 网页授权 | 微信开放文档 大致流程: 1.在公众号菜单配置链接: https://open.weixin.qq.com/connect/oauth2/authorize?appidXXXXXXXXXXXX&redirect_urihttps%3A%2F%2F测试域名%2Fws_dabai%2Fwe…

NoSQL-Redis持久化

NoSQL-Redis持久化 一、Redis 高可用&#xff1a;1.概述&#xff1a; 二、Redis持久化&#xff1a;1.持久化的功能&#xff1a;2.Redis 提供两种方式进行持久化&#xff1a; 三、RDB 持久化&#xff1a;1.定义&#xff1a;2.触发条件&#xff1a;3.执行流程&#xff1a;4.启动时…

HDFS的设计目标和重要特性

HDFS的设计目标和重要特性 设计目标HDFS重要特性主从架构分块存储机制副本机制namespace元数据管理数据块存储 设计目标 硬件故障(Hardware Failure)是常态&#xff0c;HDFS可能有成百上千的服务器组成&#xff0c;每一个组件都有可能出现故障。因此古见检测和自动快速恢复的H…

选择合适的图表,高效展现数据魅力

随着大数据时代的来临&#xff0c;数据的重要性愈发凸显&#xff0c;数据分析和可视化成为了决策和传递信息的重要手段。在数据可视化中&#xff0c;选择合适的图表是至关重要的一环&#xff0c;它能让数据更加生动、直观地呈现&#xff0c;为观众提供更有说服力的信息。本文将…

JavaScript 练手小技巧:音乐播放器的歌词显示

暑假了&#xff0c;还是不能让自己闲着&#xff0c;学点自己感兴趣的知识&#xff0c;写点自己喜欢的代码。 今天写了一个播放器的雏形&#xff0c;带歌词显示。 没去自定义播放器&#xff0c;主要是写歌词显示效果。效果图如下&#xff1a; 首先当然是要准备一个 mp3 文件。 …

RocketMQ基本概念与入门

文章目录 MQ基本结构依赖案例:productConsumer 核心概念1.nameserver2.broker3.主题队列4.queue队列5. 生产者6.消费者分组和生产者分组7.消费点位 MQ基本结构 message: 消息数据对象product: 程序代码,生成消息,发送消息到队列consumer: 程序代码,监听(绑定)队列,获取消息,执行…

分布式锁:Redis、Zookeeper

1.基于Redis实现分布式锁&#xfeff; Redis分布式锁原理如上图所示&#xff0c;当有多个Set命令发送到Redis时&#xff0c;Redis会串行处理&#xff0c;最终只有一个Set命令执行成功&#xff0c;从而只有一个线程加锁成功 2.SetNx命令加锁 利用Redis的setNx命令在Redis数据库…

数据结构【绪论】

数据结构入门级 第一章绪论 什么是数据结构&#xff1f;什么是数据类型&#xff1f; 程序数据结构算法 一、基本概念&#xff1a; 数据&#xff1a;指所有能被计算机处理的&#xff0c;无论图、文字、符号等。数据元素&#xff1a;数据的基本单位&#xff0c;通常作为整体考…

Unity TMP (TextMeshPro) 创建字体材质

1 TMP 简介 完整名称&#xff1a;Text Mesh Pro &#xff0c;unity新一代主流字体插件 1.1 组件变化 内置的Text组件以及与内置Text组件绑定的Button、DropDown、InputField均被替换为使用TextMeshPro的版本 内置的Text组件以及与内置Text组件绑定的Button、DropDown、Input…

tinymce插件tinymce-powerpaste-plugin——将word中内容(文字图片等)直接粘贴至tinymce编辑器中

TinyMCE是一款易用、且功能强大的所见即所得的富文本编辑器。同类程序有&#xff1a;UEditor、Kindeditor、Simditor、CKEditor、wangEditor、Suneditor、froala等等。 TinyMCE的优势&#xff1a; 开源可商用&#xff0c;基于LGPL2.1 插件丰富&#xff0c;自带插件基本涵盖日常…