Contrastive Imitation Learning

机器人模仿学习中对比解码的一致性采样

摘要

本文中,我们在机器人应用的对比模仿学习中,利用一致性采样来挖掘演示质量中的样本间关系。通过在排序后的演示对比解码过程中,引入相邻样本间的一致性机制,我们旨在改进用于机器人学习的稳健在线人类行为克隆方法。我们的模型基于一致性策略,在普通行为克隆和一致性选择的行为克隆中,均优于基线行为克隆方法。我们成功地将一致性采样与对比学习相结合,用于行为克隆,证明了我们注释人类演示方法的可行性。为了增强处理奖励周期性变化的稳健性,我们添加了时间噪声,以确保在存在时间相关性的情况下仍能保持性能。实验表明,在 PushT 任务中,二元和连续评分方法的性能相似,二元对比模仿学习的最终成功率达到 92.4%。未来的研究方向包括将相似状态与好坏演示进行成对映射、将该方法扩展到更多任务,以及实现在线强化学习。

1. 引言

在动态环境中从人类演示中学习面临着巨大挑战,尤其是在机器人领域,采样误差可能导致次优甚至危险的动作。由于人类演示的不确定性带来的实际挑战,给动作展开中的实施和采样带来了困难。并非所有采样序列都是一致且同样最优的,随机采样可能导致不稳定甚至危险的结果。在此,我试图开发一种行为克隆方法,利用样本间关系来学习人类演示。通过在排序后的演示对比解码过程中,引入相邻样本间的一致性机制,我们旨在尝试改进用于机器人学习的稳健在线人类行为克隆方法。

2. 相关工作

这项工作主要受到自然语言处理领域中用于开放式文本生成的对比解码研究的启发 [Li 等人(2022 年),Li、Holtzman、Fried、Liang、Eisner、Hashimoto、Zettlemoyer 和 Lewis]。语言模型(LM)在生成随机多样且准确的输出方面存在挑战,贪婪决策和最大概率并不是一个理想的解码目标。在模仿学习中,类似的贪婪方法会产生短且重复的序列轨迹。对比解码(CD)提出了一个受合理性约束的对比目标,它返回专家和新手可能性之间的差异。我们借鉴这项工作来定义损失函数和目标框架,分别将 “好” 和 “坏” 的人类演示视为专家和新手模型,以此进行训练。

此外,我们使用一致性采样来区分人类演示的质量。Sekhari 等人提出了选择性采样,主动向有噪声的专家询问反馈。他们的选择性采样算法适用于一般函数类和多种动作,并提供了一个将有噪声的专家整合进来以提高稳健性的框架 [Sekhari 等人(2024 年),Sekhari、Sridharan、Sun 和 Wu]。这与差异最小化的概念或从专家演示中进行自训练的思想一致。行为克隆领域的一个挑战是,由于演示数据集有限,策略往往会失败,在这种情况下,行为克隆方法通常难以奏效。论文显示,f-MAX(一种用于逆强化学习状态边际匹配目标的 f 散度推广的 AIRL)对其优越性能贡献最大 [Ghasemipour 等人(2020 年),Ghasemipour、Zemel 和 Gu]。

[Ma 等人(2023 年),Ma、Hu、Wang 和 Sun][Bertsch 等人(2023 年),Bertsch、Xie、Neubig 和 Gormley] 模仿学习扩散策略通过将机器人的视觉 - 运动策略表示为条件去噪扩散过程,为生成机器人行为奠定了基础。我们利用这些基础研究成果和机器人操作基准测试,相对于给定基线有 46.9% 的改进。扩散策略学习动作分布得分函数的梯度,并根据该梯度场进行迭代优化,同时结合滚动时域控制、视觉条件和时间序列扩散变换器。这些学习技术启发了我们在模仿学习中实施对比解码,并定义了成功指标和数据集模块 [Chi 等人(2023 年),Chi、Feng、Du、Xu、Cousineau、Burchfiel 和 Song]。

