PPO 跑CartPole-v1

gym-0.26.2
cartPole-v1

参考动手学强化学习书中的代码,并做了一些修改

代码

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    """PPO算法,采用截断的方式"""
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs    # 一条序列的数据用来训练轮数
        self.eps = eps  # PPO 中阶段范围的参数
        self.device = device

    def take_action(self, state):
        state = torch.FloatTensor([state]).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def gae(self, td_delta):
        td_delta = td_delta.detach().numpy()
        advantages_list = []
        advantage = 0.0
        for delta in td_delta[::-1]:
            advantage = self.gamma * self.lmbda * advantage + delta
            advantages_list.append(advantage)
        advantages_list.reverse()
        return torch.FloatTensor(advantages_list)

    def update(self, transition_dist):
        states = torch.FloatTensor(transition_dist['states']).to(self.device)
        actions = torch.tensor(transition_dist['actions']).reshape((-1, 1)).to(self.device)
        rewards = torch.FloatTensor(transition_dist['rewards']).reshape((-1, 1)).to(self.device)
        next_states = torch.FloatTensor(transition_dist['next_states']).to(self.device)
        dones = torch.FloatTensor(transition_dist['dones']).reshape((-1, 1)).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        # GAE 计算广义优势
        advantage = self.gae(td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ration = torch.exp(log_probs - old_log_probs)
            surr1 = ration * advantage
            surr2 = torch.clamp(ration, 1-self.eps, 1+self.eps) * advantage # 截断
            actor_loss = torch.mean(-torch.min(surr1, surr2))   # PPO损失函数
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()


def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0))
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))


def train():
    actor_lr = 1e-3
    critic_lr = 1e-2
    num_episodes = 500
    hidden_dim = 128
    gamma = 0.98
    lmbda = 0.95
    epochs = 10
    eps = 0.2
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    env_name = "CartPole-v1"
    env = gym.make(env_name)
    torch.manual_seed(0)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)

    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state, _ = env.reset()
                done, truncated = False, False
                while not done and not truncated:
                    action = agent.take_action(state)
                    next_state, reward, done, truncated, _ = env.step(action)
                    done = done or truncated    # 这个地方要注意
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)


    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title(f'PPO on {env_name}')
    plt.show()

    mv_return = moving_average(return_list, 9)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title(f'PPO on {env_name}')
    plt.show()


if __name__ == '__main__':
    train()


pycharm中运行结果:

效果看起很好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/328855.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024最新PyQt5及其工具(Qt Designer、PyUIC、PyRcc)手把手操作实践指南

2024最新PyQt5及其工具(Qt Designer、PyUIC、PyRcc)手把手操作实践指南 前言 最近做了一些个人项目,内部逻辑还是挺多的,而且也有想要开源的想法,但是总不能直接把源码端给大家直接运行,有一些需求还有萌…

ubuntu设置每天定时关机

ubuntu设置每天定时关机 终端输入命令: sudo crontab -e输入密码,回车。 我这里使用nano作为编辑器,你可以选择vim。 在末尾输入以下命令: 59 23 * * * sudo -u root shutdown now设置:每天23:59分,电脑…

中科院自动化所:基于关系图深度强化学习的机器人多目标包围问题新算法

摘要:中科院自动化所蒲志强教授团队,提出一种基于关系图的深度强化学习方法,应用于多目标避碰包围(MECA)问题,使用NOKOV度量动作捕捉系统获取多机器人位置信息,验证了方法的有效性和适应性。研究成果在2022年ICRA大会发…

HarmonyOS应用开发者初级认证试题库(鸿蒙)

目录 考试链接: 流程: 选择: 判断: 单选: 多选: 考试链接: 开发者能力认证-职业认证-鸿蒙能力认证-华为开发者学堂 (huawei.com)https://developer.huawei.com/consumer/cn/training/dev-…

计算机导论06-人机交互

文章目录 人机交互基础人机交互概述人机交互及其发展人机交互方式人机界面 新型人机交互技术显示屏技术跟踪与识别(技术)脑-机接口 多媒体技术多媒体技术基础多媒体的概念多媒体技术及其特性多媒体技术的应用多媒体技术发展趋势 多媒体应用技术文字&…

高斯Hack算法

背景 刷leetcode时,碰到一题需要求解n个bit中选择m个bit的所有组合集,我只想到了递归求解,没啥问题,但是在官方题解中看到了牛逼的方法(Gospers Hack),故记录一下。 4bit中2个1的情况 0011b0101b0110b1001b1010b1100b…

ASP.NET Core SingleR:初次体验和简单项目搭建

文章目录 前言应用场景SignalR 网站长什么样?第一个ASP.NET core SignalR程序确定SignalR版本新建MVC项目添加unpkg管理器添加客户端添加ChatHub文件添加SignalR服务添加网页运行测试浏览器Websocket调试type1type6Type为其它时 总结 前言 平常的网页通讯都是基于H…

苹果要在iPhone上运行AI大模型?

