robomimic应用教程(二)——策略运行与评估

得到训练好的pth后,下一并将其进行部署及效果评估

可以在jupyter notebook中进行此操作,文件为robomimic文件夹中的examples/notebooks/run_policy.ipynb

本文采用pycharm调试

该脚本用于在环境中评估策略,主要包括从model zoo下载checkpoint,在pytorch中加载checkpoint,并运行评估policy

目录

一、参数说明

二、Terminal运行

1. 执行评估策略

2. 保存方法1

3. 保存方法2

三、逐步运行与解析(run a trained policy and visualize the rollout)

1. 库引用

2. 下载policy checkpoint

3. 加载trained policy

4. 创建 rollout 环境

5. 定义 rollout 循环

6. 运行policy

7. 可视化 rollout

8.运行结果


一、参数说明

"""
The main script for evaluating a policy in an environment.

Args:
    agent (str): path to saved checkpoint pth file

    horizon (int): if provided, override maximum horizon of rollout from the one 
        in the checkpoint

    env (str): if provided, override name of env from the one in the checkpoint,
        and use it for rollouts

    render (bool): if flag is provided, use on-screen rendering during rollouts

    video_path (str): if provided, render trajectories to this video file path

    video_skip (int): render frames to a video every @video_skip steps

    camera_names (str or [str]): camera name(s) to use for rendering on-screen or to video

    dataset_path (str): if provided, an hdf5 file will be written at this path with the
        rollout data

    dataset_obs (bool): if flag is provided, and @dataset_path is provided, include 
        possible high-dimensional observations in output dataset hdf5 file (by default,
        observations are excluded and only simulator states are saved).

    seed (int): if provided, set seed for rollouts

Example usage:

    # Evaluate a policy with 50 rollouts of maximum horizon 400 and save the rollouts to a video.
    # Visualize the agentview and wrist cameras during the rollout.
    
    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --video_path /path/to/output.mp4 \
        --camera_names agentview robot0_eye_in_hand 

    # Write the 50 agent rollouts to a new dataset hdf5.

    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --dataset_path /path/to/output.hdf5 --dataset_obs 

    # Write the 50 agent rollouts to a new dataset hdf5, but exclude the dataset observations
    # since they might be high-dimensional (they can be extracted again using the
    # dataset_states_to_obs.py script).

    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --dataset_path /path/to/output.hdf5
"""

agent (str): 已保存的检查点模型文件路径(.pth文件)。

horizon (int):如果提供,将覆盖检查点中的最大回合长度,即在评估中运行多少步

env (str): 如果提供,将覆盖检查点中保存的环境名称,用于运行策略时创建新的环境

render (bool):如果提供该标志,在每次回合执行时显示屏幕上的实时渲染

video_path (str):如果提供,将回合过程录制为视频并保存到指定的路径

video_skip (int):每隔 @video_skip 步渲染一次帧到视频中

camera_names (str 或 [str]):指定用于渲染的相机名称。

dataset_path (str):如果提供,将回合数据写入到指定路径的hdf5文件中

dataset_obs (bool):如果flag及@dataset_path提供,hdf5文件中将包含可能的高维观测数据(默认情况下,仅保存模拟器状态)

seed (int): 如果提供,设置回合的随机种子

二、Terminal运行

保存回合到视频并渲染相机视图

评估一个策略,50次滚动,最大地平线400,并将滚动保存到视频中

在过程中可视化agentview和手腕摄像头

1. 执行评估策略

进行50次回合,每次最多运行400步,并将回合保存为视频文件,回合过程中显示agentview和wrist相机视角的画面

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --video_path /path/to/output.mp4 \
    --camera_names agentview robot0_eye_in_hand

2. 保存方法1

将50次回合保存到hdf5数据集文件中

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --dataset_path /path/to/output.hdf5 --dataset_obs

3. 保存方法2

将50次回合保存到hdf5数据集文件中,但不包含数据观测数据,因为这些数据可能是高维的

可以使用dataset_states_to_obs.py脚本再次提取这些观测数据

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --dataset_path /path/to/output.hdf5

三、逐步运行与解析(run a trained policy and visualize the rollout)

官方提供了python代码的示例,先看一遍相关示例,再进一步解析

示例采用jupyter演示,库环境默认已安装好robomimic和robosuite

1. 库引用

import argparse
import json
import h5py
import imageio
import numpy as np
import os
from copy import deepcopy

import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.envs.env_base import EnvBase
from robomimic.algo import RolloutPolicy

import urllib.request

2. 下载policy checkpoint

从model zoo中下载pretrained model

此处简单的说明一下 checkpoint,在深度学习模型中,checkpoint 是指在训练过程中保存的模型状态,通常包含模型的参数(权重和偏置)、优化器的状态以及其他相关的训练信息

在训练过程中定期保存模型 checkpoint,就可以在需要时恢复训练或用于模型评估和推理

模型的参数(权重和偏置)文件:在 TensorFlow 中通常是 .cpkt 文件,在 PyTorch 中通常是 .pt.pth 文件

# Get pretrained checkpooint from the model zoo

