【rl-agents代码学习】01——总体框架

文章目录

  • rl-agent Get start
    • Installation
    • Usage
    • Monitoring
  • 具体代码

学习一下rl-agents的项目结构以及代码实现思路。

source: https://github.com/eleurent/rl-agents

rl-agent Get start

Installation

pip install --user git+https://github.com/eleurent/rl-agents

Usage

rl-agents中的大部分例子可以通过cd到scripts文件夹 cd scripts,执行 python experiments.py命令实现。

Usage:
  experiments evaluate <environment> <agent> (--train|--test)
                                             [--episodes <count>]
                                             [--seed <str>]
                                             [--analyze]
  experiments benchmark <benchmark> (--train|--test)
                                    [--processes <count>]
                                    [--episodes <count>]
                                    [--seed <str>]
  experiments -h | --help

Options:
  -h --help            Show this screen.
  --analyze            Automatically analyze the experiment results.
  --episodes <count>   Number of episodes [default: 5].
  --processes <count>  Number of running processes [default: 4].
  --seed <str>         Seed the environments and agents.
  --train              Train the agent.
  --test               Test the agent.

evaluate命令允许在给定的环境中评估给定的agent。例如,

# Train a DQN agent on the CartPole-v0 environment
$ python3 experiments.py evaluate configs/CartPoleEnv/env.json configs/CartPoleEnv/DQNAgent.json --train --episodes=200

每个agent都按照标准接口与环境交互:

action = agent.act(state)
next_state, reward, done, info = env.step(action)
agent.record(state, action, reward, next_state, done, info)

环境的配置文件

{
    "id": "intersection-v0",
    "import_module": "highway_env",
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,
        "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-20, 20],
            "vy": [-20, 20]
        },
        "absolute": true,
        "order": "shuffled"
    },
    "destination": "o1"
}

agent的配置文件,核心就是"__class__": "<class 'rl_agents.agents.deep_q_network.pytorch.DQNAgent'>",利用agent_factory进行agent的创建。

{
    "__class__": "<class 'rl_agents.agents.deep_q_network.pytorch.DQNAgent'>",
    "model": {
        "type": "MultiLayerPerceptron",
        "layers": [128, 128]
    },
    "gamma": 0.95,
    "n_steps": 1,
    "batch_size": 64,
    "memory_capacity": 15000,
    "target_update": 512,
    "exploration": {
        "method": "EpsilonGreedy",
        "tau": 15000,
        "temperature": 1.0,
        "final_temperature": 0.05
    }
}

如果部分key缺失的话,会使用默认的值agent.default_config()

最后,可以在基准(baseline)测试中安排一批实验。然后在几个进程上并行执行所有实验。

# Run a benchmark of several agents interacting with environments
$ python3 experiments.py benchmark cartpole_benchmark.json --test --processes=4

基准配置文件包含环境配置列表和agent配置列表。

{
    "environments": ["configs/CartPoleEnv/env.json"],
    "agents": [
        "configs/CartPoleEnv/DQNAgent.json",
        "configs/CartPoleEnv/LinearAgent.json",
        "configs/CartPoleEnv/MCTSAgent.json"
    ]
}

Monitoring

有几种工具可用于监控agent性能:

  • Run metadata:为了可重复性,将运行所用的环境和agent配置合并,并保存到metadata.*.json文件中。
  • Gym Monitor:每次运行的主要统计数据(episode rewards, lengths, seeds)都会记录到episode_batch.*.stats.json文件中。可以通过运行scripts/analyze.py来自动可视化这些数据。
  • Logging:agent可以通过标准的Python日志记录库发送消息。默认情况下,所有日志级别为INFO的消息都会保存到logging.*.lo文件中。要保存日志级别为DEBUG的消息,请添加选项scripts/experiments.py --verbose
  • Tensorboard:默认情况下,一个tensoboard writer会记录有关有用标量、图像和模型图的信息到运行目录。可以通过运行以下命令来进行可视化:tensorboard --logdir <path-to-runs-dir>

具体代码

rl-agents核心代码集中在rl-agents文件夹和scripts文件夹中,其中,rl-agents主要实现相关的算法,scripts为相应的配置文件。
在这里插入图片描述

