深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。

策略梯度算法的基本概念

在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 $\pi_\theta(a|s)$,其中$\theta$ 是策略的参数,$s$ 是状态,$a$ 是行动。

目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数$J(\theta)$,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。

策略梯度的数学推导

假设我们的目标函数 $J(\theta)$ 定义为:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]

其中$\tau$ 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),$R(\tau)$ 是该轨迹的总回报。根据策略的定义,我们有:

\pi_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)

因此,目标函数可以重写为:

J(\theta) = \sum_{\tau} \pi_\theta(\tau) R(\tau)

为了最大化$J(\theta)$,我们需要计算其梯度 $\nabla_\theta J(\theta)$

\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} \pi_\theta(\tau) R(\tau) = \sum_{\tau} \nabla_\theta \pi_\theta(\tau) R(\tau)

使用概率分布的梯度性质,我们有:

\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau)

因此,梯度可以表示为:

\nabla_\theta J(\theta) = \sum_{\tau} \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]

这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 $\pi_\theta$ 中采样多个轨迹 $\tau$,然后计算平均值。

策略梯度算法的实现

我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略$\pi_\theta(a|s)$

步骤 1:环境设置

首先,设置环境和参数:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

env = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义

定义一个简单的策略网络:

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, n_actions):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹

编写函数来从策略中采样轨迹:

def sample_trajectory(env, policy, max_steps=1000):
    state = env.reset()
    states, actions, rewards = [], [], []
    for _ in range(max_steps):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = policy(state)
        action = np.random.choice(n_actions, p=probs.detach().numpy()[0])
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        if done:
            break
        state = next_state
    return states, actions, rewards
步骤 4:计算回报和梯度

计算每个状态的回报,并使用策略梯度定理更新策略:

def compute_returns(rewards, gamma=0.99):
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    return returns

def update_policy(policy, optimizer, states, actions, returns):
    returns = torch.FloatTensor(returns)
    loss = 0
    for state, action, G in zip(states, actions, returns):
        state = state.squeeze(0)
        probs = policy(state)
        log_prob = torch.log(probs[action])
        loss += -log_prob * G
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
步骤 5:训练策略

将上述步骤组合在一起,训练策略网络:

num_episodes = 1000
for episode in range(num_episodes):
    states, actions, rewards = sample_trajectory(env, policy)
    returns = compute_returns(rewards)
    update_policy(policy, optimizer, states, actions, returns)
    if episode % 100 == 0:
        print(f"Episode {episode}, total reward: {sum(rewards)}")
总结

通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/767254.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI大模型时代来临:企业如何抢占先机?

AI大模型时代来临:企业如何抢占先机? 2023年,被誉为大模型元年,AI大模型的发展如同一股不可阻挡的潮流,正迅速改变着我们的工作和生活方式。从金融到医疗,从教育到制造业,AI大模型正以其强大的生成能力和智能分析,重塑着行业的未来。 智能化:企业核心能力的转变 企…

【CUDA】 归约 Reduction

Reduction Reduction算法从一组数值中产生单个数值。这个单个数值可以是所有元素中的总和、最大值、最小值等。 图1展示了一个求和Reduction的例子。 图1 线程层次结构 在Reduction算法中,线程的常见组织方式是为每个元素使用一个线程。下面将展示利用许多不同方…

AI-算力集群通往AGI

背景: 自GPT-4发布以来,全球AI能力的发展势头有放缓的迹象。 但这并不意味着Scaling Law失效,也不是因为训练数据不够,而是结结实实的遇到了算力瓶颈。 具体来说,GPT-4的训练算力约2e25 FLOP,近期发布的几个…

双曲方程初值问题的差分逼近(迎风格式)

稳定性: 数值例子 例一 例二 代码 % function chap4_hyperbolic_1st0rder_1D % test the upwind scheme for 1D hyperbolic equation % u_t + a*u_x = 0,0<x<L,O<t<T, % u(x,0) = |x-1|,0<X<L, % u(0,t) = 1% foundate = 2015-4-22’; % chgedate = 202…

刷代码随想录有感(124):动态规划——最长公共子序列

题干&#xff1a; 代码&#xff1a; class Solution { public:int findLength(vector<int>& nums1, vector<int>& nums2) {vector<vector<int>>dp(nums1.size() 1, vector<int>(nums2.size() 1, 0));int res 0;for(int i 1; i <…

买华为智驾,晚了肯定要后悔

文 | AUTO芯球 作者 | 雷慢 晚了就来不及了&#xff01; 你买华为系的车&#xff0c;薅羊毛真的要趁早。 华为ADS2.0高阶智驾正在慢慢恢复原价&#xff0c; 你看啊&#xff0c;就在昨天&#xff0c;华为宣布ADS智驾优惠后价格调到3万元&#xff0c; 只有6000元的优惠了。…

153. 寻找旋转排序数组中的最小值(中等)

