强化学习之 PPO 算法:原理、实现与案例深度剖析

目录

    • 一、引言
    • 二、PPO 算法原理
      • 2.1 策略梯度
      • 2.2 PPO 核心思想
    • 三、PPO 算法公式推导
      • 3.1 重要性采样
      • 3.2 优势函数估计
    • 四、PPO 算法代码实现(以 Python 和 PyTorch 为例)
    • 五、PPO 算法案例应用
      • 5.1 机器人控制
      • 5.2 自动驾驶
    • 六、总结


一、引言

强化学习作为机器学习中的一个重要领域,旨在让智能体通过与环境交互,学习到最优的行为策略以最大化长期累积奖励。近端策略优化(Proximal Policy Optimization,PPO)算法是强化学习中的明星算法,它在诸多领域都取得了令人瞩目的成果。本文将深入探讨 PPO 算法,从原理到代码实现,再到实际案例应用,力求让读者全面掌握这一强大的算法。

二、PPO 算法原理

2.1 策略梯度

在强化学习里,策略梯度是一类关键的优化方法,你可以把它想象成是智能体在学习如何行动时的 “指南针”。假设策略由参数 θ \theta θ 表示,这就好比是智能体的 “行动指南” 参数,智能体在状态 s s s 下采取行动 a a a 的概率为 π θ ( a ∣ s ) \pi_{\theta}(a|s) πθ(as) ,即根据当前的 “行动指南”,在这个状态下选择这个行动的可能性。

策略梯度的目标是最大化累计奖励的期望,用公式表示就是: J ( θ ) = E s 0 , a 0 , ⋯ [ ∑ t = 0 T γ t r ( s t , a t ) ] J(\theta)=\mathbb{E}_{s_0,a_0,\cdots}\left[\sum_{t = 0}^{T}\gamma^{t}r(s_t,a_t)\right] J(θ)=Es0,a0,[t=0Tγtr(st,at)]

这里的 γ \gamma γ 是折扣因子,它的作用是让智能体更关注近期的奖励,因为越往后的奖励可能越不确定,就像我们在做决策时,往往会更看重眼前比较确定的好处。 r ( s t , a t ) r(s_t,a_t) r(st,at) 是在状态 s t s_t st 下采取行动 a t a_t at 获得的奖励,比如玩游戏时,在某个游戏场景下做出某个操作得到的分数。

根据策略梯度定理,策略梯度可以表示为: ∇ θ J ( θ ) = E s , a [ ∇ θ log ⁡ π θ ( a ∣ s ) A ( s , a ) ] \nabla_{\theta}J(\theta)=\mathbb{E}_{s,a}\left[\nabla_{\theta}\log\pi_{\theta}(a|s)A(s,a)\right] θJ(θ)=Es,a[θlogπθ(as)A(s,a)]

这里的 A ( s , a ) A(s,a) A(s,a) 是优势函数,它表示采取行动 a a a 相对于平均策略的优势。简单来说,就是判断这个行动比一般的行动好在哪里,好多少,帮助智能体决定是否要多采取这个行动。

2.2 PPO 核心思想

PPO 算法的核心是在策略更新时,限制策略的变化幅度,避免更新过大导致策略性能急剧下降。这就好像我们在调整自行车的变速器,如果一下子调得太猛,可能车子就没法正常骑了。

它通过引入一个截断的目标函数来实现这一点: L C L I P ( θ ) = E t [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] L^{CLIP}(\theta)=\mathbb{E}_{t}\left[\min\left(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1+\epsilon)\hat{A}_t\right)\right] LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) 是重要性采样比,它反映了新策略和旧策略对于同一个状态 - 行动对的概率差异。 A ^ t \hat{A}_t A^t 是估计的优势函数, ϵ \epsilon ϵ 是截断参数,通常设置为一个较小的值,如 0.2 。这个截断参数就像是给策略更新幅度设定了一个 “安全范围”,在这个范围内更新策略,能保证策略既有所改进,又不会变得太糟糕。

三、PPO 算法公式推导

3.1 重要性采样

重要性采样是 PPO 算法中的关键技术之一。由于直接从当前策略采样数据效率较低,我们可以从旧策略 π θ o l d \pi_{\theta_{old}} πθold 采样数据,然后通过重要性采样比 r t ( θ ) r_t(\theta) rt(θ) 来校正数据的分布。 E s ∼ π θ [ f ( s ) ] ≈ 1 N ∑ i = 1 N π θ ( s i ) π θ o l d ( s i ) f ( s i ) \mathbb{E}_{s\sim\pi_{\theta}}[f(s)]\approx\frac{1}{N}\sum_{i = 1}^{N}\frac{\pi_{\theta}(s_i)}{\pi_{\theta_{old}}(s_i)}f(s_i) Esπθ[f(s)]N1i=1Nπθold(si)πθ(si)f(si)

