LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

在这里插入图片描述

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        # 相比RoPE增加scale参数
        self.scale = scale
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        # 构建max_seq_len_cached大小的张量t
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # 张量t归一化,RoPE没有这一步
        t /= self.scale
        # einsum计算频率矩阵
        # 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 在-1维做一次拷贝、拼接
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        # 注册为模型的缓冲区cos_cached和sin_cached
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        super().__init__()
        # 与线性插值法相比,实现更简单,alpha仅用来改变base
        base = base * alpha ** (dim / (dim-2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        # 是否开启NTK(Neural Tangent Kernel)
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # emb:[max_seq_len_cached, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
                # 计算新的inv_freq
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # 缩放
                t *= self.max_position_embeddings / seq_len
            # 得到新的频率矩阵freqs
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            # freqs与自身拼接得到新的emb
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # 注册为模型的缓冲区cos_cached和sin_cached
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)

        # 长度裁剪
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/73962.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ArcGIS Pro 基础安装与配置介绍

ArcGIS Pro ArcGIS Pro作为ESRI面向新时代的GIS产品,它在原有的ArcGIS平台上继承了传统桌面软件(ArcMap)的强大的数据管理、制图、空间分析等能力,还具有其独有的特色功能,例如二三维融合、大数据、矢量切片制作及发布…

用vim打开后中文乱码怎么办

Vim中打开文件乱码主要是文件编码问题。用户可以参考如下解决方法。 1、用vim打开.vimrc配置文件 vim ~/.vimrc**注意:**如果用户根目录下没有.vimrc文件就把/etc/vim/vimrc文件复制过来直接用 cp /etc/vim/vimrc ~/.vimrc2、在.vimrc中加入如下内容 set termen…

树莓派3B CSI摄像头配置

1.硬件连接 1、找到 CSI 接口(树莓派3B的CSI接口在HDMI接口和音频口中间),需要拉起 CSI 接口挡板,如下: 2、将摄像头排线插入CSI接口。记住,有蓝色胶带的一面应该面向音频口或者网卡方向, 确认方向并插紧排线,将挡板…

app专项测试:app弱网测试

目录 弱网测试背景 网络测试要点 弱网测试关注指标 弱网测试工具 fiddler模拟网络延时场景 网络设置参考 Network Emulator Toolkit模拟网络丢包场景(windows网络) APP弱网测试 弱网使用工具: app弱网测试要点 APP网络测试要点 网络…

Mysql 搭建MHA高可用架构,实现自动failover,完成主从切换

目录 自动failover MHA: MHA 服务 项目:搭建Mysql主从复制、MHA高可用架构 实验项目IP地址配置: MHA下载地址 项目步骤: 一、修改主机名 二、编写一键安装mha node脚本和一键安装mha mangaer脚本,并执行安装…

网络安全 Day31-运维安全项目-容器架构下

容器架构下 6. Dockerfile6.1 Docker自动化DIY镜像之Dockerfile1) 环境准备2) 书写Dockerfile内容3) 运行Dockerfile生成镜像4) 运行容器5) 小结 6.2 案例14:Dockerfile-RUN指令1) 书写Dockerfile2) 构建镜像3) 启动容器4) 测试结果 6.3 Dockerfile指令 …

小程序制作教程:从零开始搭建企业小程序

在如今的数字化时代,企业介绍小程序成为了企业展示与推广的重要工具。通过企业介绍小程序,企业可以向用户展示自己的品牌形象、产品服务以及企业文化等内容,进而提高用户对企业的认知度和信任度。本文将介绍如何从零开始搭建一个企业介绍小程…

【Vue-Router】命名视图

命名视图 同时 (同级) 展示多个视图,而不是嵌套展示,例如创建一个布局,有 sidebar (侧导航) 和 main (主内容) 两个视图,这个时候命名视图就派上用场了。 可以在界面中拥有多个单独命名的视图,而不是只有一个单独的出…

LeetCode150道面试经典题--最后一个单词的长度(简单)

