如何修改大模型的位置编码 --以LLama为例

最近在看RoPE相关内容,一些方法通过简单修改位置编码就可以无需训练支持更长的文本内容。由于一些模型,已经训练好了,但是怎么修改已经训练好的模型位置编码。查了以下相关代码,记录一下。原理这里就不细讲了,贴几个相关博客。
十分钟读懂旋转编码(RoPE)
Transformer升级之路:11、将β进制位置进行到底
Transformer升级之路:10、RoPE是一种β进制编码

NTK

下图为NTK的原理证明:截取自Transformer升级之路:10、RoPE是一种β进制编码
在这里插入图片描述
在这里插入图片描述

看了上面的公式,我在考虑为什么需要建立 λ \lambda λ和k之间的关系?

因为我们要修改 β \beta β进制为 β λ \beta\lambda βλ,由于k我们是可以知道的比如我们需要把位置编码缩小为10倍,直接设置k为10,但是采用NTK的方式,维度缩小为10倍,那么我们就不确定, λ \lambda λ怎么设置了。所以需要简历 λ \lambda λ和k之间的关系,从上图可知, λ = k 2 / ( d − 2 ) \lambda=k^{2/(d-2)} λ=k2/(d2)
下面开始理解如何修改RoPE为NTK的形式:
以下为LLama的RoPE代码实现

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        logger.warning_once(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        logger.warning_once(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

λ \lambda λ和k之间的关系,那么代码怎么实现呢,我们只需要修改 β λ \beta\lambda βλ的结果即可,其中 β \beta β 1000 0 2 / d 10000^{2/d} 100002/d
参考代码为:点击

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):

    #The method is just these three lines
    max_position_embeddings = 16384
    k = 8 #Alpha value
    base = base * k ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

为什么采用base = base * k ** (dim / (dim-2)),原始的base为10000, 而 β \beta β 1000 0 2 / d 10000^{2/d} 100002/d,发现在代码里面修改的仅仅是base的结果, β \beta β= b a s e 2 / d base^{2/d} base2/d,而 λ = k 2 / ( d − 2 ) \lambda=k^{2/(d-2)} λ=k2/(d2),我们需要把k和base进行融合,修改成,base*k的形式形成新的base, λ \lambda λ等于k的指数 2 / ( d − 2 ) ∗ d / 2 ∗ 2 / d = d / ( d − 2 ) ∗ 2 / d 2/(d-2)*d/2*2/d=d/(d-2)*2/d 2/(d2)d/22/d=d/(d2)2/d λ = ( k d / ( d − 2 ) ) 2 / d \lambda=(k^{d/(d-2)})^{2/d} λ=(kd/(d2))2/d,因为2/d在RoPE的代码里面已经计算过了:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)

所以我们则赋值新的base为base * k ** (dim / (dim-2))。

Dynamic NTK

Dynamic在NTK的基础上进行简单的修改,采用NTK的时候更加灵活。
截图源于:RoPE到底是何方神圣(数学推理+优化方法)
在这里插入图片描述
代码实现:

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:#只有长度超过了预训练的阈值,进行NTK
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/483752.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Git工具的详细使用

一、环境说明 [rootgit ~]# getenforce Disabled [rootgit ~]# systemctl status firewalld ● firewalld.service - firewalld - dynamic firewall daemonLoaded: loaded (/usr/lib/systemd/system/firewalld.service; disabled; vendor preset: enabled)Active: inactive (d…

javaWeb项目-人职匹配推荐系统功能介绍

