大模型基础——从零实现一个Transformer(2)

大模型基础——从零实现一个Transformer(1)

一、引言

上一章主要实现了一下Transformer里面的BPE算法和 Embedding模块定义
本章主要讲一下 Transformer里面的位置编码以及多头注意力

二、位置编码

2.1正弦位置编码(Sinusoidal Position Encoding)

其中:

pos:表示token在文本中的位置
: i代表词向量具体的某一维度,即位置编码的每个维度对应一个波长不同的正弦或余弦波
d : d表示位置编码的最大维度,和词嵌入的维度相同,假设是512

对于位置0的编码为:

对于位置1的编码为:

2.2 正弦位置编码特性

  • 相对位置关系:pos + k的位置编码可以被位置pos的位置编码线性表示
    三角函数公式如下:

对于pos + k的位置编码:

根据式( 3 )和( 4 )整理上式有:

  • 位置之间的相对距离

𝑃𝐸𝑝𝑜𝑠+𝑘∙𝑃𝐸𝑝𝑜𝑠 的内积:

位置之间内积的关系大小如下:

可以看到内积会随着相对位置的递增而减少,从而可以表示位置的相对距离。内积的结果是对称的,所以没有方向信息。

2.3 代码实现

import torch
from torch import nn,Tensor
import math


class PositionalEmbedding(nn.Module):
    def __init__(self,d_model:int=512,dropout:float=0.1,max_positions:int=1024) -> None:
        '''

        :param d_model: embedding向量的维度
        :param dropout:
        :param max_positions: 最大长度
        '''
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Position Embedding  (max_positions,d_model)
        pe = torch.zeros(max_positions,d_model)

        # 创建position index列表 ,形状为:(max_positions, 1)
        position = torch.arange(0,max_positions).unsqueeze(1)

        # d_model 维度 偶数位是sin ,奇数位是cos
        # 计算除数,这里的除数将用于计算正弦和余弦的频率
        div_term = torch.exp(
            torch.arange(0,d_model,2) * -(math.log(10000.0) /d_model)
        )

        # 对矩阵的偶数列(0,2,4...)进行正弦函数编码
        pe[:, 0::2] = torch.sin(position * div_term)

        # 对矩阵的奇数列(1,3,5...)进行余弦函数编码
        pe[:, 1::2] = torch.cos(position * div_term)

        # 扩展维度,增加batch_size: pe (1, max_positions, d_model)
        pe = pe.unsqueeze(0)

        # buffers will not be trained
        self.register_buffer("pe", pe)

    def forward(self,x:Tensor) ->Tensor:
        """

                Args:
                    x (Tensor): (batch_size, seq_len, d_model) embeddings

                Returns:
                    Tensor: (batch_size, seq_len, d_model)
        """

        # x.size(1)是指当前x的最大长度
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

if __name__ == '__main__':
    seq_len = 128
    d_model = 512

    pe = PositionalEmbedding(d_model)

    x = torch.rand((1,100,d_model))
    print(pe(x).shape)

三、多头注意力

3.1 自注意力

公式如下:

  • 假设一个矩阵X,分别乘上权重矩阵,,就得到了Q , K , V向量矩阵

  • 然后除以 𝑑𝑘 进行缩放,再经过Softmax,得到注意力权重矩阵,接着乘以value向量矩阵V,就一次得到了所有单词的输出矩阵Z

3.2 多头注意力

将原来n_head分割乘Nx n_sub_head.对于每个头i,都有它自己不同的key,query和value矩阵: 𝑊𝑖𝐾,𝑊𝑖𝑄,𝑊𝑖𝑉 。在多头注意力中,key和query的维度是 𝑑𝑘 ,value嵌入的维度是 𝑑𝑣 (其中key,query和value的维度可以不同,Transformer里面一般设置的是相同的),这样每个头i,权重 𝑊𝑖𝑄∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝐾∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝑉∈𝑅𝑑×𝑑𝑣 ,然后与压缩到X中的输入相乘,得到 𝑄∈𝑅𝑁×𝑑𝑘,𝐾∈𝑅𝑁×𝑑𝑘,𝑉∈𝑅𝑁×𝑑𝑣 .

3.3 代码实现

import math

import torch
from torch import nn,Tensor
from typing import *

