Llama模型结构解析(源码阅读)

目录

  • 1. LlamaModel整体结构流程图
  • 2. LlamaRMSNorm
  • 3. LlamaMLP
  • 4. LlamaRotaryEmbedding

  • 参考资料:
    https://zhuanlan.zhihu.com/p/636784644
    https://spaces.ac.cn/archives/8265 ——《Transformer升级之路:2、博采众长的旋转式位置编码》

前言:本次阅读代码位置,在transformers库底下的modeling_llama.py,具体位置在:transformers/models/llama/modeling_llama.py,如下图所示:在这里插入图片描述

1. LlamaModel整体结构流程图

在这里插入图片描述

2. LlamaRMSNorm

  • 代码如下
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return (self.weight * hidden_states).to(input_dtype)
  • RMSNorm的公式如下所示:
    x i 1 n ∑ i = 1 n x i 2 + e p s ∗ w e i g h t i \frac{x_i}{\sqrt{\frac{1}{n}\sum\limits_{i=1}^{n}{x_i}^2 + eps}} * weight_i n1i=1nxi2+eps xiweighti

    • 其中,公式与代码的对应关系如下:
      在这里插入图片描述

3. LlamaMLP

  • 代码如下:
class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  • 流程图:
    在这里插入图片描述

  • 其中输入为x,输出为y

  • 代码中intermediate_size一般比hidden_size大,我们通过在jupyter notebook中打印Llama-13B的模型,可以看到如下所示:
    在这里插入图片描述

  • 总结:MLP模块就是几个nn.Linear的组合

4. LlamaRotaryEmbedding

  • 代码如下

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
  • 具体的使用,还调用了另外两个函数,如下所示:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
    
  • 注意这里的实现跟原始推导有点区别,这里实现的方式如下图所示:
    在这里插入图片描述

  • 原始推导如下图所示:
    在这里插入图片描述
    具体可以查看作者的博客:👉戳我👈

  • 总结:RoPE就是在attention计算时,K跟Q做内积之前,先给各自注入位置信息。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/97467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

无涯教程-分类算法 - 随机森林

随机森林是一种监督学习算法,可用于分类和回归,但是,它主要用于分类问题,众所周知,森林由树木组成,更多树木意味着更坚固的森林。同样,随机森林算法在数据样本上创建决策树,然后从每…

Linux文件管理知识:查找文件(第二篇)

上篇文章详细介绍了linux系统中查找文件的工具或者命令程序locate和find命令的基本操作。那么,今天这篇文章紧接着查找文件相关操作内容介绍。 Find命令所属操作列表中的条目,有助于我们想要的结果输出。上篇文章已讲到find 命令是基于搜索结果来执行操作…

LAMP介绍与配置

一.LAMP 1.1.LAMP架构的组成 CGI(通用网关接口)和FastCGI(快速公共网关接口)都是用于将Web服务器与后端应用程序(如PHP、Python等)进行交互的协议/接口。 特点 CGI FastCGI 运行方式 每个请求启动…

Seaborn数据可视化(四)

目录 1.绘制箱线图 2.绘制小提琴图 3.绘制多面板图 4.绘制等高线图 5.绘制热力图 1.绘制箱线图 import seaborn as sns import matplotlib.pyplot as plt # 加载示例数据(例如,使用seaborn自带的数据集) tips sns.load_dataset("t…

中国智慧燃气行业市场需求

文章来源:中研普华产业研究院 关键词:智慧燃气、智慧燃气场站、智慧燃气平台、设备设施数字化、数字孪生、工业互联网 智慧燃气,是以城市输气管网为基础,各终端用户协调发展,以信息通信平台为支撑,具有信…

C++信息学奥赛1177:奇数单增序列

#include<bits/stdc.h> using namespace std; int main(){int n;cin>>n; // 输入整数 n&#xff0c;表示数组的大小int arr[n]; // 创建大小为 n 的整型数组for(int i0;i<n;i) cin>>arr[i]; // 输入数组元素for(int i0;i<n;i){ // 对数组进行冒泡排序f…

在腾讯云服务器OpenCLoudOS系统中安装svn(有图详解)

1. 安装svn yum -y install subversion 安装成功&#xff1a; 2. 创建数据根目录及仓库 mkdir -p /usr/local/svn/svnrepository 创建test仓库&#xff1a; svnadmin create /usr/local/svn/test test仓库创建成功&#xff1a; 3. 修改配置test仓库 cd /usr/local/svn/te…

39.RESTful案例