experiments.py为入口程序,先从它看起,其相应的用法如下:
在这里插入图片描述

Usage:
  experiments evaluate <environment> <agent> (--train|--test) [options]
  experiments benchmark <benchmark> (--train|--test) [options]
  experiments -h | --help

Options:
  -h --help              Show this screen.
  --episodes <count>     Number of episodes [default: 5].
  --no-display           Disable environment, agent, and rewards rendering.
  --name-from-config     Name the output folder from the corresponding config files
  --processes <count>    Number of running processes [default: 4].
  --recover              Load model from the latest checkpoint.
  --recover-from <file>  Load model from a given checkpoint.
  --seed <str>           Seed the environments and agents.
  --train                Train the agent.
  --test                 Test the agent.
  --verbose              Set log level to debug instead of info.
  --repeat <times>       Repeat several times [default: 1].

首先从main函数开始,根据evaluate或者benchmark执行相应的任务。暂且先从evaluate入手。

def main():
    opts = docopt(__doc__)
    if opts['evaluate']:
        for _ in range(int(opts['--repeat'])):
            evaluate(opts['<environment>'], opts['<agent>'], opts)
    elif opts['benchmark']:
        benchmark(opts)

evaluate主要完成envagent的创建以及evaluation 对象的创建,再根据选择train或test执行不同的程序。

def evaluate(environment_config, agent_config, options):
    """
        Evaluate an agent interacting with an environment.

    :param environment_config: the path of the environment configuration file
    :param agent_config: the path of the agent configuration file
    :param options: the evaluation options
    """
    logger.configure(LOGGING_CONFIG)
    if options['--verbose']:
        logger.configure(VERBOSE_CONFIG)
    env = load_environment(environment_config)
    agent = load_agent(agent_config, env)
    run_directory = None
    if options['--name-from-config']:
        run_directory = "{}_{}_{}".format(Path(agent_config).with_suffix('').name,
                                  datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
                                  os.getpid())
    options['--seed'] = int(options['--seed']) if options['--seed'] is not None else None
    evaluation = Evaluation(env,
                            agent,
                            run_directory=run_directory,
                            num_episodes=int(options['--episodes']),
                            sim_seed=options['--seed'],
                            recover=options['--recover'] or options['--recover-from'],
                            display_env=not options['--no-display'],
                            display_agent=not options['--no-display'],
                            display_rewards=not options['--no-display'])
    if options['--train']:
        evaluation.train()
    elif options['--test']:
        evaluation.test()
    else:
        evaluation.close()
    return os.path.relpath(evaluation.run_directory)

Evaluation类中主要包含以下函数:
在这里插入图片描述

__init__的一些参数说明

参数描述
env要解决的环境,可能是包装了AbstractEnv的环境
agent解决环境的AbstractAgent agent
directory工作空间目录路径
run_directory运行目录路径
num_episodes运行的episode数
trainingagent是处于训练模式还是测试模式
sim_seed环境/agent随机性源的种子
recover从文件中恢复agent参数。如果为True,则使用默认的最新保存。如果为字符串,则将其用作路径。
display_env渲染环境,并有一个监视器录制其视频
display_agent如果支持,将agent图形添加到环境查看器中
display_rewards通过episodes显示agent的性能
close_env当评估结束时,是否应该关闭环境
step_callback_fn在每个环境步骤之后调用的回调函数。它接受以下参数:(episode, env, agent, transition, writer)。

首先看一下train,根据agent是否有batched属性,分为run_batched_episodesrun_episodes

    def train(self):
        self.training = True
        if getattr(self.agent, "batched", False):
            self.run_batched_episodes()
        else:
            self.run_episodes()
        self.close()

run_episodes就是一般强化学习的基本过程,注意其中的reset step 等函数都是经过封装的。实现自己的算法时需要注意。run_batched_episodes则主要实现一些并行计算的任务,这一部分等之后再详细介绍。