近两年,人工智能(AI)技术已经成为各大科技公司的重点研究领域,苹果公司自然也不甘落后。最新消息称,苹果甚至打算在iPhone上直接运行AI大模型... 据苹果AI研究人员表示,他们发明了一种创新的闪存利用技术&a…

Intel开发环境Quartus、Eclipse与WSL的安装

PC :win10 64bit 安装顺序:先安装Quartus 21.4,接着Eclipse或者WSL(Windows Subsystem for Linux),Eclipse与WSL的安装不分先后。 为什么要安装Eclipse? 因为Eclipse可以开发基于Nios II的C/…

【上分日记】第380场周赛(数位dp+ KMP + 位运算 + 二分 + 双指针 )

文章目录 前言正文1.3005. 最大频率元素计数2.3007.价值和小于等于 K 的最大数字3.3008. 找出数组中的美丽下标 II 总结尾序 前言 本场周赛,博主也只写出两道题(前两道, hhh菜鸡勿喷),第三道涉及位运算 ,数位dp,第四道涉及KMP。 下…

Mac系统下,保姆级Jenkins自动化部署Android

一、Jenkins自动化部署 1、安装jenkins 官网:macOS Installers for Jenkins LTS 选择macOS brew install jenkins-lts 安装最新: brew install jenkins-lts 启动jenkins服务: brew services start jenkins-lts 重启jenkins服务: brew services restart jenkin…

Angular系列教程之变更检测与性能优化

文章目录 前言变更检测的原理脏检查OnPush策略 示例代码总结 前言 Angular 除了默认的变化检测机制,也提供了ChangeDetectionStrategy.OnPush,用 OnPush 可以跳过某个组件或者某个父组件以及它下面所有子组件的变化检测。 在本文中,我们将探…

grep 在运维中的常用可选项

一、对比两个文件 vim -d <filename1> <filename2> 演示&#xff1a; 需求&#xff1a;&#xff5e;目录下有两个文件一个test.txt 以及 text2.txt,需求对比两个文件的内容。 执行后会显示如图&#xff0c;不同会高亮。 二、两次过滤 场景&#xff1a;当需要多…

如何用AI提高论文阅读效率?

已经2024年了&#xff0c;该出现一个写论文解读AI Agent了。 大家肯定也在经常刷论文吧。 但真正尝试过用GPT去刷论文、写论文解读的小伙伴&#xff0c;一定深有体验——费劲。其他agents也没有能搞定的&#xff0c;今天我发现了一个超级厉害的写论文解读的agent &#xff0c…

rust跟我学六:虚拟机检测

图为RUST吉祥物 大家好,我是get_local_info作者带剑书生,这里用一篇文章讲解get_local_info是怎么检测是否在虚拟机里运行的。 首先,先要了解get_local_info是什么? get_local_info是一个获取linux系统信息的rust三方库,并提供一些常用功能,目前版本0.2.4。详细介绍地址:…

标准ERP产品的许可期限为,2023-12-22,当前已超过使用期限,请联系系统管理员更新许可或者删除超期的许可

项目场景&#xff1a; 金蝶云星空客户端 问题描述 标准ERP产品的许可期限为&#xff0c;2023-12-22&#xff0c;当前已超过使用期限&#xff0c;请联系系统管理员更新许可或者删除超期的许可! 原因分析&#xff1a; 这个报错提示是金蝶对第三方行业许可强制绑定,行业许可到…

Android PendingIntent 闪退

先来给大家推荐一个我日常会使用到的图片高清处理在线工具&#xff0c;主要是免费&#xff0c;直接白嫖 。 有时候我看到一张图片感觉很不错&#xff0c;但是图片清晰度不合我意&#xff0c;就想有没有什么工具可以处理让其更清晰&#xff0c; 网上随便搜下就能找到&#xff…

MySQL创建表及插入数据

一、创建英雄表&#xff08;hero&#xff09; mysql> use db_han Database changed mysql> create table hero(-> name varchar(20) primary key,-> nickname varchar(20),-> address varchar(20),-> groups varchar(20),-> emile int,-> telephone i…

6.3.1认识Camtasia4(2)

6.3.2录制声音 Camtasia4也可以用来录制声音&#xff0c;不过它在音频的录制与处理上要比GoldWave5差了很多。 1&#xff0e;在Camtasia Studio主程序中&#xff0c;单击【工具】|【Camtasia音频编辑器】&#xff0c;启动Camtasia音频编辑器。在弹出的欢迎界面中&#xff0c;…

如何在网络爬虫中解决CAPTCHA?使用Python进行网络爬虫

网络爬虫是从网站提取数据的重要方法。然而&#xff0c;在进行网络爬虫时&#xff0c;常常会遇到一个障碍&#xff0c;那就是CAPTCHA&#xff08;全自动公共图灵测试以区分计算机和人类&#xff09;。本文将介绍在网络爬虫中解决CAPTCHA的最佳方法&#xff0c;并重点介绍CapSol…