强化学习_06_pytorch-PPO实践(Hopper-v4)

一、PPO优化

PPO的简介和实践可以看笔者之前的文章 强化学习_06_pytorch-PPO实践(Pendulum-v1)
针对之前的PPO做了主要以下优化:

  1. batch_normalize: 在mini_batch 函数中进行adv的normalize, 加速模型对adv的学习
  2. policyNet采用beta分布(0~1): 同时增加MaxMinScale 将beta分布产出值转换到action的分布空间
  3. 收集多个episode的数据,依次计算adv,后合并到一个dataloader中进行遍历:加速模型收敛

1.1 PPO2 代码

详细可见 Github: PPO2.py

class PPO2:
    """
    PPO2算法, 采用截断方式
    """
    def __init__(self,
                state_dim: int,
                actor_hidden_layers_dim: typ.List,
                critic_hidden_layers_dim: typ.List,
                action_dim: int,
                actor_lr: float,
                critic_lr: float,
                gamma: float,
                PPO_kwargs: typ.Dict,
                device: torch.device,
                reward_func: typ.Optional[typ.Callable]=None
                ):
        dist_type = PPO_kwargs.get('dist_type', 'beta')
        self.dist_type = dist_type
        self.actor = policyNet(state_dim, actor_hidden_layers_dim, action_dim, dist_type=dist_type).to(device)
        self.critic = valueNet(state_dim, critic_hidden_layers_dim).to(device)
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = gamma
        self.lmbda = PPO_kwargs['lmbda']
        self.k_epochs = PPO_kwargs['k_epochs'] # 一条序列的数据用来训练的轮次
        self.eps = PPO_kwargs['eps'] # PPO中截断范围的参数
        self.sgd_batch_size = PPO_kwargs.get('sgd_batch_size', 512)
        self.minibatch_size = PPO_kwargs.get('minibatch_size', 128)
        self.action_bound = PPO_kwargs.get('action_bound', 1.0)
        self.action_low = -1 * self.action_bound 
        self.action_high = self.action_bound
        if 'action_space' in PPO_kwargs:
            self.action_low = self.action_space.low
            self.action_high = self.action_space.high
        
        self.count = 0 
        self.device = device
        self.reward_func = reward_func
        self.min_batch_collate_func = partial(mini_batch, mini_batch_size=self.minibatch_size)
    
    def _action_fix(self, act):
        if self.dist_type == 'beta':
            # beta 0-1 -> low ~ high
            return act * (self.action_high - self.action_low) + self.action_low
        return act 
    
    def _action_return(self, act):
        if self.dist_type == 'beta':
            # low ~ high -> 0-1 
            act_out = (act - self.action_low) / (self.action_high - self.action_low)
            return act_out * 1 + 0
        return act 

    def policy(self, state):
        state = torch.FloatTensor(np.array([state])).to(self.device)
        action_dist = self.actor.get_dist(state, self.action_bound)
        action = action_dist.sample()
        action = self._action_fix(action)
        return action.cpu().detach().numpy()[0]
    
    def _one_deque_pp(self, samples: deque):
        state, action, reward, next_state, done = zip(*samples)
        state = torch.FloatTensor(np.stack(state)).to(self.device)
        action = torch.FloatTensor(np.stack(action)).to(self.device)
        reward = torch.tensor(np.stack(reward)).view(-1, 1).to(self.device)
        if self.reward_func is not None:
            reward = self.reward_func(reward)

        next_state = torch.FloatTensor(np.stack(next_state)).to(self.device)
        done = torch.FloatTensor(np.stack(done)).view(-1, 1).to(self.device)
        
        old_v = self.critic(state)
        td_target = reward + self.gamma * self.critic(next_state) * (1 - done)
        td_delta = td_target - old_v
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta, done).to(self.device)
        # recompute
        td_target = advantage + old_v
        action_dists = self.actor.get_dist(state, self.action_bound)
        old_log_probs = action_dists.log_prob(self._action_return(action))
        return state, action, old_log_probs, advantage, td_target
        
    def data_prepare(self, samples_list: List[deque]):
        state_pt_list = []
        action_pt_list = []
        old_log_probs_pt_list = []
        advantage_pt_list = []
        td_target_pt_list = []
        for sample in samples_list:
            state_i, action_i, old_log_probs_i, advantage_i, td_target_i = self._one_deque_pp(sample)
            state_pt_list.append(state_i)
            action_pt_list.append(action_i)
            old_log_probs_pt_list.append(old_log_probs_i)
            advantage_pt_list.append(advantage_i)
            td_target_pt_list.append(td_target_i)
            
        state = torch.concat(state_pt_list) 
        action = torch.concat(action_pt_list) 
        old_log_probs = torch.concat(old_log_probs_pt_list) 
        advantage = torch.concat(advantage_pt_list) 
        td_target = torch.concat(td_target_pt_list)
        return state, action, old_log_probs, advantage, td_target
        
    def update(self, samples_list: List[deque]):
        state, action, old_log_probs, advantage, td_target = self.data_prepare(samples_list)
        if len(old_log_probs.shape) == 2:
            old_log_probs = old_log_probs.sum(dim=1)
        d_set = memDataset(state, action, old_log_probs, advantage, td_target)
        train_loader = DataLoader(
            d_set,
            batch_size=self.sgd_batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=self.min_batch_collate_func
        )
        
        for _ in range(self.k_epochs):
            for state_, action_, old_log_prob, adv, td_v in train_loader:
                action_dists = self.actor.get_dist(state_, self.action_bound)
                log_prob = action_dists.log_prob(self._action_return(action_))
                if len(log_prob.shape) == 2:
                    log_prob = log_prob.sum(dim=1)
                # e(log(a/b))
                ratio = torch.exp(log_prob - old_log_prob.detach())
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * adv

                actor_loss = torch.mean(-torch.min(surr1, surr2)).float()
                critic_loss = torch.mean(
                    F.mse_loss(self.critic(state_).float(), td_v.detach().float())
                ).float()

                self.actor_opt.zero_grad()
                self.critic_opt.zero_grad()
                actor_loss.backward()
                critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5) 
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5) 
                self.actor_opt.step()
                self.critic_opt.step()

        return True

    def save_model(self, file_path):
        if not os.path.exists(file_path):
            os.makedirs(file_path)

        act_f = os.path.join(file_path, 'PPO_actor.ckpt')
        critic_f = os.path.join(file_path, 'PPO_critic.ckpt')
        torch.save(self.actor.state_dict(), act_f)
        torch.save(self.critic.state_dict(), critic_f)

    def load_model(self, file_path):
        act_f = os.path.join(file_path, 'PPO_actor.ckpt')
        critic_f = os.path.join(file_path, 'PPO_critic.ckpt')
        self.actor.load_state_dict(torch.load(act_f, map_location='cpu'))
        self.critic.load_state_dict(torch.load(critic_f, map_location='cpu'))
        self.actor.to(self.device)
        self.critic.to(self.device)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

    def train(self):
        self.training = True
        self.actor.train()
        self.critic.train()

    def eval(self):
        self.training = False
        self.actor.eval()
        self.critic.eval()

