LLM padding left or right

参考博客:
大部分的大模型(LLM)采用左填充(left-padding)的原因
注:文章主要内容参考以上博客,及其评论区,如有侵权,联系删除。

最近在看大模型相关内容的时候,突然想到我实习时候一直一知半解的问题,大模型采用padding left还是right。以上内容,根据我自身的理解,如有错误,请大家批评指正。这里的padding 的位置,仅仅考虑推理时候的left or right。

为什么需要padding?

因为输入的序列长度不一致,为了能够在Batch内进行数据推理,所以需要增加padding,使输入序列的长度是一致的。

为什么要考虑padding left or right?

在BERT时代,通常padding的方式为right,即在右侧进行padding,因为BERT在初始位置有个特殊token,[CLS]左侧进行padding,不好操作。
在大模型时代,可能更偏向于左侧padding, 为什么进行左侧padding,我理解主要原因可能是为了更好也是为了更好操作。
最直观的想法,如果右侧进行padding,生成的序列中间会存在padding token,还需要进一步处理padding token。如下图所示:
在这里插入图片描述

如果采用左侧的padding 的方式则是比较方便处理或者操作。在进行batch推理的时候左侧,进行操作,非常的方便, 如下图所示:
在这里插入图片描述

摘取一些比较好的解释

在这里插入图片描述

大模型batch推理时只能padding left?

大模型在推理时候,同样可以采用padding right,只不过需要增加一些步骤,没有padding left这么直观。
由于只找到LLama和Gemma的推理代码,所以仅仅参考这两个代码进行解释。
参考代码:
LLama
Gemma
下面是Gemma推理代码:

 def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: Union[float, None] = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]

        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings

        # build KV caches
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

        # prepare inputs
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id, dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
            device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
            device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )

            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2,
                                                        input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        return results[0] if is_str_prompt else results

以下面的图为例子进行讲解:
在这里插入图片描述
在推理的时候,先进行右侧padding,使长度一致。选择最短的长度同时进行处理,上图为1, 那么我们同时处理batch min(即1), 然后开始逐个token进行推理,怎么避免下图的形式呢在这里插入图片描述
核心代码为下面内容:output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, next_token_ids).unsqueeze(dim=1),主要的思路就是,如何当前的token为padding token 则填充上一步预测的token的结果,否则填充当前的token。
例如:
在这里插入图片描述
当前位置的token不为pading,则token还是3,不为2这个位置预测的token。
在这里插入图片描述
当前token为padding,则为2这个位置预测的token。
以上就是大模型采用right padding的方法。

总结

感觉pading left or right, 其实无所谓,主要就是为了方便。根据实际情况的具体需求,进行使用,用的正确,方便即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/557803.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

经典网络解读—IResNet

论文&#xff1a;Improved Residual Networks for Image and Video Recognition&#xff08;2020.4&#xff09; 作者&#xff1a;Ionut Cosmin Duta, Li Liu, Fan Zhu, Ling Shao 链接&#xff1a;https://arxiv.org/abs/2004.04989 代码&#xff1a;https://github.com/iduta…

Ubuntu22.04.4 - 网络配置 - 笔记

一、设置固定ip 1、cd /etc/netplan 查看文件夹下的配置文件 我这里叫 00-installer-config.yaml 2、sudo nano /etc/netplan/00-installer-config.yaml 完成配置后&#xff0c;按下Ctrl O保存更改&#xff0c;然后按下Ctrl X退出nano编辑器。 3、sudo netplan apply 4、ip …

C++ 继承(一)

一、继承的概念 继承是面向对象编程中的一个重要概念&#xff0c;它指的是一个类&#xff08;子类&#xff09;可以从另一个类&#xff08;父类&#xff09;继承属性和方法。子类继承父类的属性和方法后&#xff0c;可以直接使用这些属性和方法&#xff0c;同时也可以在子类中…

springboot+vue全栈开发【2.前端准备工作篇】

目录 前言准备工作Vue框架介绍MVVM模式 快速入门导入vue在vscode创建一个页面 前言 hi&#xff0c;这个系列是我自学开发的笔记&#xff0c;适合具有一定编程基础&#xff08;html、css那些基础知识要会&#xff01;&#xff09;的同学&#xff0c;有问题及时指正&#xff01;…

语雀如何显示 Markdown 语法

正常的文章链接 https://www.yuque.com/TesterRoad/t554s28/eds3pfeffefw12x94wu8rwer8o 访问后是文章&#xff0c;无法复制 markdown 的内容 在链接后增加参数 /markdown?plaintrue&linebreakfalse&anchorfalse 直接显示代码

ros2 RVIZ2 不显示urdf模型

ros2 RVIZ2 不显示urdf模型 我的情况是 &#xff1a; 没有如何报错但是不显示 Description Topic 手动写上 /robot_description

python使用tkinter和ttkbootstrap制作UI界面(二)

这次讲解UI界面常用的主键&#xff0c;延续上文的框架进行编写&#xff0c;原界面如下&#xff1a; Combobox组件应用&#xff08;下拉框&#xff09; """Combobox组件"""global comvalue_operatorcomvalue_operator tk.StringVar()value_ope…

