了解PPO算法(Proximal Policy Optimization)

Proximal Policy Optimization (PPO) 是一种强化学习算法,由 OpenAI 提出,旨在解决传统策略梯度方法中策略更新过大的问题。PPO 通过引入限制策略更新范围的机制,在保证收敛性的同时提高了算法的稳定性和效率。

PPO算法原理

PPO 算法的核心思想是通过优化目标函数来更新策略,但在更新过程中限制策略变化的幅度。具体来说,PPO 引入了裁剪(Clipping)和信赖域(Trust Region)的思想,以确保策略不会发生过大的改变。

PPO算法公式

PPO 主要有两种变体:裁剪版(Clipped PPO)和信赖域版(Adaptive KL Penalty PPO)。本文重点介绍裁剪版的 PPO。

  • 旧策略:

    \pi_{\theta_{\text{old}}}(a|s)

    其中,\theta_{\text{old}}​ 是上一次更新后的策略参数。

  • 计算概率比率:

    r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}
  • 裁剪后的目标函数:

    L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right]

    其中,\hat{A}_t​ 是优势函数(Advantage Function),\epsilon 是裁剪范围的超参数,通常取值为0.2。

  • 更新策略参数:

    a_{\text{new}} = \arg\max_{\theta} L^{\text{CLIP}}(\theta)
PPO算法的实现

下面是用Python和TensorFlow实现 PPO 算法的代码示例:

import tensorflow as tf
import numpy as np
import gym

# 定义策略网络
class PolicyNetwork(tf.keras.Model):
    def __init__(self, action_space):
        super(PolicyNetwork, self).__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(128, activation='relu')
        self.logits = tf.keras.layers.Dense(action_space, activation=None)
    
    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.logits(x)

# 定义值函数网络
class ValueNetwork(tf.keras.Model):
    def __init__(self):
        super(ValueNetwork, self).__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(128, activation='relu')
        self.value = tf.keras.layers.Dense(1, activation=None)
    
    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.value(x)

# 超参数
learning_rate = 0.0003
clip_ratio = 0.2
epochs = 10
batch_size = 64
gamma = 0.99

# 创建环境
env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

# 创建策略和值函数网络
policy_net = PolicyNetwork(n_actions)
value_net = ValueNetwork()

# 优化器
policy_optimizer = tf.keras.optimizers.Adam(learning_rate)
value_optimizer = tf.keras.optimizers.Adam(learning_rate)

def get_action(observation):
    logits = policy_net(observation)
    action = tf.random.categorical(logits, 1)
    return action[0, 0]

def compute_advantages(rewards, values, next_values, done):
    advantages = []
    gae = 0
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + gamma * next_values[i] * (1 - done[i]) - values[i]
        gae = delta + gamma * gae
        advantages.insert(0, gae)
    return np.array(advantages)

def ppo_update(observations, actions, advantages, returns):
    with tf.GradientTape() as tape:
        old_logits = policy_net(observations)
        old_log_probs = tf.nn.log_softmax(old_logits)
        old_action_log_probs = tf.reduce_sum(old_log_probs * tf.one_hot(actions, n_actions), axis=1)
        
        logits = policy_net(observations)
        log_probs = tf.nn.log_softmax(logits)
        action_log_probs = tf.reduce_sum(log_probs * tf.one_hot(actions, n_actions), axis=1)
        
        ratio = tf.exp(action_log_probs - old_action_log_probs)
        surr1 = ratio * advantages
        surr2 = tf.clip_by_value(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
        policy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))
    
    policy_grads = tape.gradient(policy_loss, policy_net.trainable_variables)
    policy_optimizer.apply_gradients(zip(policy_grads, policy_net.trainable_variables))

    with tf.GradientTape() as tape:
        value_loss = tf.reduce_mean((returns - value_net(observations))**2)
    
    value_grads = tape.gradient(value_loss, value_net.trainable_variables)
    value_optimizer.apply_gradients(zip(value_grads, value_net.trainable_variables))

