【Transformer从零开始代码实现 pytoch版】(五)总架构类的实现

Transformer总架构

在这里插入图片描述
在实现完输入部分、编码器、解码器和输出部分之后,就可以封装各个部件为一个完整的实体类了。

【Transformer从零开始代码实现 pytoch版】(一)输入部件:embedding+positionalEncoding

【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

【Transformer从零开始代码实现 pytoch版】(四)输出部件:Linear+softmax

编码器-解码器总结构代码实现

class EncoderDecoder(nn.Module):
    """ 编码器解码器架构实现、定义了初始化、forward、encode和decode部件
    """
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        """ 传入五大部件参数

        :param encoder: 编码器
        :param decoder: 解码器
        :param source_embed: 源数据embedding函数
        :param target_embed: 目标数据embedding函数
        :param generator: 输出部分类被生成器对象
        """
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = source_embed
        self.tgt_embed = target_embed
        self.generator = generator					# 生成器后面会专门用到

    def forward(self, source, target, source_mask, target_mask):
        """ 构建数据流入流出

        :param source: 源数据
        :param target: 目标数据
        :param source_mask: 源数据掩码张量
        :param target_mask: 目标数据掩码张量
        :return:
        """
        # 注意这里先用的encode和decode函数,又才在其函数里面,再用了encoder和decoder
        return self.decode(self.encode(source, source_mask), source_mask, target, target_mask)

    def encode(self, source, source_mask):
        """ 编码函数,编码部件

        :param source: 源数据张量
        :param source_mask: 源数据的掩码张量
        :return: 经过解码器的输出
        """
        return self.encoder(self.src_embed(source), source_mask)

    def decode(self, memory, source_mask, target, target_mask):
        """ 解码函数,解码部件

        :param memory:编码器的输出QV
        :param source_mask:源数据的掩码张量
        :param target:目标数据
        :param target_mask:目标数据的掩码张量
        :return:
        """
        return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

示例

# 输入参数
vocab_size = 1000
size = d_model = 512

# 编码器部分
dropout = 0.2
d_ff = 64				# 隐藏层参数
head = 8				# 注意力头数
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
encoder_layer = EncoderLayer(size, c(attn), c(ff), dropout)
encoder_N = 8
encoder = Encoder(encoder_layer, encoder_N)

# 解码器部分
dropout = 0.2
d_ff = 64
head = 8
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
decoder_N = 8
decoder = Decoder(decoder_layer, decoder_N)

# 用了nn的embedding作为输入示意
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = Generator(d_model, vocab_size)

# 输入张量和掩码张量
source = target = torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]])
source_mask = target_mask = torch.zeros(2, 4, 4)

# 实例化编码器-解码器,再带入参数实现
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_res = ed(source, target, source_mask, target_mask)
print(f"ed_res: {ed_res}\n shape:{ed_res.shape}")


ed_res: tensor([[[-0.1861,  0.0849, -0.3015,  ...,  1.1753, -1.4933,  0.2484],
         [-0.3626,  1.3383,  0.1739,  ...,  1.1304,  2.0266, -0.5929],
         [ 0.0785,  1.4932,  0.3184,  ..., -0.2021, -0.2330,  0.1539],
         [-0.9703,  1.1944,  0.1763,  ...,  0.1586, -0.6066, -0.6147]],
        [[-0.9216, -0.0309, -0.6490,  ...,  1.0177,  0.5574,  0.4873],
         [-1.4097,  0.6678, -0.6708,  ...,  1.1176,  0.1959, -1.2494],
         [-0.3204,  1.2794, -0.4022,  ...,  0.6319, -0.4709,  1.0520],
         [-1.3238,  1.1470, -0.9943,  ...,  0.4026,  1.0911,  0.1327]]],
       grad_fn=<AddBackward0>)
 shape:torch.Size([2, 4, 512])

编码器-解码器模型构建函数

