简单实现Transformer的自注意力

简单实现Transformer的自注意力

关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

在这里插入图片描述
源码获取:VX关注并回复chatgpt-0获得

  • 实现的功能

假如有八个令牌,现在想让每一个令牌至于其前面的通信,如第5个令牌不与6,7,8位置的令牌通信(这是未来的令牌),只与4,3,2,1位置的令牌通信。因此只能通过以前的上下文信息猜测后面的;一种弱的通信方式是取前面的平局值。如5位置==5,4,3,2,1位置上的平局值。

  • 实现
    • 循环的版本
import torch
from torch.nn import functional as F
import torch.nn as nn
torch.manual_seed(1337)

B,T,C = 4,8,2 #batch,time,channels 
x = torch.randn(B,T,C)
xbow = torch.zeros((B,T,C))
print(f'x: {x[0]}')
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] #()t,C
        xbow[b,t] = torch.mean(xprev,0)
print(f'xbow: {xbow[0]}')


#结果
x: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
xbow: tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
# 每一行至于自己以及自己以前的数据进行通信
  • 通过数据矩阵高效实现
a = torch.tril(torch.ones(3,3)) #下三角函数
a = a/torch.sum(a,1,keepdim=True) #对a求平均数
b = torch.randint(0,10,(3,2)).float()
c = a @ b

print(f'a:{a}')
print(f'b:{b}')
print(f'c:{c}')

#结果

a:tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b:tensor([[0., 4.],
        [1., 2.],
        [5., 5.]])
c:tensor([[0.0000, 4.0000],
        [0.5000, 3.0000],
        [2.0000, 3.6667]])
  • 使用Softmax
tril = torch.tril(torch.ones(T,T))  #下三角函数
print(f'tril:{tril}')

wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大
print(f'wei: {wei}')
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数
print(f'wei: {wei}')
xbow3 = wei @ x

print(f'xbow3: {xbow3}')
print(torch.allclose(xbow,xbow3))
  • 单头自注意力

    • 上面的自注意力是通过相同的方式获取以往的信息。但是实际上并不希望是统一的方式,因为不同的token标记会发现其他不同的标记。
    • 例如:我是元音,那么也许我正在寻找过去的辅音,或与我想知道这些辅音是什么。希望这些信息流向我,所以我现在想以依赖数据的方式收集过去的信息。这就是自注意力解决的问题。
    • 方式如下:每个节点或每个位置的每个令牌都会发出两个向量,一个发出查询query,一个发出键key。查询向量粗略的说就是我要找的东西,键向量粗略的讲就是我包含什么。
    • 现在在序列中获取这些标记之间的亲和力的方式基本上只是在键和查询之间做一个点乘积。所以我的查询与所有的其他tokens令牌的所有键进行点乘积。并且点积方式变了。如果键和查询有点对齐,它们将交互到非常高的数量,然后我将了解有关特定标记的更多信息,而不是其他不再序列中的任何其他标记。
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)

k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)

tril = torch.tril(torch.ones(T,T))  #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
out = wei @ x
print(f'out:{out.shape}')
  • 但是在真是中并不聚合到x而是计算一个v.x看作为该令牌的私人信息,与不同头交流的信息存储在v中
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)

k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)

tril = torch.tril(torch.ones(T,T))  #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
value = nn.Linear(C,head_size,bias=False)
v = value(x)
out = wei @ v
print(f'out:{out.shape}')

简单实现自注意力

关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/424964.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Vue3】Props的使用详解

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

1. MAC 安装 goland 和 go

1. 安装goland 官网下载goland 安装破解goland 参考:安装pycharm下载压缩包->解压 -> 运行 sh jetbra/scripts/uninstall.sh -> 运行 sh jetbra/scripts/install.sh打开goland,help -> register -> 输入新的激活码 (从网盘获取到的) 有一…

YOLOv9独家改进|动态蛇形卷积Dynamic Snake Convolution与空间和通道重建卷积SCConv与RepNCSPELAN4融合

专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!! 一、改进点介绍 Dynamic Snake Convolution是一种针对细长微弱的局部结构特征与复杂多变的全局形态特征设计的卷积模块。 SCConv是一种即插即用的空间…

前端学习第七天-css常用样式设置

达标要求 掌握元素的显示与隐藏 熟练应用溢出的文字隐藏 熟练掌握版心和布局流程 1. 元素的显示与隐藏 在CSS中有三个显示和隐藏的单词比较常见,我们要区分开,他们分别是 display visibility 和 overflow。 他们的主要目的是让一个元素在页面中消失…

03、MongoDB -- MongoDB 权限的设计

目录 MongoDB 权限的设计演示前准备:启动 mongodb 服务器 和 客户端 :1、启动单机模式的 mongodb 服务器2、启动 mongodb 的客户端 MongoDB 权限的设计1、MongoDB 的每个数据库都可以保存用户,不止admin数据库可以保存用户。2、保存用户的数据…

人工智能指数报告2023