开发工具:IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架:ssm、Springboot 前端:Vue、ElementUI 关键技术:springboot、SSM、vue、MYSQL、MAVEN 数据库工具:Navicat、SQLyog 项目关键技术 1、JSP技术 JSP(Java…

YOLOv5 | 网络结构 | 详细讲解YOLOv5的网络结构

⭐欢迎大家订阅我的专栏一起学习⭐ 🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀 YOLOv5涨点专栏:http://t.csdnimg.cn/70xZa YOLOv8涨点专栏:http://t.csdnimg.cn…

0基础 三个月掌握C语言(13)-下

数据在内存中的存储 浮点数在内存中的存储 常见的浮点数:3.141592、1E10等 浮点数家族包括:float、double、long double类型 浮点数表示的范围:在float.h中定义 练习 关于(float*)&n: &n:这是一…

【赠书活动】Python编程 从入门到实践 第3版(图灵出品)(文末送书-进行中)

编辑推荐 适读人群 :本书适合对Python感兴趣的所有读者阅读。 编程入门就选蟒蛇书! 【经典】Python入门经典,常居Amazon等编程类图书TOP榜 【畅销】热销全球,以12个语种发行,影响超过 250 万读者 【口碑】好评如潮…

termux+ubuntu使用笔记

文章目录 termuxtermux自动启动服务的方法1. 写.bashrc文件2. 利用termux-services来实现 安装sshtermux 执行定时任务 ubuntu参考文章 这里仅针对自己在使用过程所做的笔记 termux环境下搭建Ubuntu环境可以参考:https://github.com/MFDGaming/ubuntu-in-termux上提…

如何在 Django 中使用 pyecharts

为项目新建一个目录,将其命名为django_pyecharts_demo, 在终端中切换到这个目录,并创建一个虚拟环境。 python -m venv django_pyecharts激活虚拟环境 django_pyecharts\Scripts\activate要停止使用虚拟环境,可执行命令 deactivate创建并激…

Linux V4L2 应用编程

V4L2:Video4Linux2,是 Linux 内核中的一个框架,提供了一套用于视频设备驱动程序开发的 API。它是一个开放的、通用的、模块化的视频设备驱动程序框架,允许 Linux 操作系统和应用程序与各种视频设备(如摄像头、视频采集…

Spring-声明式事务实例(有详细注释)

前提知识 Spring-IOC容器注解方式使用https://blog.csdn.net/m0_61160520/article/details/136784799?spm1001.2014.3001.5501切点表达式https://blog.csdn.net/m0_61160520/article/details/136782885?spm1001.2014.3001.5501 案例 1.创建项目 2.导入依赖 <dependen…

003、Dynamo Python创建楼板

今天我们来创建一块楼板&#xff0c;仍然是找Dynamo里有的节点&#xff0c;可以对照参考练习。 首先&#xff0c;我们打开API手册&#xff0c;在索引里搜索Floor&#xff0c;发现在Floor的方法里&#xff0c;没有找到创建楼板的方法&#xff0c;于是在搜索栏搜索&#xff0c…

python(django(自动化))之流程接口展示功能前端开发

1、创建模板代码如下&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><title>测试平台</title> </head> <body role"document"> <nav class "navbar n…

电脑如何关闭自启动应用?cmd一招解决问题

很多小伙伴说电脑刚开机就卡的和定格动画似的&#xff0c;cmd一招解决问题&#xff1a; CtrlR打开cmd,输入&#xff1a;msconfig 进入到这个界面&#xff1a; 点击启动&#xff1a; 打开任务管理器&#xff0c;禁用不要的自启动应用就ok了

LangChain核心模块 Retrieval——文本嵌入模型、Vector stores

Text embedding models 文本嵌入模型 检索的另一个关键部分是为文档创建嵌入。 Embeddings 类是设计用于与文本嵌入模型交互的类。 Embeddings创建一段文本的矢量表示&#xff0c;这样我们就可以在向量空间中思考文本&#xff0c;并执行语义搜索之类的操作&#xff0c;在向…

详解库和程序运行过程

我最近开了几个专栏&#xff0c;诚信互三&#xff01; > |||《算法专栏》&#xff1a;&#xff1a;刷题教程来自网站《代码随想录》。||| > |||《C专栏》&#xff1a;&#xff1a;记录我学习C的经历&#xff0c;看完你一定会有收获。||| > |||《Linux专栏》&#xff1…

Websocket + Vue使用

这里有一篇文档可以参考一下> 闪现 POM文件 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId><version>2.7.0</version> </dependency> WebSocketConf…

IIS7/iis8/iis10安装II6兼容模块 以windows2022为例

因安全狗的提示 安全狗防护引|擎安装失败 可能原因是: IIS7及以上版本末安装1IS6兼容模块! .所以操作解决 如下. 在开始菜单中,找到服务器管理器.找到下图的IIS,右键添加角色和功能,找到web服务器的管理工具选项,iis6管理兼容性 打钩并安装. 如下图

力扣---最长回文子串---二维动态规划

二维动态规划思路&#xff1a; 首先&#xff0c;刚做完这道题&#xff1a;力扣---最长有效括号---动态规划&#xff0c;栈-CSDN博客&#xff0c;所以会有一种冲动&#xff0c;设立g[i]&#xff0c;表示以第i位为结尾的最长回文子串长度&#xff0c;然后再遍历一遍取最大长度即可…

Web前端-JS

JavaScript&#xff0c;简称js&#xff1a;负责网页的行为&#xff08;交互效果&#xff09;。是一门跨平台&#xff0c;面向对象的脚本语言&#xff08;编写出来的语言不需要编译&#xff0c;通过浏览器的解释就可以运行&#xff09; JS引入方式 1.内嵌样式 这样打开页面就会…

【CVPR2024】CricaVPR

【CVPR2024】CricaVPR: Cross-image Correlation-aware Representation Learning for Visual Place Recognition 这个论文提出了一种具有跨图像相关性的鲁棒全局表示方法用于视觉位置识别&#xff08;VPR&#xff0c;Visual Place Recognition &#xff09;任务&#xff0c;命…

Linux系统——iptables超细致解释

目录 内核如何处理数据包流程图 一、表 二、链 三、表、链、规则的关系 四、数据报文进/出节点经过哪些规则 五、NAT——网络地址转换 1.SNAT 2.DNAT 内核如何处理数据包流程图 规则是管理员对数据包制定的一种触发机制&#xff0c;即当数据包达到某种条件&#xff0c;…