ckpt_path = "lift_ph_low_dim_epoch_1000_succ_100.pth"
# Lift (Proficient Human)
urllib.request.urlretrieve(
    "http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/lift_ph_low_dim_epoch_1000_succ_100.pth",
    filename=ckpt_path
)

assert os.path.exists(ckpt_path)

3. 加载trained policy

调用 policy_from_checkpoint 函数,从 checkpoint 中构建正确的模型并加载训练好的权重,也可以手动加载 checkpoint

device = TorchUtils.get_torch_device(try_to_use_cuda=True)

# restore policy
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=True)

4. 创建 rollout 环境

此处简单的说明一下 rollout,直接翻译为 “推演或者模拟”,通常是智能体和环境以及模型交互的过程中产生的一系列的交互历史轨迹,通过收集数据,来评估或改进当前的策略

一个 rollout 可以包含一个或多个完整的episodes,或者只是一个episode的一部分(在实际应用中,通常一个rollout只包含一个episode的数据)

policy checkpoint 包含足够的信息来重新创建训练它的环境(也可以手动创建环境)

# create environment from saved checkpoint
env, _ = FileUtils.env_from_checkpoint(
    ckpt_dict=ckpt_dict, 
    render=False, # we won't do on-screen rendering in the notebook
    render_offscreen=True, # render to RGB images for video
    verbose=True,
)

5. 定义 rollout 循环

定义主 rollout 循环,该循环将运行 policy 到目标 horizon,并将 rollout 写入视频(可选)

def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout trajectory.
    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
    """
    assert isinstance(env, EnvBase)
    assert isinstance(policy, RolloutPolicy)
    assert not (render and (video_writer is not None))

    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    results = {}
    video_count = 0  # video frame counter
    total_reward = 0.
    try:
        for step_i in range(horizon):

            # get action from policy
            act = policy(ob=obs)

            # play action
            next_obs, r, done, _ = env.step(act)

            # compute reward
            total_reward += r
            success = env.is_success()["task"]

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                    video_writer.append_data(video_img)
                video_count += 1

            # break if done or if success
            if done or success:
                break

            # update for next iter
            obs = deepcopy(next_obs)
            state_dict = env.get_state()

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))

    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))

    return stats

6. 运行policy

此处请仔细看,和 terminal 运行指令相同,赋予参数

rollout_horizon = 400
np.random.seed(0)
torch.manual_seed(0)
video_path = "rollout.mp4"
video_writer = imageio.get_writer(video_path, fps=20)

运行主函

stats = rollout(
    policy=policy, 
    env=env, 
    horizon=rollout_horizon, 
    render=False, 
    video_writer=video_writer, 
    video_skip=5, 
    camera_names=["agentview"]
)
print(stats)
video_writer.close()

7. 可视化 rollout

from IPython.display import Video
Video(video_path)

8.运行结果

下载了数据集,生成了视频,看一下是夹取动作

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884237.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【web开发】Spring Boot 快速搭建Web项目(三)

Date: 2024.08.31 18:01:20 author: lijianzhan 简述:根据上篇原文Spring Boot 快速搭建Web项目(二),由于已经搭建好项目初始的框架,以及自动创建了一个启动类文件(TestWebApplication.java) …

【Python】Daphne:Django 异步服务的桥梁

Daphne 是 Django Channels 项目的一部分,专门用于为 Django 提供支持 HTTP、WebSocket、HTTP2 和 ASGI 协议的异步服务器。Daphne 是一个开源的 Python 异步服务器,它可以帮助开发者运行异步应用程序,并且非常适合与 Django Channels 一起使…

电子电路的基础知识

电子电路是现代电子技术的基础,由电子元件(如电阻、电容、电感、二极管、晶体管等)和无线电元件通过一定方式连接而成的电路系统。 以下是对电子电路的详细概述: 一、定义与分类 定义:电子电路是指由电子器件和有关无…

解压视频素材下载网站推荐

在制作抖音小说推文或其他短视频时,找到合适的解压视频素材非常重要。以下是几个推荐的网站,可以帮助你轻松下载高质量的解压视频素材: 蛙学网 蛙学网是国内顶尖的短视频素材网站,提供大量4K高清无水印的解压视频素材,…

【记录】Excel|不允许的操作:合并或隐藏单元格出现的问题列表及解决方案

人话说在前:这篇的内容是2022年5月写的,当时碰到了要批量处理数据的情况,但是又不知道数据为啥一直报错报错报错,说不允许我操作,最终发现是因为存在隐藏的列或行,于是就很无语地写了博客,但内容…

如何使用Flux+lora进行AI模型文字生成图片

目录 概要 前期准备 部署安装与运行 1. 部署ComfyUI 本篇的模型部署是在ComfyUI的基础上进行,如果没有部署过ComfyUI,请按照下面流程先进行部署,如已安装请跳过该步: (1)使用命令克隆 ComfyUI &…

【友元补充】【动态链接补充】

友元 友元的目的是让一个函数或者类,访问另一个类中的私有成员。 友元的关键字friend是一个修饰符。 友元分为友元类和友元函数 1.全局函数作友元 2.类作友元 3.类的一个成员函数作友元 好处:可以通过友元在类外访问类内的私有和受保护类型的成员 坏处…

Python画笔案例-068 绘制漂亮米

1、绘制漂亮米 通过 python 的turtle 库绘制 漂亮米,如下图: 2、实现代码 绘制 漂亮米,以下为实现代码: """漂亮米.py注意亮度为0.5的时候最鲜艳本程序需要coloradd模块支持,安装方法:pip install coloradd程序运行需要很长时间,请耐心等待。可以把窗口最小…