class MultiHeadAttention(nn.Module):
    def __init__(self,d_model: int = 512,n_heads: int=8,dropout: float = 0.1):
        '''

        :param d_model: embedding大小
        :param n_heads: 多头个数
        :param dropout:
        '''
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_key = d_model // n_heads

        self.q = nn.Linear(d_model,d_model)
        self.k = nn.Linear(d_model,d_model)
        self.k = nn.Linear(d_model,d_model)

        self.concat = nn.Linear(d_model,d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self,x:Tensor,is_key : bool = False) -> Tensor:
        '''
        分割向量为N个头,如果是key的话,softmax时候,key需要转置一下
        :param x:
        :param is_key:
        :return:
        '''
        batch_size = x.size(0)

        # x (batch_size,seq_len,n_heads,d_key)
        x = x.view(batch_size,-1,self.n_heads,self.d_key)
        if is_key:
            # (batch_size,n_heads,d_key,seq_len)
            return x.permute(0,2,3,1)

        # (batch_size,n_heads,seq_len,d_key
        return x.transpose(1,2)

    def merge_heads(self,x: Tensor) -> Tensor:
        x = x.transpose(1,2).contigouse().view(x.size(0),-1,self.d_model)
        return x

    def attention(self,
                  query:Tensor,
                  key:Tensor,
                  value:Tensor,
                  mask:Tensor = None,
                  keep_attentions:bool = False):

        scores = torch.matmul(query,key) / math.sqrt(self.d_key)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # weights (batch_size,n_heads,q_length,k_length)
        weights = self.dropout(torch.softmax(scores,dim=-1))

        # (batch_size,n_heads,q_length,k_length) x (batch_size,n_heads,v_length,d_key)
        # -> (batch_size,n_heads,q_length,d_key)
        # assert k_length == v_length

        # attn_output (batch_size, n_heads, q_length, d_key)
        atten_output = torch.matmul(weights,value)

        if keep_attentions:
            self.weights = weights
        else:
            del weights

        return atten_output

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                mask: Tensor = None,
                keep_attentions: bool = False)-> Tuple[Tensor,Tensor]:
        '''

        :param query:(batch_size, q_length, d_model)
        :param key:(batch_size, k_length, d_model)
        :param value:(batch_size, v_length, d_model)
        :param mask: mask for padding or decoder. Defaults to None.
        :param keep_attentions: whether keep attention weigths or not. Defaults to False.
        :return: (batch_size, q_length, d_model) attention output
        '''
        query = self.q(query)
        key = self.k(key)
        value = self.v(value)

        query,key,value = (
            self.split_heads(query),
            self.split_heads(key,is_key=True),
            self.split_heads(value)
        )

        atten_output = self.attention(query,key,value,mask,keep_attentions)

        del query
        del key
        del value

        # concat
        concat_output = self.merge_heads(atten_output)

        # the final liear
        # output (batch_size, q_length, d_model)
        output = self.concat(concat_output)
        
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/696871.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【JVM】从编译后的指令集来再次理解++i和i++的执行顺序

JVM为什么要选用基于栈的指令集架构 与基于寄存器的指令集架构相比,基于栈的指令集架构不依赖于硬件,因此可移植性更好,跨平台性更好因为栈结构的特性,永远都是先处理栈顶的第一条指令,因此大部分指令都是零地址指令&…

SpringMVC[从零开始]

SpringMVC SpringMVC简介 1.1什么是MVC MVC是一种软件架构的思想,将软件按照模型、视图、控制器来划分 M:Model,模型层,指工程中的JavaBean,作用是处理数据 JavaBean分为两类: 一类称为实体类Bean:专…

对猫毛过敏?怎么有效的缓解过敏症状,宠物空气净化器有用吗?

猫过敏是一种常见的过敏反应,由猫的皮屑、唾液或尿液中的蛋白质引起。这些蛋白质被称为过敏原,它们可以通过空气传播,被人体吸入后,会触发免疫系统的过度反应。猫过敏是宠物过敏中最常见的类型之一,对许多人来说&#…

【Java】static 修饰变量

static 一种java内置关键字,静态关键字,可以修饰成员变量、成员方法。 static 成员变量 1.static 成员变量2.类变量图解3.类变量的访问4.类变量的内存原理5.类变量的应用 1.static 成员变量 成员变量按照有无static修饰,可以分为 类变量…

Python学习打卡:day02

day2 笔记来源于:黑马程序员python教程,8天python从入门到精通,学python看这套就够了 8、字符串的三种定义方式 字符串在Python中有多种定义形式 单引号定义法: name 黑马程序员双引号定义法: name "黑马程序…

如何为色盲适配图形用户界面

首发日期 2024-05-25, 以下为原文内容: 答案很简单: 把彩色去掉, 测试. 色盲, 正式名称 色觉异常. 众所周知, 色盲分不清颜色. 如果用户界面设计的不合理, 比如不同项目只使用颜色区分, 而没有形状区分, 那么色盲使用起来就会非常难受, 甚至无法使用. 色盲中最严重的情况称为…

2024PTA算法竞赛考试编程题代码

