深入解析强化学习中的 Generalized Advantage Estimation (GAE)

中文版

深入解析强化学习中的 Generalized Advantage Estimation (GAE)


1. 什么是 Generalized Advantage Estimation (GAE)?

在强化学习中,计算策略梯度的关键在于 优势函数(Advantage Function) 的设计。优势函数 ( A ( s , a ) A(s, a) A(s,a) ) 衡量了执行某动作 ( a a a ) 比其他动作的相对价值。然而,优势函数的估计往往面临以下两大问题:

  1. 高方差问题:由于强化学习中的样本通常有限,直接使用单步回报或 Monte Carlo 方法计算优势值会导致高方差。
  2. 偏差问题:使用引入近似函数(如值函数 ( V ( s ) V(s) V(s) ))会降低方差,但可能引入偏差。

为了平衡 偏差和方差,Schulman 等人在 2016 年提出了 Generalized Advantage Estimation (GAE) 方法,它是一种在偏差和方差之间权衡的优势函数估计方法,被广泛应用于强化学习中的近端策略优化(PPO)等算法。


2. GAE 的数学原理

GAE 的核心思想是通过时间差分(Temporal Difference, TD)误差的加权和,估计优势函数:

TD 残差(Temporal Difference Residuals)
δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)

GAE 的递归定义
A t GAE = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^\text{GAE} = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l} AtGAE=l=0(γλ)lδt+l

其中:

  • ( γ \gamma γ ) 是折扣因子,用于控制未来回报的权重。
  • ( λ \lambda λ ) 是 GAE 的衰减系数,控制长期和短期偏差的平衡。
  • ( δ t \delta_t δt ) 是每一步的 TD 残差,反映即时回报和值函数的差异。

GAE 的公式可以通过递归形式表示为:
A t GAE = δ t + ( γ λ ) ⋅ A t + 1 GAE A_t^\text{GAE} = \delta_t + (\gamma \lambda) \cdot A_{t+1}^\text{GAE} AtGAE=δt+(γλ)At+1GAE

通过 ( λ \lambda λ ) 的调节,GAE 可以在单步 TD 估计(低方差高偏差)和 Monte Carlo 估计(高方差低偏差)之间找到一个平衡点。


3. GAE 在 PPO 中的应用

PPO paper:https://arxiv.org/pdf/1707.06347

在近端策略优化(PPO)算法中,策略梯度的更新依赖于优势函数 ( A t A_t At ) 的估计,而 GAE 为优势函数的估计提供了一个高效的工具。

PPO 的 损失函数 包括两部分:

  1. 策略更新(Actor Loss):
    L actor = E t [ min ⁡ ( r t ( θ ) ⋅ A t GAE , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A t GAE ) ] L^\text{actor} = \mathbb{E}_t \left[ \min(r_t(\theta) \cdot A_t^\text{GAE}, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t^\text{GAE}) \right] Lactor=Et[min(rt(θ)AtGAE,clip(rt(θ),1ϵ,1+ϵ)AtGAE)]
    其中 ( r t ( θ ) r_t(\theta) rt(θ) ) 是当前策略与旧策略的概率比。

  2. 值函数更新(Critic Loss):
    L critic = E t [ ( R t − V ( s t ) ) 2 ] L^\text{critic} = \mathbb{E}_t \left[ (R_t - V(s_t))^2 \right] Lcritic=Et[(RtV(st))2]

PPO 使用 GAE 来高效估计 ( A t A_t At ),从而使得梯度更新既稳定又高效。


4. GAE 的代码实现

以下是 GAE 的核心代码实现:

import numpy as np

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """
    使用 GAE 计算优势函数
    Args:
        rewards: 每一步的即时奖励 (list or array)
        values: 每一步的状态值函数估计 (list or array)
        gamma: 折扣因子
        lam: GAE 的衰减系数
    Returns:
        advantages: 每一步的优势函数估计
    """
    advantages = np.zeros_like(rewards)
    gae = 0  # 初始化 GAE
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * (values[t + 1] if t < len(rewards) - 1 else 0) - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# 示例数据
rewards = [1, 1, 1, 1, 1]  # 即时奖励
values = [0.5, 0.6, 0.7, 0.8, 0.9]  # 状态值函数估计
advantages = compute_gae(rewards, values)
print("GAE 计算结果:", advantages)