我们主要基于 2024 年一致性策略论文的工作进行拓展,该论文直接基于扩散策略 [Prasad 等人(2024 年),Prasad、Lin、Wu、Zhou 和 Bohg],以解决在机器人应用中实现快速策略推理时高端 GPU 的限制问题。一致性策略通过在扩散策略学习的轨迹上强制实现自一致性,使用预训练的扩散策略。具体来说,我们还使用该论文中的演示模块来处理 PushT 数据集 [Wang 等人(2022 年),Wang、Wei、Schuurmans、Le、Chi、Narang、Chowdhery 和 Zhou]。

3. 方法

3.1 动机:人类演示质量缺乏注释

模仿学习中一个长期存在的挑战是,如何有效地利用人类演示数据的质量差异,尤其是当这些数据没有标注表明演示好坏的偏好标签时。这种注释的缺乏使对比解码的实施变得复杂,因为对比解码依赖于区分好坏演示,以实现有效的行为克隆。核心问题在于人类演示对比解码的采样和注释过程。传统的离线强化学习(RL)技术假设奖励结构与演示紧密相关,但在没有明确质量注释的情况下,这一假设难以维持。

在此,假设在单峰演示任务中,大多数人类演示是成功的,次优演示被视为异常值。我们开发了一种采样方法来评估演示数据的质量。在预训练的视觉语言模型无法提供自动标注的情况下,我们利用一致性采样,假设成功行为遵循相邻轨迹,从而区分好坏演示。

通过探索弱先验和基于一致性的采样技术,我们假设大多数人类演示是成功的。利用样本间的关系,我们可以减轻缺乏明确注释的影响,仍然实现有效的行为克隆。实施基于一致性的采样,需要绘制多个序列,基于弱先验,预计其中大多数是好的。此外,我们的采样方法在演示中引入噪声,以捕捉时间相关性,并评估该方法的稳健性。最后,本项目使用扩散策略技术,探索采样算法的有效性和局限性。

3.2 仿真实验数据集

为了实施采样和对比解码,我们使用一致性策略论文中提供的人类演示数据集。我们使用三个已有的基准测试,在六个任务上进行了实验:Robomimic、Push-T 和 Franka Kitchen,这些在视觉 - 运动和基于状态的策略学习中是标准的 [Chi 等人(2023 年),Chi、Feng、Du、Xu、Cousineau、Burchfiel 和 Song] 以及 ParaDiGMS。在这里,我们重点关注 PushT 任务。

最初,我们尝试使用(RH20T 网站)和 Lerobot 存储库环境中的数据集,但最终发现这些数据集对于一致性采样来说过于复杂。最后,我们选定了对比解码论文中的数据,因为一致性和扩散策略论文中的仿真设置,能够在有限的 GPU 环境中进行更快、低延迟的评估。

PushT 任务要求使用圆形末端执行器,将 T 形块推到固定目标位置。我们使用了来自 [Chi 等人(2023 年),Chi、Feng、Du、Xu、Cousineau、Bu] 的 200 个专家演示数据集,并报告了基于状态观察的策略结果。我们每 50 个训练周期评估一次策略,将成功率记录在 wandb 上,同时也记录滚动输出视频。

3.3 采样方法

为了增强在线机器人行为克隆,我们实施了一种对比解码方法,利用批次内相邻样本间的一致性。该方法根据样本与最优行为的接近程度,将其分类为 “好” 或 “坏”。

一致性采样:在每个批次中,通过向量距离衡量轨迹相似性,与相邻轨迹接近的样本被视为 “好” 样本。这种接近程度表明成功行为的概率较高。相反,与这些相邻轨迹距离较远的样本被标记为 “坏” 样本。这些样本根据其较低的概率分布和与 “好” 轨迹的向量距离进行加权,表明其为次优行为。为了进一步改进这种方法,我们区分了来自一对弱模型和强模型的样本。这种比较通过突出两个模型样本质量的差异,有助于减少偏差。

根据 PushT 任务的期望指标,为每个演示定义一个分数。我们考虑 T 符号的覆盖百分比以及与基线 T 方向的对齐程度,分数越高表示越接近最终期望位置:

计算分数的均值和标准差:计算所有演示分数的均值(\mu)和标准差(\sigma):

基于一致性选择演示:选择分数超过均值加上标准差一定倍数的演示:

从一致性演示中采样:

从  中采样来训练模型,专注于期望行为的最具代表性的示例。这是在样本间手动调整的。此外,我们采用了一种使用好演示的加权平均质心标记的技术。这种方法有助于创建演示强度的梯度,而不是进行二元分类。远离好样本质心的演示被赋予较低的权重,表明其性能次优。这创建了一个演示强度指标,提供了一种细致的分类,而不是简单的二元标签。

这种采样方法旨在在部署期间,采样具有期望属性的行为,最终无需依赖明确的偏好注释。通过关注演示质量并利用弱先验,该方法确保机器人即使在没有直接标签的情况下,也能学习到稳健有效的行为。

3.4 增强稳健性的噪声

为了提高学习算法的稳健性,确保其能够处理现实世界中的变化,我们在演示数据中引入噪声。在采样过程中,首先确定一个概率 ,按照这个概率向演示中引入噪声。这模拟了人类演示中实际存在的误差或变化。然后,确保噪声不是完全随机的,而是具有时间相关性。噪声会持续一段时间,随着时间形成局部相关的噪声模式。对于每个演示,按照定义的概率引入噪声。这迫使学习者理解并适应随时间的变化,这对于制定稳健的策略至关重要。通过方差模拟次优性,进一步确保我们的学习算法能够处理实际的、有噪声的数据。噪声的存在还有助于识别异常值,提高算法区分好坏演示的能力。

3.5 对比解码的实现

在对比模仿学习中,目标函数可以表示为:

其中, 表示状态, 表示动作, 是由参数为  的神经网络得到的状态 - 动作对  的特征表示。

在二元情况下,每个演示被标记为好或坏。这更容易处理,因为决策是分类性的。令  是一个二元指示函数,对于好演示输出 1,对于坏演示输出 0。这个损失函数直接使用二元标签,使模型更倾向于选择被标记为好的动作,远离被标记为坏的动作。对比损失函数最大化好坏演示之间的距离:

当分数是连续的时,它们提供了对演示质量的度量,经过归一化后,取值范围从最小值到最大值。在 PushT 任务中,我们继续将  定义为覆盖度和方向的函数,它返回一个表示演示质量的连续值。我们对分数进行归一化,确保其在 0 到 1 之间(可以解释为演示为好的概率或置信水平):

这种方法允许在演示之间进行更细致的区分,模型从分数较高的演示中学习得更强。然后,损失函数可以用这些归一化分数对对比项进行加权:

在这两种情况下,这些损失函数都可以集成到模仿学习模型的训练中。对于二元分数,模型学习到好坏之间的明显区别。对于连续分数,模型的更新根据演示被认为是好或坏的程度进行加权,允许根据演示质量进行更精细的调整。这些框架可以使用机器学习中的标准优化技术来应用,通过调整参数,在逐步的训练数据批次上最小化损失函数。

4. 实验结果

我们在对比学习中比较了连续评分和二元评分。结果表明,连续评分提供了更细致的反馈,提高了模型区分不同演示质量水平的能力。
最终成功率:92.46%
最终训练损失:0.849

4.1 一致性强度加权

对超参数  进行手动调整,以确定一致性阈值的一致性强度加权。假设大多数人类演示是成功的。 的值根据在异常值范围内考虑的标准差数量进行调整。在手动调整过程中,将  设置为小于等于 1.5 会导致包含次优演示。另一方面,将  设置得高于 2.5 会导致排除过多演示,减少了有用的训练数据量,自然会导致过拟合。选择  提供了一个平衡,过滤掉了大部分异常值,同时保留了一组稳健的高质量演示用于训练。与基线相比,这种  的选择使模型实现了更好的泛化性能。

4.2 超参数调整

集成 L2 正则化来对抗过拟合,从而使模型性能更具泛化性。经过广泛测试,发现 500 个训练周期在模型复杂度和真正的确定性生成之间提供了最佳平衡。

4.3 基线比较

我们的模型最终成功率达到 92.46%,最终训练损失为 0.849,这表明模型性能强劲,能够有效地从提供的演示中学习。通过测量成功率和奖励率等指标来评估性能。为了进行全面评估,我们将对比解码策略与两个基线进行了比较:普通行为克隆和仅使用正 “好” 一致性样本的行为克隆。

