Isaac Lab 使用 Stable Baselines3 实现 Multi Input Policy

目前Isaac Lab支持的强化学习框架

Isaac Lab支持的强化学习框架介绍icon-default.png?t=N7T8http://t.csdnimg.cn/h8u7Z调研下来,能够实现字典状态量,也就是多输入状态量的有

rsl_rl、sb3、(skrl不确定),rl_games是显然不支持的,自己改了一版,花了很长时间,目前训练还不收敛,个人觉得rl_games定制网络和策略不那么友好。rsl_rl关节类的研究对象用这个多一些,但是目前master分支只支持PPO算法,algorithms分支支持算法很多,但是没有合并到master,使用不方便;sb3比较通用一点;rl_games 是 NVIDIA Isaac Gym里官方使用的,skrl比较新官方文档看着很强大。

目前Lab里rl_games 的 observation_space 只支持 gym.spaces.Box

如果你想添加新的强化学习框架,可以在这里修改:

omni.isaac.lab_tasks\omni\isaac\lab_tasks\utils\wrappers

 Adding your own learning library — Isaac Lab documentation (isaac-sim.github.io)

好了,Stable Baselines3 是直接可以用 ``MultiInputPolicy`` 的

 目前Lab里sb3的 observation dict 处理过程

sb3官方对多输入状态量的描述与例子,如图像 + 舵机数据这种,还简单定义了``SimpleMultiObsEnv``演示环境

Dict Observations

You can use environments with dictionary observation spaces. This is useful in the case where one can’t directly concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles). Stable Baselines3 provides SimpleMultiObsEnv as an example of this kind of setting. The environment is a simple grid world, but the observations for each cell come in the form of dictionaries. These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.

Examples — Stable Baselines3 2.4.0a3 documentation (stable-baselines3.readthedocs.io)

from stable_baselines3 import PPO
from stable_baselines3.common.envs import SimpleMultiObsEnv


# Stable Baselines provides SimpleMultiObsEnv as an example environment with Dict observations
env = SimpleMultiObsEnv(random_start=False)

model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)

那``MultiInputPolicy``使用的策略网络是什么,Stable Baselines3 支持使用 Dict Gym 空间处理多个输入。这可以使用 MultiInputPolicy 来完成,默认情况下,它使用 CombinedExtractor 特征提取器将多个输入转换为单个向量,由 net_arch 网络处理。具体见1、2、3描述

策略网络 — Stable Baselines3 2.4.0a3 文档 --- Policy Networks — Stable Baselines3 2.4.0a3 documentation (stable-baselines3.readthedocs.io)

Multiple Inputs and Dictionary Observations

By default, CombinedExtractor processes multiple inputs as follows:
默认情况下, CombinedExtractor 按如下方式处理多个输入:

  1. If input is an image (automatically detected, see common.preprocessing.is_image_space), process image with Nature Atari CNN network and output a latent vector of size 256.
    如果输入是图像(自动检测,请参阅 common.preprocessing.is_image_space ),则使用Nature Atari CNN网络处理图像并输出大小 256 为的潜在向量。

  2. If input is not an image, flatten it (no layers).
    如果输入不是图像,请将其拼合(无图层)。

  3. Concatenate all previous vectors into one long vector and pass it to policy.
    将所有先前的向量连接成一个长向量,并将其传递给策略。

那么如何来自定义特征提取网络呢?如CNN与MLP以及拼接大小

Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the observation space dictionary: “image” is a (1,H,W) image (channel first), and “vector” is a (D,) dimensional vector. We process “image” with a simple downsampling and “vector” with a single linear layer.
与上面非常相似,您可以定义自定义功能提取器。以下示例假定环境在观测空间字典中有两个键:“image”是 (1,H,W) 图像(通道优先),“vector”是 (D,) 维向量。我们用简单的下采样处理“图像”,用单个线性层处理“矢量”。

import gymnasium as gym
import torch as th
from torch import nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)

        extractors = {}

        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "image":
                # We will just downsample one channel of the image by 4x4 and flatten.
                # Assume the image is single-channel (subspace.shape[0] == 0)
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                total_concat_size += subspace.shape[1] // 4 * subspace.shape[2] // 4
            elif key == "vector":
                # Run through a simple MLP
                extractors[key] = nn.Linear(subspace.shape[0], 16)
                total_concat_size += 16

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []

        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)

环境定义,这里使用的是DirectRLEnv;可以参考 cartpole_camera_env.py

添加摄像头:

# camera
tiled_camera: TiledCameraCfg = TiledCameraCfg(
        prim_path="/World/envs/env_.*/Robot/body/Camera",
        offset=TiledCameraCfg.OffsetCfg(pos=(0.0, 0.0, 0.05), rot=(1.0, 0.0, 0.0, 0.0), convention="ros"),
        data_types=["rgb"],
        spawn=sim_utils.PinholeCameraCfg(
            focal_length=24.0, focus_distance=400.0, horizontal_aperture=20.955, clipping_range=(0.1, 10.0)
        ),
        width=640,
        height=480,
    )

