transfomer中attention为什么要除以根号d_k

简介

得到矩阵 Q, K, V之后就可以计算出 Self-Attention 的输出了,计算的公式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=Softmax(dk QKT)V

好处

除以维度的开方,可以将数据向0方向集中,使得经过softmax后的梯度更大.
从数学上分析,可以使得QK的分布和Q/K保持一致,

推导

对于两个独立的正态分布而言,两者的加法的期望和方差就是两个独立分布的期望和方差。
qk_T的计算过程为[len_q,dim][dim,len_k]=[len_q,len_k],qk的元素等于dim个乘积的和。对于0-1分布表乘积不会影响期望和方差,但是求和操作会使得方差乘以dim,因此对qk元素除以sqrt(dim)把标准差压回1.

这里展示一个不严谨的采样可视化过程
假设在query在(0,1)分布,key在(0,1)分布,随机采样lengthdim个点,然后统计querykey_T的散点的分布

import math
import numpy as np
import matplotlib.pyplot as plt


def plot_curve(mu=0, sigma =1):
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.stats import norm
    # 设置正态分布的参数
    # mu, sigma = 0, 1  # 均值和标准差
    # 创建一个x值的范围,覆盖正态分布的整个区间
    x = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 1000)
    # 计算对应的正态分布的概率密度值
    y = norm.pdf(x, mu, sigma)
    # 我们可以选择y值较高的点来绘制散点图,以模拟概率密度的分布
    # 这里我们可以设置一个阈值,只绘制y值大于某个值的点
    threshold = 0.01  # 可以根据需要调整这个阈值
    selected_points = y > threshold
    plt.plot(x, y, 'r-', lw=2, label='Normal dist. (mu={}, sigma={})'.format(mu, sigma))
    plt.title('Normal Distribution Scatter Approximation')
    plt.xlabel('Value')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_poins(x):
    # 因为这是一个一维的正态分布,我们通常只绘制x轴上的点
    # 但为了模拟二维散点图,我们可以简单地将y轴设置为与x轴相同或固定值(例如0)
    y = np.zeros_like(x)
    # 绘制散点图
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, alpha=0.5)  # alpha控制点的透明度
    plt.title('Normal (0, 1) Distribution Scatter Plot')
    plt.xlabel('Value')
    plt.ylabel('Value (or Frequency if binned)')
    plt.grid(True)
    plt.show()



if __name__ == '__main__':
    # 设置随机种子以便结果可复现
    np.random.seed(0)
    len = 10000
    dim = 100
    query = np.random.normal(0, 1, len*dim).reshape(len,dim)
    key = np.random.normal(0, 1, len*dim).reshape(dim,len)
    qk = np.matmul(query,key) / math.sqrt(dim)

    mean_query = query.mean()
    std_query = np.std(query,ddof=1)

    mean_key = key.mean()
    std_key = np.std(key,ddof=1)

    mean_qk = qk.mean()
    std_qk = np.std(qk,ddof=1)

    plot_poins(query)
    plot_curve(mean_query,std_query)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/668587.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

算法每日一题(python,2024.05.31)

题目来源(力扣. - 力扣(LeetCode),简单) 解题思路: 二次遍历,第一次遍历用哈希表记录每个字母的出现次数,出现一次则将它的value值赋为True,将它的下标赋为key值&#x…

leetcode74搜索二维矩阵

题目 给你一个满足下述两条属性的 m x n 整数矩阵: 每行中的整数从左到右按非严格递增顺序排列。每行的第一个整数大于前一行的最后一个整数。 给你一个整数 target ,如果 target 在矩阵中,返回 true ;否则,返回 fa…

LeetCode-47 全排列Ⅱ

LeetCode-47 全排列Ⅱ 题目描述解题思路代码说明 题目描述 给定一个可包含重复数字的序列 nums ,按任意顺序 返回所有不重复的全排列。 示例 : 输入:nums [1,1,2]输出: [[1,1,2], [1,2,1], [2,1,1]] b站题目解读讲的不好&…

充电宝哪个牌子好?怎么选充电宝?压箱底充电宝购买指南大全!

充电宝作为我们日常生活中不可或缺的便携式电源之一,市场上品牌众多、种类繁多。对于消费者来说,如何选择适合自己的充电宝成为一个值得重视的问题。有的充电宝厂家为节省成本“偷工减料”,使用劣质电池,以次充好、参数造假等现象…

Win10安装TensorRT

目录 什么是TensorRT 下载TensorRT 安装TensorRT 拷贝文件 安装whl文件 验证是否安装成功 什么是TensorRT TensorRT是由Nvidia推出的C语言开发的高性能神经网络推理库,是一个用于生成部署的优化器和运行时引擎。和cudnn类似,但它不支持训练&#xff…

Mysql(一)查询Sql是如何执行的

Hello,大家好我是极客涛😎,我最近在整理Mysql相关的知识点,所以准备开启一个Mysql的主线任务,大概耗时3周左右,整个节奏还是由浅入深,主要包括Mysql的架构、事务实现、索引组织形式、SQL优化、日…