人工智能指数报告2023 主要要点第 1 章 研究与开发第 2 章 技术性能第 3 章 人工智能技术伦理第 4 章 经济第 5 章 教育第 6 章 政策与治理第 7 章 多样性第 8 章 舆论 人工智能指数是斯坦福大学以人为本的人工智能研究所(HAI)的一项独立倡议&#xff0c…

CNN-LSTM-Attention混合神经网络归时序预测的MATLAB实现(源代码)

CNN-LSTM-Attention介绍: CNN-LSTM-Attention混合神经网络是一种结合了卷积神经网络(CNN)、长短期记忆神经网络(LSTM)和注意力机制(Attention)的模型。这种混合神经网络结合了CNN对空间特征的提…

【MySQL】查询语句:条件、排序和分页

基本查询 MySQL 数据库使用SELECT语句来查询数据。 查询字段 以下为在MySQL数据库中查询数据通用的 SELECT 语法: SELECT 字段名,字段名... FROM 表名;选择全部列 SELECT * FROM emp; -- 查询所有字段一般情况下,除非需要使用表中所有的字段数据&…

[HackMyVM] 靶场 Wave

kali:192.168.56.104 主机发现 arp-scan -l # arp-scan -l Interface: eth0, type: EN10MB, MAC: 00:0c:29:d2:e0:49, IPv4: 192.168.56.104 Starting arp-scan 1.10.0 with 256 hosts (https://github.com/royhills/arp-scan) 192.168.56.1 0a:00:27:00:00:05 (Un…

【LeetCode:2368. 受限条件下可到达节点的数目 + BFS】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

MSCKF3讲:后端理论推导(上)

MSCKF3讲:后端理论推导(上) 文章目录 MSCKF3讲:后端理论推导(上)1 MSCKF中的状态变量① IMU状态:② cam0状态:③ IMU和cam0间状态关系 2 微分方程递推(数值解)3 IMU状态预…

leetcode - 2095. Delete the Middle Node of a Linked List

Description You are given the head of a linked list. Delete the middle node, and return the head of the modified linked list. The middle node of a linked list of size n is the ⌊n / 2⌋th node from the start using 0-based indexing, where ⌊x⌋ denotes th…

ABAP - SALV教程05 添加页眉和页脚

先看看效果叭CL_SALV_TABLE提供了SET_TOP_OF_LIST方法设置页眉显示和SET_TOP_OF_LIST_PRINT方法设置页眉打印来实现添加页眉的目的。CL_SALV_TABLE提供了SET_END_OF_LIST方法设置页脚显示和SET_END_OF_LIST_PRINT方法设置页脚打印来实现添加页脚的目的。这个四个方法的传入参数…

计算机二级Python刷题笔记------基本操作题11、14、17、21、30(考察列表)

文章目录 第十一题(列表遍历)第十四题(len)第十七题(len、insert)第二十一题(append)第三十题(二维列表) 第十一题(列表遍历) 题目&a…

你敢信,copilot Pro这个带着Pro的产品是阉割版?

你敢信,copilot Pro这个带着Pro的产品是阉割版? 没错。 很多人以为copilot Pro带着Pro就是专业版,高大上。 但不知道的是,微软对于office copilot同时发布了两款产品: 针对个人家庭版office用户的copilot Pro&…

【C语言】linux内核dev_hard_start_xmit

一、中文注释 struct sk_buff *dev_hard_start_xmit(struct sk_buff *first, struct net_device *dev,struct netdev_queue *txq, int *ret) {struct sk_buff *skb first; // 初始化skb指针,指向第一个待发送的数据包int rc NETDEV_TX_OK; // 初始返回码为NETD…

C++ set和map使用

set和map 1.关联式容器2. 键值对3. set3.1 介绍3.2 简单使用 4.multiset5.map5.1 介绍5.2 简单使用 6. multimap 1.关联式容器 关联式容器是一种STL容器,用于存储键-值对。它们提供了一种通过键来快速查找值的机制。STL总共实现了两种不同结构的管理式容器&#xff…

编写dockerfile挂载卷、数据容器卷

编写dockerfile挂载卷 编写dockerfile文件 [rootwq docker-test-volume]# vim dockerfile1 [rootwq docker-test-volume]# cat dockerfile1 FROM centosVOLUME ["volume01","volume02"]CMD echo "------end------" CMD /bin/bash [rootwq dock…

ecmascript 6+(2)

引用数据类型: Object, Array, RegExp, Date等 包装类型:(底层数据类型会将简单数据类型包装为对象) String, Number, Boolean等(都是基本数据类型的构造函数) Object Object.keys(对象) 返回数组&…

4款塞纸条盲盒交友源码,可以对接公众号

一元盲盒交友源码/脱单盲盒源码/交友盲盒/恋爱盲盒公众号版 可以对接自己支付,全部自定义 没有任何bug版本,已经测试完全可以 免费源码,不包搭建指导 源码下载地址专业知识分享社区-专业知识笔记免费分享 (chaobiji.cn)