ETH开源PPO算法学习

前言

项目地址:https://github.com/leggedrobotics/rsl_rl

项目简介:快速简单的强化学习算法实现,设计为完全在 GPU 上运行。这段代码是 NVIDIA Isaac GYM 提供的 rl-pytorch 的进化版。

下载源码,查看目录,整个项目模块化得非常好,每个部分各司其职。下面我们自底向上地进行讲解加粗的部分。

rsl_rl/
│ __init__.py

├─algorithms/
│ │ __init__.py
│ │ ppo.py # PPO算法的实现
│ │
├─env/
│ │ __init__.py
│ │ vec_env.py # 实现并行处理多个环境的向量化环境
│ │
├─modules/
│ │ __init__.py
│ │ actor_critic.py # 定义 Actor-Critic 网络结构
│ │ actor_critic_recurrent.py # 定义包含循环层的 Actor-Critic 网络
│ │ normalizer.py # 数据正规化工具,有助于训练过程的稳定性
│ │
├─runners/
│ │ __init__.py
│ │ on_policy_runner.py # 实现用于执行 on-policy 算法训练循环的运行器
│ │
├─storage/
│ │ __init__.py
│ │ rollout_storage.py # 存储和管理策略 rollout 数据的工具
│ │
└─utils/
│ __init__.py
│ neptune_utils.py # 用于与 Neptune.ai 集成的工具
│ utils.py # 通用实用工具函数
│ wandb_utils.py # 用于与 Weights & Biases 集成的工具

rollout 数据储存和管理(rollout_storage.py)

定义了一个名为 RolloutStorage 的类,用于存储和管理在强化学习训练过程中从环境中收集到的数据(称为rollouts)。

  • 定义Transition

用于存储单个时间步的所有相关数据,包括观察值、动作、奖励、完成标志(dones)、值函数估计、动作的对数概率、动作的均值和标准差,以及可能的隐藏状态(对于使用循环网络的情况)。

  • 特权观察值(Privileged Observations)

除了self.observations外还有self.privileged_observations的使用,在强化学习中是指那些在训练期间可用但在实际部署或测试时不可用的额外信息。这些信息通常提供了环境的内部状态或其他有助于学习的提示,但在现实世界应用中可能难以获得或完全不可用。在训练期间使用特权观察值的一种常见方法是通过教师-学生架构(我们常常也称作特权学习),其中一个拥有全部信息的教师模型(可以访问特权观察值)来指导一个学生模型(只能访问普通观察值)。学生模型的目标是模仿教师模型的决策,尽管它没有直接访问特权信息。

  • 奖励和优势的计算
    def compute_returns(self, last_values, gamma, lam):
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

这段代码实现的是在强化学习中计算回报(returns)和优势(advantages)的逻辑,具体是使用了一种称为广义优势估算(Generalized Advantage Estimation, GAE)的方法。GAE是一种权衡偏差和方差以及平滑回报信号的技术,由以下几个数学公式定义:

  1. TD残差(Temporal Difference Residual):
    δ t = R t + γ V ( S t + 1 ) ( 1 − d o n e t ) − V ( S t ) \delta_t = R_t + \gamma V(S_{t+1}) (1 - done_t) - V(S_t) δt=Rt+γV(St+1)(1donet)V(St)
    其中, δ t \delta_t δt是时刻 t t t的TD残差, R t R_t Rt是奖励, γ \gamma γ是折扣因子, V ( S t ) V(S_t) V(St)是状态 S t S_t St的价值函数估计, d o n e t done_t donet是表示当前状态是否为终止状态的指示函数(如果当前状态为终止状态,则 d o n e t = 1 done_t = 1 donet=1;否则, d o n e t = 0 done_t = 0 donet=0)。如果 d o n e t = 1 done_t = 1 donet=1,那么 γ V ( S t + 1 ) \gamma V(S_{t+1}) γV(St+1)项将为 0,因为终止状态之后没有未来回报。

  2. GAE优势估计:
    A t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} AtGAE(γ,λ)=l=0(γλ)lδt+l
    在代码中,这个无限求和是通过迭代地计算来近似的,具体的迭代公式为:
    A t = δ t + ( γ λ ) A t + 1 ( 1 − d o n e t ) A_t = \delta_t + (\gamma \lambda) A_{t+1} (1 - done_t) At=δt+(γλ)At+1(1donet)
    其中, A t A_t At是时刻 t t t的优势估计, λ \lambda λ是用来平衡TD估计和蒙特卡罗估计之间权重的参数。

  3. 回报的计算:
    G t = A t + V ( S t ) G_t = A_t + V(S_t) Gt=At+V(St)
    其中, G t G_t Gt是时刻 t t t的回报估计。