由于自定义space,需要重写父类 ``DirectRLEnv`` 方法 _configure_gym_env_spaces,如我这里定义的:

def _configure_gym_env_spaces(self):
    """Configure the action and observation spaces for the Gym environment."""
    # observation space (unbounded since we don't impose any limits)
    self.num_actions = self.cfg.num_actions
    self.num_observations_img = self.cfg.num_observations_img
    self.num_observations_vec = self.cfg.num_observations_vec
    self.num_states = self.cfg.num_states

    # set up spaces
    self.single_observation_space = gym.spaces.Dict()

    self.single_observation_space["policy"] = gym.spaces.Dict()
    self.single_observation_space["policy"]["img"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.cfg.tiled_camera.height, self.cfg.tiled_camera.width, self.cfg.num_channels),)
    self.single_observation_space["policy"]["vec"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_observations_vec,))

    if self.num_states > 0:
        self.single_observation_space["critic"] = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.num_observations_vec, ),
        )

    self.single_action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.num_actions,))

    # batch the spaces for vectorized environments
    self.observation_space = gym.vector.utils.batch_space(self.single_observation_space, self.num_envs)
    self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs)

具体 _get_observations 实现 (省略里面具体状态量获取)

def _get_observations(self) -> dict:
    # 具体 observations 实现
    observations = {"policy": {"img": self._tiled_camera.data.output[data_type].clone(), "vec": obs}}
    return observations

sb3策略配置,policy 部分为 'MultiInputPolicy'

# Reference: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L32
seed: 42

n_timesteps: 20000
policy: 'MultiInputPolicy'
n_steps: 16
batch_size: 64
gae_lambda: 0.95
gamma: 0.99
n_epochs: 10
ent_coef: 0.01
learning_rate: !!float 3e-4
clip_range: !!float 0.2
policy_kwargs: "dict(
                  activation_fn=nn.ELU,
                  net_arch=[256, 256, 256, 128],
                  squash_output=False,
                )"
vf_coef: 1.0
max_grad_norm: 1.0
device: "cuda:0"

 开始训练:

python source/standalone/workflows/sb3/train.py --task=你的智能体名 --headless --enable_cameras

PS:

测试发现,Stable Baselines3的训练速度比 rsl_rl 和 rl_games要慢,GPU利用率也低,不知道是不是超参数设置不一样,后续会继续对比;


欢迎加QQ群一起交流学习:723139415

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/739261.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

servlet的生命周期

1、Servlet的生命周期就是servlet类对象什么时候创建?什么时候调用对应的方法,什么时候销毁。 对象的生命周期: Student student new Student(); //创建对象 student.setName("eric"); // 使用对象 student.show();// 使用对象 student null; // 销毁…

踩坑——VS添加相对路径

需求:我需要将模型放到程序里面。 过程:附加包含目录添加目录,发现找不到onnx模型文件。我就想是不是相对路径不对,该来搞去都不对。 解决办法: 相对路径值得是运行程序的当下环境,什么是运行程序呢&…

Tomcat简介与安装

目录 一、Tomcat 简介 1、Tomcat好帮手---JDK 2、安装Tomcat & JDK 1、系统环境说明 2 、安装JDK 3、安装Tomcat 二、Tomcat目录介绍 1、tomcat主目录介绍 2、webapps目录介绍 3、Tomcat配置介绍(conf) 4、Tomcat的管理 5、tomcat 配置管…

微信支付还能这么玩?设置好自动扣费,停车费、电影票一键搞定

在这个快节奏的时代,微信支付以其便捷性成为我们日常生活中不可或缺的一部分。但你知道吗? 微信支付的功能远不止于此,它还能通过自动扣费功能,让我们的生活变得更加智能和轻松。从停车费到电影票,一键搞定&#xff0…

【Python/Pytorch - 网络模型】-- SVD算法

文章目录 文章目录 00 写在前面01 基于Pytorch版本的SVD算代码02 理论知识 00 写在前面 (1)矩阵的奇异值分解在最优化问题、特征值问题、最小二乘方问题、广义逆矩阵问题及统计学等方面都有重要应用; (2)应用&#…

pgAdmin后台命令执行漏洞(CVE-2023-5002)

​ 我们可以看到针对于漏洞 CVE-2022-4223,官方做了一定的修复措施。 web\pgadmin\misc_init_.py#validate_binary_path ​ 首先是添加了 login_required​ 进行权限校验。在 Flask 框架中,login_required​ 装饰器通常与 Flask-Login 扩展一起使用。…

探索Linux的奇妙世界 :第三关---Linux的基本指令(中篇)

