Proximal Policy Optimization (PPO) is a reinforcement learning algorithm proposed by OpenAI to address the problem of overly large policy updates in traditional policy gradient methods. By introducing a mechanism that limits how far the policy can move in a single update, PPO improves training stability and efficiency while preserving convergence.
Principles of the PPO Algorithm
The core idea of PPO is to update the policy by optimizing a surrogate objective while restricting how much the policy is allowed to change in each update. Concretely, PPO draws on the ideas of clipping and trust regions to ensure that the new policy never drifts too far from the old one.
PPO Algorithm Formulas
PPO comes in two main variants: the clipped version (Clipped PPO) and the adaptive KL-penalty version (Adaptive KL Penalty PPO). This article focuses on the clipped variant.
- Old policy: $\pi_{\theta_{\text{old}}}(a_t \mid s_t)$, where $\theta_{\text{old}}$ are the policy parameters from the previous update.
- Compute the probability ratio: $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$
- Clipped objective function: $L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \text{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\hat{A}_t\right)\right]$, where $\hat{A}_t$ is the advantage function and $\epsilon$ is the clipping-range hyperparameter, typically set to 0.2 (a small numerical sketch of this objective follows the list).
- Update the policy parameters: $\theta \leftarrow \theta + \alpha \nabla_\theta L^{\text{CLIP}}(\theta)$, i.e. maximize the clipped objective by gradient ascent; in the code below this is done by minimizing $-L^{\text{CLIP}}$ with Adam.
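To make the effect of the clipping term concrete, the short sketch below evaluates the clipped surrogate for a couple of hand-picked ratio/advantage pairs. It is an illustrative aid rather than part of the derivation; the sample values and the helper name clipped_surrogate are assumptions introduced here.

import numpy as np

def clipped_surrogate(ratio, advantage, epsilon=0.2):
    # min(r * A, clip(r, 1 - eps, 1 + eps) * A), averaged over the batch
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return np.mean(np.minimum(unclipped, clipped))

# Positive advantage: gains from pushing the ratio above 1 + epsilon are cut off
print(clipped_surrogate(np.array([1.5]), np.array([1.0])))   # 1.2, not 1.5
# Negative advantage: the pessimistic (clipped) value applies once the ratio drops below 1 - epsilon
print(clipped_surrogate(np.array([0.5]), np.array([-1.0])))  # -0.8, not -0.5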
Implementing the PPO Algorithm
Below is a code example that implements the PPO algorithm in Python with TensorFlow. It uses the classic gym API, in which env.step returns four values.
import tensorflow as tf
import numpy as np
import gym  # classic gym API (env.step returns obs, reward, done, info)

# Policy network: maps observations to action logits
class PolicyNetwork(tf.keras.Model):
    def __init__(self, action_space):
        super(PolicyNetwork, self).__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(128, activation='relu')
        self.logits = tf.keras.layers.Dense(action_space, activation=None)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.logits(x)

# Value network: maps observations to a scalar state-value estimate
class ValueNetwork(tf.keras.Model):
    def __init__(self):
        super(ValueNetwork, self).__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(128, activation='relu')
        self.value = tf.keras.layers.Dense(1, activation=None)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.value(x)

# Hyperparameters
learning_rate = 0.0003
clip_ratio = 0.2
epochs = 10        # number of training iterations (one episode collected per iteration)
batch_size = 64    # not used in this simplified single-batch update
gamma = 0.99

# Create the environment
env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

# Create the policy and value networks
policy_net = PolicyNetwork(n_actions)
value_net = ValueNetwork()

# Optimizers
policy_optimizer = tf.keras.optimizers.Adam(learning_rate)
value_optimizer = tf.keras.optimizers.Adam(learning_rate)

def get_action(observation):
    # Sample an action from the categorical distribution defined by the logits
    logits = policy_net(observation)
    action = tf.random.categorical(logits, 1)
    return int(action[0, 0])
def compute_advantages(rewards, values, next_values, dones):
    # Generalized advantage estimation (GAE) with lambda = 1: accumulate the TD errors
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) backwards in time.
    # (A lambda-parameterized variant is sketched below for reference.)
    advantages = []
    gae = 0.0
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + gamma * next_values[i] * (1 - dones[i]) - values[i]
        gae = delta + gamma * gae * (1 - dones[i])
        advantages.insert(0, gae)
    return np.array(advantages, dtype=np.float32)
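# The function above corresponds to GAE with lambda = 1 (a pure discounted sum of TD errors).
# A common variant exposes the lambda parameter explicitly. This is an optional sketch, not
# part of the original listing, and `lam` is an assumed extra hyperparameter.
def compute_gae(rewards, values, next_values, dones, lam=0.95):
    advantages = []
    gae = 0.0
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + gamma * next_values[i] * (1 - dones[i]) - values[i]
        gae = delta + gamma * lam * gae * (1 - dones[i])
        advantages.insert(0, gae)
    return np.array(advantages, dtype=np.float32)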
def ppo_update(observations, actions, advantages, returns):
    # The "old" action log-probabilities must be treated as constants: if they were
    # recomputed inside the gradient tape from the same network, the ratio would be
    # identically 1 and its gradient would vanish. In a full implementation they are
    # stored at rollout time; recomputing them once before the tape is equivalent here
    # because only a single update is performed per batch.
    old_logits = policy_net(observations)
    old_log_probs = tf.nn.log_softmax(old_logits)
    old_action_log_probs = tf.stop_gradient(
        tf.reduce_sum(old_log_probs * tf.one_hot(actions, n_actions), axis=1))

    with tf.GradientTape() as tape:
        logits = policy_net(observations)
        log_probs = tf.nn.log_softmax(logits)
        action_log_probs = tf.reduce_sum(log_probs * tf.one_hot(actions, n_actions), axis=1)
        # Probability ratio and clipped surrogate objective
        ratio = tf.exp(action_log_probs - old_action_log_probs)
        surr1 = ratio * advantages
        surr2 = tf.clip_by_value(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
        policy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))
    policy_grads = tape.gradient(policy_loss, policy_net.trainable_variables)
    policy_optimizer.apply_gradients(zip(policy_grads, policy_net.trainable_variables))

    with tf.GradientTape() as tape:
        # Squeeze predictions to shape (N,) so they match returns; otherwise the
        # subtraction would broadcast to an (N, N) matrix.
        values_pred = tf.squeeze(value_net(observations), axis=1)
        value_loss = tf.reduce_mean((returns - values_pred) ** 2)
    value_grads = tape.gradient(value_loss, value_net.trainable_variables)
    value_optimizer.apply_gradients(zip(value_grads, value_net.trainable_variables))
# Training loop: collect one episode per iteration, then perform one PPO update
for epoch in range(epochs):
    observations = []
    actions = []
    rewards = []
    values = []
    next_values = []
    dones = []

    obs = env.reset()
    done = False
    while not done:
        obs = obs.reshape(1, -1).astype(np.float32)
        observations.append(obs)

        action = get_action(obs)
        actions.append(action)

        value = float(value_net(obs)[0, 0])
        values.append(value)

        obs, reward, done, _ = env.step(action)
        rewards.append(reward)
        dones.append(done)

        # Bootstrap value of the next state (0 at episode termination)
        if done:
            next_values.append(0.0)
        else:
            next_values.append(float(value_net(obs.reshape(1, -1).astype(np.float32))[0, 0]))

    # GAE gives the advantages; the returns are the advantages plus the value baseline
    values = np.array(values, dtype=np.float32)
    advantages = compute_advantages(rewards, values, next_values, dones)
    returns = advantages + values

    observations = np.concatenate(observations, axis=0)
    actions = np.array(actions)

    ppo_update(observations, actions, advantages, returns)
    print(f'Epoch {epoch+1} completed')
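Once training finishes, the learned policy can be checked with a short greedy rollout. The snippet below is a minimal evaluation sketch under the same assumptions as the listing above (classic gym API, the policy_net defined earlier); it is not part of the original tutorial, and the episode count is arbitrary.

# Evaluate the trained policy with greedy (argmax) action selection
eval_episodes = 5
for ep in range(eval_episodes):
    obs = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        logits = policy_net(obs.reshape(1, -1).astype(np.float32))
        action = int(tf.argmax(logits, axis=1)[0])
        obs, reward, done, _ = env.step(action)
        total_reward += reward
    print(f'Evaluation episode {ep+1}: return = {total_reward}')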
Summary
By limiting the size of each policy update, through the clipping mechanism and trust-region-style constraints, PPO makes training noticeably more stable and efficient. Its combination of simplicity and effectiveness has made it one of the most popular algorithms in reinforcement learning today. Understanding and implementing PPO makes it easier to apply it to a wide range of reinforcement learning tasks and to improve model performance.