transformer参数推导

一、目录

1.Bert Embedding 参数量计算
2.多头自注意力self_attention 参数计算: d_model* d_model + 3*(d_model* d_qkvnum_heads)
3. 全连接层参数量
4.layerNormer 参数量 2
hidden
5. 编码器 解码器参数
6. 语言模型head 参数:hidden* vocab

二、实现

在这里插入图片描述参考:https://zhuanlan.zhihu.com/p/636500748

import torch
def count_parameters(model:torch.nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
  1. Bert Embedding 参数量计算
    包含三个表示层+一个LayerNorm 层,表示层为wordembedding+tokentype_embedding+position_embedding
    假设词表大小 vocab size 为 30522,seq_length 为 512,那么有:
    wordsembedding 参数为:(vocab,hidden)
    segment_embedding 参数为:(2,hidden)
    position_embedding 参数为:(512,hidden)
    layerNorm 参数为 hidden*2
    合并:(30522+2+512)1024+ 10242

  2. 多头自注意力self_attention 参数计算: d_model* d_model + 3*(d_model* d_qkv*num_heads)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

在这里插入图片描述

d_model=512
n_head=8
multihead_attention=nn.MultiheadAttention(embed_dim=d_model,num_heads=n_head)
print(count_parameters(multihead_attention))
print(4 * (d_model * d_model + d_model))
  1. 全连接层参数量
    FeedForward 参数 Linear(1024,4096) 以及 Linear(4096,1024)
    参数为:210244096
class TransformerFordWard(nn.Module):
    def __init__(self,d_model,d_ff):
        super(TransformerFordWard,self).__init__()
        self.d_model=d_model
        self.d_ff=d_ff
        self.linear1 = nn.Linear(self.d_model, self.d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(self.d_ff, self.d_model)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x


d_model=512
d_ff=2048
feed_forward = TransformerFordWard(d_model, d_ff)
print(count_parameters(feed_forward)) # 2099712
print(2 * d_model * d_ff + d_model + d_ff)   # 2099712
  1. layerNormer 参数量 2*hidden
d_model = 512
layer_normalization = nn.LayerNorm(d_model)
print(count_parameters(layer_normalization)) # 1024
print(d_model * 2) # 1024
  1. 编码器 解码器参数

    编码器= attention + feed_forward+2layer_norm
    解码器= 2
    attention +feed_forward+3* layer_norm

from torch import nn
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
print(count_parameters(encoder_layer))  # 3152384

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
print(count_parameters(decoder_layer))  # 4204032
print(decoder_layer)
  1. 语言模型head 参数:hidden* vocab
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/448793.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

仿生蝴蝶制作——蝴蝶翅膀制作

前言 上一次已经设计好了的翅膀图纸 接下来就是根据这个图纸来制作翅膀。 过程中其实可以不用尺子准确测量,直接用碳纤维棒比着剪下来就好了,然后把减下来的一截比着剪下另一只翅膀需要的材料。因为左右两只翅膀差别不能太大,所以这样是最好…

【Java设计模式】十四、策略模式

文章目录 1、策略模式2、案例:促销策略3、总结4、在源码中的实际应用 1、策略模式 从A地到B地,出行方式可选汽车、火车、飞机中的一种 日常开发,开发工具可IDEA,可Eclipse 不管你选飞机还是火车,你最终都可以实现从A地…

SpringBoot注解--08--注解@JsonInclude

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 JsonInclude注解是jackSon中最常用的注解之一,是为实体类在接口序列化返回值时增加规则的注解 1.JsonInclude用法2.JsonInclude注解中的规则有 案例需求…

vue自定义主题皮肤方案

方案一:CSS变量换肤(推荐) 利用css定义变量的方法,用var在全局定义颜色变量(需将变量提升到全局即伪类选择器 :root)然后利用js操作css变量,document.getElementsByTagName(‘body’)[0].style…

网络套接字1

网络套接字1 📟作者主页:慢热的陕西人 🌴专栏链接:Linux 📣欢迎各位大佬👍点赞🔥关注🚓收藏,🍉留言 本博客主要内容讲解了udp的Linux环境下的使用&#xff0c…

鼠标在QTreeView、QTableView、QTableWidget项上移动,背景色改变

目录 1. 前言 2. 需求 3. 功能实现 3.1. 代码实现 3.2. 功能讲解 4. 附录 1. 前言 本博文用到了Qt的model/view framework框架,如果对Qt的“模型/视图/委托”框架不懂,本博文很难读懂。如果不懂这方面的知识,请在Qt Assistant 中输入Model/View…

一款适合程序员开发复杂系统的通用平台——JNPF 开发平台

在过去,很多开发工具更侧重代码编辑,针对数据库增删改查(CRUD)类的 Web 系统开发,在界面设计、前后端数据交互等环节主要还是靠写代码,效率比较低。目前很多所谓的低代码开发平台,大多数也都是基…

SQLServer数据库系列之:查询SQLServer数据库上面的连接信息、session信息、sql语句

SQLServer数据库系列之:查询SQLServer数据库上面的连接信息、session信息、sql语句 一、查询数据库上的连接信息二、查询SQLServer数据库的session信息SQLServer数据库从入门到精通系列文章之:SQLServer数据库百篇技术文章汇总 数据库专栏系列文章阅读传送门:详细整理汇总M…

超越 Siri 和 Alexa:探索LLM(大型语言模型)的世界

揭秘LLM:语言模型新革命,智能交互的未来趋势 近年来,虚拟助手的世界发生了重大转变。 虽然 Siri 和 Alexa 本身就是革命性的,但一种称为大型语言模型 (LLM) 的新型人工智能正在将虚拟助手的概念提升到一个全新的水平。 在这篇博文…

农业四情监测系统---气象科普

农业四情监测系统是一个集环境数据采集、分析与决策支持于一体的智能化系统。它主要通过对农田的土壤、气候、作物生长及病虫害等四个关键要素的实时监测,为农业生产提供精准的数据支持和科学的决策依据。 在土壤方面,系统能够检测土壤温度、湿度、pH值及…

我终于解决MathPage.wll文件找不到问题|(最新版Word上亲测)运行时错误,53’: 文件未找到:athPage.WLL

1、问题症状: 运行时错误,53’: 文件未找到:athPage.WLL 2、 解决方案 第一步 首先我们要先找到MathType安装目录下MathPage.wll文件,直接在此电脑中搜索MathPage.wll,找到文件所在位置。 第二步 打开Word文件&#xff0c…

App拉起微信小程序参考文章

App拉起微信小程序参考文章h5页面跳转小程序-----明文URL Scheme_weixin://dl/business/?appid*appid*&path*path*&qu-CSDN博客文章浏览阅读561次,点赞16次,收藏5次。仅需两步,就能实现h5跳转小程序,明文 URL Scheme&…

Linux---多线程(上)

一、线程概念 线程是比进程更加轻量化的一种执行流 / 线程是在进程内部执行的一种执行流线程是CPU调度的基本单位,进程是承担系统资源的基本实体 在说线程之前我们来回顾一下进程的创建过程,如下图 那么以进程为参考,我们该如何去设计创建一个…

闭包表(Closure Table)存储和查询树形数据结构

闭包表通过在关系表中记录树节点之间的直接和间接关系来表示节点之间的层次结构,目的是支持高效的树遍历和查询操作。 一、创建闭包表 CREATE TABLE departments (id int NOT NULL COMMENT ID,name varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_…

Redis冲冲冲——Redis持久化方式及其区别

目录 引出Redis持久化方式Redis入门1.Redis是什么?2.Redis里面存Java对象 Redis进阶1.雪崩/ 击穿 / 穿透2.Redis高可用-主从哨兵3.持久化RDB和AOF4.Redis未授权访问漏洞5.Redis里面安装BloomFilte Redis的应用1.验证码2.Redis高并发抢购3.缓存预热用户注册验证码4.R…

掌握React中的useCallback:优化性能的秘诀

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

XWPFDocument中XmlCursor的使用

类名&#xff1a; org.apache.xmlbeans Interface XmlCursor版本&#xff1a; 原xml代码&#xff1a; <w:p w14:paraId"143E3662" w14:textId"4167FBA7" w:rsidR"001506F2" w:rsidRPr"003F3D89" w:rsidRDefault"001506F2&qu…

OpenStack安装步骤

一、准备OpenStack安装环境 1、创建实验用的虚拟机实例。 内存建议16GB&#xff08;8GB也能运行&#xff09;CPU&#xff08;处理器&#xff09;双核且支持虚拟化硬盘容量不低于200GB&#xff08;&#xff01;&#xff09;网络用net桥接模式 运行虚拟机 2、禁用防火墙与SELin…

2024会声会影永久免费版新功能软件特色及新功能

会声会影2024永久免费版是一款收到很多用户公认的极佳视频编辑软件&#xff0c;里面的每一个功能都特别的强悍你能够一键给图片视频添加特效非常的过瘾&#xff0c;赶快来一起下载试试吧。 会声会影2023-安装包&#xff1a; https://souurl.cn/gtyDFc 会声会影2023-安装包&…

Golang 开发实战day03 - Arrays Slices

Golang 教程03 - Arrays&#xff0c;Slices Go语言中的数组和切片都是用于存储数据的类型&#xff0c;但它们之间存在一些重要的区别。了解这些区别对于有效地使用它们至关重要。 1. Arrays 数组 1.1 定义 数组是一种固定大小的数据结构&#xff0c;用于存储相同类型的值。…