代码中使用的变量名与数学符号的对应关系:

变量名数学符号含义
rewards[step] R t R_t Rt时刻 t t t的奖励
gamma γ \gamma γ折扣因子,用于计算未来奖励的现值
values[step] V ( S t ) V(S_t) V(St)状态 S t S_t St在当前策略下的价值函数估计
dones[step] d o n e t done_t donet指示当前状态 S t S_t St是否为终止状态的标志(1 表示终止,0 表示非终止)
delta δ t \delta_t δt时刻 t t t的 TD 残差
advantage A t A_t At时刻 t t t的优势估计,根据 GAE 方法计算
lam λ \lambda λ用于 GAE 计算中平衡 TD 估计和蒙特卡罗估计之间权重的参数
returns[step] G t G_t Gt时刻 t t t的回报估计
advantages A t n o r m A_t^{norm} Atnorm标准化后的优势估计
mu_A, sigma_A μ A \mu_A μA, σ A \sigma_A σA优势估计的平均值和标准差
epsilon ϵ \epsilon ϵ避免除零错误而加的小常数,通常取值为 1e-8

代码中的循环从最后一个转换开始向前迭代,使用以上的数学公式来计算每一步的优势和回报。最后,它还对优势进行了标准化处理,即从每个优势中减去所有优势的平均值,并除以标准差,以减少训练期间的方差并加速收敛。标准化公式如下:
A t n o r m = A t − μ A σ A + ϵ A_t^{norm} = \frac{A_t - \mu_A}{\sigma_A + \epsilon} Atnorm=σA+ϵAtμA
其中, μ A \mu_A μA是优势的平均值, σ A \sigma_A σA是优势的标准差, ϵ \epsilon ϵ​ 是为了防止除以零而加的一个小常数(在代码中为 1e-8)。

  • 轨迹的平均长度

类中并没有显式存储轨迹的长度,轨迹长度隐含在self.dones之中。代码中使用的方法是:将每个环境中最后一步置为‘1’,然后flatten(展开)、拼接所有环境中的dones得到flat_dones,差分数组中为‘1’位置的索引得到智能体在每个环境中的步数,即轨迹长度。这个统计量有助于了解训练过程中智能体的表现。

  • mini-batch迭代器

mini_batch_generator 函数通过在多个训练周期(num_epochs)内,从经验回放缓冲区中随机选择小批量数据(包括观察值 observations、动作 actions、奖励 rewards 等)来生成小批量数据集。该函数利用 torch.randperm 生成随机索引 indices 来随机化数据抽样,进而支持基于批处理的学习方法,如梯度下降。通过每次只处理必要的数据量,该生成器在优化模型参数的同时,也优化了内存使用,确保了训练过程的高效性和灵活性。

(未完待续)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/414632.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年3月5-7日第12届生物发酵技术装备展-奥博仪表

参观企业介绍 潍坊奥博仪表科技发展有限公司成立于2002年3月,注册资金1000万元,已有20多年的发展历程,是一家专业从事流量仪表开发、生产与测控系统集成的高新技术企业和双软认证企业。 目前公司以仪表、通讯产品、自控系统、软件的研发、生…

