【RL Base】强化学习:信赖域策略优化(TRPO)算法

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】(51)---《信赖域策略优化(TRPO)算法》

信赖域策略优化(TRPO)算法

目录

1.信赖域策略优化

2.什么是 TRPO?

3.TRPO 的公式解析

1. 策略优化目标

2. KL 散度约束

3. Fisher 信息矩阵

4. 二次约束优化

[Python] TRPO算法实现

1.TRPO算法伪代码

2.TRPO算法pytorch实现代码

[Notice]  注意事项

4.TRPO 的优势

5.TRPO 的应用场景

6.结论


1.信赖域策略优化

        在强化学习(RL)领域,如何稳定地优化策略是一个核心挑战。2015 年,由 John Schulman 等人提出的信赖域策略优化(Trust Region Policy Optimization, TRPO)算法为这一问题提供了优雅的解决方案。TRPO 通过限制策略更新的幅度,避免了策略更新过大导致的不稳定问题,是强化学习中经典的策略优化方法之一。


2.什么是 TRPO?

        TRPO 是一种基于策略梯度的优化算法,其目标是通过限制新策略和旧策略之间的差异来确保训练的稳定性。TRPO 在高维、连续动作空间中表现尤为出色,尤其适用于机器人控制、游戏 AI 等领域。

基本思想

        TRPO 的核心目标是寻找新策略 ( \pi_\theta ),使得累积奖励最大化,同时限制新旧策略之间的变化幅度。数学形式化如下:

\max_\theta \mathbb{E}{s \sim \rho{\pi_\text{old}}, a \sim \pi_\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} A_{\pi_\text{old}}(s, a) \right]

其中:

  • ( \pi_\text{old} ): 旧策略。
  • ( \pi_\theta ): 新策略。
  • ( A_{\pi_\text{old}}(s, a) ): 优势函数,衡量动作( a ) 相对于状态( s )的优劣。

TRPO 在优化过程中通过限制新旧策略之间的 KL 散度来实现稳定更新:

[ \mathbb{E}{s \sim \rho{\pi_\text{old}}} \left[ \text{KL}(\pi_\text{old} | \pi_\theta) \right] \leq \delta ]


3.TRPO 的公式解析

1. 策略优化目标

TRPO 的优化目标是:

L(\pi_\theta) = \mathbb{E}{s, a \sim \pi\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} A_{\pi_\text{old}}(s, a) \right]

其中:

  • ( \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} ): 概率比率,衡量新旧策略的变化。
  • ( A_{\pi_\text{old}}(s, a) ): 优势函数,用于评价动作的相对收益。

2. KL 散度约束

        通过添加 KL 散度约束,TRPO 控制新策略( \pi_\theta )与旧策略( \pi_\text{old} )的差异:

\mathbb{E}{s \sim \rho{\pi_\text{old}}} \left[ \text{KL}(\pi_\text{old} | \pi_\theta) \right] \leq \delta

  • ( \text{KL}(\pi_\text{old} | \pi_\theta) ): 衡量新旧策略概率分布的差异。
  • ( \delta ): 预设的信赖域阈值。

3. Fisher 信息矩阵

        KL 散度的二阶导数可以通过 Fisher 信息矩阵近似表示:

H = \mathbb{E}{s \sim \rho{\pi_\text{old}}, a \sim \pi_\text{old}} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]

4. 二次约束优化

        TRPO 的优化过程最终转化为一个带二次约束的优化问题:

\max_\theta g^T \delta\theta \quad \text{subject to } \delta\theta^T H \delta\theta \leq \delta

  • ( g = \nabla_\theta L(\pi_\theta) ): 策略梯度。
  • ( H ): Fisher 信息矩阵,用于表示 KL 散度的变化。

通过拉格朗日乘数法,TRPO 可以高效求解上述问题。


[Python] TRPO算法实现

1.TRPO算法伪代码

"""《TRPO 的实现流程》
    时间:2024.11
    作者:不去幼儿园
"""
初始化策略参数 θ 和目标网络
for 迭代轮次 do
    1. 使用策略 π_θ 采样轨迹 (s, a, r)
    2. 计算优势函数 A(s, a) (可以用 GAE 方法)
    3. 计算策略梯度 g = ∇_θ L(θ)
    4. 估计 Fisher 信息矩阵 H
    5. 通过优化问题 max δθ^T g, subject to δθ^T H δθ ≤ δ 更新策略
    6. 更新目标网络
