Encoder-decoder 与Decoder-only 模型之间的使用区别

承接上文:Transformer Encoder-Decoer 结构回顾
笔者以huggingface T5 transformer 对encoder-decoder 模型进行了简单的回顾。

由于笔者最近使用decoder-only模型时发现,其使用细节和encoder-decoder有着非常大的区别;而huggingface的接口为了实现统一化,很多接口的使用操作都是以encoder-decoder的用例为主(如T5),导致在使用hugging face运行decoder-only模型时(如GPT,LLaMA),会遇到很多反直觉的问题。

本篇进一步涉及decoder-only的模型,从技术细节上,简单列举一些和encoder-decoder模型使用上的区别。

以下讨论均以huggingface transformer接口为例。

1. 训练时input与output合并

对于encoder-decoder模型,我们需要把input和output 分别 喂给模型的encoder和decoder。也就是说,像T5这种模型,会有一个单独的encoder编码input的上下文信息,由decoder解码output和计算loss。简而言之,如果是encoder-decoder模型,我们 只把 output喂给decoder(用于计算loss,teacher forcing),这对于我们大多是人来说是符合直觉的。

但decoder-onyl模型,需要你手动地将input和output合并在一起,作为decoder的输入。因为,从逻辑上讲,对于decoder-only模型而言,它们并没有额外的encoder去编码input的上下文,所以需要把input作为“前文”,让decoder基于这一段“前文”,把“后文”的output预测出来(auto regressive)。因此,在训练时,input和output是合并在一起喂给decoder-only 模型的(input这段前文必须要有)。这对于大多数习惯了使用encoder-decoder的人来说,是很违反直觉的。

于此相对应的,decoder-only 模型计算loss时的“答案”(ground truth reference)也得是input和output的合并(因为计算loss的时候,输入token representation得和输出ground truth reference要对应)。而这样一来,decoder 的loss就既包含output,又会涉及input上的预测error。由于我们大多数情况下不希望去惩罚decoder模型在input上的error,一般的做法是,训练时我们只计算output上的loss ,即,把input token对应的ground truth全部设置为-100(cross entropy ignore idx)。

2. 测试时,手动提取output

encoder-decoder模型的输出就是很“纯粹”的output(模型的预测结果)

但decoder-only模型,在做inference的时候,模型的输出就会既包含output也包含input(因为input也喂给了decoder)

所以这种情况下,decoder-only 模型我们需要手动地把output给分离出来。

如下所示:
在这里插入图片描述
笔者也很无语,huggingface的 model.generate() 接口为什么不考虑一下,对于decoder-only模型,设置一个额外参数,能够自动提取output(用input token的数量就可以自动定位output,不难实现的)

3. batched inference的速度和准确度

如果想要批量地进行预测,简单的做法就是把一个batch的样本,进行tokenization之后,在序列末尾(右边)pad token,补足长度差异。这对于encoder-decoder 模型来说是适用的。

但是对于decoder-only模型,你需要在训练时,额外地将tokenizer的pad 位置设置为左边:
在这里插入图片描述
因为你一旦设置为默认的右边,模型在做inference时,一个batch的样本,所有pad token就都在序列末尾。而decoder only模型是auto regressive地生成新token的,最右边的pad token就很容易影响到模型生成的内容

有人就会问,这个时候和encoder-decoder模型一样,用attention mask把那些pad tokens都遮掉,不就不会影响模型生成的内容了吗?

但是很遗憾,对于decoder-only模型,huggingface model.generate 接口并不支持输入attention mask(如下面官方api所描述):
在这里插入图片描述
所以你如果想batched inference,不得不在训练和测试的时候,把tokenizer的pad设置在左手边,以降低pad token对生成内容的影响;或者干脆设置batch size为1

经过笔者自己的实验,推理时batch size==1能够显著提升推理准确度
以下为笔者测试的性能表现排序:

  1. batch size 为1 (完全没有pad token的影响),性能最好
  2. batch size不为1,pad token在左侧(pad token影响降低,但还是会损伤推理性能),在部分任务上,性能降低较为明显
  3. batch size不为1,pad token在右侧(pad token对大量的样本的预测产生极大干扰,模型最后的输出基本都是乱码),性能接近于0

总之,当前huggingface的生成接口对于decoder-only模型的支持度并不是非常高,decoder-only模型推理的速度和精度,以及接口使用的便捷程度,都会逊于encoder-decoder模型。

