用code去探索理解Llama架构的简单又实用的方法

除了白月光我们也需要朱砂痣

      我最近也在反思,可能有时候算法和论文也不是每个读者都爱看,我也会在今后的文章中加点code或者debug模型的内容,也许还有一些好玩的应用demo,会提升这部分在文章类型中的比例

      今天带着大家通过代码角度看一下Llama,或者说看一下Casual-LLM的Transfomer到底长啥样

     对Transfomer架构需要更了解的读者,可以先看这个系列

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(1) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(2) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(3) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(4) (qq.com)

       友情提示,看代码和debug都不需要GPU,在你的PC上就可以做

       首先先安装transfomer库

       pip install transfomers

       然后进入到库下面,一般在这

图片

       进去就能找到Transfomers的库,往下拉到models,可以发现各种模型都在里面

图片

找到Llama,包含以下文件

图片

      点进modeling_llama.py发现1200多行,根本没法看(一般人没那么多耐心,但是其实仔细看两遍还是很有收获的),然后点击左边outline 大纲,(或者ctrl+shift+o),就可以有选择的看你想要研究的具体网络层,这样感觉压力瞬间小了百分之80以上

图片

      我想查某个函数的代码块,想对它加深了解,举个例子,看起来比较怪异命名的,这个线性扩展RoPE embedding的函数

图片

       然后ctl按住函数名字就能查到它作用于attetion的机制里面,在这种情况下即使我不知道它到底干啥,也能猜个89不离10,至少和什么主模块相关我清楚了

图片

     本章的话,我们先从模型主体部分看起,点LlamaModel,就能看到非常清晰的逻辑

    

图片

    主体部分包含3个子模块:

  • 先要embedding token

  • 再要包含一个通过for循环,不断的持续经过的decoder层

  • 还要包含一个Normal(RMSNorm)

     

     从大面上看也就这么3个操作

     初始化部分,初始了哪些模块我们看完了,再看看细节,从forward看起(大家看任何网络都要重点看forward)

图片

 forward输入部分,要求输入的参数:

  • input_ids:输入序列标记的索引,形状为(batch_size, sequence_length)的torch.LongTensor。

  • attention_mask:避免对填充标记进行注意力计算的掩码,形状为(batch_size, sequence_length)的torch.Tensor

  • position_ids:输入序列标记在位置嵌入中的索引,形状为(batch_size, sequence_length)的torch.LongTensor

  • past_key_values:包含预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),用于加速顺序解码的tuple,推理用的,当use_cache=True时,会返回这个参数,训练不用管

  • inputs_embeds:直接传入embedding表示而不是input_ids,形状为(batch_size, sequence_length, hidden_size)的torch.FloatTensor

  • use_cache:是否使用缓存加速解码的布尔值,当设置为True时,past_key_values的键值状态将被返回,用于加速解码

  • output_attentions:是否返回所有注意力层的注意力张量,布尔值,当设置为True时,会在返回的结果中包含注意力张量

  • output_hidden_states:是否返回所有层的隐藏状态,布尔值,当设置为True时,会在返回的结果中包含隐藏状态

  • return_dict:是否返回utils.ModelOutput对象而不是普通的元组,布尔值,当设置为True时,会返回一个ModelOutput对象

         我们继续看,下面就是一些操作步骤,输入的input_id,会被向量化,生成hidden_states

图片

    hidden_states然后就被扔进了若干个hidden_layer被for循环来回的操作,比如Llama7B的32层

   

图片

     我们简单写一段逻辑描述上述的代码

     比如在把"我爱你"已经分词的情况下 我=100,爱=200,你=300

  input_ids = [100,200,300]

  input_ids  -> nn.Emebdding(dims=3) -> hidden_states

  hidden_states = [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.7.0.8.1.1]]

  hidden_states ->layer1 ->layer2 ------>layer32

  最终还是hidden_states

  也就是最终的shape和初始的hidden_states的shape的相同的

   hidden_states -> Norm -> hidden_states

   其实没必要把hidden_states理解的那么悬,就当它是个中间变量就可以了

     Layer里面都有什么呢?我们点进去Layer里面就能看到

图片

     包含了我们在Transfomer里学到的attention层,MLP层,LayerNorm这些层

     继续往下

图片

     我们可以看到首先定义了残差,然后hidden_states在没过layer之前就先被Normal了一下(这个知识点以前也讲过,Llama RMS LN是前LN),然后过attetion,过MLP,最后hidden_states=residual+hidden_states, 这样就把位置编码啥的也都带过来了

     然后我们点击进入attetion这块,看看它是咋做的

图片

      这块其实感觉都不用讲了,我觉得全部代码就这块写的逻辑是最清楚的

图片