# 训练循环
for epoch in range(epochs):
    observations = []
    actions = []
    rewards = []
    values = []
    next_values = []
    dones = []
    
    obs = env.reset()
    done = False
    while not done:
        obs = obs.reshape(1, -1)
        observations.append(obs)
        action = get_action(obs)
        actions.append(action)
        
        value = value_net(obs)
        values.append(value)
        
        obs, reward, done, _ = env.step(action.numpy())
        rewards.append(reward)
        dones.append(done)
        
        if done:
            next_values.append(0)
        else:
            next_value = value_net(obs.reshape(1, -1))
            next_values.append(next_value)
    
    returns = compute_advantages(rewards, values, next_values, dones)
    advantages = returns - values

    observations = np.concatenate(observations, axis=0)
    actions = np.array(actions)
    returns = np.array(returns)
    advantages = np.array(advantages)
    
    ppo_update(observations, actions, advantages, returns)

    print(f'Epoch {epoch+1} completed')
总结

PPO 算法通过引入裁剪机制和信赖域约束,限制了策略更新的幅度,提高了训练过程的稳定性和效率。其简单而有效的特性使其成为目前强化学习中最流行的算法之一。通过理解并实现 PPO 算法,可以更好地应用于各种强化学习任务,提升模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/787902.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LAMP万字详解(概念、构建步骤)

目录 LAMP Apache 起源 主要特点 软件版本 编译安装httpd服务器 编译安装的优点 操作步骤 准备工作 编译 安装 优化执行路径 添加服务 守护进程 配置httpd 查看 Web 站点的访问情况 虚拟主机 类型 部署基于域名的虚拟主机 为虚拟主机提供域名解析&#xff…

ESP32的I2S引脚及支持的音频标准使用说明

ESP32 I2S 接口 ESP32 有 2 个标准 I2S 接口。这 2 个接口可以以主机或从机模式,在全双工或半双工模式下工作,并且可被配置为 8/16/32/48/64-bit 的输入输出通道,支持频率从 10 kHz 到 40 MHz 的 BCK 时钟。当 1 个或 2 个 被配置为主机模式…

db期末复习自用[应试向 附习题]

第一章 数据库系统实现整体数据的结构化,主要特征之一,是db区别于文件系统的本质区别。 数据库系统三个阶段:人工、文件、数据库系统。 数据库管理系统的功能:数据库定义、操纵 、(保护、存储、维护)、数…

大模型/NLP/算法面试题总结2——transformer流程//多头//clip//对比学习//对比学习损失函数

用语言介绍一下Transformer的整体流程 1. 输入嵌入(Input Embedding) 输入序列(如句子中的单词)首先通过嵌入层转化为高维度的向量表示。嵌入层的输出是一个矩阵,每一行对应一个输入单词的嵌入向量。 2. 位置编码&…

020-GeoGebra中级篇-几何对象之点与向量

本文概述了在GeoGebra中如何使用笛卡尔或极坐标系输入点和向量。用户可以通过指令栏输入数字和角度,使用工具或指令创建点和向量。在笛卡尔坐标系中,示例如“P(1,0)”;在极坐标系中,示例如“P(1;0)”或“v(5;90)”。文章还介绍了点…

SpringBoot + MyBatisPlus 实现多租户分库

一、引言 在如今的软件开发中,多租户(Multi-Tenancy)应用已经变得越来越常见。多租户是一种软件架构技术,它允许一个应用程序实例为多个租户提供服务。每个租户都有自己的数据和配置,但应用程序实例是共享的。而在我们的Spring Boot MyBati…

刷代码随想录有感(130):动态规划——编辑距离

题干&#xff1a; 代码&#xff1a; class Solution { public:int minDistance(string word1, string word2) {vector<vector<int>>dp(word1.size() 1, vector<int>(word2.size() 1));for(int i 0; i < word1.size(); i)dp[i][0] i;for(int j 0; j …

使用Mplayer实现MP3功能

核心功能 1. 界面设计 项目首先定义了一个clearscreen函数&#xff0c;用于清空屏幕&#xff0c;为用户界面的更新提供了便利。yemian函数负责显示主菜单界面&#xff0c;提供了包括查看播放列表、播放控制、播放模式选择等在内的9个选项。 2. 文件格式支持 is_supported_f…

数据抓取技术在视频内容监控与快速读取中的应用

引言 在数字化时代&#xff0c;视频内容的快速读取和监控对于内容提供商来说至关重要。思通数科的OPEN-SPIDER抓取技术为这一需求提供了高效的解决方案。 OPEN-SPIDER技术概述 OPEN-SPIDER是思通数科开发的一种先进的数据抓取技术&#xff0c;它能够&#xff1a; - 高效地从各…

Qt 音频编程实战项目