kettle 使用动态变量名定义变量

name是变量,value 值也是变量 我需要把name作为变量名,value作为变量值; 在kettle中,使用javascript脚本 key与lastVsxzl都是变量 //Script here setVariable(key,lastVsxzl,r);var rgetVariable(key,r); Demo 1、从记事本里面…

sensitive-word 敏感词 v0.16.1 新特性支持字典内存资源释放

敏感词系列 sensitive-word-admin 敏感词控台 v1.2.0 版本开源 sensitive-word-admin v1.3.0 发布 如何支持分布式部署? 01-开源敏感词工具入门使用 02-如何实现一个敏感词工具?违禁词实现思路梳理 03-敏感词之 StopWord 停止词优化与特殊符号 04-…

【第十三节】C++控制台版本坦克大战小游戏

目录 一、游戏简介 1.1 游戏概述 1.2 知识点应用 1.3 实现功能 1.4 开发环境 二、项目设计 2.1 类的设计 2.2 各类功能 三、程序运行截图 3.1 游戏主菜单 3.2 游戏进行中 3.3 双人作战 3.4 编辑地图 一、游戏简介 1.1 游戏概述 本项目是一款基于C语言开发的控制台…

linux--------线程的同步和互斥

前言 提示:以下是本篇文章正文内容,下面案例可供参考 一、线程互斥 (1)互斥: 任何时刻,互斥保证有且只有一个执行流进入临界区,访问临界资源,通常对临界资源起保护作用 要了解互…

fastjson 泛型转换问题(详解)

系列文章目录 附属文章一:fastjson TypeReference 泛型类型(详解) 文章目录 系列文章目录前言一、代码演示1. 不存在泛型转换2. 存在泛型转换3. 存在泛型集合转换 二、原因分析三、解决方案1. 方案1:重新执行泛型的 json 转换2. …

使用Python突破网站验证码限制

之前有小伙伴说,在web自动化的过程中,经常会被登录的验证码给卡住,不知道如何去通过验证码的验证,今天专门给大家来聊聊验证码的问题。 常见的验证码一般分为两类,一类是图文验证码,一类是滑块验证码&#…

c#基础()

学习目标 了解:嵌套类,匿名类,对象初始化器 重点:类的定义以及对象,构造方法,this和static关键字 掌握:面向对象的概念,访问修饰符,垃圾回收 面向对象 面向对象的概…

面试题:SpringBoot启动流程

具体步骤 新建一个Spring应用程序 (new springApplication()): 确认web应用的类型加载ApplicationContextInitializer加载ApplicationListener记录主启动类 运行应用程序(.run): 准备环境对象Environment,用于加载…

Java学习【String类详解】

Java学习【String类详解】 String的介绍及定义方式String类型的比较String类型的查找charAt()访问字符indexOf()查找下标 转化和替换数值和字符串转化大小写的转换字符串转数组格式化替换 字符串的拆分和截取split()拆分substring()截取trim()去除两边空格 StringBuilder和Stri…

09Linux GDB学习笔记

Linux GDB使用 目录 文章目录 Linux GDB使用先编译文件1.检查安装1.1 安装GDB 2.启动GDB3.退出GDB4.设置断点4.1 在指定行号处设置断点4.2 在指定函数名处设置断点4.3 在指定源文件和行号处设置断点 4.4查看断点信息4.5删除断点5.运行5.1 <font color#ff0000>逐过程&am…

java web爬虫

目录 读取本地文件 从网站读取文件 java爬虫 总结 读取本地文件 import java.io.File; import java.io.PrintWriter; import java.util.Scanner;public class ReplaceText {public static void main() throws Exception{File file new File("basic\\test.txt"…

Sui与Atoma合作为开发者提供AI支持

AI初创公司Atoma宣布其即将推出的推理网络将与Sui集成&#xff0c;该网络将使开发者能够在他们的应用程序中使用AI工具。Atoma选择Sui作为其第一个区块链集成对象是由于Sui的可扩展性和性能。 尽管生成式AI在过去几年中引起了轰动&#xff0c;但它尚未进入许多消费者应用程序。…

openfiler安装部署-1

openfiler安装部署 简介1 下载openfiler2 openfiler 安装2.1 vmware 典型配置2.2 稍后安装操作系统2.3 新建虚拟机向导2.4 命名虚拟机2.5 指定磁盘容量2.6 添加系统镜像&#xff0c;准备安装系统2.7 启动安装系统2.8 初始化磁盘&#xff0c;选择"Yes"2.9 创建分区&am…

软链接和硬链接

1.软链接 > 也称为符号链接 1.1软链接的创建 注&#xff1a;不管是源文件还是链接文件&#xff0c;最好都用上绝对路径 ln -s 链接源 链接名 //创建链接文件 ln -sf 链接源 链接名 //修改链接的源 s 如果目标链接名称已经存在&#xff0…