二、 Pytorch实践

2.1 智能体构建与训练

PPO2主要是收集多轮的结果序列进行训练,增加训练轮数,适当降低学习率,稍微增Actor和Critic的网络深度
详细可见 Github: test_ppo.Hopper_v4_ppo2_test

import os
from os.path import dirname
import sys
import gymnasium as gym
import torch
# 笔者的github-RL库
from RLAlgo.PPO import PPO
from RLAlgo.PPO2 import PPO2
from RLUtils import train_on_policy, random_play, play, Config, gym_env_desc

env_name = 'Hopper-v4'
gym_env_desc(env_name)
print("gym.__version__ = ", gym.__version__ )
path_ = os.path.dirname(__file__) 
env = gym.make(
    env_name, 
    exclude_current_positions_from_observation=True,
    # healthy_reward=0
)
cfg = Config(
    env, 
    # 环境参数
    save_path=os.path.join(path_, "test_models" ,'PPO_Hopper-v4_test2'), 
    seed=42,
    # 网络参数
    actor_hidden_layers_dim=[256, 256, 256],
    critic_hidden_layers_dim=[256, 256, 256],
    # agent参数
    actor_lr=1.5e-4,
    critic_lr=5.5e-4,
    gamma=0.99,
    # 训练参数
    num_episode=12500,
    off_buffer_size=512,
    off_minimal_size=510,
    max_episode_steps=500,
    PPO_kwargs={
        'lmbda': 0.9,
        'eps': 0.25,
        'k_epochs': 4, 
        'sgd_batch_size': 128,
        'minibatch_size': 12, 
        'actor_bound': 1,
        'dist_type': 'beta'
    }
)
agent = PPO2(
    state_dim=cfg.state_dim,
    actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,
    critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,
    action_dim=cfg.action_dim,
    actor_lr=cfg.actor_lr,
    critic_lr=cfg.critic_lr,
    gamma=cfg.gamma,
    PPO_kwargs=cfg.PPO_kwargs,
    device=cfg.device,
    reward_func=None
)
agent.train()
train_on_policy(env, agent, cfg, wandb_flag=False, train_without_seed=True, test_ep_freq=1000, 
                 online_collect_nums=cfg.off_buffer_size,
                 test_episode_count=5)