5. 数值模拟

假设我们有以下场景:

  • 即时奖励:玩家在每一步获得固定奖励 ( r t = 1 r_t = 1 rt=1 )。
  • 状态值估计:模型估计的值函数逐步递增。

我们将模拟不同 ( λ \lambda λ ) 值对优势函数估计的影响。

import matplotlib.pyplot as plt

# 参数设置
gamma = 0.99
rewards = [1, 1, 1, 1, 1]
values = [0.5, 0.6, 0.7, 0.8, 0.9]

# 不同 lambda 值的 GAE
lambda_values = [0.5, 0.8, 0.95, 1.0]
results = {}

for lam in lambda_values:
    advantages = compute_gae(rewards, values, gamma, lam)
    results[lam] = advantages

# 绘图
for lam, adv in results.items():
    plt.plot(adv, label=f"λ = {lam}")

plt.xlabel("时间步 (t)")
plt.ylabel("优势函数 (A_t)")
plt.title("不同 λ 对 GAE 的影响")
plt.legend()
plt.grid()
plt.show()

绘图结果
在这里插入图片描述


6. 总结
  1. GAE 的优势

    • 低方差:通过 ( λ \lambda λ ) 控制,引入更多的短期回报,减少方差。
    • 高效率:兼顾短期 TD 和长期 Monte Carlo 的优点。
    • 灵活性:可以根据任务需求调整偏差和方差的权衡。
  2. PPO 中的应用
    GAE 是 PPO 算法中计算优势函数的重要工具,其估计结果直接影响策略梯度和价值函数的更新效率。

通过本文的介绍,我们可以更深入地理解 GAE 的数学原理、代码实现以及其在实际场景中的应用,希望对强化学习爱好者有所帮助!

英文版

Deep Dive into Generalized Advantage Estimation (GAE) in Reinforcement Learning


1. What is Generalized Advantage Estimation (GAE)?

In reinforcement learning, estimating the advantage function ( A ( s , a ) A(s, a) A(s,a) ) is a crucial step in computing the policy gradient. The advantage function measures how much better a specific action ( a ) is compared to others in a given state ( s s s ). However, estimating this function poses two main challenges:

  1. High variance: Direct computation using one-step rewards or Monte Carlo rollouts often results in high variance, which makes optimization unstable.
  2. Bias: Introducing approximations (e.g., using value functions ( V ( s ) V(s) V(s) )) reduces variance but introduces bias.

To balance bias and variance, Schulman et al. introduced Generalized Advantage Estimation (GAE) in 2016. GAE is an efficient method that adjusts the advantage function estimate using a weighted sum of temporal difference (TD) residuals.


2. Mathematical Foundation of GAE

The key idea of GAE is to compute the advantage function using a combination of short-term and long-term rewards, weighted by a decay factor.

Temporal Difference (TD) Residual:
δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)

GAE Recursive Formula:
A t GAE = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^\text{GAE} = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l} AtGAE=l=0(γλ)lδt+l

Where:

  • ( γ \gamma γ ) is the discount factor, controlling the weight of future rewards.
  • ( λ \lambda λ ) is the GAE decay factor, balancing short-term and long-term contributions.
  • ( δ t \delta_t δt ) is the TD residual, representing the difference between immediate rewards and value estimates.

Alternatively, GAE can be written recursively:
A t GAE = δ t + ( γ λ ) ⋅ A t + 1 GAE A_t^\text{GAE} = \delta_t + (\gamma \lambda) \cdot A_{t+1}^\text{GAE} AtGAE=δt+(γλ)At+1GAE

By adjusting ( λ \lambda λ ), GAE can interpolate between:

  • Low variance, high bias: ( λ = 0 \lambda = 0 λ=0 ) (one-step TD residual).
  • High variance, low bias: ( λ = 1 \lambda = 1 λ=1 ) (Monte Carlo return).