end for

2.TRPO算法pytorch实现代码

import argparse
from itertools import count

import gym
import scipy.optimize

import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from trpo import trpo_step
from utils import *

torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

torch.set_default_tensor_type('torch.DoubleTensor')

parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.995, metavar='G',
                    help='discount factor (default: 0.995)')
parser.add_argument('--env-name', default="Reacher-v1", metavar='G',
                    help='name of the environment to run')
parser.add_argument('--tau', type=float, default=0.97, metavar='G',
                    help='gae (default: 0.97)')
parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
                    help='l2 regularization regression (default: 1e-3)')
parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G',
                    help='max kl value (default: 1e-2)')
parser.add_argument('--damping', type=float, default=1e-1, metavar='G',
                    help='damping (default: 1e-1)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 1)')
parser.add_argument('--batch-size', type=int, default=15000, metavar='N',
                    help='random seed (default: 1)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()

env = gym.make(args.env_name)

num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

env.seed(args.seed)
torch.manual_seed(args.seed)

policy_net = Policy(num_inputs, num_actions)
value_net = Value(num_inputs)

def select_action(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std = policy_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action

def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))

    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
                
        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    def get_kl():
        mean1, log_std1, std1 = policy_net(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)

running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)

for i_episode in count(1):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
        state = env.reset()
        state = running_state(state)

        reward_sum = 0
        for t in range(10000): # Don't infinite loop while learning
            action = select_action(state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_sum += reward

            next_state = running_state(next_state)

            mask = 1
            if done:
                mask = 0

            memory.push(state, np.array([action]), mask, next_state, reward)

            if args.render:
                env.render()
            if done:
                break

            state = next_state
        num_steps += (t-1)
        num_episodes += 1
        reward_batch += reward_sum

    reward_batch /= num_episodes
    batch = memory.sample()
    update_params(batch)

    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
            i_episode, reward_sum, reward_batch))
import numpy as np

import torch
from torch.autograd import Variable
from utils import *


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    fval = f(True).data
    print("fval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            print("fval after", newfval.item())
            return True, xnew
    return False, x


def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        kl = get_kl()
        kl = kl.mean()

        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

        return flat_grad_grad_kl + v * damping

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]

    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))

    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss

完成代码见附件 

     🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。


[Notice]  注意事项

        由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。


4.TRPO 的优势

  1. 更新稳定性
    TRPO 限制策略更新幅度,避免因更新过大导致的性能崩溃。

  2. 适用范围广泛
    TRPO 在高维连续动作空间中表现出色,适用于机器人控制等复杂任务。

  3. 高效性
    TRPO 在信赖域内寻找最优更新方向,平衡了稳定性和优化效率。


5.TRPO 的应用场景

1. 机器人控制

在动态环境中优化机器人动作,例如步态优化和抓取任务。

2. 游戏 AI

训练高维连续动作的游戏智能体,例如模拟物理环境中的策略优化。

3. 资源调度

优化任务分配和资源调度问题,例如云计算中的负载平衡。


6.结论

        TRPO 是强化学习领域的一个重要里程碑,其通过信赖域限制优化策略更新的方式,在稳定性和性能提升方面提供了良好的平衡。尽管 TRPO 在实践中因其计算复杂性被更新的算法(如 PPO)部分取代,但其核心思想仍然是深度强化学习发展的重要基础。

参考文献:Trust Region Policy Optimization

 更多自监督强化学习文章,请前往:【强化学习(RL)】专栏


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928207.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

黑马2024AI+JavaWeb开发入门Day04-SpringBootWeb入门-HTTP协议-分层解耦-IOCDI飞书作业

视频地址&#xff1a;哔哩哔哩 讲义作业飞书地址&#xff1a;day04作业&#xff08;IOC&DI&#xff09; 作业很简单&#xff0c;主要是练习拆分为三层架构controller、service、dao&#xff0c;并基于IOC & DI进行解耦。 1、结构&#xff1a; 2、代码 网盘链接&…