1.题目 给你一个字符串 s,由若干单词组成,单词前后用一些空格字符隔开。返回字符串中 最后一个 单词的长度。 单词 是指仅由字母组成、不包含任何空格字符的最大子字符串。 2.示例 3.思路 通过对字符串的反转,转为数组开始遍历&#xff0c…

Python中使用隧道爬虫ip提升数据爬取效率

作为专业爬虫程序员,我们经常面临需要爬取大量数据的任务。然而,有些网站可能会对频繁的请求进行限制,这就需要我们使用隧道爬虫ip来绕过这些限制,提高数据爬取效率。本文将分享如何在Python中使用隧道爬虫ip实现API请求与响应的技…

UML之四种事物

目录 结构事物 行为事物 分组事物: 注释事物 结构事物 1.类(Class) -类是对一组具有相同属性、方法、关系和语义的对象的描述。一个类实现一个或多个接口 2.接口(interface) -接口描述 了一个类或构件的一个服务的操作集。接口仅仅是定义了一组操作的规范&…

每日后端面试5题 第三天

1. 线程有哪几种状态以及各种状态之间的转换?(必会) 看图: 图片来自 线程状态转换图及其5种状态切换_小曹的blog的博客-CSDN博客 图片来自 总算把线程六种状态的转换说清楚了! - 知乎 线程一共有4种状态,分别是: 1.…

【日常积累】RPM包依赖下载及私有yum仓库搭建

概述 某些时候,我们需要下载某个RPM包依赖的依赖。如某些内网环境,就需要自行准备rpm包。可以通过能上互联网的服务器进行相应的rpm包下载,然后在拷贝到相应的服务器安装,或者搭建自己的内容rpm包仓库。 查看*.rpm 包依赖&#…

分布式系统监控zabbix安装部署以及使用

文章目录 分布式系统监控zabbix安装部署及使用一.zabbix监控1.什么是zabbix2.zabbix功能3.zabbix的构成4.zabbix的3种架构4.1 C/S架构4.2 分布式架构:zabbix-proxy-client架构4.3 master-node-client架构 5.zabbix工作原理及数据流向6.zabbix监控模式 二.zabbix部署…

41、可靠传输——停等ARQ

前面两节内容我们学习了传输层的基本概况的一些知识,包括传输层在TCP/IP协议栈中负责的任务、传输层的两大协议,以及端口号、套接字等一些基本的概念。从这一节开始,我们将开启两大协议中TCP协议的学习。 但是,经过之前的学习&am…

Kotlin语法

整理关键语法列表如下: https://developer.android.com/kotlin/interop?hlzh-cn官方指导链接 语法形式 说明 println("count ${countnum}")字符串里取值运算 val count 2 var sum 0 类型自动推导 val 定义只读变量,优先 var定义可变变量…

shell之正则表达式及三剑客grep命令

一、正则表达式概述 什么是正则表达式? 正则表达式是一种描述字符串匹配规则的重要工具 1、正则表达式定义: 正则表达式,又称正规表达式、常规表达式 使用字符串描述、匹配一系列符合某个规则的字符串 正则表达式 普通字符: 大小写字母…

【云原生】K8S存储卷:PV、PVC详解

目录 一、emptyDir存储卷二、hostPath存储卷三、nfs共享存储卷四、PVC 和 PV4.1 NFS使用PV和PVC4.2创建动态PV 一、emptyDir存储卷 容器磁盘上的文件的生命周期是短暂的,这就使得在容器中运行重要应用时会出现一些问题。首先,当容器崩溃时,ku…

ReBel 论文学习笔记

论文:《Combining Deep Reinforcement Learning and Search for Imperfect-Information Games》 地址:https://arxiv.org/abs/2007.13544v2 代码:https://github.com/facebookresearch/rebel 材料: BV1gt4y1k77C(1小时…

Linux 当fork在for循环中的问题

以下代码会打印几个"A"&#xff1f; 例1.代码如下&#xff1a; int main(int argc, char* argv[],char* envp[]) { for(int i 0;i < 2; i ) { fork(); printf("A\n"); } exit(0); } 代码分析&#xff1a; //父进程for(int i …