Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

Transformer学习

  • 1 位置编码模块
    • 1.1 PE代码
    • 1.2 测试PE
    • 1.3 原文代码
  • 2 多头自注意力模块
    • 2.1 多头自注意力代码
    • 2.2 测试多头注意力
  • 3 未来序列掩码矩阵
    • 3.1 代码
    • 3.2 测试掩码

在这里插入图片描述

1 位置编码模块

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i)=\sin(pos/10000^{2i/d_{\mathrm{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i+1)=\cos(pos/10000^{2i/d_\mathrm{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

pos 是序列中每个对象的索引, p o s ∈ [ 0 , m a x s e q l e n ] pos\in [0,max_seq_len] pos[0,maxseqlen], i i i 向量维度序号, i ∈ [ 0 , e m b e d d i m / 2 ] i\in [0,embed_dim/2] i[0,embeddim/2], d m o d e l d_{model} dmodel是模型的embedding维度

1.1 PE代码

import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import seaborn as sns

def get_pos_ecoding(max_seq_len,embed_dim):
	# 初始化位置矩阵 [max_seq_len,embed_dim]
	pe = torch.zeros(max_seq_len,embed_dim])
	position = torch.arange(0,max_seq_len).unsqueeze(1) # [max_seq_len,1]
	print("位置:", position,position.shape)
	div_term = torch.exp(torch.arange(0,embed_dim,2)*-(math.log(10000.0)/embed_dim))   # 除项维度为embed_dim的一半,因为对矩阵分奇数和偶数位置进行填充。
	pe[:,0::2] = torch.sin(position/div_term)
	pe[:,1::2] = torch.cos(position/div_term)
	return pe

1.2 测试PE

pe = get_pos_ecoding(8,4)
plt.figure(figsize=(8,8))
sns.heatmap(pe)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")

输出:
位置: tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]]) torch.Size([8, 1])
除项: tensor([1.0000, 0.0100]) torch.Size([2])
在这里插入图片描述

plt.figure(figsize=(8, 5))
plt.plot(positional_encoding[1:, 1], label="dimension 1")
plt.plot(positional_encoding[1:, 2], label="dimension 2")
plt.plot(positional_encoding[1:, 3], label="dimension 3")
plt.legend()
plt.xlabel("Sequence length")
plt.ylabel("Period of Positional Encoding")

在这里插入图片描述

1.3 原文代码

class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
    	# max_len 序列最大长度,自定义的,不是真正的最大长度
    	# d_model 模型嵌入维度
        super(PositionalEncoding, self).__init__()
        # 实例化dropout层
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        # 初始化一个位置编码矩阵, shape: (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        #二维张量扩充为三维张量 shape: (1,max_len, d_model)
        pe = pe.unsqueeze(0)
        
        # 将位置编码注册为模型的buffer,注册为buffer之后,不会进行更新
        # 注册为buffer后可以再模型保存后重新加载时候将这个位置编码器和模型参数加载进来
        self.register_buffer('pe', pe)  # 注册的名字pe,变量也是pe
        
    def forward(self, x):
    	# x 序列的嵌入表示
    	# pe是按max_len进行注册的,太长了,将第二个维度(max_len对应的维度)缩小为真正的序列x的最大长度
    	# Variable(self.pe[:, :x.size(1)], requires_grad=False) 即位置编码
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

2 多头自注意力模块

2.1 多头自注意力代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

# 复制网络,即使用几层网络就改变N的数量
# 如 4层线性层  clones(nn.Linear(model_dim,model_dim),4)
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 计算注意力
def attention(q, k, v, mask=None, dropout=None):
	# q,k,v [bs,-1,head,embed_dim//head], q.size(-1) = embed_dim//head
    d_k = q.size(-1)
    # (head,embed_dim//head)*(embed_dim//head,head)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
    	# 如果使用mask, 0的位置用-1e9填充
        scores = scores.masked_fill(mask == 0, -1e9)
    # 对scores的最后一个维度进行softmax
    p_attn = F.softmax(scores, dim = -1)
	# 判断是否需要进行dorpout处理
    if dropout is not None:
        p_attn = dropout(p_attn)
    # 返回添加注意力后的结果,注意力系数
    return torch.matmul(p_attn, v), p_attn

# 计算多头注意力
class Multi_Head_Self_Att(nn.Module):
    def __init__(self,head,model_dim,dropout=0.1):
        super(Multi_Head_Self_Att,self).__init__()+
        # 判断嵌入维度能否被head整除,不能整除抛出异常
        assert model_dim % head == 0
        self.d_k = model_dim//head
        self.head = head
        self.linears = clones(nn.Linear(model_dim,model_dim),4)
        self.att = None
        self.dropout = nn.Dropout(p=dropout)
    def forward(self,q,k,v,mask=None):
        if mask is not None:
        	# 掩码非空,扩充维度,代表多头中的第n个头
            mask = mask.unsqueeze(1)
        nbatches = q.size(0)
        
        # zip函数 将线性层与q,k,v分别对应(self.linears,q),(self.linears,k),(self.linears,v)
        # q,k,v [bs,-1,head,embed_dim/head]
        # 使用线性层l处理x,把处理后的x形状view成(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2)) 第二个维度自适应维度大小
        # transpose(1,2) 让代表序列长度的维度与词向量的维度相邻
        q,k,v = [l(x).view(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2) for l,x in zip(self.linears,(q,k,v))] 
        
        # 返回计算注意力之后的值作为x和注意力分数
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        # 通过多头注意力计算后,得到了每个头计算结果组成的4维张量,需要将其转换成与输入一样的维度
        # 将第2个维度和第三个维度换回来,维度组成(nbatches,-1,model_dim)
        # contiguous()使得转置之后的张量能够运用view方法
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, int(self.head * self.d_k))
        # 之前建立了四个线性层,前面q,k,v用了三个线性层,最后一个现成层对注意力结果进行一次线性变换
        return self.linears[-1](x),self.attn  

2.2 测试多头注意力

# 模型参数
head = 4
model_dim = 128
seq_len = 10
dropout = 0.1

# 生成示例输入
q = torch.randn(seq_len, model_dim)
k = torch.randn(seq_len, model_dim)
v = torch.randn(seq_len, model_dim)

# 创建多头自注意力模块
att = Multi_Head_Self_Att(head, model_dim, dropout=dropout)

# 运行模块
output,att = att(q, k, v)

# 输出形状
print("Output shape:", output.shape)
print(att.shape())
sns.heatmap(att.squeeze().detach().cpu())

输出
Output shape: torch.Size([10, 1, 128])
torch.Size([10, 4, 1, 1])
在这里插入图片描述

3 未来序列掩码矩阵

作用:

  • 解码器中的掩码是防止泄露未来要预测的部分,掩码矩阵是一个除对角线的上三角矩阵
  • 序列填充部分的掩码是判断哪些部位是填充的部位,填充的部位在计算注意力时保证期注意力分数为0【以添加一个负无穷的小数,使得其softmax值为0】

3.1 代码

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    print("掩码矩阵:",subsequent_mask)
    return torch.from_numpy(subsequent_mask) == 0

测试掩码

plt.figure(figsize=(5,5))
print(subsequent_mask(8),subsequent_mask(8).shape)
plt.imshow(subsequent_mask(8)[0])

掩码矩阵:
[[[0 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1]
[0 0 0 0 1 1 1 1]
[0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0]]]

tensor([[[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True]]]) torch.Size([1, 8, 8])
在这里插入图片描述
紫色部分为添加掩码的部分

3.2 测试掩码

import torch
from torch.autograd import Variable


# 函数接受两个参数 tgt 和 pad,其中 tgt 是目标序列的张量,pad 是表示填充的值
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # 首先,创建一个掩码 tgt_mask,其形状与 tgt 的形状相同,用于指示哪些位置不是填充位置。
    # 这是通过将 tgt 张量中不等于 pad 的位置设置为 True(1),其余位置设置为 False(0)来实现的。
    # unsqueeze(-2) 的作用是在倒数第二个维度上添加一个维度,以便后续的逻辑运算。
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # 调用 subsequent_mask 函数,生成一个用于遮挡未来词的掩码。这个掩码是一个上三角矩阵,
    # 其对角线及其以下的元素为 True(1),其余元素为 False(0)。
    
    # tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)):
    # 将 tgt_mask 和生成的未来词掩码进行逻辑与操作,将未来词位置的掩码设置为 False,即遮挡掉未来词。
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(*attn_shape), diagonal=1)
    return subsequent_mask == 0