普通行为克隆:这个基线是在所有可用演示上训练模型,而不区分好坏样本。虽然它提供了对整体演示质量的基本理解,但没有利用区分不同演示质量的潜在好处。普通行为克隆方法的成功率为 88.32%,训练损失为 1.256。较低的成功率和较高的训练损失表明,该模型难以从混合质量的数据中有效泛化。

正 “好” 一致性样本的行为克隆:在这个基线中,模型仅在通过一致性采样方法确定为 “好” 的演示上进行训练。通过专注于高质量演示,这种方法旨在提高模型的学习效率和性能。好一致性样本的行为克隆成功率达到 85.74%,训练损失为 1.024。与普通行为克隆相比的这种改进,证明了利用高质量数据进行训练的好处。

对比解码策略:我们提出的对比解码策略优于两个基线,成功率达到 92.46%,训练损失为 0.849,为该任务提供了最优策略。

5. 结论

最终,本项目的目标是在对比模仿学习中,利用一致性采样挖掘演示质量中的样本间关系。我们的研究结果表明,我们的模型优于基线行为克隆方法。我们成功地将一致性采样与对比学习相结合用于行为克隆,证明了我们注释人类演示方法的可行性。为了增强稳健性,我们引入噪声来处理奖励的周期性变化。这种方法在存在时间相关性的情况下,有效地保持了性能。我们的实验表明,在 PushT 任务中,二元和连续评分方法的性能相似。

未来的研究方向包括将相似状态与好坏演示进行成对映射。此外,我们可以将该方法扩展到 PushT 任务之外的更多任务,并实现在线强化学习,以进一步提高模型性能和适应性。鉴于自然语言处理领域的进展,对比解码与样本间一致性策略技术相结合,是一种很有前景的技术,可用于在机器人轨迹序列生成中引入独特行为。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962709.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC? 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping(路由映射) 3.2一般参数传递 3.3RequestParam(参数重命名) 3.4RequestBody(传递JSON数据) 3.5Pa…

DeepSeek的使用技巧介绍

DeepSeek是一款由杭州深度求索人工智能技术有限公司开发的AI工具,结合了自然语言处理和深度学习技术,能够完成多种任务,如知识问答、数据分析、文案创作、代码开发等。以下将从使用技巧、核心功能及注意事项等方面详细介绍DeepSeek的使用方法…

创新创业计划书|建筑垃圾资源化回收

目录 第1部分 公司概况........................................................................ 1 第2部分 产品/服务...................................................................... 3 第3部分 研究与开发.................................................…

为AI聊天工具添加一个知识系统 之80 详细设计之21 符号逻辑 之1