2.2 训练出的智能体观测

最后将训练的最好的网络拿出来进行观察

agent.load_model(cfg.save_path)
agent.eval()
env_ = gym.make(env_name, 
                exclude_current_positions_from_observation=True,
                render_mode='human'
                ) # , render_mode='human'
play(env_, agent, cfg, episode_count=3, play_without_seed=True, render=True)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/424347.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

排序(3)——直接选择排序

目录 直接选择排序 基本思想 整体思路(升序) 单趟 多趟 代码实现 特性总结 直接选择排序 基本思想 每一次从待排序的数据元素中选出最小(或最大)的一个元素,存放在序列的起始位置,直到全部待排序的…

JOSEF约瑟 漏电继电器LLJ-400F 配套零序互感器φ100mm 50-500mA 0.1S 导轨安装

系列型号: LLJ-150F(S)漏电继电器 LLJ-160F(S)漏电继电器 LLJ-200F(S)漏电继电器 LLJ-250F(S)漏电继电器 LLJ-300F(S)漏电继电器 LLJ-320F(S)漏电继电器 LLJ-400F(S)漏电继电器 LLJ-500F(S)漏电继电器 LLJ-600F(S)漏电继电器 一、产品用途及特点 LLJ-FS系列漏电继电…

MWC 2024丨美格智能推出5G RedCap系列FWA解决方案,开启5G轻量化新天地

2月27日,在MWC 2024世界移动通信大会上,美格智能正式推出5G RedCap系列FWA解决方案。此系列解决方案具有低功耗、低成本等优势,可以显著降低5G应用复杂度,快速实现5G网络接入,提升FWA部署的经济效益。 RedCap技术带来了…

MATLAB知识点:if条件判断语句的嵌套

​讲解视频:可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇(数学建模清风主讲,适合零基础同学观看)_哔哩哔哩_bilibili 节选自​第4章:MATLAB程序流程控制 我们通过一个…

NCT 全国青少年编程图形化编程(Scratch)等级考试(一级)模拟测试H

202312 青少年软件编程等级考试Scratch一级真题 第 1 题 【 单选题 】 以下说法合理的是( ) A :随意点开不明来源的邮件 B :把密码设置成 abc123 C :在虚拟社区上可以辱骂他人 D :在改编他人的作品前, 先征得他人同意 正确答案: D 试题解析&…

【Scratch画图100例】图49-scratch绘制直角风车 少儿编程 scratch编程画图案例教程 考级比赛画图集训案例

目录 scratch绘制直角风车 一、题目要求 1、准备工作 2、功能实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、实现流程 1、案例分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 …

regexpire-攻防世界-MISC

用nc链接靶机: rootkali:~/Desktop# nc 220.249.52.133 37944 Can you match these regexes? Bv*(clementine|sloth)*Q*eO(clinton|alien)*(cat|elephant)(cat|trump)[a-zA-Z]*(dolphin|clementine)\W*(table|apple)* 大致上是服务端给出一个正则表达式&#xff0c…