就业班 第三阶段(nginx) 2401--4.19 day3 nginx3

二、企业 keepalived 高可用项目实战 1、Keepalived VRRP 介绍 keepalived是什么keepalived是集群管理中保证集群高可用的一个服务软件&#xff0c;用来防止单点故障。 ​ keepalived工作原理keepalived是以VRRP协议为实现基础的&#xff0c;VRRP全称Virtual Router Redundan…

黑马python-python基础语法

1.注释&#xff1a; 单行注释&#xff1a;#注释内容 多行注释&#xff1a; """ 第一行 第二行 第三行 """ 或 第一行 第二行 第三行 2.定义变量 变量名值 变量名满足标识符命名规则即可 3.标识符命名规则&#xff1a; 有数组、字母、下划线组成…

欢乐钓鱼大师加速、暴击内置脚本,直接安装

无需手机root,安装软件即可使用&#xff0c;仅限安卓。 网盘自动获取 链接&#xff1a;https://pan.baidu.com/s/1lpzKPim76qettahxvxtjaQ?pwd0b8x 提取码&#xff1a;0b8x

从零开始学习Linux(4)----yum和vim

1.Linux软件包管理器yum Linux中我们要进行工具/指令/程序&#xff0c;安装&#xff0c;检查卸载等&#xff0c;需要yum的软件 安装软件的方式&#xff1a; 源代码安装---交叉编译的工具rpm包直接安装yum/apt-get yum是我们Linux预装的一个指令&#xff0c;搜索&#xff0c;下…

claude3国内注册

claude3国内注册 Claude 3 作为大型语言模型的强大之处在于其先进的算法设计和大规模训练数据的应用&#xff0c;能够执行复杂和多样化的任务。以下是 Claude 3 主要的强项&#xff1a; 接近人类的理解能力&#xff1a;Claude 3 能够更加深入地理解文本的含义&#xff0c;包括…

外贸企业邮箱有什么用?如何选择适合的外贸企业邮箱?

外贸公司每天都需要与各个国家的客户打交道&#xff0c;通过邮箱聊天、谈合作。由于语言、文化差异&#xff0c;一个小错误可能会致使业务失败和数据泄漏风险。做为外贸企业的重要沟通工具&#xff0c;企业电子邮件的功效是显而易见的。那样&#xff0c;外贸企业邮箱有什么用&a…

【在本机上部署安装禅道详细操作步骤2024】

1、进入禅道官网&#xff0c;选择开源版进行下载&#xff1a;禅道下载 - 禅道开源项目管理软件 2、根据自身电脑环境选择合适的版本&#xff0c;此处是windows版本&#xff1a; 3、双击打开下载好的.exe安装包-选择安装目录-【Extract】-然后就等着安装完成就行了 4、安装完成…

JavaSE高阶篇-反射

第一部分、Junit单元测试 1&#xff09;介绍 1.概述:Junit是一个单元测试框架,在一定程度上可以代替main方法,可以单独去执行一个方法,测试该方法是否能跑通,但是Junit是第三方工具,所以使用之前需要导入jar包 2&#xff09;Junit的基本使用&#xff08;重点啊&#xff09; 1.…

【jinja2】模板渲染

HTML文件 return render_template(index.html)h1: 一级标题 变粗变大(狗头 <

VSCode断点调试(ROS)

0、安装ros插件 在扩展商店中安装ROS插件&#xff08;Microsoft&#xff09; 1、修改CMakeList.txt # set(CMAKE_BUILD_TYPE "Release") // 注释Release模式 set(CMAKE_BUILD_TYPE "Debug") // 设置为Debug模式 # set(CMAKE_CXX_FLAGS_RELEASE &…

PolarDB MySQL 版 Serverless评测|一文带你体验什么是极致弹性

PolarDB MySQL 版 Serverless评测|一文带你体验什么是极致弹性 什么是PolarDB MySQL 版PolarDB MySQL版体验弹性压测一弹性压测二弹性压测三弹性缩容 操作体验 在体验PolarDB MySQL 版之前&#xff0c;这里先为大家提供一下PolarDB MySQL 版 Serverless评测入口&#xff0c;以供…

五种主流数据库:集合运算

关系型数据库中的表与集合理论中的集合类似&#xff0c;表是由行&#xff08;记录&#xff09;组成的集合。因此&#xff0c;SQL 支持基于数据行的各种集合运算&#xff0c;包括并集运算&#xff08;Union&#xff09;、交集运算&#xff08;Intersect&#xff09;和差集运算&a…

neo4j使用详解(十八、java driver使用及性能优化<高级用法>——最全参考)

Neo4j系列导航&#xff1a; neo4j安装及简单实践 cypher语法基础 cypher插入语法 cypher插入语法 cypher查询语法 cypher通用语法 cypher函数语法 neo4j索引及调优 neo4j java Driver等更多 1.依赖引入 <dependency><groupId>org.neo4j.driver</groupId><…