(长期更新)《零基础入门 ArcGIS(ArcMap) 》实验三----学校选址与路径规划(超超超详细!!!)

目录 实验三 学校选址与道路规划 3.1 实验内容及目的 3.1.1 实验内容 3.1.2 实验目的 3.2 实验方案 3.3 操作流程 3.3.1 环境设置 3.3.2 地势分析 &#xff08;1&#xff09;提取坡度: (2)重分类: 3.3.3 学校点分析 (1)欧氏距离: (2)重分类: 3.3.4 娱乐场所点分析 (1)欧氏距离…

【Python网络爬虫笔记】8- (BeautifulSoup)抓取电影天堂2024年最新电影,并保存所有电影名称和链接

目录 一. BeautifulSoup的作用二. 核心方法介绍2.1 构造函数2.2 find()方法2.3 find_all()方法2.4 select()方法 三. 网络爬虫中使用BeautifulSoup四、案例爬取结果 一. BeautifulSoup的作用 解析HTML/XML文档&#xff1a;它可以将复杂的HTML或XML文本转换为易于操作的树形结构…

MATLAB期末复习笔记(中)

目录 三、MATLAB函数和程序结构 1.MATLAB文件 2.变量和数据类型 &#xff08;1&#xff09;变量 &#xff08;2&#xff09;变量类型 &#xff08;3&#xff09;字符串 3.函数文件 &#xff08;1&#xff09;函数文件规范 &#xff08;2&#xff09;子函数和私有函数 &…

算法刷题Day8:BM30 二叉搜索树与双向链表

题目 牛客网题目传送门 思路 对二叉搜索树进行中序遍历&#xff0c;结果就是按序数组。因此想办法把前面遍历过的节点给记下来&#xff0c;记作pre。当遍历到某个节点node的时候&#xff0c;令前驱指向pre&#xff0c;然后让pre的后驱指向node。 代码 class TreeNode:def…

深入解析 Dubbo 中的常见问题及优化方案: 数据量限制与配置错误20241203

&#x1f31f; 深入解析 Dubbo 中的常见问题及优化方案&#xff1a;数据量限制与配置错误 在分布式系统中&#xff0c;Dubbo 作为高性能的 RPC 框架广泛应用于企业服务化架构。然而&#xff0c;在实际使用过程中&#xff0c;开发者往往会遇到一些复杂问题&#xff0c;比如 数据…

debian ubuntu armbian部署asp.net core 项目 开机自启动

我本地的环境是 rk3399机器&#xff0c;安装armbian系统。 1.安装.net core 组件 sudo apt-get update && \sudo apt-get install -y dotnet-sdk-8.0或者安装运行库&#xff0c;但无法生成编译项目 sudo apt-get update && \sudo apt-get install -y aspnet…

【AI系统】Ascend C 编程范式

Ascend C 编程范式 AI 的发展日新月异&#xff0c;AI 系统相关软件的更新迭代也是应接不暇&#xff0c;作为一本讲授理论的作品&#xff0c;我们将尽可能地讨论编程范式背后的原理和思考&#xff0c;而少体现代码实现&#xff0c;以期让读者理解 Ascend C 为何这样设计&#x…

hadoop环境配置-创建hadoop用户+更新apt+安装SSH+配置Java环境

一、创建hadoop用户(在vm安装的ubantu上打开控制台) 1、sudo useradd -m hadoop -s /bin/bash &#xff08;创建hadoop用户&#xff09; 2、sudo passwd hadoop (设置密码) 3、sudo adduser hadoop sudo&#xff08;将新建的hadoop用户设置为管理员&#xff09; 执行如下图 将…

嵌入式系统应用-LVGL的应用-平衡球游戏 part1

平衡球游戏 part1 1 平衡球游戏的界面设计2 界面设计2.1 背景设计2.2 球的设计2.3 移动球的坐标2.4 用鼠标移动这个球2.5 增加边框规则2.6 效果图 3 为小球增加增加动画效果3.1 增加移动效果代码3.2 具体效果图片 平衡球游戏 part2 第二部分文章在这里 1 平衡球游戏的界面设计…