比如我们要了解一群鸟的飞行习惯,直接去观察所有鸟的飞行轨迹很困难,那我们可以先观察一部分容易观察到的鸟(旧策略采样),然后根据这些鸟和所有鸟的一些特征差异(重要性采样比),来推测整个鸟群的飞行习惯。

3.2 优势函数估计

优势函数 A ( s , a ) A(s,a) A(s,a) 可以通过多种方法估计,常用的是广义优势估计(Generalized Advantage Estimation,GAE): A ^ t = ∑ k = 0 ∞ ( γ λ ) k δ t + k \hat{A}_t=\sum_{k = 0}^{\infty}(\gamma\lambda)^k\delta_{t + k} A^t=k=0(γλ)kδt+k

其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_{t}=r_t+\gamma V(s_{t + 1})-V(s_t) δt=rt+γV(st+1)V(st) 是 TD 误差, λ \lambda λ 是 GAE 参数,通常在 0 到 1 之间。优势函数的估计就像是给智能体的行动打分,告诉它每个行动到底有多好,以便它做出更好的决策。

四、PPO 算法代码实现(以 Python 和 PyTorch 为例)

import torch

import torch.nn as nn

import torch.optim as optim

import gym

class Policy(nn.Module):

def __init__(self, state_dim, action_dim):

    super(Policy, self).__init__()

    self.fc1 = nn.Linear(state_dim, 64)

    self.fc2 = nn.Linear(64, 64)

    self.mu_head = nn.Linear(64, action_dim)

    self.log_std_head = nn.Linear(64, action_dim)

def forward(self, x):

    x = torch.relu(self.fc1(x))

    x = torch.relu(self.fc2(x))

    mu = torch.tanh(self.mu_head(x))

    log_std = self.log_std_head(x)

    std = torch.exp(log_std)

    dist = torch.distributions.Normal(mu, std)

    return dist

class Value(nn.Module):

def __init__(self, state_dim):

    super(Value, self).__init__()

    self.fc1 = nn.Linear(state_dim, 64)

    self.fc2 = nn.Linear(64, 64)

    self.v_head = nn.Linear(64, 1)

def forward(self, x):

    x = torch.relu(self.fc1(x))

    x = torch.relu(self.fc2(x))

    v = self.v_head(x)

    return v

def ppo_update(policy, value, optimizer_policy, optimizer_value, states, actions, rewards, dones, gamma=0.99,

           clip_epsilon=0.2, lambda_gae=0.95):

states = torch.FloatTensor(states)

actions = torch.FloatTensor(actions)

rewards = torch.FloatTensor(rewards)

dones = torch.FloatTensor(dones)

values = value(states).squeeze(1)

returns = []

gae = 0

for i in reversed(range(len(rewards))):

    if i == len(rewards) - 1:

        next_value = 0

    else:

        next_value = values[i + 1]

    delta = rewards[i] + gamma * next_value * (1 - dones[i]) - values[i]

    gae = delta + gamma * lambda_gae * (1 - dones[i]) * gae

    returns.insert(0, gae + values[i])

returns = torch.FloatTensor(returns)

old_dist = policy(states)

old_log_probs = old_dist.log_prob(actions).sum(-1)

for _ in range(3):

    dist = policy(states)

    log_probs = dist.log_prob(actions).sum(-1)

    ratios = torch.exp(log_probs - old_log_probs)

    advantages = returns - values.detach()

    surr1 = ratios * advantages

    surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages

    policy_loss = -torch.min(surr1, surr2).mean()

    optimizer_policy.zero_grad()

    policy_loss.backward()

    optimizer_policy.step()

    value_loss = nn.MSELoss()(value(states).squeeze(1), returns)

    optimizer_value.zero_grad()

    value_loss.backward()

    optimizer_value.step()

def train_ppo(env_name, num_episodes=1000):

env = gym.make(env_name)

state_dim = env.observation_space.shape[0]

action_dim = env.action_space.shape[0]

policy = Policy(state_dim, action_dim)

value = Value(state_dim)

optimizer_policy = optim.Adam(policy.parameters(), lr=3e-4)

optimizer_value = optim.Adam(value.parameters(), lr=3e-4)