3. Application of GAE in PPO

In Proximal Policy Optimization (PPO), GAE plays a critical role in estimating the advantage function, which is used in both the policy update and value function update.

PPO Loss Function:

  1. Actor Loss (Policy Update):
    L actor = E t [ min ⁡ ( r t ( θ ) ⋅ A t GAE , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A t GAE ) ] L^\text{actor} = \mathbb{E}_t \left[ \min(r_t(\theta) \cdot A_t^\text{GAE}, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t^\text{GAE}) \right] Lactor=Et[min(rt(θ)AtGAE,clip(rt(θ),1ϵ,1+ϵ)AtGAE)]
    Where ( r t ( θ ) r_t(\theta) rt(θ) ) is the probability ratio between the new and old policies.

  2. Critic Loss (Value Function Update):
    L critic = E t [ ( R t − V ( s t ) ) 2 ] L^\text{critic} = \mathbb{E}_t \left[ (R_t - V(s_t))^2 \right] Lcritic=Et[(RtV(st))2]

PPO relies on GAE to provide a stable and accurate advantage estimate ( A t GAE A_t^\text{GAE} AtGAE ), ensuring efficient policy gradient updates.


4. Code Implementation

Below is a Python implementation of GAE:

import numpy as np

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """
    Compute Generalized Advantage Estimation (GAE).
    Args:
        rewards: List of rewards at each timestep.
        values: List of value function estimates at each timestep.
        gamma: Discount factor.
        lam: GAE decay factor.
    Returns:
        advantages: GAE-based advantage estimates.
    """
    advantages = np.zeros_like(rewards)
    gae = 0  # Initialize GAE
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * (values[t + 1] if t < len(rewards) - 1 else 0) - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# Example usage
rewards = [1, 1, 1, 1, 1]  # Reward at each timestep
values = [0.5, 0.6, 0.7, 0.8, 0.9]  # Value function estimates
advantages = compute_gae(rewards, values)
print("GAE Advantages:", advantages)
# GAE 计算结果: [4 3 2 1 0]

5. Numerical Simulation

We can simulate how different values of ( λ \lambda λ ) impact the GAE estimates using the following script:

import matplotlib.pyplot as plt

# Parameters
gamma = 0.99
rewards = [1, 1, 1, 1, 1]
values = [0.5, 0.6, 0.7, 0.8, 0.9]

# Compute GAE for different lambda values
lambda_values = [0.5, 0.8, 0.95, 1.0]
results = {}

for lam in lambda_values:
    advantages = compute_gae(rewards, values, gamma, lam)
    results[lam] = advantages

# Plot the results
for lam, adv in results.items():
    plt.plot(adv, label=f"λ = {lam}")

plt.xlabel("Time Step (t)")
plt.ylabel("Advantage (A_t)")
plt.title("Impact of λ on GAE")
plt.legend()
plt.grid()
plt.show()

6. Summary
  1. What GAE Solves:

    • GAE balances the bias-variance trade-off in advantage estimation, making it a key tool in reinforcement learning.
  2. GAE in PPO:

    • GAE ensures stable and efficient policy updates by providing accurate advantage estimates for the actor-critic framework.
  3. Key Takeaways:

    • ( λ \lambda λ ) is a critical hyperparameter in GAE, allowing control over the trade-off between bias and variance.
    • GAE is widely adopted in modern reinforcement learning algorithms, particularly in on-policy methods like PPO.

This blog post illustrates the importance of GAE in reinforcement learning, along with its implementation and impact on training stability. By leveraging GAE, algorithms like PPO achieve superior performance in complex environments.

后记

2024年12月12日21点38分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/935495.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于SpringBoot框架的个性化的旅游网站 (计算机毕业设计)+万字毕业论文 T025

