LLaMA开源大模型源码分析!

 Datawhale干货 

作者:宋志学,Datawhale成员

花了一晚上照着transformers仓库的LLaMA源码,把张量并行和梯度保存的代码删掉,只留下模型基础结构,梳理了一遍LLaMA的模型结构。

今年四月份的时候,我第一次接触深度学习,也是今年第一次接触Datawhale,在Datawhale和小伙伴一起学习、讨论了大半年,不知不觉已经可以做到看源码的程度了。

Datawhale才是一个没有围墙的大学,在这里无论你有什么想法💡,只要你愿意前进,总会有小伙伴和你一起。

博客地址:

https://flowus.cn/kmno4/share/527055be-464f-4f0f-98c5-8b8f72a1fc2e

LLaMA-Model

在transformers仓库中可以看到llama的源码,首先是LlamaModel类,继承自PreTrainedModel,这个类是所有模型的基类,包含了一些通用的方法,比如保存模型、加载模型、初始化权重等。

继承关系为:LlamaModel-> LlamaPreTrainedModel-> PreTrainedModel

LlamaConfig

LlamaConfig 中主要是定义一些参数,比如vocab_size、hidden_size、num_hidden_layers、num_attention_heads等。所有的参数有默认值,可以直接创建cofing就能用。

config = LlamaConfig()

LlamaModel

6783830017ed6f9a29869202cd8218ff.jpeg

LlamaModel 初始化

  • 设置了模型的两个属性:padding_idx(用于指定填充标记的索引),vocab_size(词汇表的大小)

  • 初始化了模型的嵌入层、解码器层、归一化层

  • 嵌入层(nn.Embedding):模型使用嵌入层将输入的标记映射成密集的向量表示。

  • 解码器层(nn.ModuleList()):模型包含多个解码器层,这些层都是由 LlamaDecoderLayer 定义

  • 归一化层 LlamaRMSNorm:归一化层使用的是 Root Mean Square Layer Normalization(RMS Layer Norm)

  • 设置了是否使用 gradient_checkpoint 主要是用来节省显存

  • 调用 post_init() 完成一些初始化和准备检查的代码

