大模型基础——从零实现一个Transformer(5)

大模型基础——从零实现一个Transformer(1)-CSDN博客

大模型基础——从零实现一个Transformer(2)-CSDN博客

大模型基础——从零实现一个Transformer(3)-CSDN博客

大模型基础——从零实现一个Transformer(4)-CSDN博客


一、前言

上一篇文章已经把Encoder模块和Decoder模块都已经实现了,

接下来来实现完整的Transformer

二、Transformer

Transformer整体架构如上,直接把我们实现的Encoder 和Decoder模块引入,开始堆叠

import torch
from torch import  nn,Tensor
from torch.nn import  Embedding

#引入自己实现的模块
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask

class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size:int,
                 target_vocab_size:int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions:int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool=False) -> None:
        '''

        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is done prior to attention and feedforward operations(Pre-Norm).
                Otherwise it's done after(Post-Norm). Default to False.
        '''
        super().__init__()
        # Token embedding
        self.src_embeddings = Embedding(source_vocab_size,d_model)
        self.target_embeddings = Embedding(target_vocab_size,d_model)

        # Position embedding
        self.encoder_pos = PositionalEmbedding(d_model,dropout,max_positions)
        self.decoder_pos = PositionalEmbedding(d_model,dropout,max_positions)


        # 编码层定义
        self.encoder = Encoder(d_model,num_encoder_layers,n_heads,d_ff,dropout,norm_first)
        # 解码层定义
        self.decoder = Decoder(d_model,num_decoder_layers,n_heads,d_ff,dropout,norm_first)

        self.pad_idx = pad_idx


    def encode(self,
               src:Tensor,
               src_mask: Tensor=None,
               keep_attentions: bool=False) -> Tensor:
        '''
        编码过程
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence
        :param keep_attentions:  whether keep attention weigths or not. Defaults to False.
        :return: (batch_size, seq_length, d_model) encoder output
        '''

        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)

        return self.encoder(src_embedded,src_mask,keep_attentions)

    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) ->Tensor:
        '''

        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the  sequence from the last layer of the encoder.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions:  whether keep attention weigths or not. Defaults to False.
        :return: output (batch_size, tgt_seq_length, tgt_vocab_size)
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)

        # logits (batch_size, target_seq_length, d_model)
        logits = self.decoder(target_embedded,memory,target_mask,memory_mask,keep_attentions)
        return  logits


    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor=None,
                target_mask: Tensor=None,
                keep_attention:bool=False)->Tensor:
        '''

        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param target:  (batch_size, tgt_seq_length) the sequence to the decoder
        :param src_mask:
        :param target_mask:
        :param keep_attention: whether keep attention weigths or not. Defaults to False.
        :return: (batch_size, tgt_seq_length, tgt_vocab_size)
        '''
        memory = self.encode(src,src_mask,keep_attention)
        return  self.decode(target,memory,target_mask,src_mask,keep_attention)

三、测试

写个简单的main函数,测试一下整体网络是否正常

if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # padding对应的index,一般都是0
    pad_idx = 0

    batch_size = 1
    max_positions = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))
    target_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))

    ## 最后5位置是padding
    src_tensor[:,-5:] = 0

    ## 最后10位置是padding
    target_tensor[:, -10:] = 0

    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    targe_mask =  make_target_mask(target_tensor)

    logits = model(src_tensor,target_tensor,src_mask,targe_mask)
    print(logits.shape) 
    #torch.Size([1, 20, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/718964.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue标签组

先看样式 再看代码 <div v-else class"relative"><n-tabs ref"tabsInstRef" v-model:value"selectValue" class"min-w-3xl myTabs"><n-tab-panev-for"(tab) in songsTags" :key"tab.name" displ…

【数据结构初阶】--- 堆

文章目录 一、什么是堆&#xff1f;树二叉树完全二叉树堆的分类堆的实现方法 二、堆的操作堆的定义初始化插入数据&#xff08;包含向上调整详细讲解&#xff09;向上调整删除堆顶元素&#xff08;包含向下调整详细讲解&#xff09;向下调整返回堆顶元素判断堆是否为空销毁 三、…

Docker安装(内网无网环境),亲测简单易懂

文章目录 前言一、安装环境二、安装步骤三、启动四、查看状态总结 前言 Docker安装&#xff08;内网无网环境&#xff09;&#xff0c;亲测简单易懂 一、安装环境 CentOS Linux release 7.x Docker版本&#xff1a;18.09.8 二、安装步骤 &#xff01;&#xff01;&#xf…

网络学习(三)TCP三次握手、四次挥手,及Wireshark抓包验证

目录 一、什么是 TCP 三次握手&#xff1f;二、什么是 TCP 四次挥手&#xff1f;三、Wireshark抓包验证3.1 如何捕获三次握手、四次挥手3.2 TCP 三次握手的记录3.3 数据传输3.4 TCP 四次挥手的记录 一、什么是 TCP 三次握手&#xff1f; TCP&#xff08;Transmission Control …

计算机组成原理之存储器(二)

文章目录 随机读写存储器RAM静态MOS存储单元与存储芯片动态MOS存储单元与存储芯片 半导体存储器逻辑设计存储器的读写以及刷新存储器的读写动态存储芯片的刷新 随机读写存储器RAM 静态MOS存储单元与存储芯片 静态RAM用半导体管的导通和截止来记忆&#xff0c;只要不掉电&#x…

2.线性神经网络

目录 1.线性回归一个简化模型线性模型&#xff1a;可以看做是单层神经网络衡量预估质量训练数据参数学习显示解总结 2.基础优化方法小批量随机梯度下降总结 3.Softmax回归&#xff1a;其实是一个分类问题回归VS分类从回归到多类分类---均方损失Softmax和交叉熵损失 4.损失函数L…

使用插件永久解决IDEA使用Shift+F10失效问题(不需要换老版本输入法)

在日常编程中&#xff0c;使用快捷键可以大大提高开发效率。然而&#xff0c;有时候我们会遇到IDEA 中&#xff0c;ShiftF10 快捷键失效。这个蛋疼的问题现在终于可以得到解决&#xff0c;上个月在逛V2EX的时候看见一位大佬做的插件。 大佬链接&#xff1a;https://www.v2ex.c…

编程精粹—— Microsoft 编写优质无错 C 程序秘诀 02:设计并使用断言

这是一本老书&#xff0c;作者 Steve Maguire 在微软工作期间写了这本书&#xff0c;英文版于 1993 年发布。2013 年推出了 20 周年纪念第二版。我们看到的标题是中译版名字&#xff0c;英文版的名字是《Writing Clean Code ─── Microsoft’s Techniques for Developing》&a…

51单片机宏定义的例子

代码 demo.c #include "hardware.h"void delay() {volatile unsigned int n;for(n 0; n < 50000; n); }int main(void) {IO_init();while(1){PINSET(LED);delay();PINCLR(LED);delay();}return 0; }cfg.h #ifndef _CFG_H_ #define _CFG_H_// #define F_CPU …

nacos注册中心配置中心集群搭建

文章目录 学习连接1.Nacos安装与简单使用1.1. Nacos安装指南Windows安装下载安装包解压端口配置启动访问 Linux安装安装JDK上传安装包解压端口配置启动 1.2.服务注册到nacos使用步骤引入依赖配置nacos地址重启 示例父工程pom.xmluser-servicepom.xmlapplication.ymlUserApplica…

Jupyter Notebook 中 %run 魔法命令

目录 基本用法运行 Python 脚本运行 Jupyter Notebook 的其他单元格传递命令行参数 示例运行 Python 脚本示例运行其他 Jupyter Notebook 示例传递命令行参数示例 注意事项与 import 命令的区别%runimport 结论 %run 是 Jupyter Notebook 中的一个强大工具&#xff0c;它允许你…

【机器学习】第4章 决策树算法(重点)

一、概念 1.原理看图&#xff0c;非常简单&#xff1a; &#xff08;1&#xff09;蓝的是节点&#xff0c;白的是分支&#xff08;条件&#xff0c;或者说是特征&#xff0c;属性&#xff0c;也可以直接写线上&#xff0c;看题目有没有要求&#xff09;&#xff0c; &#xff0…

MySQL----InooDB行级锁、间隙锁

行级锁 行锁&#xff0c;也称为记录锁&#xff0c;顾名思义就是在记录上加的锁。 注意&#xff1a; InnoDB行锁是通过给索引上的索引项加锁来实现的&#xff0c;而不是给表的行记录加锁实现的&#xff0c;这就意味着只有通过索引条件检索数据&#xff0c;InnoDB才使用行级锁…

【开发工具】git服务器端安装部署+客户端配置

自己安装一个轻量级的git服务端&#xff0c;仅仅作为代码维护&#xff0c;尤其适合个人代码管理。毕竟代码的版本管理是很有必要的。 这里把git服务端部署在centos系统里&#xff0c;部署完成后可以通过命令行推拉代码&#xff0c;进行版本和用户管理。 一、服务端安装配置 …

【Kubernetes】k8s--安全机制

机制说明 Kubernetes 作为一个分布式集群的管理工具&#xff0c;保证集群的安全性是其一个重要的任务。API Server 是集群内部各个组件通信的中介&#xff0c; 也是外部控制的入口。所以 Kubernetes 的安全机制基本就是围绕保护 API Server 来设计的。 比如 kubectl 如果想向 …

新版FMEA培训内容中关于团队协作的部分可以怎么展开?

团队协作&#xff0c;作为新版FMEA的核心要素之一&#xff0c;其重要性不言而喻。在FMEA的分析过程中&#xff0c;团队成员的密切合作与沟通是确保分析全面性和准确性的关键。通过团队协作&#xff0c;不同领域的专家能够共同参与到潜在故障模式的识别、评估与预防中来&#xf…

解决ubuntu22.04共享文件夹问题

刚开机发现ubuntu里面的共享文件夹访问不了了 ubuntuwxy:/mnt/hgfs$ ls找了几篇博客&#xff0c;设置如下指令即可&#xff0c;记得退出当前目录重新进入刷新一下 sudo vmhgfs-fuse .host:/ /mnt/hgfs/ -o allow_other -o uid1000 仅供参考

针对indexedDB的简易封装

连接数据库 我们首先创建一个DBManager类&#xff0c;通过这个类new出来的对象管理一个数据库 具体关于indexedDB的相关内容可以看我的这篇博客 indexedDB class DBManager{}我们首先需要打开数据库&#xff0c;打开数据库需要数据库名和该数据库的版本 constructor(dbName,…

[WTL/Win32]_[中级]_[MVP架构在实际项目中应用的地方]

场景 在开发Windows和macOS的界面软件时&#xff0c;Windows用的是WTL/Win32技术&#xff0c;而macOS用的是Cocoa技术。而两种技术的本地语言一个主打是C,另一个却是Object-c。界面软件的源码随着项目功能增多而增多&#xff0c;这就会给同步Windows和macOS的功能造成很大负担…

Aigtek高压放大器在柔性爬行机器人驱动性能研究中的应用

实验名称&#xff1a;柔性爬行机器人的材料测试 研究方向&#xff1a;介电弹性体的最小能量结构是一种利用DE材料的电致变形与柔性框架形变相结合设计的新型柔性驱动器&#xff0c;所谓最小能量是指驱动器在平衡状态时整个系统的能量最小&#xff0c;当系统在外界的电压刺激下就…