下面是GPT-4对于为何decoder-only模型做推理时不需要attention mask(最后一句话是亮点。。。):

The instruction from the Hugging Face Transformers documentation regarding the use of input_ids for decoder-only models during generation (like GPT-2, GPT-3) stems from the typical use case and architecture of these models.
Why input_ids?
Decoder-only models like GPT-2 or GPT-3 are designed to generate text sequentially. When generating text, the model uses its previously generated tokens as context to generate the next token. Since these models are autoregressive, they generate one token at a time, and the sequence of generated tokens grows until it reaches a stopping condition (like the max_length or eos_token_id). During this process, the only necessity is to know which tokens have been generated (encoded as input_ids), so the model knows the context based on which it should generate the subsequent token(s).
About attention_mask
While it’s true that attention masks are used to prevent the model from attending to certain positions within the input sequence (for example, padding tokens in a batched input scenario), during the generation phase, the need to specify an attention_mask explicitly is not as critical as during training or evaluation. This is because, during generation, the model is focusing on the tokens it has already generated and the positions it needs to fill next – all of which are relevant and none should be masked out.
That said, for certain generation scenarios or model configurations, you might want to control the attention mechanism explicitly. The Hugging Face Transformers library does allow for passing additional arguments like attention_mask in some contexts, but for the default use case of generating text with decoder-only models, including the attention_mask is not necessary.
Batched Inference
For batched inference, input_ids are necessary to understand the sequence (or sequences) from which to start generating text. If the sequences within a batch have different lengths, padding might be necessary to shape the input tensor appropriately. In such cases, an attention_mask becomes relevant because it allows the model to distinguish between the actual content and the padding. Therefore, while the generation call as you’ve shown doesn’t explicitly mention attention_mask, depending on the specifics of your use case and the model’s implementation, you might still need or want to provide it to ensure proper handling of batched inputs.

其他待补充

。。。

总结

总而言之,个人认为 huggingface目前的模型接口,对于decoder-only模型的使用并不是很友好。在使用过程中需要注意很多细节,不然会遇到许多问题,而这些问题,encoder-decoder模型是完全不会有的。

参考:

  • 官方接口
  • alpaca-lora

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/404468.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设计模式——抽象工厂模式

定义: 抽象工厂模式(Abstract Factory Pattern)提供一个创建一系列或相互依赖对象的接口,而无须指定它们具体的类。 概述:一个工厂可以提供创建多种相关产品的接口,而无需像工厂方法一样,为每一个产品都提供一个具体…

【开源】JAVA+Vue.js实现医院门诊预约挂号系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的遥感目标检测系统(Python+PySide6界面+训练代码)

摘要:本文介绍了一种基于深度学习的遥感目标检测系统系统的代码,采用最先进的YOLOv8算法并对比YOLOv7、YOLOv6、YOLOv5等算法的结果,能够准确识别图像、视频、实时视频流以及批量文件中的遥感目标。文章详细解释了YOLOv8算法的原理&#xff0…

力扣226 翻转二叉树 Java版本

文章目录 题目描述解题思路代码 题目描述 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1] 示例 2: 输入:root…

智慧物业信息管理系统平台及APP建设项目

随着城市化步伐的不断加快,物业管理成为城市建设中不可或缺的一部分。为了更好地满足各方对物业管理的全面需求,智慧物业信息管理系统平台及APP项目,融合了八大子系统,旨在为其提供更全面、高效的物业管理解决方案。 1. 物业信用…

盘点自动化汽车生产线设备 数据采集分析联合各设备

1.机器人自动装配线 机器人自动装配线已成为汽车制造业中的常见场景。这些机器人在汽车组装的各个环节发挥关键作用,从焊接和铆接到零部件组装。它们不仅提高了装配速度,还确保了产品的一致性,降低了废品率。 2.3D打印技术 3D打印技术正在汽车…

Draw.io绘制UML图教程

一、draw.io介绍 1、draw.io简介 draw.io 是一款强大的免费在线图表绘制工具,支持创建流程图、组织结构图、时序图等多种图表类型。它提供丰富的形状库、强大的文本编辑和样式设置功能,使用户能够轻松创建专业级图表。draw.io 具有用户友好的界面&…

UTONMOS开启数智龙年,打造元宇宙游戏圈新名片