DVWA 靶场之 Command Injection(命令执行)middlehigh

对于 middle 难度的 我们直接先看源码 <?phpif( isset( $_POST[ Submit ] ) ) {// Get input$target $_REQUEST[ ip ];// Set blacklist$substitutions array(&& > ,; > ,);// Remove any of the characters in the array (blacklist).$target str_rep…

Pytest教程:一种利用 Python Pytest Hook 机制的软件自动化测试网络数据抓包方法

随着计算机技术的发展&#xff0c;使得网络应用的数量不断增加&#xff0c;因此网络数据抓包成为了网络应用开发和测试中非常重要的一部分。目前&#xff0c;已有许多网络数据抓包工具可供使用&#xff0c;例如 Wireshark、Tcpdump、Fiddler 等&#xff0c;但这些工具需要手动配…

5G网络频谱划分与应用

频率越大&#xff0c;波长越短。补充&#xff1a;微波频段&#xff1a;300MHZ~3000GHZ。 5G网络工作频带与带宽配置 &#xff08;1&#xff09; FR1:410MHZ~7125MHZ。 FR2&#xff1a;24250MHZ~52600MHZ 注&#xff1a;早期将6GHZ已下频段作为FR1&#xff0c;但是最新的频段…

Python打发无聊时光:10.用flask创造按键控制的网页小游戏

游戏介绍: 《秦蓝大冒险》是一款简单而紧张的追逐游戏。在这个游戏中&#xff0c;玩家将控制一名名叫“吕千”的角色&#xff0c;试图在一个封闭的空间内逃避一个名为“秦蓝”的追逐者。随着时间的推移&#xff0c;“秦蓝”会不断追踪玩家的位置&#xff0c;努力捕捉到他。 游…

C语言中如何进行内存管理

主页&#xff1a;17_Kevin-CSDN博客 收录专栏&#xff1a;《C语言》 C语言是一种强大而灵活的编程语言&#xff0c;但与其他高级语言不同&#xff0c;它要求程序员自己负责内存的管理。正确的内存管理对于程序的性能和稳定性至关重要。 一、引言 C 语言是一门广泛使用的编程语…

2.1_6 线程的实现方式和多线程模型

文章目录 2.1_6 线程的实现方式和多线程模型&#xff08;一&#xff09;线程的实现方式&#xff08;1&#xff09;用户级线程&#xff08;2&#xff09;内核级线程 &#xff08;二&#xff09;多线程模型&#xff08;1&#xff09;一对一模型&#xff08;2&#xff09;多对一模…

stable diffusion webUI之赛博菩萨【秋葉】——工具包新手安裝与使用教程

stable diffusion webUI之赛博菩萨【秋葉】——工具包新手安裝与使用教程 AI浪潮袭来&#xff0c;还是学习学习为妙赛博菩萨【秋葉】简介——&#xff08;葉ye&#xff0c;四声&#xff0c;同叶&#xff09;A绘世启动器.exe&#xff08;sd-webui-aki-v4.6.x&#xff09;工具包安…

VirtualBox虚拟机配置双网卡

安装完后Linux后。下一步需要设置Linux的网络。为了便于学习测试&#xff0c;通常我们需要虚拟机能通宿主机和外网。类似达到下面的效果。虚拟机跟宿主本机和外网&#xff0c;及另外一台同网段的虚拟机也是相通 解决思路是-->配置双网卡&#xff0c;网卡1使用NAT网络模式&a…

vue组件中data为什么必须是一个函数

查看本专栏目录 关于作者 还是大剑师兰特&#xff1a;曾是美国某知名大学计算机专业研究生&#xff0c;现为航空航海领域高级前端工程师&#xff1b;CSDN知名博主&#xff0c;GIS领域优质创作者&#xff0c;深耕openlayers、leaflet、mapbox、cesium&#xff0c;canvas&#x…

fiddler抓包工具使用(一)

一、fiddler简介 1. 简介 fiddler是一款强大的抓包工具&#xff0c;它的原理以web代理服务器的形式进行工作fiddler是好用的web调试工具之一 能记录所有客户端和服务器的http和https请求修改输入、输出数据包数据允许监视设置断点弱网测试 2. 工作原理 代理就是在客户端和服…

QT C++实战:实现用户登录页面及多个界面跳转

主要思路 一个登录界面&#xff0c;以管理员Or普通用户登录管理员&#xff1a;一个管理员的操作界面&#xff0c;可以把数据录入到数据库中。有返回登陆按钮&#xff0c;可以选择重新登陆&#xff08;管理员Or普通用户普通用户&#xff1a;一个主界面&#xff0c;负责展示视频…

java动态代理面试题,java反射原理面试题

01 并发宝典&#xff1a;面试专题 面试专题分为四个部分&#xff0c;分别如下 Synchronized 相关问题 可重入锁 ReentrantLock 及其他显式锁相关问题 Java 线程池相关问题 Java 内存模型相关问题 1.1 Synchronized 相关问题&#xff08;这里整理了八问&#xff09; 问题一…

揭示预处理中的秘密!(二)

目录 ​编辑 1. #运算符 2. ##运算符 3. 命名约定 4. #undef 5. 命令行定义 6. 条件编译 7. 头文件的被包含的方式 8.嵌套文件包含 9. 其他预处理指令 10. 完结散花 悟已往之不谏&#xff0c;知来者犹可追 …

【Go-Zero】测试API查询信息无法返回数据库信息与api、rpc文件编写规范

【Go-Zero】测试API查询信息无法返回数据库信息与api、rpc文件编写规范 大家好 我是寸铁&#x1f44a; 总结了一篇测试API查询信息无法返回数据库信息与api、rpc文件编写规范的文章✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 问题背景 大家好&#xff0c;我是寸铁&#xff01…

C++——基础语法(2):函数重载、引用

4. 函数重载 函数重载就是同一个函数名可以重复被定义&#xff0c;即允许定义相同函数名的函数。但是相同名字的函数怎么在使用的时候进行区分呢&#xff1f;所以同一个函数名的函数之间肯定是要存在不同点的&#xff0c;除了函数名外&#xff0c;还有返回类型和参数两部分可以…

前后端项目-part03

文章目录 5.4.4 机构名称5.4.4.1 创建实体类Company5.4.4.2 创建实体类CompanyMapper5.4.4.3 创建实体类CompanyService5.4.4.4 创建实体类CompanyController5.4.4.5 后端测试5.4.4.6 修改basic.js5.4.4.7 修改course.vue5.4.4.8 测试5.4.5 课程标签5.4.5.1 效果5.4.5.2 修改co…

golang学习5,glang的web的restful接口

1. //返回json r.GET("/getJson", controller.GetUserInfo) package mainimport (/*"net/http"*/"gin/src/main/controller""github.com/gin-gonic/gin" )func main() {r : gin.Default()r.GET("/get", func(ctx *…

【Linux系统化学习】信号概念和信号的产生

目录 信号的概念 从生活中的例子中感知信号 前台进程和后台进程 前台进程 后台进程 操作系统如何知道用户向键盘写入数据了&#xff1f; 进程如何得知自己收到了信号&#xff1f; 信号捕捉 signal函数 Core Dump&#xff08;核心转储&#xff09; 信号产生的方式 通…

如何选择合适的汽车芯片ERP系统?

随着汽车产业的飞速发展&#xff0c;汽车芯片作为关键组件&#xff0c;其管理变得愈发重要。为了高效管理汽车芯片的生产、销售、库存等各个环节&#xff0c;许多企业开始引入汽车芯片ERP(企业资源规划)系统。那么&#xff0c;如何选择合适的汽车芯片ERP系统呢? 明确需求是关键…