智能Ai语音机器人的应用价值有哪些?

随着时间的推移,人工智能的发展越来越成熟,智能时代也离人们越来越近,近几年人工智能越来越火爆,人工智能的应用已经开始渗透到各行各业,与生活交融,成为人们无法拒绝,无法失去的一个重要存在。…

医疗大数据安全与隐私保护:数据分类分级的基石作用

医疗行业在数字化转型中迅猛发展,医疗大数据作为核心驱动力,深刻改变医疗服务的模式与效率。它不仅促进医疗信息的流通与共享,推动个性化、精准化的医疗服务新生态。同时,也在提升医疗服务质量、优化医疗资源配置等方面展现巨大潜…

Spring Ioc底层原理代码详细解释

文章目录 概要根据需求编写XML文件,配置需要创建的bean编写程序读取XML文件,获取bean相关信息,类,属性,id前提知识点Dom4j根据第二步获取到的信息,结合反射机制动态创建对象,同时完成属性赋值将…

蓝桥杯【物联网】零基础到国奖之路:十二. TIM

蓝桥杯【物联网】零基础到国奖之路:十二. TIM 第一节 理论知识第二节 cubemx配置 第一节 理论知识 STM32L071xx器件包括4个通用定时器、1个低功耗定时器(LPTIM)、2个基本定时器、2个看门狗定时器和SysTick定时器。 通用定时器(TIM2、TIM3、…

Spring Cloud Alibaba-(6)Spring Cloud Gateway【网关】

Spring Cloud Alibaba-(1)搭建项目环境 Spring Cloud Alibaba-(2)Nacos【服务注册与发现、配置管理】 Spring Cloud Alibaba-(3)OpenFeign【服务调用】 Spring Cloud Alibaba-(4)Sen…

数据分析工具julius ai如何使用

什么是julius ai Julius AI 是一款强大的ai数据分析工具。用户可以使用excel、数据库、文本文件等多种格式的数据,Julius AI 会自动分析这些数据并提供详细的解释和可视化图表。官网显示它目前已经有三十万用户。它也支持手机版。 虽然openai也支持生成图表&#xf…

asp.net core grpc快速入门

环境 .net 8 vs2022 创建 gRPC 服务器 一定要勾选Https 安装Nuget包 <PackageReference Include"Google.Protobuf" Version"3.28.2" /> <PackageReference Include"Grpc.AspNetCore" Version"2.66.0" /> <PackageR…

项目实战:k8s部署考试系统

一、新建nfs服务器&#xff08;192.168.1.44&#xff09; 1.基础配置&#xff08;IP地址防火墙等&#xff09; 2.配置时间同步 [rootlocalhost ~]# yum -y install ntpdate.x86_64 [rootlocalhost ~]# ntpdate time2.aliyun.com 27 Sep 10:28:08 ntpdate[1634]: adjust tim…

MySql在更新操作时引入“两阶段提交”的必要性

日志模块有两个redo log和binlog&#xff0c;redo log 是引擎层的日志&#xff08;负责存储相关的事&#xff09;&#xff0c;binlog是在Server层&#xff0c;主要做MySQL共嗯那个层面的事情。redo log就像一个缓冲区&#xff0c;可以让当更新操作的时候先放redo log中&#xf…

node.js npm 安装和安装create-next-app -windowsserver12

1、官网下载windows版本NODE.JS https://nodejs.org/dist/v20.17.0/node-v20.17.0-x64.msi 2、安装后增加两个文件夹目录node_global、node_cache npm config set prefix "C:\Program Files\nodejs\node_global" npm config set prefix "C:\Program Files\nod…

基于SpringBoot的新冠检测信息管理系统的设计与实现

文未可获取一份本项目的java源码和数据库参考。 国内外在该方向的研究现状及分析 新型冠状病毒肺炎疫情发生以来&#xff0c;中国政府采取积极的防控策略和措施&#xff0c;经过两个多月的不懈努力&#xff0c;有效控制了新发病例的増长&#xff0c;本地传播已经趋于完全控制…

Mysql高级篇(中)——锁机制

锁机制 一、概述二、分类1、读锁2、写锁★、FOR SHARE / FOR UPDATE&#xff08;1&#xff09;NOWAIT&#xff08;2&#xff09;SKIP LOCKED&#xff08;3&#xff09;NOWAIT 和 SKIP LOCKED 的比较 ★、 脏写3、表级锁之 S锁 / X锁&#xff08;1&#xff09;总结&#xff08;2…