TR2 - Transformer模型的复现


目录

  • 理论知识
  • 模型结构
    • 结构分解
      • 黑盒
      • 两大模块
      • 块级结构
      • 编码器的组成
      • 解码器的组成
  • 模型实现
    • 多头自注意力块
    • 前馈网络块
    • 位置编码
    • 编码器
    • 解码器
    • 组合模型
    • 最后附上引用部分
  • 模型效果
  • 总结与心得体会


理论知识

Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。

模型结构

模型整体结构

结构分解

黑盒

以机器翻译任务为例
黑盒

两大模块

在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
两大模块

块级结构

继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
块组成

编码器的组成

继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
编码器的组成

解码器的组成

而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
解码器的组成

模型实现

通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。

多头自注意力块

class MultiHeadAttention(nn.Module):
    """多头注意力模块"""
    def __init__(self, dims, n_heads):
        """
        dims: 每个词向量维度
        n_heads: 注意力头数
        """
        super().__init__()

        self.dims = dims
        self.n_heads = n_heads

        # 维度必需整除注意力头数
        assert dims % n_heads == 0
        # 定义Q矩阵
        self.w_Q = nn.Linear(dims, dims)
        # 定义K矩阵
        self.w_K = nn.Linear(dims, dims)
        # 定义V矩阵
        self.w_V = nn.Linear(dims, dims)

        self.fc = nn.Linear(dims, dims)
        # 缩放
        self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        # 例如: [32, 1024, 300] 计算10头注意力
        Q = self.w_Q(query)
        K = self.w_K(key)
        V = self.w_V(value)

        # [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组
        Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)

        # 1. 计算QK/根dk
        # [32, 1024, 10, 30] * [32, 1024, 30, 10] -> [32, 1024, 10, 10] 交换最后两维实现乘法
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            # 将需要mask掉的部分设置为很小的值
            attention = attention.masked_fill(mask==0, -1e10)
        # 2. softmax
        attention = torch.softmax(attention, dim=-1)

        # 3. 与V相乘
        # [32, 1024, 10, 10] * [32, 1024, 10, 30] -> [32, 1024, 10, 30]
        x = torch.matmul(attention, V)

        # 恢复结构
        # 0 2 1 3 把 第2,3维交换回去
        x = x.permute(0, 2, 1, 3).contiguous()
        # [32, 1024, 10, 30] -> [32, 1024, 300]
        x = x.view(batch_size, -1, self.n_heads*(self.dims//self.n_heads))
        # 走一个全连接层
        x = self.fc(x)
        return x

前馈网络块

class FeedForward(nn.Module):
    """前馈传播"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

位置编码

class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 用来存位置编码的向量
        pe = torch.zeros(max_len, d_model).to(device)
        # 准备位置信息
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)* -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)

        # 注册一个不参数梯度下降的模型参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        x  = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)

编码器

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x

解码器

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.enc_attn = MultiHeadAttention(d_model, n_heads)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, mask, enc_mask):
        # 自注意力
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # 编码器-解码器注意力
        attn_output = self.enc_attn(x, enc_output, enc_output, enc_mask)
        x = x + self.dropout(attn_output)
        x = self.norm2(x)

        # 前馈网络
        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm3(x)

        return x

组合模型

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, trg, src_mask, trg_mask):
        # 词嵌入
        src = self.embedding(src)
        src = self.positional_encoding(src)
        trg = self.embedding(trg)
        trg = self.positional_encoding(trg)

        # 编码器
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        # 解码器
        for layer in self.decoder_layers:
            trg = layer(trg, src, trg_mask, src_mask)

        output = self.fc_out(trg)

        return output

最后附上引用部分

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

模型效果

编写代码测试模型的复现是否正确(没有跑任务)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_size = 10000
d_model = 512
n_heads = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff = 2048
dropout = 0.1

transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout).to(device)

src = torch.randint(0, vocab_size, (32, 10)).to(device) # 源语言
trg = torch.randint(0, vocab_size, (32, 20)).to(device) # 目标语言

src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(device)
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2).to(device)

output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

打印结果

torch.Size([32, 20, 10000])

说明模型正常运行了

总结与心得体会

我是从CV模型学到Transfromer来的,通过对Transformer模型的复现我发现:

  • 类似于残差的连接在Transformer中也十分常见,还有先缩小再放大的Bottleneck结构。
  • 整个Transformer模型的核心处理对特征的维度没有变化,这一点和CV模型完全不同。
  • Transformer的核心是多头自注意机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/502739.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Echarts地图之——如何给地图添加外边框轮廓

有时候我们希望给地图外围加一圈边框来增加美感 但实际情况中,我们需要把国界的边框和各个省份属于国界的边框相吻合,否则就会造成两者看起来是错位的感觉 这就需要我们把echarts registerMap的全国省份json和国界边框json的坐标相一致。 这个json我们可…

WEPE系统安装纯净版window11教程(包含pe内系统安装方法)

目录 一.安装u盘启动盘 1.1制作安装系统引导盘 1.2下载保存windows镜像 1.3根据自己电脑品牌查询进入BIOS设置的方法 1.4我们成功进入了PE 二.重装系统 2.1遇到问题 2.2重新来到这个界面 三.PE中基本软件的作用 四.学习声明 今天不敲代码,今天来讲讲We P…

PetaLinux 去除自动获取 IP 地址

问题:系统启动的时候会自动检测 IP 地址,如不需要这个功能(该过程需耗时十几秒)。可以自定义 IP 地址,去掉这一步。 操作步骤如下: 所有命令均需在非管理员模式下执行 1. 初始化 PetaLinux 运行环境 运行…

LabVIEW车载轴承振动监测系统

LabVIEW车载轴承振动监测系统 随着汽车工业的快速发展,车用轴承的稳定性和可靠性对保障车辆安全运行越来越重要。目前,大多数车用轴承工作在恶劣的环境下,容易出现各种故障。开发了一种基于LabVIEW的车载轴承振动监测系统,提高车…

算法题:桃飘火焰焰,梨堕雪漠漠(Java贪心)

链接:桃飘火焰焰,梨堕雪漠漠 来源:牛客网 题目描述 在某游戏平台打折之际,EternityEternityEternity兴致勃勃地在该游戏平台上购买了nnn个不同的游戏,从1到nnn编号。 通过游览游戏论坛EternityEternityEternity确定…

# Apache SeaTunnel 究竟是什么?

作者 | Shawn Gordon 翻译 | Debra Chen 原文链接 | What the Heck is Apache SeaTunnel? 我在2023年初开始注意到Apache SeaTunnel的相关讨论,一直低调地关注着。该项目始于2017年,最初名为Waterdrop,在Apache DolphinScheduler的创建者…

LInux: fork()究竟是如何工作的?为何一个变量能够接受两个返回值?

LInux: fork函数究竟是如何工作的?为何一个变量能够接受两个返回值? 前言一、fork()用法二 、fork()应用实例展示三、fork()工作原理3.1 为什么要创建子进程?3.2 fork()究竟干了些什么?3.3 fork为什么会存在两个返回值&#xff1f…

文件上传漏洞-黑名单检测

黑名单检测 一般情况下,代码文件里会有一个数组或者列表,该数组或者列表里会包含一些非法的字符或者字符串,当数据包中含有符合该列表的字符串时,即认定该数据包是非法的。 如下图,定义了一个数组$deny_ext array(.a…

如何在Linux系统部署ONLYOFFICE协作办公利器并实现多人实时编辑文档

文章目录 1. 安装Docker2. 本地安装部署ONLYOFFICE3. 安装cpolar内网穿透4. 固定OnlyOffice公网地址 本篇文章讲解如何使用Docker在本地服务器上安装ONLYOFFICE,并结合cpolar内网穿透实现公网访问。 Community Edition允许您在本地服务器上安装ONLYOFFICE文档&…

DSVPN实验报告

一、分析要求 1. 配置R5为ISP,只能进行IP地址配置,所有地址均配为公有IP地址。 - 在R5上,将接口配置为公有IP地址,并确保只进行了IP地址配置。 2. R1和R5之间使用PPP的PAP认证,R5为主认证方;R2于R5之间…

Figma使用问题(更新自己遇到的问题)

文章目录 前言一、如何安装插件?方法1:Figma Community / Figma中文社区方法2:菜单栏 二、图片倾斜插件使用1.Angle Mockups前提:执行过程: 三.中文字体插件(宋体等)Chinese Font Picker前提&am…

【算法题】三道题理解算法思想——二分查找算法

二分查找算法 本篇文章中会带大家从零基础到学会利用二分查找的思想解决算法题,我从力扣上筛选了三道题,难度由浅到深,会附上题目链接以及算法原理和解题代码,希望大家能坚持看完,绝对能有收获,大家有更好…

阿里云2核4G服务器租用价格,支持多少人在线?

阿里云2核4G服务器多少钱一年?2核4G配置1个月多少钱?2核4G服务器30元3个月、轻量应用服务器2核4G4M带宽165元一年、企业用户2核4G5M带宽199元一年。可以在阿里云CLUB中心查看 aliyun.club 当前最新2核4G服务器精准报价、优惠券和活动信息。 阿里云官方2…

[项目实践]---RSTP生成树

[项目实践] 目录 [项目实践] 一、项目环境 二、项目规划 三、项目实施 四、项目测试 |验证 ---RSTP生成树 一、项目环境 Jan16 公司为提高网络的可靠性,使用了两台高性能交换机作为核心交换机,接入层交 换机与核心层交换机互联,形成冗…

数据恢复宝典:揭秘分区合并后的数据拯救之路

在计算机存储管理中,分区合并是一项常见的硬盘操作。它通过将两个或多个相邻的磁盘分区合并成一个更大的分区,来扩展存储空间或简化磁盘管理。然而,这个看似简单的操作背后,却隐藏着数据丢失的巨大风险。许多用户在尝试分区合并时…

【Linux系统】信号量实现同步和互斥

一.回顾 在这之前已经讲解了System V版本的信号量,主要内容为以下3点: 信号量本质是一把计数器申请信号量本质就是预订资源PV操作(申请和释放)是原子的 今天我们要学习的是POSIX版本的信号量,以上三点同样遵循 二.信号量VS互斥锁 1.联系&…

蓝桥杯23年第十四届省赛真题-三国游戏|贪心,sort函数排序

题目链接: 1.三国游戏 - 蓝桥云课 (lanqiao.cn) 蓝桥杯2023年第十四届省赛真题-三国游戏 - C语言网 (dotcpp.com) 虽然这道题不难,很容易想到,但是这个视频的思路理得很清楚: [蓝桥杯]真题讲解:三国游戏&#xff0…

OpenHarmony系统开发之应用接口文件转换工具介绍

简介: 应用接口文件转换工具是根据异构格式接口文件(.h 文件)转换生成 OpenHarmony 系统应用层需要的 TS(type-script)接口文件(*.d.ts)的工具。若某个服务实现方式为 c,且供应用层访问的接口已在.h 文件中定义,此时,NAPI 接口开…

23年蓝桥杯javaB组

23年蓝桥杯java-b组 前言: 23年蓝桥杯当时并没有参加,不过打算参加24年的蓝桥杯,于是打算复习下23年的题目,哦,不做不知道,做了几道题后评价一下,真的是老🐷上🏠&#…

13 完全分布式搭建-集群配置

1.集群部署规划 NameNode 和 SecondaryNameNode 不要安装在同一台服务器 ResourceManager 也很消耗内存,不要和 NameNode、SecondaryNameNode 配置在 同一台机器上。 在文章中与教材上有区别,在理论课上已讲解。 masterslave01slave02HDFS NameNode D…