Pytorch深度解析:Transformer嵌入层源码逐行解读

前言

本部分博客需要先阅读博客:
《Transformer实现以及Pytorch源码解读(一)-数据输入篇》
作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # 词嵌入层
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # 编码层:使用Transformer
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # 输出层
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。

我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_
init
,rest_parameters,forward:

Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            # print("===========================================1")
            # print(self.weight)
            #将self.weight进行nornal归一化
            self.reset_parameters()
            print("===========================================2")
            print(self.weight)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)

        self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。

N

(

mean

,

std

2

)

.

\mathcal{N}(\text{mean}, \text{std}^2).

N

(

mean

,

std

2

)

.

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    TORCH_CHECK(
        dimensions >= 2,
        "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {
  return F::detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}

继续追踪,发现weight中的每个变量被下面的c++代码填充了正太分布的随机数。

void normal_kernel(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {
  CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
  templates::cpu::normal_kernel(self, mean, std, generator);
}

随机数的生成调用如下的代码,首先询问:目前代码是在什么设备上运行,并调用cpu或者gup上的随机数生成方法。

template <typename T>
static inline T * check_generator(c10::optional<Generator> gen) {
  TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
  TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
  TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
  return gen->get<T>();
}

/**
 * Utility function used in tensor implementations, which
 * supplies the default generator to tensors, if an input generator
 * is not supplied. The input Generator* is also static casted to
 * the backend generator type (CPU/CUDAGeneratorImpl etc.)
 */
template <typename T>
static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {
  return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}

至此,embedding的每个随机数的生成过程都清楚了。

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/716863.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

怎么给二维码添加文字或logo?快速美化二维码的使用技巧

怎么给已生成的二维码修改样式呢&#xff1f;目前常规生成的二维码大多是普通黑白色的&#xff0c;没有明显的标识不利于用户辨别。想要提升二维码的辨识度可以通过添加logo、添加文字的方式来改变二维码的样式&#xff0c;让用户看到二维码就知道是否是自己需要的内容&#xf…

智能制造uwb高精度定位系统模块,飞睿智能3厘米定位测距芯片,无人机高速传输

在科技日新月异的今天&#xff0c;定位技术已经渗透到我们生活的方方面面。从手机导航到自动驾驶&#xff0c;再到无人机定位&#xff0c;都离不开精准的定位系统。然而&#xff0c;随着应用场景的不断拓展&#xff0c;传统的定位技术如GPS、WiFi定位等&#xff0c;因其定位精度…

【AI基础】大模型部署工具之ollama的安装部署

ollama是大模型部署方案&#xff0c;对应docker&#xff0c;本质也是基于docker的容器化技术。 从前面的文章可以看到&#xff0c;部署大模型做的准备工作是比较繁琐的&#xff0c;包括各个环节的版本对应。ollama提供了一个很好的解决方案。 ollama主要针对主流的LLaMA架构的…

如何使用xurlfind3r查找目标域名的已知URL地址

关于xurlfind3r xurlfind3r是一款功能强大的URL地址查询工具&#xff0c;该工具本质上是一个CLI命令行工具&#xff0c;可以帮助广大研究人员从多种在线源来查询目标域名的已知URL地址。 功能介绍 1、从被动在线源获取URL地址以实现最大数量结果获取&#xff1b; 2、支持从Way…

可通过小球进行旋转的十字光标(vtkResliceCursor)

前一段事件看到VTK的一个例子&#xff1a; 该案例是vtk.js写的&#xff0c;觉得很有意思&#xff0c;个人正好也要用到&#xff0c;于是萌生了用C修改VTK源码来实现该功能的想法。原本以为很简单&#xff0c;只需要修改一下vtkResliceCursor就可以了&#xff0c;加上小球&#…

【面试 - 页面优化举例】页面跳转卡顿问题解决 - 页面跳转速度优化

目录 为何要优化如何优化优化1 - 懒加载优化2 - el-tree 子节点默认不展开 为何要优化 页面A跳转到也页面B时&#xff0c;页面出现卡顿情况&#xff1a; 【问题】页面A → 页面B时&#xff0c;页面B进入到了 created 钩子后过了六七秒才进入到 mounted 钩子&#xff1b;【分析经…

遗传算法浅理解

1. 什么是遗传算法&#xff1f; ​ 遗传算法&#xff0c;又称为 Genetic algorithm(GA)Genetic algorithm(GA)。其主要思想就是模拟生物的遗传与变异。它的用途非常广泛&#xff0c;可以用于加速某些求最大或者最小值的算法&#xff08;换句话说就是加速算法收敛&#xff0c;最…

PV180R1K1T1NMMC派克通轴传动结构柱塞泵

PV180R1K1T1NMMC派克通轴传动结构柱塞泵 派克柱塞泵的结构组成部分&#xff1a;柱塞、手把、斜盘、压盘、滑履、泵体、配油盘、传送轴。其优点如下&#xff1a; 1、结构紧凑耐用&#xff0c;具有灵活的安装接口 2、安静的工作 3、效率高 4、降低功耗和减少发热 5、具有“…

升级到tomcat10和Java 21后,idea控制台system.out.println输出中文乱码问题

最近一次性从tomcat 9升级到tomcat 10&#xff0c;同时Java sdk也从1.8升级到21。 升级过程中&#xff0c;当然会遇到很多问题&#xff0c;但是控制台输出中文乱码问题&#xff0c;着实折腾了很久。 1、尝试各种方法 网上说的很多通用方法都试过了&#xff0c;就是不生效。包…

编码在网络安全中的应用和原理

前言:现在的网站架构复杂&#xff0c;大多都有多个应用互相配合&#xff0c;不同应用之间往往需要数据交互&#xff0c;应用之间的编码不统一&#xff0c;编码自身的特性等都很有可能会被利用来绕过或配合一些策略&#xff0c;造成一些重大的漏洞。 什么是编码&#xff0c;为什…

别再这么起号了!TikTok小白起号误区,你中招了吗?

看过不少Tiktok新手的起号失败案例&#xff0c;总结下来就是以下这几个问题&#xff0c;今天结合一些个人起号心得给大家分享怎么成功在TK起号&#xff0c;希望对大家有所帮助。 手机/网络环境 首先我们要确保手机环境和网络环境没有问题&#xff0c;如果被TK判断出是非海外用户…

【YOLOv8改进[注意力]】在YOLOv8中添加ECA高效通道注意力(2020.4)的实践 + 含全部代码和详细修改方式 + 手撕结构图

本文将进行在YOLOv8中添加ECA高效通道注意力的实践,助力YOLOv8目标检测效果的实践,文中含全部代码、详细修改方式以及手撕结构图。助您轻松理解改进的方法。 改进前和改进后的参数对比: 目录 一 ECA 二 在YOLOv8中添加ECA注意力

synchronized死锁

1、死锁案例 /*** descpription: 死锁案例* date 2024/6/17*/ public class DeadLockDemo {public static void main(String[] args) {Object objA new Object();Object objB new Object();new Thread(() ->{synchronized (objA){System.out.println(Thread.currentThrea…

【产品经理】ERP订单处理3-解密促销策略

由于订单到电商ERP系统中&#xff0c;订单金额已经不能支持改动&#xff0c;故电商前台端的优惠券活动、满减活动、支付优惠活动无法使用&#xff0c;故电商ERP只涉及赠品的活动。 一、订单金额阶梯送 顾名思义&#xff1a;订单金额在某个或者某几个金额范围内赠送商品。 字…

设备物联网关在实际生产中的作用解析-天拓四方

随着物联网技术的迅猛发展&#xff0c;设备物联网关作为连接物理世界与数字世界的核心组件&#xff0c;其应用已经渗透到工业、农业、医疗等多个领域。本案例将聚焦于设备物联网关在某制造企业中的应用&#xff0c;详细解析其在实际生产中的重要作用。 案例背景 某制造企业面…

Browserslist: caniuse-lite is outdated。浏览器列表:caniuse lite已经过时???

一、最近运行项目启动时提示 Browserslist: caniuse-lite is outdated. Please run: npx update-browserslist-dblatest Why you should do it regularly: https://github.com/browserslist/update-db#readme 这要是这一句&#xff0c;Browserslist: caniuse-lite is outdated.…

离散数学-代数系统证明题归类

什么是独异点&#xff1f; 运算 在B上封闭&#xff0c;运算 可结合&#xff0c;且存在幺元。 学会合理套用题目公式结合律 零元&#xff1f; 群中不可能有零元 几个结论要熟记&#xff1a; 1.当群的阶为1时&#xff0c;它的唯一元素视作幺元e 2.若群的阶大于1时&#xff0c;…

成都爱尔胡建斌院长提醒一张眼底照,眼病早知道

眼底藏在眼睛后方&#xff0c;平时没注意无察觉&#xff0c;其实非常重要。它包含的部位多掌控着视觉问题&#xff0c;稍不注意就是视觉受损&#xff0c;视觉缺失&#xff0c;严重的甚至失明致盲。 眼球前面的角膜、晶体等&#xff0c;被称为眼前段&#xff0c;后面则被称之为…

Springboot微服务整合缓存的时候报循环依赖的错误 两种解决方案

错误再现 Error starting ApplicationContext. To display the conditions report re-run your application with debug enabled. 2024-06-17 16:52:41.008 ERROR 20544 --- [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPLI…

如何通过改善团队合作来提高招聘效率

当招聘顶尖人才时&#xff0c;时间就是一切。招聘效率取决于团队快速响应和完成任务的能力&#xff0c;但招聘经理和面试官并不总是最关心重要的招聘任务。更重要的是&#xff0c;求职者的经历取决于准备好的面试官是否准时出现。有时候最好的候选人会接受另一份工作&#xff0…