,QKV矩阵就是下面前三个线性层,然后分别主管生成qkv,上面的参数也解释了多头,和多头dim是咋算出来的,GQA咋算,这块代码写的实在清晰,属实没啥可说的

       最后看一眼MLP层,就是FFN

图片

      这一层是由3个线性层组成的gate+up+down,和标准transfomer的FFN是真不一样,主要有两点不同:

1- 原始transfomer是由两个线性层组成,这里有3个,原因就在于 SwiGLU 激活函数比类似 ReLU 的激活函数,需要多一个 Linear 层(gate)进行门控,所以你说SwiGLU是比ReLU效率高了,还是低了呢?

图片

(当然这个门控可以并行操作)

2- 原始Transfomer第一个线性层先将维度映射为4h维,第二个线性层再映射回h维,接着进行激活函数操作。而llama则是将原有4h变成一个常量作为输入,且计算方式也略有不同,可能是因为4h这个说法如果模型太大会罩不住,就用一个常量来代替了(我瞎猜的)

        这基本网络就讲完了,其实大家看看也没啥玩意,比较简单

      再比如说我们像看一下Llama是咋做下游任务的,modeling里面有2个,我们挑CLM任务(主管生成的任务,CasualLM)来看

图片

     如上图所示,就是在Llama的基础模型之上,加了一个线性层

     然后我们展开LlamaForCausalLM

图片

      可以看到它包含很多内容,这里面重点肯定还是forward(前向),我们就看一下它的前向是怎么进行的,输入是啥,输出是啥?

图片

      我们先看forward输入部分,要求输入的参数和前面都一样唯一的区别就是Label

  • Labels:没法解释,就是字面上的Labels

    图片

    图片

     输出这块我们就关注hidden_states就行了,这也是最重要的

图片

      然后hiden_states和你的labels进行匹配求交叉熵损失,最后损失最小的就是最高概率生成的token 

      下面那个文本分类的其实也是一样的,只不过它计算Loss的时候会有说明,单标签就是回归任务,多标签就变分类任务了

图片

     通过这些简单的方法,就能把读者们把以前理解不太透彻的网络模块都理一遍,加深印象      

     本章完

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/383933.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Hadoop:认识MapReduce

MapReduce是一个用于处理大数据集的编程模型和算法框架。其优势在于能够处理大量的数据,通过并行化来加速计算过程。它适用于那些可以分解为多个独立子任务的计算密集型作业,如文本处理、数据分析和大规模数据集的聚合等。然而,MapReduce也有…

Github 2024-02-11 开源项目日报Top10

