【LLM】LongRoPE:LLM上下文窗口扩展方法及非官方实现

前言

目前,大多数LLMs的上下文窗口限制在4k个标记左右,这意味着模型在处理超过这个长度的文本时性能会下降。这种限制对于需要大量上下文信息的场景,虽然可以通过在更长的文本上进行微调来将预训练LLM的上下文窗口扩展上下文窗口,但要进一步扩展上下文窗口面临着三个主要挑战:

  1. 新位置索引的未训练引入了许多灾难性值,导致分布外问题,使得微调难以收敛。
  2. 微调通常需要相应长度的文本。然而,当前数据集中特别是超过1000k的长文本非常有限。此外,对超长文本进行训练计算成本高昂,需要大量的训练时间和GPU资源。
  3. 当扩展到极长的上下文窗口时,注意力会变得分散,因为它需要在大量的标记位置上进行分配,这会降低模型在原始短上下文上的性能。

paper:LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens

link:https://arxiv.org/abs/2402.13753

LongRoPE

创新点

  1. 通过有效搜索识别并利用了位置插值中的两种非均匀性,为微调提供了更好的初始化,并在非微调情况下实现了8倍的扩展。
  2. 引入了一种渐进式扩展策略,首先对长度为256k的LLM进行微调,然后在微调后的扩展LLM上进行第二次位置插值,以实现2048k的上下文窗口。
  3. 在8k长度上重新调整LongRoPE,以恢复短上下文窗口的性能。

位置插值中的非均匀性问题

位置插值中的非均匀性问题是指在扩展大型语言模型(LLMs)的上下文窗口时,如何有效地为新增的token位置分配位置嵌入(positional embeddings),以便模型能够在更长的序列上保持或提升性能。在LongRoPE这篇文章中,作者们发现并利用了两种主要的非均匀性,以改进位置插值方法:

  1. RoPE维度的非均匀性

    • RoPE(Rotary Positional Embedding)是一种在Transformer架构中广泛使用的位置嵌入方法,它通过旋转角度来表示token的位置。
    • 不同的RoPE维度具有不同的旋转频率,这意味着低维度(高频率)和高维度(低频率)在表示位置信息时的重要性和敏感性不同。
    • 低维度对于位置信息的变化更敏感,因此在插值时应使用较小的缩放因子,以保持相邻位置token的区分度。
    • 高维度可以承受更大的插值,因为它们对于位置信息的变化不那么敏感。
  2. Token位置的非均匀性

    • 在输入序列的开始部分,token接收到的注意力分数较高,这些位置的token对于模型理解上下文尤为重要。
    • 因此,序列初始的token位置应该使用较小的插值,或者不进行插值,以保留这些关键位置的原始RoPE信息。
    • 随着序列位置的增加,可以应用更大的插值因子,因为远离序列开始的token对于模型理解上下文的重要性逐渐降低。

LongRoPE采用了以下方法解决这些非均匀性问题:

  • 有效的位置插值:通过进化搜索算法(evolutionary search)来寻找每个RoPE维度的最佳缩放因子(rescale factors),这些因子基于token位置进行调整。
  • 渐进式扩展策略:首先对长度为256k的LLM进行微调,然后在微调后的模型上进行第二次位置插值,以实现2048k的上下文窗口,而无需直接在极长文本上进行微调。
  • 短上下文窗口性能恢复:通过额外的进化搜索来调整RoPE缩放因子,以便在扩展到极长上下文窗口后,仍能保持在原始短上下文窗口内的高性能。

搜索算法

LongRoPE非官方实现

import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import gzip
import io