def make_model(source_vocab, target_vocab, N=6, d_model=512, d_ff=2048, head=8, dropout=0.1):
    """ 用于构建模型

    :param source_vocab: 源数据词汇总数
    :param target_vocab: 目标词汇总数
    :param N: 解码器/解码器堆叠层数
    :param d_model: 词嵌入维度
    :param d_ff: 前馈全连接层隐藏层维度
    :param dropout: 置0比率
    :return: 返回构建编码器-解码器模型
    """
    # 拷贝函数,来保证拷贝的函数彼此之间相互独立,不受干扰
    c = copy.deepcopy

    # 实例化多头注意力
    attn = MultiHeadedAttention(head, d_model)

    # 实例化全连接层
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)

    # 实例化位置编码类,得到对象position
    position = PositionalEncoding(d_model, dropout)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embedding(d_model, source_vocab), c(position)),
        nn.Sequential(Embedding(d_model, source_vocab), c(position)),
        Generator(d_model, target_vocab)
    )

    # 模型结构构建完成后,初始化模型中的参数
    for p in model.parameters():
        # 这里判定当参数维度大于1的时候,则会将其初始化成一个服从均匀分布的矩阵
        if p.dim() > 1:
            nn.init.xavier_normal(p)        # 生成服从正态分布的数,默认为U(-1, 1),更改第二个参数可以改值

    return model

示例

source_vocab = target_vocab = 11
N = 6
res = make_model(source_vocab, target_vocab, N)
print(res)


EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w1): Linear(in_features=512, out_features=2048, bias=True)
          (w2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-1): 2 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w1): Linear(in_features=512, out_features=2048, bias=True)
          (w2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-2): 3 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embedding(
      (lut): Embedding(512, 11)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tgt_embed): Sequential(
    (0): Embedding(
      (lut): Embedding(512, 11)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (generator): Generator(
    (project): Linear(in_features=512, out_features=11, bias=True)
  )
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/139588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IP-guard WebServer 命令执行漏洞复现

简介 IP-guard是一款终端安全管理软件&#xff0c;旨在帮助企业保护终端设备安全、数据安全、管理网络使用和简化IT系统管理。在旧版本申请审批的文件预览功能用到了一个开源的插件 flexpaper&#xff0c;使用的这个插件版本存在远程命令执行漏洞&#xff0c;攻击者可利用该漏…

动作捕捉系统通过VRPN与ROS系统通信

NOKOV度量动作捕捉系统支持通过VRPN与机器人操作系统ROS通信&#xff0c;进行动作捕捉数据的传输。 一、加载数据 打开形影动捕软件&#xff0c;加载一段后处理数据。 这里选择一段小车飞机的同步数据。在这段数据里面&#xff0c;场景下包含两个刚体&#xff0c;分别是小车和…

1688商品详情API接口的使用方法、注意事项以及示例代码

1688商品详情API接口使用方法 1688商品详情API接口是1688平台提供的用于获取商品详细信息的接口。通过该接口&#xff0c;您可以获取到商品的ID、名称、价格、销量、评价等信息&#xff0c;从而进行进一步的数据分析和应用开发。本文将介绍1688商品详情API接口的使用方法、注意…

python爬虫top250电影数据

之前看到的&#xff0c;我改了一下&#xff0c;多了很多东西 import requests from bs4 import BeautifulSoup from openpyxl import Workbook from openpyxl.styles import Font import redef extract_movie_info(info):# 使用正则表达式提取信息pattern re.compile(r导演:…

react函数式组件props形式子向父传参

父组件中定义 子组件中触发回调传值 import { useState } from "react"; function Son(params) {const [count, setCount] useState(0);function handleClick() {console.log(params, paramsparamsparamsparamsparamsparams);params.onClick(111)setCount(count 1…

猫罐头怎么选择?精选的5款口碑好的猫罐头推荐!

猫罐头因其成分约80%为水分&#xff0c;对于不喜欢喝水的猫咪来说&#xff0c;正是可以用来补充水分的替代方案。 而近年来市面上也有越来越多讲究食用安全性的猫罐头&#xff0c;像是强调无添加多余加工品、或是不含谷物成分等的商品。但也因为种类过多&#xff0c;让铲屎官容…

k8s系列文章二:集群配置

一、关闭交换分区 # 临时关闭分区 swapoff -a # 永久\关闭自动挂载swap分区 sudo sed -i / swap / s/^\(.*\)$/#\1/g /etc/fstab 二、修改cgroup管理器 ubuntu 系统&#xff0c;debian 系统&#xff0c;centos7 系统&#xff0c;都是使用 systemd 初始化系统的。systemd 这边…

11-13 代理模式

调用者 代理对象 目标对象 代理对象除了可以完成核心任务&#xff0c;还可以增强其他任务,无感的增强 代理模式目的: 不改变目标对象的目标方法的前提,去增强目标方法 分为:静态代理,动态代理 静态代理 有对象->前提需要有一个类&#xff0c;那么我们可以事先写好一个类&a…

擎创动态 | 再获上海区政府肯定,擎创科技被评为年度优秀高新技术企业

11月6日&#xff0c;上海市静安区副区长张慧和市北高新集团总裁陈军一行来到擎创科技调研指导&#xff0c;由擎创科技高管张健和陈莹陪同交流。 陈莹女士首先向副区长一行详细介绍了擎创科技的发展现状、落地实践效益以及未来的规划布局。在公司的成长过程中&#xff0c;得到静…

【Unity】 场景优化策略

Unity 场景优化策略 GPU instancing 使用GPU Instancing可以将多个网格相同、材质相同、材质属性可以不同的物体合并为一个批次&#xff0c;从而减少Draw Calls的次数。这可以提高性能和渲染效率。 GPU instancing可用于绘制在场景中多次出现的几何体&#xff0c;例如树木或…

腾讯云优惠服务器有哪些?腾讯云服务器优惠券领取入口汇总

腾讯云此次推出云服务器中最实惠的2核2G服务器以每年仅需88元的超低价格为用户提供稳定可靠的计算资源。这样的价格对于个人网站、小型企业以及学生开发者来说绝对是一笔难以忽视的优惠。 腾讯云双十一领9999代金券 https://1111.mian100.cn 腾讯云新用户领2860代金券 https:…

快速拉取聚水潭单据的ETL工具

聚水潭介绍 聚水潭平台则是国内较为出名的电商ERP平台&#xff0c;为企业提供了便捷的销售和管理服务&#xff0c;专注于提高交易效率&#xff0c;但是如何将数据快速同步到其他系统一直是很多企业的痛点。 ETLCloud数据集成平台提供了丰富的数据分析工具和算法模型&#xff…

Nat. Med. | 成年人的城市生活环境对心理健康的影响

今天为大家介绍的是来自Jiayuan Xu和Gunter Schumann团队的一篇论文。城市居民暴露于许多可能相互结合和相互作用的环境因素&#xff0c;这些因素可能影响心理健康。目前尚未有工作尝试建模城市生活的复杂实际暴露与大脑和心理健康之间的关系&#xff0c;以及这如何受遗传因素调…

js设置图片放大缩小拖动

效果: 思路: 在外层box进行相对定位relative,img设置绝对定位absolute;通过监听滚轮事件(wheel),设置样式缩放中心点(transformOrigin)和缩放转换(transform);获取到图片大小和位置,设置对应图片宽度高度和top、left偏移;鼠标按下事件(mousedown)和鼠标移动事…

AI机器人软件定制流程

一、项目概述 AI机器人软件定制流程是根据客户的需求&#xff0c;定制开发一款具有人工智能功能的机器人软件。本方案将详细介绍AI机器人软件定制的整个流程&#xff0c;包括需求分析、设计、开发、测试和交付等环节。 二、需求分析 在定制AI机器人软件之前&#xff0c;需要…

pmp考试是智商税吗,是一场持久的割韭菜战吗?

考试只是考试&#xff0c;是不是智商税要看人&#xff0c;跟风考&#xff0c;考了不用&#xff0c;那真的就是智商税&#xff0c;被割韭菜。 那么&#xff0c;什么人适合考PMP&#xff1f; 1、有项目管理实践经验的人&#xff1a;PMP是基于项目管理实践经验的认证考试&#x…

Python数据结构:元组(Tuple)详解

1.介绍和基础操作 Python中的元组&#xff08;Tuple&#xff09;是不可变有序序列&#xff0c;可以容纳任意数据类型&#xff08;包括数字、字符串、布尔型、列表、字典等&#xff09;的元素&#xff0c;通常用圆括号() 包裹。与列表&#xff08;List&#xff09;类似&#xff…

Python实现猎人猎物优化算法(HPO)优化XGBoost回归模型(XGBRegressor算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 猎人猎物优化搜索算法(Hunter–prey optimizer, HPO)是由Naruei& Keynia于2022年提出的一种最新的…

论文导读 | 图流的分割和摘要

前 言 本次论文导读介绍有关图流的分割和摘要问题的3篇文章。第1篇是partition的&#xff0c;第2篇是summarization的。 首先介绍第一篇文章。 文章一&#xff1a;图分割在分布式系统中有广泛的应用 文章的问题定义是用划分边的方式来分割图。如图所示&#xff0c;把图&#…