在这里插入图片描述

    def run_episodes(self):
        for self.episode in range(self.num_episodes):
            # Run episode
            terminal = False
            self.reset(seed=self.episode)
            rewards = []
            start_time = time.time()
            while not terminal:
                # Step until a terminal step is reached
                reward, terminal = self.step()
                rewards.append(reward)

                # Catch interruptions
                try:
                    if self.env.unwrapped.done:
                        break
                except AttributeError:
                    pass

            # End of episode
            duration = time.time() - start_time
            self.after_all_episodes(self.episode, rewards, duration)
            self.after_some_episodes(self.episode, rewards)

test为模型测试部分

    def test(self):
        """
        Test the agent.

        If applicable, the agent model should be loaded before using the recover option.
        """
        self.training = False
        if self.display_env:
            self.wrapped_env.episode_trigger = lambda e: True
        try:
            self.agent.eval()
        except AttributeError:
            pass
        self.run_episodes()
        self.close()

其中eval也需要进行重写。

    def eval(self):
        """
            Set to testing mode. Disable any unnecessary exploration.
        """
        pass

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/138582.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PaaS基础建设

PaaS&#xff08;Platform-as-a-Service&#xff1a;平台即服务&#xff09;是应用程序和服务的部署平台。Paas为开发、测试和管理软件应用程序提供所需的开发环境&#xff0c;是云计算服务类型之一。 PaaS是什么&#xff1f;IaaS、SaaS、PaaS三种云服务区别 PaaS&#xff08;P…

​《水经注全国三维离线GIS系统》硬件安装教程

有些工作&#xff0c;是需要一些外在动力才能完成的。 为什么这么讲呢&#xff1f; 因为正是在客户的要求下&#xff0c;我们才撰写了《水经注全国三维离线GIS系统》的硬件安装教程&#xff0c;而且还录制了视频教程。 当用户收到货物以后&#xff0c;就可以通过本教程清点货…

深度学习AI识别人脸年龄