# 示例数据
tgt = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0], [6, 7, 8, 9, 10]])  # 目标序列,假设填充值为 0
pad = 0  # 填充值

# 创建掩码
tgt_mask = make_std_mask(tgt, pad)

# 打印结果
print("目标序列:")
print(tgt)
print("\n生成的掩码:")
print(tgt_mask)

输出:

目标序列:
tensor([[ 1, 2, 3, 0, 0],
[ 4, 5, 0, 0, 0],
[ 6, 7, 8, 9, 10]])

生成的掩码:
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, False, False],
[ True, True, True, False, False]],
[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False]],
[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/517079.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Java设计模式】序:设计模式总体概述

目录 什么是设计模式设计模式的分类1 创建型模式1.1. 单例(Singleton)1.2 原型(Prototype)1.3 工厂方法(FactoryMethod)1.4 抽象工厂(AbstractFactory)1.5 建造者(Builde…

ajax教程

文章目录 一、原生ajax1、AJAX 简介2、特点1)优点2)缺点 二、http协议1、概念2、Cookie和Session机制1)Cookie2)Session3)报文 二、请求头1、概念2、常见请求头:3、Content-Type 三、AJAX使用1、详细操作2、…

竞赛 Yolov安全帽佩戴检测 危险区域进入检测 - 深度学习 opencv