从被动响应到主动帮助,ProActive Agent开启人机交互新篇章

在人工智能领域&#xff0c;我们正见证着一场革命性的变革。传统的AI助手&#xff0c;如ChatGPT&#xff0c;需要明确的指令才能执行任务。但现在&#xff0c;清华大学联合面壁智能等团队提出了一种全新的主动式Agent交互范式——ProActive Agent&#xff0c;它能够主动观察环境…

2.mysql 中一条更新语句的执行流程是怎样的呢?

前面我们系统了解了一个查询语句的执行流程&#xff0c;并介绍了执行过程中涉及的处理模块。 相信你还记得&#xff0c;一条查询语句的执行过程一般是经过连接器、分析器、优化器、执行器等功能模块&#xff0c;最后到达存储引擎。 那么&#xff0c;一条更新语句的执行流程又…

NaviveUI框架的使用 ——安装与引入(图标安装与引入)

文章目录 概述安装直接引入引入图标样式库 概述 &#x1f349;Naive UI 是一个轻量、现代化且易于使用的 Vue 3 UI 组件库&#xff0c;它提供了一组简洁、易用且功能强大的组件&#xff0c;旨在为开发者提供更高效的开发体验&#xff0c;特别是对于构建现代化的 web 应用程序。…

web vue 滑动选择 n宫格选中 九宫格选中

页面动态布局经常性要交给客户来操作&#xff0c;他们按时他们的习惯在同一个屏幕内显示若干个子视图&#xff0c;尤其是在医学影像领域对于影像的同屏显示目视对比显的更为重要。 来看看如下的用户体验&#xff1a; 设计为最多支持5行6列页面展示后&#xff0c;右侧的布局则动…

ELK的Filebeat

目录 传送门前言一、概念1. 主要功能2. 架构3. 使用场景4. 模块5. 监控与管理 二、下载地址三、Linux下7.6.2版本安装filebeat.yml配置文件参考&#xff08;不要直接拷贝用&#xff09;多行匹配配置过滤配置最终配置&#xff08;一、多行匹配、直接读取日志文件、EFK方案&#…

C#调用c++创建的动态链接库dll文件

在C#中调用外部DLL文件是一种常见的编程实践&#xff0c;它具有以下几个重要意义&#xff1a;1.代码重用&#xff1b;2.模块化&#xff1b;3.性能优化&#xff1b;4.安全性&#xff1b;5.跨平台兼容性&#xff1b;6.方便更新和维护&#xff1b;7.利用特定技术或框架&#xff1b…

重建大师重建的模型坐标有偏差怎么解决?

第一遍自由网空三&#xff0c;跑完之后刺点&#xff0c;然后控制点平差增强参数解算&#xff0c;方法如下&#xff1a; &#xff08;1&#xff09;跑完自由网空三后&#xff0c;选择编辑控制点&#xff0c;出现刺点窗口后&#xff0c;导入控制点参数 &#xff08;2&#xff09…

Apache Airflow 快速入门教程

Apache Airflow已经成为Python生态系统中管道编排的事实上的库。与类似的解决方案相反&#xff0c;由于它的简单性和可扩展性&#xff0c;它已经获得了普及。在本文中&#xff0c;我将尝试概述它的主要概念&#xff0c;并让您清楚地了解何时以及如何使用它。 Airflow应用场景 …

GEE Download Data——气温数据的下载

GEE数据下载第二弹!今天我们来分享气温数据的下载。 一、数据介绍 气温数据我们要用到的是MODIS数据产品,MOD11A2 V6.1 产品提供 1200 x 1200 公里网格内 8 天平均陆地表面温度 (LST)。 MOD11A2 中的每个像素值都是该 8 天内收集的所有相应 MOD11A1 LST 像素的简单平均值。…

分布式推理框架 xDit

1. xDiT 简介 xDiT 是一个为大规模多 GPU 集群上的 Diffusion Transformers&#xff08;DiTs&#xff09;设计的可扩展推理引擎。它提供了一套高效的并行方法和 GPU 内核加速技术&#xff0c;以满足实时推理需求。 1.1 DiT 和 LLM DiT&#xff08;Diffusion Transformers&am…