基于springboot的宠物咖啡馆平台的设计与实现论文

基于Spring Boot的宠物咖啡馆平台的设计与实现 摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了基于Spring Boot的宠物咖啡馆平台的设计与实现的开发全过程。通过分析基于Spring Boot的宠物咖啡馆平台的设计与…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:形状裁剪)

用于对组件进行裁剪、遮罩处理。 说明: 从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 clip clip(value: boolean | CircleAttribute | EllipseAttribute | PathAttribute | RectAttribute) 按指定的形状对当…

数据结构:排序算法+查找算法

一、概念 程序数据结构算法 1.算法的特性和要求 特性: 确定性(每次运行相同的输入都是同样的结果)、有穷性、输入、输出、可行性 设计要求: 正确性、高效率、低存储、健壮性、可读性 2.时间复杂度 3.常见排序算法的时间复杂…

6.5 共享数据

本节介绍Android的四大组件之一ContentProvider的基本概念和常见用法:首先说明如何使用内容提供器封装内部数据的外部访问接口,然后阐述如何使用内容解析器通过外部接口操作内部数据,最后叙述如何利用内容解析器读写联系人信息,以…

SQL无列名注入

SQL无列名注入 ​ 前段时间,队里某位大佬发了一个关于sql注入无列名的文章,感觉好像很有用,特地研究下。 关于 information_schema 数据库: ​ 对于这一个库,我所知晓的内容并不多,并且之前总结SQL注入的…

2W字-35页PDF谈谈自己对QT某些知识点的理解

2W字-35页PDF谈谈自己对QT某些知识点的理解 前言与总结总体知识点的概况一些笔记的概况笔记阅读清单 前言与总结 最近,也在对自己以前做的项目做一个知识点的梳理,发现可能自己以前更多的是用某个控件,以及看官方手册,但是没有更…

EchoServer回显服务器简单测试

目录 工具介绍 工具使用 测试结果 工具介绍 github的一个开源项目,是一个测压工具 EZLippi/WebBench: Webbench是Radim Kolar在1997年写的一个在linux下使用的非常简单的网站压测工具。它使用fork()模拟多个客户端同时访问我们设定的URL,测试网站在压力下工作的…

【Vue3】解锁Vue3黑科技:探索接口、泛型和自定义类型的前端奇迹

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

WebServer -- 注册登录

目录 🍉整体内容 🌼流程图 🎂载入数据库表 提取用户名和密码 🚩同步线程登录注册 补充解释 代码 😘页面跳转 补充解释 代码 🍉整体内容 概述 TinyWebServer 中,使用数据库连接池实现…

vscode c/c++ 检测到 #include 错误。请更新 includePath。

问题背景 使用vscode打开项目后,头文件显示红色波浪线,没有引入。 检测到 #include 错误。请更新 includePath。已为此翻译单元(xxx)禁用波形曲线。 解决方法 gcc -v -E -x c - 显示所有头文件路径。 打开c_cpp_properties.json文件,粘贴…

Verilog Constructs、Verilog系统任务和功能

下表列出了Verilog构造在Vivado合成中的支持状态。 Verilog系统任务和功能 Vivado合成支持系统任务或功能,如下表所示。Vivado合成会忽略不支持的系统任务。 使用转换函数 使用以下语法对任何表达式调用$signed和$unsigned系统任务。 $signed(expr&am…

「爬虫职海录」三镇爬虫

HI,朋友们好 「爬虫职海录」第三期更新啦! 本栏目的内容方向会以爬虫相关的“岗位分析”和“职场访谈”为主,方便大家了解一下当下的市场行情。 本栏目持续更新,暂定收集国内主要城市的爬虫岗位相关招聘信息,有求职…

计算机设计大赛 深度学习机器视觉车道线识别与检测 -自动驾驶

文章目录 1 前言2 先上成果3 车道线4 问题抽象(建立模型)5 帧掩码(Frame Mask)6 车道检测的图像预处理7 图像阈值化8 霍夫线变换9 实现车道检测9.1 帧掩码创建9.2 图像预处理9.2.1 图像阈值化9.2.2 霍夫线变换 最后 1 前言 🔥 优质竞赛项目系列,今天要分…