1 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 Yolov安全帽佩戴检测 危险区域进入检测 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:3分创新点:4分 该项目较为新颖&am…

代码随想录阅读笔记-二叉树【二叉搜索树中的搜索】

题目 给定二叉搜索树(BST)的根节点和一个值。 你需要在BST中找到节点值等于给定值的节点。 返回以该节点为根的子树。 如果节点不存在,则返回 NULL。 例如, 在上述示例中,如果要找的值是 5,但因为没有节点…

Sybase ASE中的char(N)的坑以及与PostgreSQL的对比

1背景 昨天,一朋友向我咨询Sybase ASE中定长字符串类型的行为,说他们的客户反映,同样的char类型的数据,通过jdbc来查,Sybase库不会带空格,而PostgreSQL会带。是不是这样的?他是PostgreSQL的专业大拿,但因为他手头没有现成的Sybase ASE环境,刚好我手上有,便于一试。 …

PostgreSQL 表膨胀原因和解决方案

在 PostgreSQL 中,表膨胀是一个常见的问题,它会导致数据库性能下降,甚至在极端情况下会耗尽磁盘空间。了解表膨胀的原因及其解决方案,对于维护数据库性能和稳定性至关重要。 表膨胀的原因 MVCC (多版本并发控制) PostgreSQL 使…

Kotlin:for循环的几种示例