RESTful案例 准备环境 Employee.java public class Employee {private Integer id;private String lastName;private String email;//1 male, 0 femaleprivate Integer gender; } //省略get、set和构造方法EmployeeDao.java package com.atguigu.SpringMVC.dao;import com.…

C++信息学奥赛1178:成绩排序

#include<bits/stdc.h> using namespace std; int main(){int n;cin>>n; // 输入整数 n&#xff0c;表示数组的大小int arr[n]; // 创建大小为 n 的整型数组 arrstring brr[n]; // 创建大小为 n 的字符串数组 brrfor(int i0;i<n;i) cin>>brr[i]>>ar…

有线耳机插入电脑没声音

有线耳机插入电脑没声音 首先确保耳机和电脑都没问题&#xff0c;那就有可能是声音输出设备设置错误 右击任务栏的声音图标-打开声音设置-选择输出设备。

2 hadoop的目录

1. 目录结构&#xff1a; 其中比较的重要的路径有&#xff1a; hdfs,mapred,yarn &#xff08;1&#xff09;bin目录&#xff1a;存放对Hadoop相关服务&#xff08;hdfs&#xff0c;yarn&#xff0c;mapred&#xff09;进行操作的脚本 &#xff08;2&#xff09;etc目录&#x…

线上问诊:业务数据采集

系列文章目录 线上问诊&#xff1a;业务数据采集 线上问诊&#xff1a;数仓数据同步 文章目录 系列文章目录前言一、环境安装1.DataX 二、全量同步1.DataX配置文件生成2.启动hadoop测试一下。3.全量同步 三、增量同步1.配置Flume2.编写Flume拦截器3.通道测试4.修改Maxwell参数…

Pytorch学习:神经网络模块torch.nn.Module和torch.nn.Sequential

文章目录 1. torch.nn.Module1.1 add_module&#xff08;name&#xff0c;module&#xff09;1.2 apply(fn)1.3 cpu()1.4 cuda(deviceNone)1.5 train()1.6 eval()1.7 state_dict() 2. torch.nn.Sequential2.1 append 3. torch.nn.functional.conv2d 1. torch.nn.Module 官方文档…

环保数字化,让污染无处遁形

环保一直以来都是我国大力推崇的举措&#xff0c;“保护环境、人人有责”的标语深入人心&#xff0c;但是环保绝不是某一天某一年就能做好的事情&#xff0c;而在于一朝一夕坚持不懈&#xff0c;下文将针对环保的场景介绍一下数字孪生技术在环保领域的应用。 一、环保背景 新中…

几个nlp的小项目(文本分类)

几个nlp的小项目(文本分类) 导入加载数据类、评测类查看数据集精确展示数据测评方法设置参数tokenizer,token化的解释对数据集进行预处理加载预训练模型进行训练设置训练模型的参数一个根据任务名获取,测评方法的函数创建预训练模型开始训练本项目的工作完成了什么任务?导…

CNN 02(CNN原理)

一、卷积神经网络(CNN)原理 1.1 卷积神经网络的组成 定义 卷积神经网络由一个或多个卷积层、池化层以及全连接层等组成。与其他深度学习结构相比&#xff0c;卷积神经网络在图像等方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他浅层或深度神经…

景联文科技数据标注:人体关键点标注用途及各点的位置定义

人体关键点标注是一种计算机视觉任务&#xff0c;指通过人工的方式&#xff0c;在指定位置标注上关键点&#xff0c;例如人脸特征点、人体骨骼连接点等&#xff0c;常用来训练面部识别模型以及统计模型。这些关键点可以表示图像的各个方面&#xff0c;例如角、边或特定特征。在…

unity 之参数类型之引用类型

文章目录 引用类型引用类型与值类型的差异 引用类型 在Unity中&#xff0c;引用类型是指那些在内存中存储对象引用的数据类型。以下是在Unity中常见的引用类型的介绍&#xff1a; 节点&#xff08;GameObject&#xff09;&#xff1a; 在Unity中&#xff0c;游戏对象&#xff…

day28 异常

to{}catch{} try{}catch{}的流传输 try {fis new FileInputStream("file-APP\\fos.txt");fos new FileOutputStream("fos.txt");int a ;while ((a fis.read())! -1){fos.write(a);}System.out.println(a); } catch (IOException e) {e.printStackTrace()…

在编辑器中使用正则

正则是一种文本处理工具&#xff0c;常见的功能有文本验证、文本提取、文本替换、文本切割等。有一些地方说的正则匹配&#xff0c;其实是包括了校验和提取两个功能。 校验常用于验证整个文本的组成是不是符合规则&#xff0c;比如密码规则校验。提取则是从大段的文本中抽取出…