Transformer 代码剖析7 - 词元嵌入(TokenEmbedding) (pytorch实现)

一、类定义与继承关系剖析

1.1 代码结构图示

神经网络基础模块
词嵌入基类
自定义词元嵌入
构造函数定义
基类初始化
词汇量参数
维度参数
填充标识参数

1.2 代码实现精讲

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
from torch import nn

class TokenEmbedding(nn.Embedding):
    """
    基于PyTorch实现的动态词元嵌入模块
    实现词元索引到高维向量的可学习映射
    核心功能:将离散的词元序列转换为连续的语义空间表示
    """
    
    def __init__(self, vocab_size, d_model):
        """
        词元嵌入构造器
        
        :param vocab_size: 词表容量(不同词元的总数)
        :param d_model: 嵌入维度(与Transformer模型维度一致)
        设计要点:
        - 继承nn.Embedding的矩阵运算特性
        - 固化填充索引为可训练参数
        - 保持维度与模型其他组件兼容
        """
        super(TokenEmbedding, self).__init__(
            vocab_size, # 嵌入数量 num_embeddings # 嵌入矩阵行数 = 词表大小
            d_model, # 嵌入维度 embedding_dim # 嵌入矩阵列数 = 模型维度
            padding_idx=1 # 填充符索引的特殊处理
        )

二、核心参数深度解读

2.1 参数矩阵可视化

假设词表容量vocab_size=10000,模型维度d_model=512时:

参数维度元素数量数学意义
weight[10000,512]5,120,000可训练的嵌入查询矩阵
padding_idxscalar1动态掩码位置标识

2.2 关键参数说明

1. vocab_size

  • 控制嵌入矩阵的行维度
  • 决定模型可处理的词元种类上限
  • 典型值域:BERT系列(~30000),GPT系列(~50000)

2. d_model

  • 控制嵌入向量的列维度
  • 与Transformer隐藏层维度严格对齐
  • 典型值域:512(原始论文)、768(BERT-base)、1024(大型模型)

3. padding_idx

  • 实现动态序列掩码的关键参数
  • 索引位置对应的梯度会被自动抑制
  • 防止填充符影响模型语义理解

三、运算过程分步推演

3.1 前向传播示例

输入序列:[3, 28, 1, 0] (1为填充符)

运算步骤:

1. 建立索引映射:

[[3],[[0.2, -0.5, ..., 1.2],  # 索引3的嵌入
 [28],[0.7, 1.1, ..., -0.3],   # 索引28的嵌入
 [1],[0.0, 0.0, ..., 0.0],    # 填充符固定值
 [0]][-0.9, 0.4, ..., 0.1]]   # 索引0的嵌入

2. 矩阵缩放(后续处理):

embeddings * sqrt(d_model)  # 维度对齐的数学技巧

3.2 梯度传播特性

  • 可微分性: 整个映射过程保持梯度通路
  • 参数更新: 通过反向传播调整嵌入矩阵
  • 特殊处理: padding_idx位置梯度始终为0

四、设计哲学解析

4.1 继承关系价值

TokenEmbedding
torch.nn.Embedding
torch.nn.Module
PyTorch基础设施

优势分析:

  • 复用性:继承矩阵运算和参数管理功能
  • 扩展性:保留自定义前向传播的可能性
  • 兼容性:无缝对接PyTorch生态工具

4.2 工程实践建议

1. 初始化技巧:

  • 默认采用均匀分布 U ( − 1 d m o d e l , 1 d m o d e l ) U(-\sqrt{\frac{1}{d_{model}}}, \sqrt{\frac{1}{d_{model}}}) U(dmodel1 ,dmodel1 )
  • 可扩展为Xavier/Kaiming初始化:
    # Xavier均匀初始化(默认)
    nn.init.xavier_uniform_(self.weight)
    
    # 特殊处理填充符
    self.weight.data[1].zero_()
    

2. 维度对齐策略:

# 与位置编码相加前的缩放
embeddings = embeddings * math.sqrt(d_model)

3. 混合精度训练:

# 自动转换为半精度
with autocast():
    embeddings = embedding_layer(input_ids)

4. 填充符处理机制:

  • 训练阶段自动跳过无效位置的计算
  • 推理阶段维持序列形状一致性

5. 计算复杂度分析:

  • 时间复杂度: O ( B ⋅ S ⋅ D ) O(B \cdot S \cdot D) O(BSD)
  • 空间复杂度: O ( V ⋅ D ) O(V \cdot D) O(VD)

完整实现细节可参考PyTorch中sparse.py 模块解析的相关文章(嵌入(Embedding)基类代码解析)或PyTorch官方Embedding文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/980424.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

不同规模企业如何精准选择AI工具: DeepSeek、Grok 和 ChatGPT 三款主流 AI 工具深度剖析与对比

本文深入探讨了最近国内外主流的 DeepSeek、Grok 和 ChatGPT 三款主流 AI 工具的技术细节、性能表现、应用场景及局限性,并从技术能力、功能需求、成本预算、数据安全和合规以及服务与支持五个关键维度,详细分析了不同规模企业在选择 AI 工具时的考量因素…

利用 Python 爬虫进行跨境电商数据采集

1 引言2 代理IP的优势3 获取代理IP账号4 爬取实战案例---(某电商网站爬取)4.1 网站分析4.2 编写代码4.3 优化代码 5 总结 1 引言 在数字化时代,数据作为核心资源蕴含重要价值,网络爬虫成为企业洞察市场趋势、学术研究探索未知领域…

【数据挖掘】Matplotlib

Matplotlib 是 Python 最常用的 数据可视化 库之一,在数据挖掘过程中,主要用于 数据探索 (EDA)、趋势分析、模式识别 和 结果展示。 📌 1. Matplotlib 基础 1.1 安装 & 导入 # 如果未安装 Matplotlib,请先安装 # pip instal…

使用Java构建高效的Web服务架构

使用Java构建高效的Web服务架构 随着互联网技术的飞速发展,Web服务在现代应用中扮演着至关重要的角色。尤其是在企业级应用中,如何构建一个高效、可扩展且易维护的Web服务架构,成为了开发者和架构师面临的一项重要挑战。Java作为一种成熟、稳…

数据库MySQL,在终端输入后,提示不是内部命令等

【解决问题】mysql提示不是内部或外部命令,也不是可运行的程序 一般这种问题是因为没有在系统变量里面添加MySQL的可执行路径 以下是添加可执行路径的方法: 第一步:winR输入services.msc 然后找到MySQL,右击属性并复制MySQL的可执…

LabVIEW正弦信号处理:FFT与最小二乘拟合的参数提取

问题一:LabVIEW能否对采集的正弦力信号进行快速傅里叶变换(FFT),并得到幅值和相位结果? 答案: 可以。LabVIEW通过内置信号处理工具包提供完整的FFT分析功能,具体实现如下: FFT分析流…

Hive-05之查询 分组、排序、case when、 什么情况下Hive可以避免进行MapReduce

一、目标 掌握hive中select查询语句中的基本语法掌握hive中select查询语句的分组掌握hive中select查询语句中的join掌握hive中select查询语句中的排序 二、要点 1. 基本查询 注意 SQL 语言大小写不敏感SQL 可以写在一行或者多行关键字不能被缩写也不能分行各子句一般要分行…

React:B站评论demo,实现列表渲染、删除按钮显示和功能实现、导航栏渲染切换及高亮显示、评论区的排序

功能要求: 1、渲染评论列表 2、删除评论功能:只显示自己评论的删除按钮;点击删除按钮,删除当前评论,列表中不再显示。 3、渲染导航Tab(最新 | 最热)和其 高亮实现 4、评论排序功能实现&…

ST表解决RMQ问题

引入 给定你一个长度为n的数组a,再给你q次询问,每次询问给定你一个区间[L,R],让你求a数组中L~R中的最大值/最小值 我们利用常规算法求时很显然会超时,以此我们需要一个数据结构——ST表来解决 ST表 ST表是一个类似于线段树的东…