153. 寻找旋转排序数组中的最小值 1. 题目描述2.详细题解3.代码实现3.1 Python3.2 Java 1. 题目描述 题目中转&#xff1a;153. 寻找旋转排序数组中的最小值 2.详细题解 如果不考虑 O ( l o g n ) O(log n) O(logn)的时间复杂度&#xff0c;直接 O ( n ) O(n) O(n)时间复杂…

从百数教学看产品设计:掌握显隐规则,打造极致用户体验

字段显隐规则允许通过一个控件&#xff08;如复选框、单选按钮或下拉菜单&#xff09;来控制其他控件&#xff08;如文本框、日期选择器等&#xff09;和标签页&#xff08;如表单的不同部分&#xff09;的显示或隐藏。 这种规则通常基于用户的选择或满足特定条件来触发&#…

vue3实现echarts——小demo

版本&#xff1a; 效果&#xff1a; 代码&#xff1a; <template><div class"middle-box"><div class"box-title">检验排名TOP10</div><div class"box-echart" id"chart1" :loading"loading1"&…

【Portswigger 学院】路径遍历

路径遍历&#xff08;Path traversal&#xff09;又称目录遍历&#xff08;Directory traversal&#xff09;&#xff0c;允许攻击者通过应用程序读取或写入服务器上的任意文件&#xff0c;例如读取应用程序源代码和数据、凭证和操作系统文件&#xff0c;或写入应用程序所访问或…

API Object设计模式

API测试面临的问题 API测试由于编写简单&#xff0c;以及较高的稳定性&#xff0c;许多公司都以不同工具和框架维护API自动化测试。我们基于seldom框架也积累了几千条自动化用例。 •简单的用例 import seldomclass TestRequest(seldom.TestCase):def test_post_method(self…

GDB 远程调试简介

文章目录 1. 前言2. GDB 远程调试2.1 准备工作2.1.1 准备 客户端 gdb 程序2.1.2 准备 服务端 gdbserver2.1.3 准备 被调试程序 2.2 调试2.2.1 通过网络远程调试2.2.1.1 通过 gdbserver 直接启动程序调试2.2.1.2 通过 gdbserver 挂接到已运行程序调试 2.2.2 通过串口远程调试2.2…

紫鸟浏览器搭配IPXProxy代理IP的高效使用指南

​紫鸟指纹浏览器一款专门为跨境电商而生的防关联浏览器&#xff0c;能够帮助跨境电商卖家解决多店铺管理问题。紫鸟指纹浏览器为跨境电商卖家提供稳定的登录环境&#xff0c;并且搭配IP代理&#xff0c;能够解决浏览器指纹记录问题&#xff0c;提高操作的安全性。那如何利用紫…

广州AI绘图模型训练外包定制公司

&#x1f680;设计公司如何借助AI人工智能降本增效&#xff0c;广州这家AI公司值得借鉴— 触站AI&#xff0c;智能图像的创新引擎 &#x1f31f; &#x1f3a8; 触站AI&#xff0c;绘制设计界的未来蓝图 &#x1f3a8;在AI技术的浪潮中&#xff0c;触站AI以其前沿的AI图像技术…

RK3568驱动指南|第十六篇 SPI-第188章 mcp2515驱动编写:复位函数

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

Redux 使用及基本原理

什么是Redux Redux 是用于js应用的状态管理库&#xff0c;通常和React一起用。帮助开发者管理应用中各个组件之间的状态&#xff0c;使得状态的变化变得更加可预测和易于调试。 Redu也可以不和React组合使用。&#xff08;通常一起使用&#xff09; Redux 三大原则 单一数据源…

在uni-app使用vue3使用vuex

在uni-app使用vue3使用vuex 1.在项目目录中新建一个store目录&#xff0c;并且新建一个index.js文件 import { createStore } from vuex;export default createStore({//数据&#xff0c;相当于datastate: {count:1,list: [{name: 测试1, value: test1},{name: 测试2, value: …

从hugging face 下模型

支持国内下载hugging face 的东西 下模型权重 model_id 是红色圈复制的 代码 记得设置下载的存储位置 import os from pathlib import Path from huggingface_hub import hf_hub_download from huggingface_hub import snapshot_downloadmodel_id"llava-hf/llava-v1…

Swift 中强大的 Key Paths(键路径)机制趣谈(下)

概览 在上一篇博文 Swift 中强大的 Key Paths(键路径)机制趣谈(上)中,我们介绍了 Swift 语言中键路径机制的基础知识,并举了若干例子讨论了它的一些用武之地。 而在本文中我们将再接再厉,继续有趣的键路径大冒险,为 KeyPaths 画上一个圆满的句号。 在本篇博文中,您将…

C++:二维数组的遍历

方式一&#xff1a; #include <vector> #include <iostream> int main() { // 初始化一个2x3的二维向量&#xff08;矩阵&#xff09; std::vector<std::vector<float>> matrix { {1.0, 2.0, 3.0}, // 第一行 {4.0, 5.0, 6.0} // 第二行 };…