系统合集跳转 源码获取链接 一、系统环境 运行环境: 最好是java jdk 1.8&#xff0c;我们在这个平台上运行的。其他版本理论上也可以。 IDE环境&#xff1a; Eclipse,Myeclipse,IDEA或者Spring Tool Suite都可以 tomcat环境&#xff1a; Tomcat 7.x,8.x,9.x版本均可 操作系统…

java web 实验五 Servlet控制层设计(设计性)

实验五 Servlet控制层设计&#xff08;设计性&#xff09; //代码放在资源包里了 实验目的 熟悉Servlet的基本语法。掌握采用HTML、JS、JDBC、JSP、Servlet和四层结构的综合应用。实验要求 本实验要求每个同学单独完成&#xff1b;调试程序要记录调试过程中出现的问题及解决…

汽车保养系统+ssm

摘 要 由于APP软件在开发以及运营上面所需成本较高&#xff0c;而用户手机需要安装各种APP软件&#xff0c;因此占用用户过多的手机存储空间&#xff0c;导致用户手机运行缓慢&#xff0c;体验度比较差&#xff0c;进而导致用户会卸载非必要的APP&#xff0c;倒逼管理者必须改…

仿iOS日历、飞书日历、Google日历的日模式

仿iOS日历、飞书日历、Google日历的日模式&#xff0c;24H内事件可自由上下拖动、自由拉伸。 以下是效果图&#xff1a; 具体实现比较简单&#xff0c;代码如下&#xff1a; import android.content.Context; import android.graphics.Canvas; import android.graphics.Color;…

网易云信荣获“HarmonyOS NEXT SDK星河奖”

近日&#xff0c;鸿蒙生态伙伴 SDK 开发者论坛在北京举行。 网易云信凭借在融合通信领域的技术创新和鸿蒙生态贡献&#xff0c;荣获鸿蒙生态“HarmonyOS NEXT SDK星河奖”。 会上&#xff0c;华为鸿蒙正式推出 SDK 生态繁荣伙伴支持计划&#xff0c;旨在为 SDK 领域伙伴和开发…

Electromagnetic Tracking Smart Car based on STM32F401 or GD32F450ZGT6

Electromagnetic Smart Car1 基于梁山派的电磁循迹智能车的主控芯片来自立创梁山派板载的国产兆易创新GD32F450ZGT6&#xff0c;整车采用模块化开发&#xff0c;由电源模块、L298N驱动模块、电磁循迹模块、梁山派、调试模块和显示模块组成。 嵌入式软件开发环境是&#xff1a;K…

Windows下Docker Desktop+k8s安装和部署程序

Windows下Docker Desktopk8s安装和部署程序 一、安装Docker DesktopKubernetes 1.需要安装windows版的docker 安装 Docker Desktop&#xff0c;启用Hyper-V、虚拟机平台和容器 https://www.docker.com/get-started/ 2.启用Kubernetes 打开Docker-Desktop&#xff0c;启用…

网络原理03

回顾 应用层&#xff1a;应用程序&#xff0c;数据具体如何使用 传输层&#xff1a;关注起点和终点 网络层&#xff1a;关注路径规划 数据链路层&#xff1a;关注相邻节点的转发 物理层&#xff1a;硬件设备 应用层 应用程序 在应用层&#xff0c;很多时候&#xff0c;…

HTTP 状态码大全

常见状态码 200 OK # 客户端请求成功 400 Bad Request # 客户端请求有语法错误 不能被服务器所理解 401 Unauthorized # 请求未经授权 这个状态代码必须和WWW- Authenticate 报头域一起使用 403 Forbidden # 服务器收到请求但是拒绝提供服务 404 Not Found # 请求资源不存…

Ajax--实现检测用户名是否存在功能

目录 &#xff08;一&#xff09;什么是Ajax &#xff08;二&#xff09;同步交互与异步交互 &#xff08;三&#xff09;AJAX常见应用情景 &#xff08;四&#xff09;AJAX的优缺点 &#xff08;五&#xff09;使用jQuery实现AJAX 1.使用JQuery中的ajax方法实现步骤&#xf…

unique_ptr自定义删除器,_Compressed_pair利用偏特化减少存储的一些设计思路