[数据结构] - - - 链表

一、定义 链表:是一种常见的线性数据结构,它通过一组节点(Node)来存储数据,每个节点包含两部分:数据域和指针域。 1.1 链表的基本概念 节点(Node):链表的最小单元&#…

Linux的动态库与静态库

目录 动静态库的基本原理 认识动静态库 动静态库各自的特征 静态库 动态库 动静态库与内存 静态库的加载方式 动态库的加载方式 加载到物理内存的细节 静态库的打包与使用 打包 使用 动态库的打包与使用 打包 使用 我以前写的一篇文章中就用网吧与在宿舍自己组装电…

图漾PercipioIPTool软件使用

文章目录 前期准备1.PercipioIPTool软件1.1 更改网络适配器1.2 更改自动获取IP1.3设置静态IP 前期准备 1.一根超五类及其以上规格网线(cat5e、cat6…) 2.相机,配套网线和IO线 3.配套软件PercipioViewer或者PercipioIPTool软件(Windows环境使…

EasyRTC嵌入式WebRTC技术与AI大模型结合:从ICE框架优化到AI推理

实时通信技术在现代社会中扮演着越来越重要的角色,从视频会议到在线教育,再到远程医疗,其应用场景不断拓展。WebRTC作为一项开源项目,为浏览器和移动应用提供了便捷的实时通信能力。而EasyRTC作为基于WebRTC的嵌入式解决方案&…

《白帽子讲 Web 安全:点击劫持》

目录 摘要: 一、点击劫持概述 二、点击劫持的实现示例:诱导用户收藏指定淘宝商品 案例 构建恶意页面: 设置绝对定位和z - index: 控制透明度: 三、其他相关攻击技术 3.1图片覆盖攻击与 XSIO 3.2拖拽劫持与数据…

计算机网络---SYN Blood(洪泛攻击)

文章目录 三次握手过程SYN Flood攻击原理防御措施协议层优化网络层拦截系统配置调整 TCP协议是 TCP/IP 协议栈中一个重要的协议,平时我们使用的浏览器,APP等大多使用 TCP 协议通讯的,可见 TCP 协议在网络中扮演的角色是多么的重要。 TCP 协议…

GitCode 助力 python-office:开启 Python 自动化办公新生态

项目仓库:https://gitcode.com/CoderWanFeng1/python-office 源于需求洞察,打造 Python 办公神器 项目作者程序员晚枫在运营拥有 14w 粉丝的 B 站账号 “Python 自动化办公社区” 时,敏锐察觉到非程序员群体对 Python 学习的强烈需求。在数字…

Trae智能协作AI编程工具IDE:如何在MacBook Pro下载、安装和配置使用Trae?

Trae智能协作AI编程工具IDE:如何在MacBook Pro下载、安装和配置使用Trae? 一、为什么选择Trae智能协作IDE? 在AI编程新时代,Trae通过以下突破性功能重新定义开发体验: 双向智能增强:AI不仅提供代码补全&a…

Qt空项目代码解释

一、 背景 创建的是一个 QWidget 项目。 二、main.cpp 1、图片 2、代码解释 (1)QApplication Qt 图形化界面中一定有 QApplication (2)Widget w; 是 QWidget 的子类。 (3)w.show(); 继承父类的显示…

Codeforces Round 1007 (Div. 2)(ABCD1)

A. The Play Never Ends 翻译: 让我们来介绍一种双人游戏--乒乓球,在这种游戏中,胜负永远分明,不可能出现平局。 索赛、福福和浩海三人想用一生的时间打乒乓球。他们决定用以下方式永远打下去: 在每场比赛中&#xff…

多元数据直观表示(R语言)

一、实验目的: 通过上机试验,掌握R语言实施数据预处理及简单统计分析中的一些基本运算技巧与分析方法,进一步加深对R语言简单统计分析与图形展示的理解。 数据: 链接: https://pan.baidu.com/s/1kMdUWXuGCfZC06lklO5iXA 提取码: …