for episode in range(num_episodes):

    states, actions, rewards, dones = [], [], [], []

    state = env.reset()

    done = False

    while not done:

        state = torch.FloatTensor(state)

        dist = policy(state)

        action = dist.sample()

        next_state, reward, done, _ = env.step(action.detach().numpy())

        states.append(state)

        actions.append(action)

        rewards.append(reward)

        dones.append(done)

        state = next_state

    ppo_update(policy, value, optimizer_policy, optimizer_value, states, actions, rewards, dones)

    if episode % 100 == 0:

        total_reward = 0

        state = env.reset()

        done = False

        while not done:

            state = torch.FloatTensor(state)

            dist = policy(state)

            action = dist.mean

            next_state, reward, done, _ = env.step(action.detach().numpy())

            total_reward += reward

            state = next_state

        print(f"Episode {episode}, Average Reward: {total_reward}")

if __name__ == "__main__":

train_ppo('Pendulum-v1')

五、PPO 算法案例应用

5.1 机器人控制

在机器人控制领域,PPO 算法可以用于训练机器人的运动策略。例如,训练一个双足机器人行走,机器人的状态可以包括关节角度、速度等信息,行动则是关节的控制指令。通过 PPO 算法,机器人可以学习到如何根据当前状态调整关节控制,以实现稳定高效的行走。

5.2 自动驾驶

在自动驾驶场景中,车辆的状态包括位置、速度、周围环境感知信息等,行动可以是加速、减速、转向等操作。PPO 算法可以让自动驾驶系统学习到在不同路况和环境下的最优驾驶策略,提高行驶的安全性和效率。

六、总结

PPO 算法作为强化学习中的优秀算法,以其高效的学习能力和良好的稳定性在多个领域得到了广泛应用。通过深入理解其原理、公式推导,结合代码实现和实际案例分析,我们能够更好地掌握和运用这一算法,为解决各种复杂的实际问题提供有力的工具。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/967503.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

apache-poi导出excel数据

excel导出 自动设置宽度&#xff0c;设置标题框&#xff0c;设置数据边框。 excel导出 添加依赖 <dependency><groupId>org.apache.poi</groupId><artifactId>poi-ooxml</artifactId><version>5.2.2</version></dependency>…

10 FastAPI 的自动文档

FastAPI 是一个功能强大且易于使用的 Web 框架&#xff0c;它的最大亮点之一就是内置的 自动文档生成 功能。通过集成 Swagger UI 和 ReDoc&#xff0c;FastAPI 可以自动为我们的 API 生成交互式文档。这不仅使得开发者能够更快速地了解和测试 API&#xff0c;还能够为前端开发…

微软AI研究团队推出LLaVA-Rad:轻量级开源基础模型,助力先进临床放射学报告生成

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

mysql8.0使用MHA实现高可用

一、MHA 介绍 MHA&#xff08;Master HA&#xff09;是一款开源的 MySQL 的高可用程序&#xff0c;它为 MySQL 主从复制架构提供了 automating master failover 功能。MHA 在监控到 master 节点故障时&#xff0c;会提升其中拥有最新数据的 slave 节点成为新的master 节点&…

D3实现站点路线图demo分享

分享通过D3实现的站点路线分布图demo&#xff0c;后续会继续更新其他功能。 功能点 点位弹窗 效果图如下&#xff1a; 轨迹高亮 效果图如下&#xff1a; 添加路线箭头 箭头展示逻辑&#xff1a;根据高速路线最后两个点位&#xff0c;计算得出箭头的点位 效果图如下&#x…

【系统架构设计师】操作系统 ③ ( 存储管理 | 页式存储弊端 - 段式存储引入 | 段式存储 | 段表 | 段表结构 | 逻辑地址 的 合法段地址判断 )

文章目录 一、页式存储弊端 - 段式存储引入1、页式存储弊端 - 内存碎片2、页式存储弊端 - 逻辑结构不匹配3、段式存储引入 二、段式存储 简介1、段式存储2、段表3、段表 结构4、段内地址 / 段内偏移5、段式存储 优缺点6、段式存储 与 页式存储 对比 三、逻辑地址 的 合法段地址…

物联网软件开发与应用方向应该怎样学习,学习哪些内容,就业方向是怎样?(文末领取整套学习视频,课件)物联网硬件开发与嵌入式系统

随着物联网技术的飞速发展&#xff0c;物联网软件开发与应用方向成为了众多开发者关注的焦点。那么&#xff0c;如何在这个领域中脱颖而出呢&#xff1f;本文将为你提供一份详细的学习指南&#xff0c;帮助你从零开始&#xff0c;逐步掌握物联网软件开发与应用的核心技能。 一…

Linux——基础命令1

$&#xff1a;普通用户 #&#xff1a;超级用户 cd 切换目录 cd 目录 &#xff08;进入目录&#xff09; cd ../ &#xff08;返回上一级目录&#xff09; cd ~ &#xff08;切换到当前用户的家目录&#xff09; cd - &#xff08;返回上次目录&#xff09; pwd 输出当前目录…

OpenFeign远程调用返回的是List<T>类型的数据