class RoPEPositionalEncoding(nn.Module):
    """
    Rotary Position Encoding (RoPE) module.
    """

    def __init__(self, d_model, max_len=5000, base=10000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.base = base
        self.theta = torch.tensor(
            [base ** (-2 * (i // 2) / d_model) for i in range(d_model)]
        )

    def forward(self, positions):
        angles = positions.unsqueeze(-1) * self.theta
        return torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)


def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
    """
    Perform non-uniform interpolation on position embeddings.

    Args:
        pos_embed (torch.Tensor): Position embeddings.
        extension_ratio (float): Extension ratio for context window.
        lambda_factors (list): Lambda factors for interpolation.
        n_hat (int): Threshold for applying interpolation.

    Returns:
        torch.Tensor: Interpolated position embeddings.
    """
    d_model = pos_embed.shape[-1]
    interpolated_pos = pos_embed.clone()

    for i in range(d_model // 2):
        mask = torch.arange(pos_embed.shape[-2]) < n_hat
        scale = torch.where(
            mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i]
        )
        interpolated_pos[..., i * 2] *= scale
        interpolated_pos[..., i * 2 + 1] *= scale

    return interpolated_pos


def search_lambda_factors(
    model,
    data,
    extension_ratio,
    population_size,
    num_mutations,
    num_crossovers,
    max_iterations,
):
    """
    Search for optimal lambda factors using evolutionary search.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of input sequences.
        extension_ratio (float): Extension ratio for context window.
        population_size (int): Size of the population for evolutionary search.
        num_mutations (int): Number of mutations per iteration.
        num_crossovers (int): Number of crossovers per iteration.
        max_iterations (int): Maximum number of iterations for evolutionary search.

    Returns:
        list: Optimal lambda factors found by the search.
    """
    population = initialize_population(population_size, extension_ratio)

    for i in range(max_iterations):
        perplexities = evaluate_population(model, data, population)
        parents = select_topk(population, perplexities, k=population_size // 2)
        population = mutate(parents, num_mutations) + crossover(parents, num_crossovers)

    return min(population, key=lambda x: evaluate_individual(model, data, x))


def progressive_extension(model, data, base_length, target_length):
    """
    Progressively extend the context window of the model.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of input sequences.
        base_length (int): Base context window length.
        target_length (int): Target context window length.

    Returns:
        tuple: (Extended model, lambda factors, base lambda factors)
    """
    curr_model = model
    curr_length = base_length

    while curr_length < target_length:
        lambda_factors, n_hat = search_lambda_factors(
            curr_model, data, curr_length / base_length
        )
        curr_model = fine_tune(curr_model, data, curr_length, lambda_factors, n_hat)
        curr_length *= 2

    lambda_factors_base, _ = search_lambda_factors(
        curr_model, data, curr_length / base_length, max_length=base_length
    )

    return curr_model, lambda_factors, lambda_factors_base


class LongRoPEModel(nn.Module):
    """
    Long Range Rotary Position Encoding (LongRoPE) model.

    This model extends the context window of transformer-based models beyond the
    typical limit by using non-uniform interpolation of rotary position embeddings.
    It enables the model to handle longer input sequences while maintaining the
    ability to capture long-range dependencies.

    Attributes:
        d_model (int): Dimension of the model.
        n_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        max_len (int): Maximum sequence length.
        rope (RoPEPositionalEncoding): Rotary Position Encoding (RoPE) module.
        transformers (nn.ModuleList): List of transformer encoder layers.
        lambda_factors (list): Lambda factors for non-uniform interpolation.
        lambda_factors_base (list): Lambda factors for the base model.
        extension_ratio (float): Extension ratio for the context window.
        n_hat (int): Threshold for applying interpolation.

    Methods:
        forward(input_ids):
            Perform forward pass on the input sequence.

            Args:
                input_ids (torch.Tensor): Input sequence tensor.

            Returns:
                torch.Tensor: Output embeddings from the model.

        extend_context(data_path, target_length, max_sequence_length, tokenizer):
            Extend the context window of the model.

            Args:
                data_path (str): Path to the input data file.
                target_length (int): Target context window length.
                max_sequence_length (int): Maximum sequence length for input data.
                tokenizer: Tokenizer object for encoding input data.

            Returns:
                LongRoPEModel: Extended LongRoPE model.
    """

    def __init__(self, d_model, n_heads, num_layers, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.rope = RoPEPositionalEncoding(d_model, max_len)
        self.transformers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
                for _ in range(num_layers)
            ]
        )
        self.lambda_factors = None
        self.lambda_factors_base = None

    def forward(self, input_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        pos_embeddings = self.rope(positions)

        if self.lambda_factors is not None:
            pos_embeddings = non_uniform_interpolation(
                pos_embeddings, self.extension_ratio, self.lambda_factors, self.n_hat
            )

        input_embeddings = input_ids + pos_embeddings

        for transformer in self.transformers:
            input_embeddings = transformer(input_embeddings)

        return input_embeddings

    def extend_context(self, data_path, target_length, max_sequence_length, tokenizer):
        """
        Extend the context window of the model.

        Args:
            data_path (str): Path to the input data file.
            target_length (int): Target context window length.
            max_sequence_length (int): Maximum sequence length for input data.
            tokenizer: Tokenizer object for encoding input data.

        Returns:
            LongRoPEModel: Extended LongRoPE model.
        """
        if tokenizer is None:
            raise ValueError("Tokenizer is required for extending context.")

        self.extension_ratio = target_length / self.rope.max_len

        data = load_data(data_path, tokenizer, max_sequence_length)
        model, lambda_factors, lambda_factors_base = progressive_extension(
            self, data, self.rope.max_len, target_length
        )

        self.lambda_factors = lambda_factors
        self.lambda_factors_base = lambda_factors_base
        self.n_hat = self.rope.max_len // 2

        return model


def load_data(data_path, tokenizer, max_sequence_length):
    """
    Load and preprocess the input data.

    Args:
        data_path (str): Path to the input data file.
        tokenizer: Tokenizer object for encoding input data.
        max_sequence_length (int): Maximum sequence length for input data.

    Returns:
        list: List of preprocessed input sequences.
    """
    if data_path is None or tokenizer is None:
        raise ValueError("Data path and tokenizer are required for loading data.")

    if data_path.endswith(".gz"):
        with gzip.open(data_path, "rt", encoding="utf-8") as file:
            text_data = file.read()
    else:
        with open(data_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    tokenized_data = tokenizer.encode(text_data)

    sequences = [
        tokenized_data[i : i + max_sequence_length]
        for i in range(0, len(tokenized_data), max_sequence_length)
    ]

    tensor_data = [torch.tensor(seq, dtype=torch.long) for seq in sequences]

    return tensor_data


def initialize_population(population_size, extension_ratio):
    """
    Initialize the population for evolutionary search.

    Args:
        population_size (int): Size of the population.
        extension_ratio (float): Extension ratio for context window.

    Returns:
        list: Initialized population.
    """
    population = []

    population.append(torch.ones(512) * extension_ratio)

    ntk_factors = torch.tensor([extension_ratio ** (2 * i / 512) for i in range(512)])
    population.append(ntk_factors)

    yarn_factors = torch.ones(512)
    yarn_factors[:128] = 1.0
    yarn_factors[128:256] = extension_ratio ** (1 / 3)
    yarn_factors[256:] = extension_ratio
    population.append(yarn_factors)

    for _ in range(population_size - 3):
        factors = torch.ones(512)
        for i in range(512):
            if random.random() < 0.1:
                factors[i] = random.uniform(1, extension_ratio)
        population.append(factors)

    return population


def evaluate_individual(model, data, individual):
    """
    Evaluate an individual lambda factor configuration.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of input sequences.
        individual (list): Lambda factor configuration.

    Returns:
        float: Perplexity score for the individual.
    """
    model.lambda_factors = individual
    perplexities = []

    for seq in data:
        input_ids = seq.unsqueeze(0)
        output = model(input_ids)
        perplexity = torch.exp(torch.mean(output))
        perplexities.append(perplexity.item())

    return np.mean(perplexities)


def evaluate_population(model, data, population):
    """
    Evaluate the population of lambda factor configurations.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of input sequences.
        population (list): Population of lambda factor configurations.

    Returns:
        list: Perplexity scores for each individual in the population.
    """
    perplexities = []
    for individual in population:
        perplexity = evaluate_individual(model, data, individual)
        perplexities.append(perplexity)
    return perplexities


def select_topk(population, perplexities, k):
    """
    Select the top-k individuals from the population based on perplexity scores.

    Args:
        population (list): Population of lambda factor configurations.
        perplexities (list): Perplexity scores for each individual in the population.
        k (int): Number of top individuals to select.

    Returns:
        list: Top-k individuals from the population.
    """
    indices = np.argsort(perplexities)[:k]
    return [population[i] for i in indices]


def mutate(parents, num_mutations):
    """
    Perform mutation on the parent population.

    Args:
        parents (list): Parent population.
        num_mutations (int): Number of mutations to perform.

    Returns:
        list: Mutated population.
    """
    mutated_population = []
    for _ in range(num_mutations):
        parent = random.choice(parents)
        child = parent.clone()
        for i in range(512):
            if random.random() < 0.1:
                child[i] *= random.uniform(0.8, 1.2)
        mutated_population.append(child)
    return mutated_population


def crossover(parents, num_crossovers):
    """
    Perform crossover on the parent population.

    Args:
        parents (list): Parent population.
        num_crossovers (int): Number of crossovers to perform.

    Returns:
        list: Crossover population.
    """
    crossover_population = []
    for _ in range(num_crossovers):
        parent1, parent2 = random.sample(parents, 2)
        child = parent1.clone()
        for i in range(512):
            if random.random() < 0.5:
                child[i] = parent2[i]
        crossover_population.append(child)
    return crossover_population


def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
    """
    Fine-tune the LongRoPE model.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of input sequences.
        target_length (int): Target context window length.
        lambda_factors (list): Lambda factors for interpolation.
        n_hat (int): Threshold for applying interpolation.
        num_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.

    Returns:
        nn.Module: Fine-tuned LongRoPE model.
    """
    model.lambda_factors = lambda_factors
    model.n_hat = n_hat
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        for seq in data:
            optimizer.zero_grad()

            seq_len = seq.size(0)
            if seq_len <= target_length:
                input_ids = seq.unsqueeze(0)
            else:
                start_idx = random.randint(0, seq_len - target_length)
                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)

            output = model(input_ids)
            loss = torch.mean(output)

            loss.backward()
            optimizer.step()

    return model


# Example usage
data_path = "path/to/your/dataset"
d_model = 512
n_heads = 8
num_layers = 6
base_length = 4096
target_length = 2048 * 1024

data = load_data(data_path)
model = LongRoPEModel(d_model, n_heads, num_layers, base_length)
model = model.extend_context(data, target_length)

input_ids = torch.randn(2, target_length, d_model)
output = model(input_ids)
print(output.shape)  # Expected shape: (batch_size, target_length, d_model)

dad

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/483664.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【鸿蒙系统】 ---OpenHarmony加快本地编译(二)

&#x1f48c; 所属专栏&#xff1a;【鸿蒙系统】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢…

SpringBoot3集成PostgreSQL

标签&#xff1a;PostgreSQL.Druid.Mybatis.Plus&#xff1b; 一、简介 PostgreSQL是一个功能强大的开源数据库系统&#xff0c;具有可靠性、稳定性、数据一致性等特点&#xff0c;且可以运行在所有主流操作系统上&#xff0c;包括Linux、Unix、Windows等。 通过官方文档可以…

MySQL:表的操作

文章目录 创建表查看表结构修改表删除表 前面对于库的操作有了认识后&#xff0c;下面进行表的操作 创建表 以下图为例 创建表其实和定义结构体有点类似&#xff0c;总的来说就是先定义列名&#xff0c;然后后面跟着是列的数据类型&#xff0c;之后在定义结束后可以带上对应的…

校园圈子系统--自带校园跑腿功能,校园交友,校园陪玩,校园交友墙,地图找伴,二手市场等功能。源码交付,支持二开!APP小程序H5等移动端都有。

一、需求分析 在搭建校园论坛平台之前&#xff0c;我们需要进行详细的需求分析。这包括以下几个方面&#xff1a; 1.用户需求 我们需要了解目标用户群体的需求和喜好&#xff0c;包括学生的年龄层次、兴趣爱好、关注话题等。通过调查问卷、访谈等方式收集用户需求&#xff0c;为…

数学算法(算法竞赛、蓝桥杯)--最大公约数,欧几里得算法

1、B站视频链接&#xff1a;G05 最大公约数 欧几里得算法_哔哩哔哩_bilibili 题目链接&#xff1a;[NOIP2001 普及组] 最大公约数和最小公倍数问题 - 洛谷 #include <bits/stdc.h> using namespace std; typedef long long LL; LL x,y,ans;LL gcd(LL a,LL b){return b0?…

包叔推荐12代i3-独显组装电脑主机配置清单

去年Intel第十代i5-依然是主流热选机型。 今年&#xff0c;随着i3-的价格优势越来越大&#xff0c;已经成功取代了i5-。 今天包叔推荐几套12代i3-独立显卡组装电脑主机配置。 列表&#xff1a;一组核心显示配置&#xff0c;其余三组均为独立显示配置。 适合主机预算在2000元至3…

JDK下载配置

一、JDK的作用 Java开发环境&#xff1a;JDK提供了完整的Java开发环境&#xff0c;包含编译器&#xff08;javac&#xff09;、解释器&#xff08;java&#xff09;、打包工具&#xff08;jar&#xff09;、文档生成工具&#xff08;javadoc&#xff09;等一系列工具&#xff0…

【高并发服务器 01】—— 基础知识回顾

接下来四周时间&#xff0c;我将会做一个高并发服务器相关的项目。 前置知识&#xff1a;操作系统系统编程、网络编程、基础的数据结构、C语言。 开发环境&#xff1a;VMware虚拟机&#xff1a;Ubuntu 20.04.6 LTS、vscode 今天先回顾一些基础知识。 1.文件与IO 标准IO&#…

软件测试教程 性能测试概论

文章目录 1. 性能测试实施的流程1.1 常见的性能问题1.2 性能测试是什么&#xff1f;1.3 性能测试和功能测试之间的区别1.4 什么样的系统/软件表现属于性能好&#xff0c;什么样的软件性能表现属于性能不好1.5 为什么要进行性能测试1.6 性能测试实施的流程1.7 常见的性能指标以及…

Python虚拟环境conda的安装使用

文章目录 conda虚拟环境的详细步骤和注意事项&#xff1a;**安装Conda****创建Conda虚拟环境****激活Conda虚拟环境****安装Python包****管理Conda环境****其他优势与特性** 相较于venv&#xff0c;使用conda管理虚拟环境有以下优势&#xff1a;**性能****资源占用****其他性能…

【jvm】jinfo使用

jinfo介绍 jinfo 是一个命令行工具&#xff0c;用于查看和修改 Java 虚拟机&#xff08;JVM&#xff09;的配置参数。它通常用于调试和性能调优。 使用 jinfo 命令&#xff0c;你可以查看当前 JVM 的配置参数&#xff0c;包括堆大小、线程数、垃圾回收器类型等。此外&#xf…

立体统计图表绘制方法(分离式环图)

立体统计图表绘制方法&#xff08;分离式环形图&#xff09; 记得我学统计学的时候&#xff0c;那些统计图表大都是平面的框框图&#xff0c;很呆板&#xff0c;就只是表现出统计的意义就好了。在网络科技发展进步的当下&#xff0c;原来一些传统的统计图表都有了进一步的创新。…

unity 多屏幕操作

想了解基础操作请移步&#xff1a;&#xff08;重点是大佬写的好&#xff0c;这里就不再赘述&#xff09; Unity 基础 之 使用 Display 简单的实现 多屏幕显示的效果_unity display-CSDN博客 在panel上也可以通过获取 Canvas&#xff0c;来达到切换多屏幕的操作&#xff0c; …

数据结构基础:一篇文章教你单链表(头插,尾插,查找,头删等的解析和代码)

和我一起学编程呀&#xff0c;大家一起努力&#xff01; 这篇文章耗时比较久&#xff0c;所以大家多多支持啦 链表的结构及结构 概念&#xff1a;链表是⼀种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表 中的指针链接次序实现的。 理解&a…

在MongoDB建模1对N关系的基本方法

“我在 SQL 和规范化数据库方面拥有丰富的经验&#xff0c;但我只是 MongoDB 的初学者。如何建立一对 N 关系模型&#xff1f;” 这是我从参加 MongoDB 分享日活动的用户那里得到的最常见问题之一。 我对这个问题没有简短的答案&#xff0c;因为方法不只有一种&#xff0c;还有…

LangChain核心模块 Retrieval——文档加载器

Retrieval ​ 许多LLM申请需要用户的特定数据&#xff0c;这些数据不属于模型训练集的一部分&#xff0c;实现这一目标的主要方法是RAG(检索增强生成)&#xff0c;在这个过程中&#xff0c;将检索外部数据&#xff0c;然后在执行生成步骤时将其传递给LLM。 ​ LangChain 提供…

探索设计模式的魅力:精准、快速、便捷:游标尺模式在软件设计中的三大优势

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;并且坚持默默的做事。 精准、快速、便捷&#xff1a;游标尺模式在软件设计中的三大优势 文章目录 一、案例场景&…

黑马程序员:C++核心编程——3.函数提高

目录 1.函数默认参数 2.函数占位数 3.函数重载* 1.函数默认参数 形参列表中可以有默认值。 注意&#xff1a;如果某个位置有默认值&#xff0c;那么这个位置之后的都要有。如果函数声明有默认值了&#xff0c;函数实现的时候就不能有默认值&#xff08;防止默认值不同而导…

蓝桥杯第十五届抱佛脚(二)竞赛中的数据结构

蓝桥杯第十五届抱佛脚&#xff08;二&#xff09;内置数据结构 文章目录 蓝桥杯第十五届抱佛脚&#xff08;二&#xff09;内置数据结构在竞赛中常见的数据结构数组(Array)链表(Linked List)栈(Stack)队列(Queue)树(Tree)映射(Map) 内置数据结构的快速使用迭代器&#xff08;It…

综合知识篇20-基于中间件的开发新技术考点(2024年软考高级系统架构设计师冲刺知识点总结系列文章)

专栏系列文章: 2024高级系统架构设计师备考资料(高频考点&真题&经验)https://blog.csdn.net/seeker1994/category_12593400.html案例分析篇00-【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例…