一、 打印 0 到 2 1.1 方式一:0 until 3 /*** 打印 0 到 2*/ fun print0To2M1(){for (inex in 0 until 3){// 不包含3print("$inex ")} }运行结果 1.2 方式二:inex in 0 …2 /*** 打印 0 到 2*/ fun print0To2M2(){for (inex in 0 ..2){//…

A First Course in the Finite Element Method【Daryl L.】|PDF电子书

专栏导读 作者简介:工学博士,高级工程师,专注于工业软件算法研究本文已收录于专栏:《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现,并提供所有案例完整源码;2.单元…

MySQL-逻辑架构:逻辑架构分析、SQL执行流程、数据库缓冲池

逻辑架构 1. 逻辑架构剖析 1.1 第1层:连接层 系统(客户端)访问MySQL服务器前,做的第一件事就是建立TCP连接。 经过三次握手建立连接成功后,MySQL服务器对TCP传输过来的账号密码做身份认证、权限获取。 用户名或密码…

【go】模板展示不同k8s命名空间的deployment

gin模板展示k8s命名空间的资源 这里学习如何在前端单页面,调用后端接口展示k8s的资源 技术栈 后端 -> go -> gin -> gin模板前端 -> gin模板 -> html jsk8s -> k8s-go-client ,基本资源(deployment等) 环境 go 1.19k8s 1.23go m…

Windows程序设计课程作业-1

文章目录 1. 作业内容2. 设计思路分析与难点3. 代码实现3.1 接口定义3.2 工厂类实现3.3 委托和事件3.4 主函数3.5 代码运行结果 4. 代码地址5. 总结&改进思路6. 阅读参考 1. 作业内容 使用 C# 编码(涉及类、接口、委托等关键知识点),实现…

nssm 工具把asp.net core mvc变成 windows服务,使用nginx反向代理访问

nssm工具的作用&#xff1a;把项目部署成Windows服务&#xff0c;可以在系统后台运行 1.创建一个asp.net core mvc的项目weblication1 asp.net core mvc项目要成为windows服务需要安装下面的nuget包 <ItemGroup><PackageReference Include"Microsoft.Extension…

Python卷积网络车牌识别系统(V2.0)

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

【随笔】Git 高级篇 -- 相对引用1(十二)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

工业组态 物联网组态 组态编辑器 web组态 组态插件 编辑器

体验地址&#xff1a;by组态[web组态插件] BY组态是一款非常优秀的纯前端的【web组态插件工具】&#xff0c;可无缝嵌入到vue项目&#xff0c;react项目等&#xff0c;由于是原生js开发&#xff0c;对于前端的集成没有框架的限制。同时由于BY组态只是一个插件&#xff0c;不能独…

【机器学习】《机器学习算法竞赛实战》第7章用户画像

文章目录 第7章 用户画像7.1 什么是用户画像7.2 标签系统7.2.1 标签分类方式7.2.2 多渠道获取标签7.2.3 标签体系框架 7.3 用户画像数据特征7.3.1 常见的数据形式7.3.2 文本挖掘算法7.3.3 神奇的嵌入表示7.3.4 相似度计算方法 7.4 用户画像的应用7.4.1 用户分析7.4.2 精准营销7…

B/S架构SaaS模式 医院云HIS系统源码,自主研发,支持电子病历4级

B/S架构SaaS模式 医院云HIS系统源码&#xff0c;自主研发&#xff0c;支持电子病历4级 系统概述&#xff1a; 一款满足基层医院各类业务需要的云HIS系统。该系统能帮助基层医院完成日常各类业务&#xff0c;提供病患挂号支持、病患问诊、电子病历、开药发药、会员管理、统计查…

基于Java微信小程序的医院挂号小程序,附源码

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

Java学习之面向对象三大特征

目录 继承 作用 实现 示例 instanceof 运算符 示例 要点 方法的重写(Override) 三个要点 示例 final关键字 作用 继承和组合 重载和重写的区别 Object类详解 基本特性 补充&#xff1a;IDEA部分快捷键 " "和equals()方法 示例 Super关键字 示例 …

Golang实现一个聊天工具

简介 聊天工具作为实时通讯的必要工具&#xff0c;在现代互联网世界中扮演着重要的角色。本博客将指导如何使用 Golang 构建一个简单但功能完善的聊天工具&#xff0c;利用 WebSocket 技术实现即时通讯的功能。 项目源码 点击下载 为什么选择 Golang Golang 是一种高效、简…