在使用 OpenFeign 进行远程调用时&#xff0c;如果接口返回的是 List 类型的数据&#xff0c;可以通过以下方式处理&#xff1a; 直接定义返回类型为List Feign 默认支持 JSON 序列化/反序列化&#xff0c;如果服务端返回的是 List的JSON格式数据&#xff0c;可以直接在 Feig…

向量数据库简单对比

文章目录 一、Chroma二、Pinecone/腾讯云VectorDB/VikingDB三、redis四、Elasticsearch五、Milvus六、Qdrant七、Weaviate八、Faiss 一、Chroma 官方地址&#xff1a; https://www.trychroma.com/优点 ①简单&#xff0c;非常简单构建服务。 ②此外&#xff0c;Chroma还具有自…

字符指针、数组指针和函数指针

1. 字符指针变量 1.1 简单例子 字符指针 char* 在C语言中主要由两种用法&#xff1a; 1.用于存放一个字符变量的地址。 2.用字符指针接收一个字符串。 这里并不是将整个字符串的地址存入 pstr 指针&#xff0c;指针变量 pstr 中存放的是常量字符串的首字符 h 的地址。 以一个…

【Linux网络编程】之守护进程

【Linux网络编程】之守护进程 进程组进程组的概念组长进程 会话会话的概念会话ID 控制终端控制终端的概念控制终端的作用会话、终端、bash三者的关系 前台进程与后台进程概念特点查看当前终端的后台进程前台进程与后台进程的切换 作业控制相关概念作业状态&#xff08;一般指后…

JS宏进阶:XMLHttpRequest对象

一、概述 XMLHttpRequest简称XHR&#xff0c;它是一个可以在JavaScript中使用的对象&#xff0c;用于在后台与服务器交换数据&#xff0c;实现页面的局部更新&#xff0c;而无需重新加载整个页面&#xff0c;也是Ajax&#xff08;Asynchronous JavaScript and XML&#xff09;…

怎么查看电脑显存大小(查看电脑配置)

这里提供一个简单的方法查看 winr打开cmd 终端输入dxdiag进入DirectX 点击显示查看设备的显示内存&#xff08;VRAM&#xff09; 用这个方法查看电脑配置和显存是比较方便的 dxdiag功能 Dxdiag是Windows的DirectX诊断工具&#xff0c;其主要作用包括但不限于以下几点&#…

优惠券平台(一):基于责任链模式创建优惠券模板

前景概要 系统的主要实现是优惠券的相关业务&#xff0c;所以对于用户管理的实现我们简单用拦截器在触发接口前创建一个单一用户。 // 用户属于非核心功能&#xff0c;这里先通过模拟的形式代替。后续如果需要后管展示&#xff0c;会重构该代码 UserInfoDTO userInfoDTO new…

【机器学习】数据预处理之scikit-learn的Scaler与自定义Scaler类进行数据归一化

scikit-learn的Scaler数据归一化 一、摘要二、训练数据集和测试数据集的归一化处理原则三、scikit-learn中的Scalar类及示例四、自定义StandardScaler类进行数据归一化处理五、小结 一、摘要 本文主要介绍了scikit-learn中Scaler的使用方法&#xff0c;特别强调了数据归一化在…

机器学习中过拟合和欠拟合问题处理方法总结

目录 一、背景二、过拟合(Overfitting)2.1 基本概念2.2 过拟合4个最主要的特征2.3 防止过拟合的11个有效方法 三、欠拟合&#xff08;Underfitting&#xff09;3.1 基本概念3.2 欠拟合的4个特征3.3 防止欠拟合的11个有效方法 四、总结五、参考资料 一、背景 在机器学习模型训练…

ABP框架9——自定义拦截器的实现与使用

一、AOP编程 AOP定义:面向切片编程&#xff0c;着重强调功能&#xff0c;将功能从业务逻辑分离出来。AOP使用场景&#xff1a;处理通用的、与业务逻辑无关的功能&#xff08;如日志记录、性能监控、事务管理等&#xff09;拦截器:拦截方法调用并添加额外的行为&#xff0c;比如…

基于YoloV11和驱动级鼠标模拟实现Ai自瞄

本文将围绕基于 YoloV11 和驱动级鼠标实现 FPS 游戏 AI 自瞄展开阐述。 需要着重强调的是&#xff0c;本文内容仅用于学术研究和技术学习目的。严禁任何个人或组织将文中所提及的技术、方法及思路应用于违法行为&#xff0c;包括但不限于在各类游戏中实施作弊等违规操作。若因违…

示例代码:C# MQTTS双向认证(客户端)(服务器EMQX)

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github&#xff1a;codetoys&#xff0c;所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的&#xff0c;可以在任何平台上使用。 源码指引&#xff1a;github源…