本文要点 要点 前面我们讨论了本项目中的正则表达式。现在我们将前面讨论的正则表达式视为狭义的符号文本及其符号规则rule(认识的原则--认识上认识对象的约束),进而在更广泛的视角下将其视为符号逻辑及其符号原则principle(知识…

Spring Boot 热部署实现指南

在开发 Spring Bot 项目时,热部署功能能够显著提升开发效率,让开发者无需频繁重启服务器就能看到代码修改后的效果。下面为大家详细介绍一种实现 Spring Boot 热部署的方法,同时也欢迎大家补充其他实现形式。 步骤一、开启 IDEA 自动编译功能…

ARM嵌入式学习--第十一天(中断处理 , ADC)

--中断的概念 中断是指计算机运行过程中,出现某些意外情况需主机干预时,机器能自动停止正在运行的程序并转入处理新情况的程序,处理完毕后又返回被暂停的程序继续运行 --CPU处理事情的方式 -轮询方式 不断查询是否有事情需要处理&#xff0c…

ARM嵌入式学习--第十天(UART)

--UART介绍 UART(Universal Asynchonous Receiver and Transmitter)通用异步接收器,是一种通用串行数据总线,用于异步通信。该总线双向通信,可以实现全双工传输和接收。在嵌入式设计中,UART用来与PC进行通信,包括与监控…

socket实现HTTP请求,参考HttpURLConnection源码解析

背景 有台服务器,网卡绑定有2个ip地址,分别为: A:192.168.111.201 B:192.168.111.202 在这台服务器请求目标地址 C:192.168.111.203 时必须使用B作为源地址才能访问目标地址C,在这台服务器默认…

漏洞扫描工具之xray

下载地址:https://github.com/chaitin/xray/releases 1.9.11 使用文档:https://docs.xray.cool/tools/xray/Scanning 与burpsuite联动: https://xz.aliyun.com/news/7563 参考:https://blog.csdn.net/lza20001103/article/details…

正月初三特殊的一天

在我们河南豫东地区,初三这一天一般情况下可以在家休息,不需要串门走亲戚,给亲戚的长辈或比自己辈份长的拜年。 特殊的正月初三 还有两种情况,正月初三这一天必须去走亲戚。一种是有去世的亲戚没有过三周年,正月初三这…

强化学习笔记——4策略迭代、值迭代、TD算法

基于策略迭代的贝尔曼方程和基于值迭代的贝尔曼方程,关系还是不太理解 首先梳理一下: 通过贝尔曼方程将强化学习转化为值迭代和策略迭代两种问题 求解上述两种贝尔曼方程有三种方法:DP(有模型),MC&#xff…

HTTP协议和静态web服务器

一、HTTP协议 1 HTTP协议的定义 网络协议 网络协议是指计算机通信网络中两台计算机之间进行通信所必须共同遵守的规定或规则。HTTP协议 HTTP协议(超文本传输协议)是一种网络通信协议,它允许将超文本标记语言(HTML)文档从Web服务器传送到客户端的浏览器。默认端口:80HTTPS协…

智能汽车网络安全威胁报告

近年来随着智能汽车技术的快速发展,针对智能汽车的攻击也逐渐从传统的针对单一车辆控制器的攻击转变为针对整车智能化服务的攻击,包括但不限于对远程控制应用程序的操控、云服务的渗透、智能座舱系统的破解以及对第三方应用和智能服务的攻击。随着WP.29 …

在虚拟机里运行frida-server以实现对虚拟机目标软件的监测和修改参数(一)(android Google Api 35高版本版)

frida-server下载路径 我这里选择较高版本的frida-server-16.6.6-android-x86_64 以root身份启动adb 或 直接在android studio中打开 adb root 如果使用android studio打开的话,最好选择google api的虚拟机,默认以root模式开启 跳转到下载的frida-se…

Node.js——body-parser、防盗链、路由模块化、express-generator应用生成器

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

【2024年华为OD机试】(C卷,200分)- 启动多任务排序 (JavaScriptJava PythonC/C++)

一、问题描述 题目解析 本题是一个典型的拓扑排序问题。拓扑排序用于解决有向无环图(DAG)中的节点排序问题,使得对于图中的每一条有向边 (u, v),u 在排序中总是位于 v 的前面。在本题中,任务之间的依赖关系可以看作是有向图中的边,而任务的执行顺序就是拓扑排序的结果。…

【NLP251】NLP RNN 系列网络

NLP251 系列主要记录从NLP基础网络结构到知识图谱的学习 1.原理及网络结构 1.1RNN 在Yoshua Bengio论文中( http://proceedings.mlr.press/v28/pascanu13.pdf )证明了梯度求导的一部分环节是一个指数模型…

Privacy Eraser,电脑隐私的终极清除者

Privacy Eraser 是一款专为保护用户隐私而设计的全能型软件,它不仅能够深度清理计算机中的各类隐私数据,还提供了多种系统优化工具,帮助用户提升设备的整体性能。通过这款软件,用户可以轻松清除浏览器历史记录、缓存文件、Cookie、…

数据分析常用的AI工具

数据分析领域中常用的AI工具种类繁多,涵盖了从数据处理、分析到可视化和预测的各个环节。以下是一些常见且广泛应用的AI数据分析工具及其特点: 1. 数据处理与清洗工具 Python库:如PandasAI,集成了生成式AI能力,支持自…

npm常见报错整理

npm install时报UNMET PEER DEPENDENCY 现象 npm install时报UNMET PEER DEPENDENCY,且执行npm install好几遍仍报这个。 原因 不是真的缺少某个包,而是安装的依赖版本不对,警告你应该安装某一个版本。 真的缺少某个包。 解决 看了下package.json文件,我的react是有的…