新年已过,全国各个城市早已客流涌动、热闹非凡。这种繁华景象不仅存在于现实世界,也被复刻到元宇宙的虚拟空间中。 据介绍,UTONMOS“源起山海-神念无界”元宇宙游戏是以原创IP玄幻神话故事“元宇宙史纪”为蓝本打造的元宇宙游戏空间&#xf…

图文说明Linux云服务器如何更改实例镜像

一、应用场景举例 在学习Linux的vim时,我们难免要对vim进行一些配置,这里我们提供一个vim插件的安装包: curl -sLf https://gitee.com/HGtz2222/VimForCpp/raw/master/install.sh -o./install.sh && bash ./install.sh 但是此安装包…

可变形注意力(Deformable Attention)及其拓展

文章目录 一、补充知识(一)可变形卷积(Deformable Convolution)(二)多头注意力机制 二、可变形注意力模块三、可变形自注意力模块(一)偏移模块:(二&#xff0…

“比特币暴涨讯号显现”!减半牛市来临前还有一次震撼回撤?“52000美元保卫战”已经打响!

虽然比特币在20日一度冲高至近5.3万美元大关,创下自2021年11月来新高,但随后开始回落,在51000美元至52000美元之间反复窄幅波动,甚至在21日晚一度跌至50625美元。比特币的未来走势,已牵动不少投资者的心。 自1月底比特…

华为OD机试真题-来自异国的客人-2023年OD统一考试(C卷)--Python3--开源

题目: 考察内容: 10进制转为任何进制 代码: """ 题目分析:输入: k --物品价值;n 幸运数字;m 进制 输出: 幸运数字的个数 异常;0 eg; 10 2 4思路&…

STM32—触摸键

目录 1 、 电路构成及原理图 2 、编写实现代码 3、代码讲解 4、烧录到开发板调试、验证代码 5、检验效果 此笔记基于朗峰 STM32F103 系列全集成开发板的记录。 1 、 电路构成及原理图 触摸键简单的了解就是一次电容的充放电过程。从原理图可以看出,触摸键 …

代码随想录算法训练营第21天—回溯算法01 | ● 理论基础 ● *77. 组合

理论基础 回溯是一种纯暴力搜索的方法,它和递归相辅相成,通常是执行完递归之后紧接着执行回溯相较于以往使用的for循环暴力搜索,回溯能解决更为复杂的问题,如以下的应用场景应用场景 组合问题 如一个集合{1,2,3,4},找…

Linux 权限详解

目录 一、权限的概念 二、权限管理 三、文件访问权限的相关设置方法 3.1chmod 3.2chmod ax /home/abc.txt 一、权限的概念 Linux 下有两种用户:超级用户( root )、普通用户。 超级用户:可以再linux系统下做任何事情&#xff…

Vant轮播多个div结合二维数组的运用

需求说明 在开发H5的时候,结合Vant组件的轮播组件Swipe实现如下功能。我们查阅vant组件库官方文档可以得知,每个SwipeItem组件代表一个卡片,实现的是每屏展示单张图片或者单个div轮播方式,具体可以查阅:Vant 2 - 轻量、…

如何计算点、线、面关系

从公众号转载,关注微信公众号掌握更多技术动态 --------------------------------------------------------------- 普遍有三种方式 面积和判别法:判断目标点与多边形的每条边组成的三角形面积和是否等于该多边形,相等则在多边形内部。 夹角…

c#程序,oracle使用Devart驱动解决第第三方库是us7ascii,数据乱码的问题

最近做项目,要跟对方系统的库进行读写,结果发现对方采用的是oracle的us7ascii编码,我们系统默认采用的是ZHS16GBK,导致我们客户端读取和写入对方库的数据都是乱码,搜索网上,发现需要采用独立的oracle驱动去…

网络知识

目录 IP地址(Internet protocol address) —— 互联网协议地址 子网掩码 网关 路由 DNS(Domain Name Server) —— 域名服务器 IP地址(Internet protocol address) —— 互联网协议地址 子网掩码 作用:划分网段 网络部分相同的IP地址&a…

简介高效的 CV 入门指南: 100 行实现 InceptionResNet 图像分类

简介高效的 CV 入门指南: 100 行实现 InceptionResNet 图像分类 概述InceptionResNetInception 网络基本原理关键特征 ResNet 网络深度学习早期问题残差学习 InceptionResNet 网络InceptionResNet v1InceptionResNet v2改进的 Inception 模块更有效的残差连接设计 100 行实现 I…