1. man指令(重要) Linux的命令有很多参数,我们不可能全记住,我们可以通过查看联机手册获取帮助。访问 Linux 手册页的命令是man 语法 : man [ 选项 ] 命令。 常用选项: -k 根据关键字搜索联机帮助 num 只在第num章节找 -a 将所有章节的都显…

游戏行业新质生产力洞察报告 | 七成游戏企业技术投入显著增加 AI应用率99%

近日,伽马数据发布了《中国游戏产业新质生产力发展报告》。报告围绕中国游戏产业推动“新质生产力”发展的关键路径和重点领域进行深入讨论,并通过对相关数据和典型案例的深入分析,清晰呈现当前中国游戏企业在发展新质生产力过程中的探索与实…

【服务器02】之【阿里云平台】

百度一下阿里云官网 点击注册直接使用支付宝注册可以跳过认证 成功登录后,点击产品 点击免费试用 点击勾选 选一个距离最近的 点满GB 注意:一般试用的时用的是【阿里云】,真正做项目时用的是【腾讯云】 现在开始学习使用: 首先…

STM32学习之一:什么是STM32

目录 1.什么是STM32 2.STM32命名规则 3.STM32外设资源 4. STM32的系统架构 5. 从0到1搭建一个STM32工程 学习stm32已经很久了,因为种种原因,也有很久一段时间没接触过stm32了。等我捡起来的时候,发现很多都已经忘记了,重新捡…

2024年【低压电工】考试题库及低压电工考试报名

题库来源:安全生产模拟考试一点通公众号小程序 低压电工考试题库是安全生产模拟考试一点通总题库中生成的一套低压电工考试报名,安全生产模拟考试一点通上低压电工作业手机同步练习。2024年【低压电工】考试题库及低压电工考试报名 1、【单选题】()仪表…

计算机网路面试HTTP篇三

HTTPS RSA 握手解析 我前面讲,简单给大家介绍了的 HTTPS 握手过程,但是还不够细! 只讲了比较基础的部分,所以这次我们再来深入一下 HTTPS,用实战抓包的方式,带大家再来窥探一次 HTTPS。 对于还不知道对称…

【数列极限证明大题】解题方法,证明数列极限存在并求此极限,单调有界准则

文章目录 数列极限证明大题1.单调有界准则1.1 证有界性和单调性 1.2真题实战1.2 证明有界性中常用到的不等式 写在最前,持续更新中 数列极限证明大题 数列极限的证明大题的目标是,证明数列极限存在且求此极限。 核心方法是:单调有界准则&…

免费分享:2000-2020年中国长时间序列夜间灯光数据集(附下载方法)

夜间灯光数据集直观反映了地表夜间灯光亮度,进而揭示了人类活动强度,为分析城市扩张、人口迁移、经济发展等提供了连续、全面的视角,有助于深入理解中国城市化的历史进程和未来趋势。 数据简介 基于DMSP/OLS第四版非辐射定标夜间年平均灯光强…

green bamboo snake

green bamboo snake 【竹叶青蛇】 为什么写这个呢,因为回县城听说邻居有人被蛇咬伤,虽然不足以危及生命,严重的送去市里了。 1)这种经常都是一动不动,会躲在草地、菜地的菜叶里面、果树上、有时候会到民房大厅休息&a…

Python 接口自动化测试

一、基础准备 1. 环境搭建 工欲善其事必先利其器,废话不多说。我们先开始搭建环境。 # 创建项目目录mkdir InterfaceTesting# 切换到项目目录下cd InterfaceTesting# 安装虚拟环境创建工具pip install virtualenv# 创建虚拟环境,env代表虚拟环境的名称&…

1Panel应用推荐:Bitwarden开源密码管理器

1Panel(github.com/1Panel-dev/1Panel)是一款现代化、开源的Linux服务器运维管理面板,它致力于通过开源的方式,帮助用户简化建站与运维管理流程。为了方便广大用户快捷安装部署相关软件应用,1Panel特别开通应用商店&am…

中国港口年鉴(2000-2023年)

数据年限:2000-2023(齐全) 数据格式:pdf、excel 数据内容: 一、记述和反映了中国大陆江、海、河港口在深化改革、调整结构、整合资源、开拓经营、加快建设等方面所取得的成就和发展进程,香港特别行政区、澳…

YOLOv8改进 | SPPF | 具有多尺度带孔卷积层的ASPP【CVPR2018】

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡 专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有40篇内容,内含各种Head检测头、损失函数Loss、…

Linux源码阅读笔记04-实时调度类及SMP和NUMA

Linux进程分类 实时进程普通进程 如果系统中有一个实时进程并且可执行,调度器总是会选择他,除非有另外一个优先级高的实时进程。SCHED_FIFO:没有时间片,被调度器选择之后,可以运行任意长的时间。SCHED_RR:有…