def __init__(self, config: LlamaConfig):
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    # embedding 层
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    # 中间的一堆 decoderlayers 层
    self.layers = nn.ModuleList(
        [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
    )
    self._use_sdpa = config._attn_implementation == "sdpa"
    self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()

可以看一下 post_init() 的代码,主要是初始化权重和gradient_checkpointing相关的一些事情。该方法在PreTrainedModel基类中,transformers中所有模型基本都继承这个类。

def post_init(self):
    """
    A method executed at the end of each Transformer model initialization, to execute code that needs the model's
    modules properly initialized (such as weight initialization).
    """
    self.init_weights()
    self._backward_compatibility_gradient_checkpointing()

LlamaModel forward

forward 部分的代码有点长,但其实大部分都是张量并行或者是节省显存相关的代码,对于理解模型结构来说可以直接忽略。

首先进来就是把 inputs_ids 进行向量化,然后拿到 hidden_states 。然后是存起来所有的hidden_states 进入 decoder_layer 再拿一个 hidden_states,作为下一轮 decoder_layerhidden_states 输入,最后给 hidden_states norm一下。如下代码所示:

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds

for decoder_layer in self.layers:
    # 存起来所有的 hidden_states
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    # 这里是 decoder_layer 的 forward
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    # 再拿一个 hidden_states,作为下一轮 decoder_layer 的 hidden_states 输入
    hidden_states = layer_outputs[0]

hidden_states = self.norm(hidden_states)

最后就是以 BaseModelOutputWithPast 的形式输出。ok,接下来继续看decoder_layer中的其他代码。

LlamaDecoderLayer

Embedding层不用多说,用的就是torch中的nn.Embedding。那就直接来看DecoderLayer。

8e444bd665786abd4f4a471a8a72b244.png

DecoderLayers 初始化

先来看初始化。

  • hidden_size : 也就是在上面说的输入输出。

  • self_attn : 别看它写这么多啊,其实就是选一下用什么 attention 。看见大写字母不要怕,直接点进去看看怎么个事!

    LLAMA_ATTENTION_CLASSES = {
        "eager": LlamaAttention,
        "flash_attention_2": LlamaFlashAttention2,
        "sdpa": LlamaSdpaAttention,
    }
  • mlp : 一个全连接层 LlamaMLP 这个待会后面再说,输入输出都是 hidden_size 大小。

  • input_layernorm : LlamaRMSNorm 层,输入时候的norm

  • post_attention_layernorm : 丢入 mlp 之前的操作。

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\

DecoderLayers forward

首先复制一份 hidden_statesresidual。然后 hidden_states 进入 input_layernorm 进行norm。然后进入 self_attn 进行 attention 操作,拿到 hidden_statesself_attn_weightspresent_key_value。然后 hidden_statesresidual 相加,得到 hidden_states

然后 hidden_states 进入 post_attention_layernorm 进行norm。最后 hidden_states 进入 mlp 进行全连接操作,拿到 hidden_states。然后 hidden_statesresidual 相加,得到 hidden_states。最后输出 hidden_states

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    **kwargs,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
    outputs += (self_attn_weights,)

if use_cache:
    outputs += (present_key_value,)

return outputs

Llama Attention

31aaddd17134c8360db0117ba21d30b6.png

看代码首先映入眼帘的就是  Attention Is All You Need  好好好,很有精神!那我们接着往下看。

先来看 init 部分叭。

  • layer_idx : 这个就是第几个 DecoderLayers 层。不用关心。

  • attention_dropout : 用于dropout的概率。

  • hidden_size : 输入输出大小。

  • num_attention_heads : 多头注意力的头数。

  • head_dim : 多头注意力的维度 self.hidden_size // self.num_heads,和transformers中的一样。

  • num_key_value_heads : 用于key和value的头数。

其他的参数都在 LlamaConfig 中有默认值,可以直接使用,也可以直接去LlamaConfig的源码中看具体的解释,这里就不再多说。

再往下就是 q_projk_projv_projo_proj 四个矩阵(全连接层),耳熟能详了。

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

LlamaAttention forward

重头戏来了,attention forward 部分。

注意:其中有关于张量并行或者显存节省的部分我就直接省略了,直接看主要代码。这个笔记主要是分析llama的模型结构,并不讨论如何节省显存。

首先拿到 hidden_statesbatch_sizeseq_len 。然后把 hidden_states 丢入 q_projk_projv_proj 三个矩阵(全连接层),拿到 query_stateskey_statesvalue_states 。然后把 query_stateskey_statesvalue_states reshape 为下一步计算做准备。

将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果

key_statesvalue_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化

然后 attn_weights 加上 attention_mask,再 softmaxdropout。然后 attn_weightsvalue_states 相乘,把 attn_output reshape 为下一步计算做准备,最后把 attn_output 丢入 o_proj ,然后return就行了。

好了,至此。我觉得llama最激动人心的地方已经结束了。

# 获取 batch_size 和 seq_len
bsz, q_len, _ = hidden_states.size()

# 把 hidden_states 丢入 q_proj、k_proj、v_proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# 把 q_proj、k_proj、v_proj 的输出 reshape 为下一步计算做准备
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# 将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# 首先,它将key_states和value_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

# 然后 attn_weights 加上 attention_mask
attn_weights = attn_weights + attention_mask

# softmax + dropout
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

# 然后 attn_weights 和 value_states 相乘
attn_output = torch.matmul(attn_weights, value_states)

# 然后把 attn_output reshape 为下一步计算做准备
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

# 最后把 attn_output 丢入 o_proj
attn_output = self.o_proj(attn_output)

# 返回 attn_output、attn_weights、present_key_value
return attn_output, attn_weights, past_key_value

LlamaMLP

c1cd4cf6f3c0e2a88c2f2b536e2fd10f.png

看完 attention 再看 MLP ,突然就觉得好简单了,哈哈哈。这部分代码比较少,就直接放到一起了。

x进来之后先进去up_proj和gate_proj,gate_proj进行激活,然后这俩再乘起来,丢进 down_proj。那直接放个图叭,这个过程有点简单了。

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 这俩不必多说
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # 三个全连接层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

LlamaRMSNorm

RMSNorm函数可以用以下数学公式表示:

其中:

  • 是层的输入。

  • 代表层的权重。

  • 是权重的数量。

  • 是一个小常数,用于数值稳定性(以避免除以零的情况)。

这种归一化有助于通过确保权重的规模不会变得过大或过小来稳定学习过程,这在具有许多层的深度学习模型中特别有用。

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

参考:https://space.bilibili.com/45156039

8c2428ec014740c4d5aa5ee991a3cb14.png

干货学习,三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/265425.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

3.苍穹外卖-day03

苍穹外卖-day03 课程内容 公共字段自动填充 新增菜品 菜品分页查询 删除菜品 修改菜品 功能实现:菜品管理 菜品管理效果图: 1. 公共字段自动填充 1.1 问题分析 在上一章节我们已经完成了后台系统的员工管理功能和菜品分类功能的开发,在新…

AI工程化—— 如何让AI在企业多快好省的落地?

1. 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。 点击跳转到网站 作为计算机科学的一个重要领域,机器学习也是目前人工智能领域非常活跃的分支之一。机器学习通过分析海量数据、总结规…

sql_lab之sqli中的堆叠型注入(less-38)

堆叠注入(less-38) 1.判断注入类型 http://127.0.0.3/less-38/?id1 and 12 -- s 没有回显 http://127.0.0.3/less-38/?id1 and 11 -- s 有回显 则说明是单字节’注入 2.查询字段数 http://127.0.0.3/less-38/?id1 order by 4 -- s 报错 http:/…

c# OpenCvSharp 检测(斑点检测、边缘检测、轮廓检测)(五)

在C#中使用OpenCV进行图像处理时,可以使用不同的算法和函数来实现斑点检测、边缘检测和轮廓检测。 斑点检测边缘检测轮廓检测 一、斑点检测(Blob) 斑点检测是指在图像中找到明亮或暗的小区域(通常表示为斑点)&#…

数据结构和算法-二叉排序树(定义 查找 插入 删除 时间复杂度)

文章目录 二叉排序树总览二叉排序树的定义二叉排序树的查找二叉排序树的插入二叉排序树的构造二叉排序树的删除删除的是叶子节点删除的是只有左子树或者只有右子树的节点删除的是有左子树和右子树的节点 查找效率分析查找成功查找失败 小结 二叉排序树 总览 二叉排序树的定义 …

JavaScript系列-函数(function)

文章目录 函数定义函数的特征 创建函数方式函数声明实现函数内部操作对外部可见 函数表达式匿名表达式带名称表达式 函数调用方式函数提升函数作用域作用域和函数栈递归 嵌套函数和闭包闭包特性-保存变量 使用 arguments 对象箭头函数定义 更多内容 函数定义 提示:函…

【MYSQL】MYSQL 的学习教程(六)之 SQL 语句执行流程

1. 一条 SQL 查询语句是如何被执行的 MySQL 的基本架构示意图如下所示: MYSQL 线程处理请求流程: SQL 接口:MySQL 中处理请求的线程在获取到请求以后获取 SQL 语句去交给 SQL 接口去处理查询解析器:解析器会将 SQL 接口传递过来…

VSCode Emoji 在 Windows10 下的显示问题

VSCode Emoji 在 Windows10 下的显示问题 问题描述 使用系统快捷键 Win ;(分号) 或 Win .(句号) 可以打开系统的 Emoji 面板,用于输入表情符号。 但是在 Windows 10 的 VSCode 中,一部分 Emoji 的显示会出现问题,比如以下这些&#xff1…

跟着LearnOpenGL学习9--光照

文章目录 一、颜色二、创建光照场景 一、颜色 显示世界中有无数种颜色,每一个物体都有它们自己的颜色。我们需要使用(有限的)数值来模拟现实世界中(无限的)的颜色,所以并不是所有现实世界中的颜色都可以用…

网络游戏管理新规:重氪滚服成历史,SLG、MMO与小游戏逻辑亟待换新

12月22日,国家新闻出版署网站发布《网络游戏管理办法(草案征求意见稿)》(以下简称《办法》),向社会公开征求意见。 午间《办法》一经发出后,在行业内立刻引发震动,诸多从业者表示&a…

将mapper.xml保存为idea的文件模板

将mapper.xml保存为idea的文件模板 在idea的File and Code Templates中将需要使用模板的内容添加为模板文件。 那么接下来请看图&#xff0c;跟着步骤操作吧。 mapper.xml文件内容 <?xml version"1.0" encoding"UTF-8"?> <!DOCTYPE mapper P…

基于SpringBoot的校园疫情防控管理系统 JAVA简易版

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学生2.2 老师2.3 学校管理部门 三、系统展示四、核心代码4.1 新增健康情况上报4.2 查询健康咨询4.3 新增离返校申请4.4 查询防疫物资4.5 查询防控宣传数据 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBoot…

数字音频编辑软件audition 2021 mac功能介绍

Audition 2021 mac是一款专业数字音频编辑软件&#xff0c;提供先进的音频混音、编辑和效果处理功能&#xff0c;专为音频和视频专业人员设计。无论是要录制音乐、无线电广播&#xff0c;还是为录像配音&#xff0c;Audition都能帮到您。它可提供先进的音频混合、编辑、控制和效…

大创项目推荐 深度学习 植物识别算法系统

文章目录 0 前言2 相关技术2.1 VGG-Net模型2.2 VGG-Net在植物识别的优势(1) 卷积核&#xff0c;池化核大小固定(2) 特征提取更全面(3) 网络训练误差收敛速度较快 3 VGG-Net的搭建3.1 Tornado简介(1) 优势(2) 关键代码 4 Inception V3 神经网络4.1 网络结构 5 开始训练5.1 数据集…

AutoEncoder个人记录

原理 最常见的降维算法有主成分分析法PCA&#xff0c;通过对协方差矩阵进行特征分解而得到数据的主要成分&#xff0c;但是 PCA 本质上是一种线性变换&#xff0c;提取特征的能力极为有限。 AutoEncoder把长度为d_in输入特征向量变换到长度为d_out的输出向量&#xff0c;借助于…

地震勘探原理---数字滤波处理

一. 地震数字滤波的目标 核心任务&#xff1a;去噪&#xff0c;提高地震资料信噪比 噪声压制: 野外采集中可以通过检波器组合, 震源组合, 地震多次覆盖技术来压制干扰波, 但是由于多种原因, 野外采集的资料仍然残留一定干扰波, 必须采用室内数字处理的方式来进行压制. 根据有效…

MyBatis 通过 SqlSession 实现动态Entity批量插入

需要几个关键点: 1、entity对应的service需要继承BaseService 2、entity对应的serviceImpl需要实现baseMapper方法&#xff0c;需要把当前的mapper返回去 3、entity对应的Mapper需要BaseMapper

Java并发工具类---ForkJoin、countDownlatch、CyclicBarrier、Semaphore

一、Fork Join fork join是JDK7引入的一种并发框架&#xff0c;采用分而治之的思想来处理并发任务 ForkJoin框架底层实现了工作窃取&#xff0c;当一个线程完成任务处于空闲状态时&#xff0c;会窃取其他工作线程的任务来做&#xff0c;这样可以充分利用线程来进行并行计算&a…

官宣!DevExpress Blazor UI组件,支持全新的.NET 8渲染模式

DevExpress Blazor UI组件使用了C#为Blazor Server和Blazor WebAssembly创建高影响力的用户体验&#xff0c;这个UI自建库提供了一套全面的原生Blazor UI组件&#xff08;包括Pivot Grid、调度程序、图表、数据编辑器和报表等&#xff09;。 .NET 8为Blazor引入了令人兴奋的重…

柯桥外语学习-俄语零基础入门教学之与衣服有关的词汇

本期为大家带来的是与衣物有关的相关词汇&#xff01; 最近全国大范围降温&#xff0c;大家一定要关注天气预告及时增减衣物&#xff0c;小心不要感冒啦~ 一、服装组成部分 领子 воротник 方领 квадрадный воротник 圆领 закругленн…