主要是利用偏特化&#xff0c; 如果自定义删除器是空类&#xff08;没有成员变量&#xff0c;可以有成员函数&#xff09;&#xff1a; _Compressed_pair会继承删除器&#xff08;删除器作为基类&#xff09;&#xff0c;但_Compressed_pair里不保存删除器对象&#xff0c;只…

AGCRN论文解读

一、创新点 传统GCN只能基于静态预定义图建模全局共享模式&#xff0c;而AGCRN通过两种GCN的增强模块&#xff08;NAPL、DAGG&#xff09;实现了更精细的节点特性学习和图结构生成。 1 节点自适应参数学习模块&#xff08;NAPL&#xff09; 传统GCN通过共享参数&#xff08;权重…

使用观测云排查数据库死锁故障

故障发现 核心应用 pod 发生重启&#xff0c;同时接收到对应使用者反馈业务问题&#xff0c;开始排查。 观测云排查现场 1、根据重启应用信息&#xff0c;查询 APM 执行数据库 update 操作大量报错&#xff0c;执行时间在 5min 以上。 分析 APM 链路异常&#xff0c;发现是触…

UNIX数据恢复—UNIX系统常见故障问题和数据恢复方案

UNIX系统常见故障表现&#xff1a; 1、存储结构出错&#xff1b; 2、数据删除&#xff1b; 3、文件系统格式化&#xff1b; 4、其他原因数据丢失。 UNIX系统常见故障解决方案&#xff1a; 1、检测UNIX系统故障涉及的设备是否存在硬件故障&#xff0c;如果存在硬件故障&#xf…

npm或yarn包配置地址源

三种方法 1.配置.npmrc 文件 在更目录新增.npmrc文件 然后写入需要访问的包的地址 2.直接yarn.lock文件里面修改地址 简单粗暴 3.yarn install 的时候添加参数 设置包的仓库地址 yarn config set registry https://registry.yarnpkg.com 安装&#xff1a;yarn install 注意…

opencv——图片矫正

图像矫正 图像矫正的原理是透视变换&#xff0c;下面来介绍一下透视变换的概念。 听名字有点熟&#xff0c;我们在图像旋转里接触过仿射变换&#xff0c;知道仿射变换是把一个二维坐标系转换到另一个二维坐标系的过程&#xff0c;转换过程坐标点的相对位置和属性不发生变换&a…

【学习】企业通过CMMI认证,还需要申请CSMM资质吗

​ 企业通过CMMI认证之后&#xff0c;是否还有必要申请CSMM资质&#xff1f;这是一个值得软件企业深思的问题。虽然CMMI和CSMM都在组织的软件过程改进和认证方面发挥着重要作用&#xff0c;但它们各自拥有自己的特点在。企业需要根据自身发展需求来选择适合的认证方式。首先我…

OpenHarmony-3.HDF input子系统(5)

HDF input 子系统OpenHarmony-4.0-Release 1.Input 概述 输入设备是用户与计算机系统进行人机交互的主要装置之一&#xff0c;是用户与计算机或者其他设备通信的桥梁。常见的输入设备有键盘、鼠标、游戏杆、触摸屏等。本文档将介绍基于 HDF_Input 模型的触摸屏器件 IC 为 GT91…

BurpSuite之移动端流量抓包

学习视频来自B站UP主泷羽sec&#xff0c;如涉及侵权马上删除文章。 笔记只是方便学习&#xff0c;以下内容只涉及学习内容&#xff0c;切莫逾越法律红线。 安全见闻&#xff0c;包含了各种网络安全&#xff0c;网络技术&#xff0c;旨在明白自己的渺小&#xff0c;知识的广博&a…

Any2Policy: Learning Visuomotor Policy with Any-Modality(类似AnyGPT)

发表时间&#xff1a;NeurIPS 2024 论文链接&#xff1a;https://readpaper.com/pdf-annotate/note?pdfId2598959255168534016&noteId2598960522854466816 作者单位&#xff1a;Midea Group Motivation&#xff1a;Current robotic learning methodologies often focus…