根据Github Trendings的统计,今日(2024-02-11统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目4非开发语言项目2C项目1C项目1Solidity项目1JavaScript项目1Rust项目1HTML项目1 免费服务列表 | f…

shell脚本编译与解析

文章目录 shell变量全局变量(环境变量)局部变量设置PATH 环境变量修改变量属性 启动文件环境变量持久化 ./和. 的区别脚本编写重定向判断 和循环命令行参数传入参数循环读取命令行参数获取用户输入 处理选项处理简单选项处理带值选项 重定向显示并且同时…

【开源】基于JAVA+Vue+SpringBoot的实验室耗材管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 耗材档案模块2.2 耗材入库模块2.3 耗材出库模块2.4 耗材申请模块2.5 耗材审核模块 三、系统展示四、核心代码4.1 查询耗材品类4.2 查询资产出库清单4.3 资产出库4.4 查询入库单4.5 资产入库 五、免责说明 一、摘要 1.1…

为什么说 2023 年是 AI 视频生成的突破年?2024 年的 AI 视频生成有哪些值得期待的地方?

Diffusion Models视频生成-博客汇总 前言:2023年是 AI 视频生成的突破年,AI视频已经达到GPT-2级别了。去年我们取得了长足的进步,但距离普通消费者每天使用这些产品还有很长的路要走。视频的“ChatGPT时刻”何时到来? 目录 前言 …

计算机网络——06分组延时、丢失和吞吐量

分组延时、丢失和吞吐量 分组丢失和延时是怎样发生的 在路由器缓冲区的分组队列 分组到达链路的速率超过了链路输出的能力分组等待排到队头、被传输 延时原因: 当当前链路有别的分组进行传输,分组没有到达队首,就会进行排队,从…

SHA-512在Go中的实战应用: 性能优化和安全最佳实践

SHA-512在Go中的实战应用: 性能优化和安全最佳实践 简介深入理解SHA-512算法SHA-512的工作原理安全性分析SHA-512与SHA-256的比较结论 实际案例分析数据完整性验证用户密码存储数字签名总结 性能优化技巧1. 利用并发处理2. 避免不必要的内存分配3. 适当的数据块大小总结 与其他…

【JavaEE】_传输层协议UDP与TCP

目录 1. 开发中常见的数据组织格式 1.1 XML 1.2 JSON 1.3 Protobuf 2. 端口号 3. UDP协议 4. TCP协议 4.1 特点 4.2 TCP报文格式 4.3 TCP可靠性机制 4.3.1 确认应答机制 4.3.2 超时重传机制 4.3.2.1 丢包的两种情况 4.3.2.2 重传时间 4.3.3 连接管理机制 4.3.3…

分享88个文字特效,总有一款适合您

分享88个文字特效,总有一款适合您 88个文字特效下载链接:https://pan.baidu.com/s/1Y0JCf4vLyxIJR6lfT9VHvg?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不…

160基于matlab的负熵和峭度信号的盲分离

基于matlab的负熵和峭度信号的盲分离。基于峭度的FastICA算法的收敛速度要快,迭代次数比基于负熵的FastICA算法少四倍以上。SMSE随信噪比增大两种判据下的FastICA算法都逐渐变小,但是基于峭度的算法的SMSE更小,因此基于峭度的FastICA算法性能…

H12-821_26

26.下列选项中,哪些路由前缀满足下面的IP-Prefix条件? A.20.0.1.0/24 B.20.0.1.0/23 C.20.0.1.0/25 D.20.0.1.0/28 答案:ACD 注释: 前缀列表可以匹配路由前缀和网络掩码。 ip ip-prefix test index 10 permit 20.0.0.0 16 greater-equal 24 less-equal…

【开源】SpringBoot框架开发个人健康管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 健康档案模块2.2 体检档案模块2.3 健康咨询模块 三、系统展示四、核心代码4.1 查询健康档案4.2 新增健康档案4.3 查询体检档案4.4 新增体检档案4.5 新增健康咨询 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpri…

ZigBee学习——在官方例程实现组网

✨Z-Stack版本:3.0.2 ✨IAR版本:10.10.1 ✨这篇博客是在善学坊BDB组网实验的基础上进行完善,并指出实现的过程中会出现的各种各样的问题! 善学坊教程地址: ZigBee3.0 BDB组网实验 文章目录 一、基础工程选择二、可能遇…

Linux笔记之Docker进行镜像备份与迁移

Linux笔记之Docker进行镜像备份与迁移 ——2024-02-11 code review! 文章目录 Linux笔记之Docker进行镜像备份与迁移1. 导出容器文件系统为 tar 归档文件2. 将 tar 归档文件导入为新的 Docker 镜像3. 运行新的 Docker 镜像并创建容器 1. 导出容器文件系统为 tar 归档文件 要导…

【前端web入门第五天】03 清除默认样式与外边距问题【附综合案例产品卡片与新闻列表】

文章目录: 1.清除默认样式 1.1清除内外边距1.2清除列表圆点(项目符号) 3.外边距问题-合并现象4.外边距问题–塌陷问题5.行内元素垂直内外边距6.圆角与盒子阴影 6.1圆角 6.2 盒子模型-阴影(拓展) 综合案例一 产品卡片 综合案例二 新闻列表 1.清除默认样式 在实际设计开发中,要…

OpenCV-36 多边形逼近与凸包

目录 一、多边形的逼近 二、凸包 一、多边形的逼近 findContours后的轮廓信息countours可能过于复杂不平滑,可以用approxPolyDP函数对该多边形曲线做适当近似,这就是轮廓的多边形逼近。 apporxPolyDP就是以多边形去逼近轮廓,采用的是Doug…

带特效喝酒神器小程序源码-多种游戏支持流量主

由多个游戏组合而成,每一个小程序都基本带特效~~ 功能如下 1.小马快跑(支持竞选模式和个人单选模式,PS马是真的在跑哟) 2.彩票智能选号(支持多个彩种选号,快来选你的专属号码吧) 3.整蛊鳄鱼(少了一颗牙自动往酒杯加酒,看你和几杯) 4.真心话大冒险(这个就不多做解释啦) 5.…

【数学建模】【2024年】【第40届】【MCM/ICM】【F题 减少非法野生动物贸易】【解题思路】

一、题目 (一) 赛题原文 2024 ICM Problem F: Reducing Illegal Wildlife Trade Illegal wildlife trade negatively impacts our environment and threatens global biodiversity. It is estimated to involve up to 26.5 billion US dollars per y…

PyCharm2023.3.2配置conda环境

重点在于Path to conda这一步,需要找到conda.bat这个文件,PyCharm才能识别出现有的conda环境。

分享76个文字特效,总有一款适合您

分享76个文字特效,总有一款适合您 76个文字特效下载链接:https://pan.baidu.com/s/1rIiUdCMQScoRVKhFhXQYpw?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不…