以下链接来自 落痕的寒假 GitHub - luohenyueji/OpenCV-Practical-Exercise: OpenCV practical exercise GitHub - luohenyueji/OpenCV-Practical-Exercise: OpenCV practical exercise import cv2 as cv import time import argparsedef getFaceBox(net, frame, conf_thresh…

深度剖析c语言程序 -- 函数栈帧的创建和销毁(纯肝货)

本章的内容: 什么是函数栈帧&#xff1f; 理解函数栈帧能解决什么问题&#xff1f; 函数栈帧的创建和销毁解析 本文放到 --> 该专栏内&#xff1a;http://t.csdnimg.cn/poMzA 目录 什么是函数栈帧❓ 理解函数栈帧能解决什么问题呢&#xff1f;&#x1f4a2; 函数栈帧的…

计数排序及优化

&#x1f389;个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名乐于分享在学习道路上收获的大二在校生 &#x1f43b;‍❄个人主页&#x1f389;&#xff1a;GOTXX&#x1f43c;个人WeChat&#xff1a;ILXOXVJE&#x1f43c;本文由GOTXX原创&#xff0c;首发CSDN&a…

汽车一键启动智能系统功能作用

在现代科技的推动下&#xff0c;我们的生活每天都在发生着变化。其中&#xff0c;汽车智能一键启动系统就是科技改变生活的最好例子之一。 首先&#xff0c;我们来简单了解一下汽车智能一键启动系统。它是一种利用先进的电子技术和无线通信技术&#xff0c;实现无需钥匙即可启…

基于单片机智能输液器监控系统的设计

**单片机设计介绍&#xff0c; 基于单片机智能输液器监控系统的设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的智能输液器监控系统可以实现对输液过程的实时监测和控制&#xff0c;以下是一个基本的设计介绍&am…

【数据结构——队列的实现(单链表)】

数据结构——队列的实现&#xff08;单链表&#xff09; 一.队列1.1队列的概念及结构 二.队列的实现2.1 头文件的实现——&#xff08;Queue.h&#xff09;2.2 源文件的实现—— &#xff08;Queue.c&#xff09;2.3 源文件的实现—— &#xff08;test.c&#xff09; 三.队列的…

拼多多API接口,打造智能化电商平台

近年来&#xff0c;电商行业的崛起给人们的购物带来了极大的方便。随着电商行业的发展&#xff0c;拼多多作为新兴电商平台已经成为市场焦点。 同时&#xff0c;随着技术的发展&#xff0c;API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&…

pta 6翻了 Python3

“666”是一种网络用语&#xff0c;大概是表示某人很厉害、我们很佩服的意思。最近又衍生出另一个数字“9”&#xff0c;意思是“6翻了”&#xff0c;实在太厉害的意思。如果你以为这就是厉害的最高境界&#xff0c;那就错啦 —— 目前的最高境界是数字“27”&#xff0c;因为这…

云课五分钟的一些想法

起源 自中学起&#xff0c;就积极学习和掌握互联网相关知识&#xff0c;到如今已经快30年了。 个人也全程经历了从信息时代的互联网&#xff08;硬&#xff09;到智能时代的大模型&#xff08;软&#xff09;。 整体信息到智能的基础设施&#xff0c;由硬到软&#xff0c;机…

CRM系统:除了销售管理,还能做些什么?

企业的健康发展&#xff0c;离不开业绩的提升。在企业数字化转型的背景下&#xff0c;采用数字化应用进行管理已成为共识。许多企业认识到了应该使用CRM客户管理系统来进行销售管理&#xff0c;但CRM能做的还有很多。下面说说除了销售管理&#xff0c;CRM系统还能做些什么&…

继承和多态_Java零基础手把手保姆级教程(超详细)

文章目录 Java零基础手把手保姆级教程_继承和多态&#xff08;超详细&#xff09;1. 继承1.1 继承的实现&#xff08;掌握&#xff09;1.2 继承的好处和弊端&#xff08;理解&#xff09; 2. 继承中的成员访问特点2.1 继承中变量的访问特点&#xff08;掌握&#xff09;2.2 sup…

石英增强光声光谱气体传感技术中的高精密压力控制解决方案

摘要&#xff1a;光声池内气体压力的可调节控制以及稳定性是保证光声法高精度测量的关键&#xff0c;但在目前的光声和光谱研究中&#xff0c;对气体样品池内压力控制技术的报道极为简单&#xff0c;甚至很多都是错误的&#xff0c;根本无法实现高精度调节和控制&#xff0c;为…

Autosar模块介绍:Memory_3(MemIf-内存接口抽象)

上一篇 | 返回主目录 | 下一篇 Autosar模块介绍&#xff1a;Memory_3(MemIf-内存接口抽象 1 基本术语解释2 MemIf组成结构图 1 基本术语解释 编号缩写原文解释1(Logical) Block——可单独寻址的连续内存区域&#xff08;即&#xff0c;用于读、写、擦除、比较等操作&#xff…

眼科动态图像处理系统使用说明(2023-8-11 ccc)

眼科动态图像处理系统使用说明 2023-8-11 ccc 动态眼科图像捕捉存贮分析与传输系统&#xff0c;是由计算机软件工程师和医学专家组结合&#xff0c;为满足医院临床工作的需要&#xff0c;在2000年开发的专门用于各类眼科图像自动化分析、处理和传输的软件系统。该系统可以和各…

【NLP】大型语言模型,ALBERT — 用于自监督学习的 Lite BERT

&#x1f50e;大家好&#xff0c;我是Sonhhxg_柒&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流&#x1f50e; &#x1f4dd;个人主页&#xff0d;Sonhhxg_柒的博客_CSDN博客 &#x1f4c3; &#x1f381;欢迎各位→点赞…

基于正余弦算法优化概率神经网络PNN的分类预测 - 附代码

基于正余弦算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于正余弦算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于正余弦优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神经网络…

【Linux】第十五站:环境变量

文章目录 一、进程相关的一些概念1.一些常见的概念2.对于并发3.**进程切换** 二、环境变量1.PATH环境变量2.HOME环境变量3.SHELL环境变量4.env5.系统调用接口与环境变量6.什么是环境变量&#xff1f;7.命令行参数8.main函数的第三个命令行参数9.如何验证环境变量是可以被继承的…

2、工厂模式的实现

工厂模式概念 工厂模式是一种常用的设计模式&#xff0c;它主要用于实例化对象。这种模式的主要思想是在不暴露具体的实现细节的情况下&#xff0c;让客户端能够创建具有特定接口的对象。它可以让我们在运行时决定实例化哪个类。 在C语言中&#xff0c;实例化对象意味着创建一…