一Qt 音频基础知识 QT multimediaQMediaPlayer 类&#xff1a;媒体播放器&#xff0c;主要用于播放歌曲、网络收音 机等功能。QMediaPlaylist 类&#xff1a;专用于播放媒体内容的列表。 二 音频项目实战程序 //版本5.12.8 .proQT core gui QT multimedia greate…

基于深度学习的电影推荐系统

1 项目介绍 1.1 研究目的和意义 在电子商务日益繁荣的今天&#xff0c;精准预测商品销售数据成为商家提升运营效率、优化库存管理以及制定营销策略的关键。为此&#xff0c;开发了一个基于深度学习的商品销售数据预测系统&#xff0c;该系统利用Python编程语言与Django框架&a…

在RockyLinux上安装Solr8.11(新版本)

在RockyLinux上安装Solr8.11&#xff08;新版本&#xff09; 安装准备安装java环境 安装Solr下载修改配置开放端口验证一下 安装准备 安装java环境 搜索提供可安装的包 yum search java 我们在这里看到有很多&#xff0c;我这里安装的1.8版本。我们这里选择描述为Runtime en…

斯坦福大学博士在GitHub发布的漫画机器学习小抄,竟斩获129k标星

斯坦福大学数据科学博士Chris Albon在GitHub上发布了一份超火的机器学习漫画小抄&#xff0c;发布仅仅一天就斩获GitHub榜首标星暴涨120k&#xff0c;小编有幸获得了一份并把它翻译成中文版本&#xff0c;今天给大家分享出来&#xff01; 轻松的画风配上让人更容易理解的文字讲…

万字总结GBDT原理、核心参数以及调优思路

万字总结GBDT原理、核心参数以及调优思路 在机器学习领域&#xff0c;梯度提升决策树&#xff08;Gradient Boosting Decision Tree, GBDT&#xff09;以其卓越的预测性能和强大的模型解释能力而广受推崇。GBDT通过迭代地构建决策树&#xff0c;每一步都在前一步的残差上进行优…

【力扣高频题】042.接雨水问题

上一篇我们通过采用 双指针 的方法解决了 经典 容器盛水 问题 &#xff0c;本文我们接着来学习一道在面试中极大概率会被考到的经典题目&#xff1a;接雨水 问题 。 42. 接雨水 给定 n 个非负整数&#xff0c;表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子…

【高校科研前沿】中国农业大学姚晓闯老师等人在农林科学Top期刊发表长篇综述:深度学习在农田识别中的应用

文章简介 论文名称&#xff1a;Deep learning in cropland field identification: A review&#xff08;深度学习在农田识别中的应用&#xff1a;综述&#xff09; 第一作者及单位&#xff1a;Fan Xu&#xff08;中国农业大学土地科学与技术学院&#xff09; 通讯作者及单位&…

【电路笔记】-C类放大器

C类放大器 文章目录 C类放大器1、概述2、C类放大介绍3、C类放大器的功能4、C 类放大器的效率5、C类放大器的应用:倍频器6、总结1、概述 尽管存在差异,但我们在之前有关 A 类、B 类和 AB 类放大器的文章中已经看到,这三类放大器是线性或部分线性的,因为它们在放大过程中再现…

【WebGIS平台】传统聚落建筑科普数字化建模平台

基于上述概括出建筑单体的特征部件&#xff0c;本文利用互联网、三维建模和地理信息等技术设计了基于浏览器/服务器&#xff08;B/S&#xff09;的传统聚落建筑科普数字化平台。该平台不仅实现了对传统聚落建筑风貌从基础到复杂的数字化再现&#xff0c;允许用户轻松在线构建从…

Java线程池及面试题

1.线程池介绍 顾名思义&#xff0c;线程池就是管理一系列线程的资源池&#xff0c;其提供了一种限制和管理线程资源的方式。每个线程池还维护一些基本统计信息&#xff0c;例如已完成任务的数量。 总结一下使用线程池的好处&#xff1a; 降低资源消耗。通过重复利用已创建的…

去除Win32 Tab Control控件每个选项卡上的深色对话框背景

一般情况下&#xff0c;我们是用不带边框的对话框来充当Tab Control的每个选项卡的内容的。 例如&#xff0c;主对话框IDD_TABBOX上有一个Tab Control&#xff0c;上面有两个选项卡&#xff0c;第一个选项卡用的是IDD_DIALOG1充当内容&#xff0c;第二个用的则是IDD_DIALOG2。I…