目录 前言 题目和代码 L1-006 连续因子 L1-009 N个数求和 L2-004 这是二叉搜索树吗? L2-006 树的遍历 L2-007 家庭房产 L4-118 均是素数 L4-203 三足鼎立 L2-002 链表去重 L2-003 月饼 L2-026 小字辈 L4-201 出栈序列的合法性 L4-205 浪漫侧影 前言 所…

【数据结构】AVL树(平衡二叉树)

目录 一、AVL树的概念二、AVL树的节点三、AVL树的插入四、AVL树的旋转1.插入在较高左子树的左侧,使用右单旋2.插入在较高右子树的右侧,使用左单旋3.插入较高左子树的右侧,先左单旋再右单旋4.插入较高右子树的左侧,先右单旋再左单旋…

unity基础(五)地形详解

目录 一 创建地形 二 调整地形大小 三 创建相邻地形 四 创建山峰 五 创建树木 七 添加风 八 添加水 简介: Unity 中的基础地形是构建虚拟场景的重要元素之一。 它提供了一种直观且灵活的方式来创建各种地形地貌,如山脉、平原、山谷等。 通过 Unity 的地形…

C51学习归纳9 --- I2C通讯学习(重点)

首先,我自己学习过以后的直观感觉,通信协议是单片机的灵魂之一,只有规定好了通信协议我们才能够正确的接收到信息,才能实现更加深入的研究。所以这一部分是需要好好学习的。 本节借助一个可存储的芯片AT24C02,进行在I2…

开源低代码平台技术为数字化转型赋能!

实现数字化转型升级是很多企业未来的发展趋势,也是企业获得更多发展商机的途径。如何进行数字化转型?如何实现流程化办公?这些都是摆在客户面前的实际问题,借助于开源低代码平台技术的优势特点,可以轻松助力企业降低开…

【设计模式】创建型设计模式之 建造者模式

文章目录 一、介绍定义UML 类图 二、用法1 简化复杂对象具体构建过程省略抽象的 Builder 类省略 Director 类 三、用法2 控制对象构造方法、限制参数关系Guava 中使用建造者模式构建 cache 来进行参数校验 一、介绍 定义 建造者模式,将一个复杂的对象的构建过程与…

互联网应用主流框架整合之SpringMVC初始化及各组件工作原理

Spring MVC的初始化和流程 MVC理念的发展 SpringMVC是Spring提供给Web应用领域的框架设计,MVC分别是Model-View-Controller的缩写,它是一个设计理念,不仅仅存在于Java中,各类语言及开发均可用,其运转流程和各组件的应…

探索OrangePi AIpro:单板计算机的深度体验之旅

准备阶段:环境与资料 在开始我们的探索之旅前,确保您已准备好以下装备: OrangePi AIpro:我们的主角,一台功能强大的单板计算机。Windows 10笔记本电脑:作为我们的辅助工具,用于管理和测试。路…

FastAPI:在大模型中使用fastapi对外提供接口

通过本文你可以了解到: 如何安装fastapi,快速接入如何让大模型对外提供API接口 往期文章回顾: 1.大模型学习资料整理:大模型学习资料整理:如何从0到1学习大模型,搭建个人或企业RAG系统,如何评估…

python ---使用python操作mysql ---> pymysql

本章内容: 1:能够完成从MySQL中读取出数据; [重点] 查询: execute()、fetchall() 2:能够将数据写入MySQL数据库。 [重点] 插入数据: execute() sql insert into xxx [掌握]pymysql模块的安装 目标:了解如何安装pymysql模块? 当要使用Python和M…

操作系统复习-存储管理之虚拟内存

虚拟内存概述 有些进程实际需要的内存很大,超过物理内存的容量。多道程序设计,使得每个进程可用物理内存更加稀缺。不可能无限增加物理内存,物理内存总有不够的时候。虚拟内存是操作系统内存管理的关键技术。使得多道程序运行和大程序运行称…

永久免费的iPhone,iPad,Mac,iWatch锁屏,桌面壁纸样机生成器NO.105

使用这个壁纸样机生成器,生成iPhone,iPad,Mac,iWatch锁屏,桌面壁纸,展示你的壁纸作品,一眼就看出壁纸好不好看,适不适合 资源来源于网络,免费分享仅供学习和测试使用&am…

【C语言初阶】分支语句

🌟博主主页:我是一只海绵派大星 📚专栏分类:C语言 ❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、什么是语句 二、if语句 悬空else 三、switch语句 default 四、switch语句与if-else语句性能对比如何&#xff1f…

【Python核心数据结构探秘】:元组与字典的完美协奏曲

文章目录 🚀一、元组⭐1. 元组查询的相关方法❤️2. 坑点🎬3. 修改元组 🌈二、集合⭐1. 集合踩坑❤️2. 集合特点💥无序性💥唯一性 ☔3. 集合(交,并,补)🎬4. …