基于PPO的强化学习超级马里奥自动通关

目录

一、环境准备

二、训练思路

1.训练初期:

2.思路整理及改进:

思路一:

思路二:

思路三:

思路四:

3.训练效果:

三、结果分析

四、完整代码

训练代码:

测试代码:

底模: 


本文将基于强化学习中的PPO算法训练一个自动玩超级马里奥的智能体,用于强化学习的项目实践

源码及底模放于文末(可自行取用)

一、环境准备

所需环境如下:

pip install nes-py
pip install gym-super-mario-bros
pip install setuptools==65.5.0 "wheel<0.40.0"
pip install gym==0.21.0
pip install stable-baselines3【extra】==1.6.0
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

注意: 在环境配置方面,nes-py库安装的先决条件是 安装Microsoft Visual C++,其下载地址为:Microsoft C++ Build Tools - Visual Studio

在安装Microsoft Visual C++时需选择桌面开发:

二、训练思路

1.训练初期:

使用了最简单的训练框架,并选择PPO算法中较简单的的CnnPolicy网络(可以尝试MlpPolicy和MultiInputPolicy网络我没试是因为太懒了)以及马里奥操控中的SIMPLE_MOVEMENT操作模块:

自然,效果是不尽人意的,马里奥在所选关卡的第三根水管处(即最高的那个水管)不断尝试跳跃,直至时间耗尽也未能通过。

2.思路整理及改进:

思路一:

        既然训练效果不佳,是否跟训练轮数有关?固将总训练轮数增加至3000000,并尝试训练。跑出来的模型有所改进,马里奥在成功越过所有水管后,遇到了新的难题——越过两个断崖。至此,无论如何增加轮数,马里奥似乎到了一个瓶颈,固继续进行修改。

思路二:

        在增加训练轮数的基础上,选择对关卡的环境图像进行预处理——使用GrayScaleObservation转换为灰度观察,并保留通道维度。同时,我们对训练参数进行调整:

        尝试训练后,能够得到一个不稳定越过断崖的新模型,但对断崖之后的环境似乎有些陌生,陷入了前半段关卡的“局部最优解”。

思路三:

        由于之前的训练过程中使用了较小的学习率(1e-9),进而使得马里奥在关卡中陷入了局部最优,所以选择对学习率进行微调,使其在最开始的训练阶段使用较大的学习率,在后期减小学习率,从而达到先快速探索参数空间并加速收敛,再提高模型的稳定性和收敛精度。

至此,训练出来的测试模型,奖励反馈有所增长,但实际测试效果与调整前相差不多。

思路四:

        在上述尝试无明显效果后,猜测效果的好坏是否与马里奥的奖励机制有关,固在查阅奖励部分代码后,对“抵达终点”的奖励予以提高,希望对效果有所改善。

然结果并没有明显改观,更换调整方向。分别尝试马里奥的三套运动方式

经过对比,complex_movement的效果远超另外两套,且在前面思路的改动下模型质量有显著提升,固整理上述调整方案,进行底模训练。

3.训练效果:

        以奖励折扣率gamma = 0.9、gae_lambda = 0.9、clip_range = 0.2、步长n_steps = 7168,并用1e-3作为开始训练的学习率,并在训练过程中使其动态地在1e-5,1e-7中调整,修改抵达终点的奖励反馈,同时设置训练轮数为4000000,训练动作组为complex_movement进行训练。得到基础奖励回报为1520的底模,并将其继续用于迁移学习,得到2300的新模型。在实际测试后发现,模型确有改观,固继续将新模型用于训练,最终得到3200的最终模型,其能顺利到达终点并进入关卡的下一阶段。

三、结果分析

        与之前的训练经验相比,使用复杂的动作组未必比简单的动作组训练出的效果差,学习率的调整也是必要的,先用较大学习率打好基础,再有小学习率继续细化模型。同时,要给足够的训练轮数(足够的训练时间)。若是能够把奖励机制更进一步细化增加奖励细节,对其的训练是会更有帮助的。

四、完整代码

训练代码:

from nes_py.wrappers import JoypadSpace
import time
import os
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
from gym.wrappers import GrayScaleObservation
from gym import Wrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import PPO
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback


# 定义自定义奖励包装器
class CustomRewardWrapper(Wrapper):
    def __init__(self, env):
        super(CustomRewardWrapper, self).__init__(env)
        self.curr_score = 0

    def step(self, action):
        state, reward, done, info = self.env.step(action)

        # 自定义的奖励
        reward += (info["score"] - self.curr_score) / 40.
        self.curr_score = info["score"]

        if done:
            if info["flag_get"]:
                reward += 50
            else:
                reward -= 50

        return state, reward / 10., done, info


class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Callback for saving a model (the check is done every ``check_freq`` steps)
    based on the training reward (in practice, we recommend using ``EvalCallback``).

    :param check_freq: (int)
    :param log_dir: (str) Path to the folder where the model will be saved.
      It must contains the file created by the ``Monitor`` wrapper.
    :param verbose: (int)
    """

    def __init__(self, check_freq, save_model_dir, verbose=1):
        super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = os.path.join(save_model_dir, './')
        self.best_model_subdir = os.path.join(self.save_path, 'best_model')
        self.best_mean_reward = -np.inf
        self.best_model_path = None
        self.best_score_model_path = os.path.join(self.save_path, 'pass_customs_model.zip')  # 增加通关模型路径

    # def _init_callback(self) -> None:
    def _init_callback(self):
        # Create folder if needed
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    # def _on_step(self) -> bool:
    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            print('self.n_calls: ', self.n_calls)
            model_path1 = os.path.join(self.save_path, 'model_{}'.format(self.n_calls))
            self.model.save(model_path1)
            # Save the best model
            x, y = ts2xy(load_results(monitor_dir), 'timesteps')
            if len(x) > 0:
                mean_reward = np.mean(y[-self.check_freq:])
                if self.verbose > 0:
                    print("Num timesteps: {}, Best mean reward: {:.2f}, Last mean reward: {:.2f}".format(
                        self.n_calls, self.best_mean_reward, mean_reward))

                if mean_reward > self.best_mean_reward:
                    if self.best_model_path is not None:
                        try:
                            os.remove(self.best_model_path)  # Delete the old best model
                        except OSError:
                            pass
                    self.best_mean_reward = mean_reward
                    # Update path for the new best model
                    self.best_model_path = os.path.join(self.save_path, 'best_model.zip')
                    # Save the new best model
                    self.model.save(self.best_model_path)

                    if self.verbose > 0:
                        print("New best mean reward: {:.2f} - saving best model".format(mean_reward))

                        # Save the best mean reward to a file
                        reward_record_file = './Mario_model_save/model/mario_model/best_mean_reward.txt'
                        with open(reward_record_file, 'a') as file:
                            # 将最佳平均奖励值和时间戳一同写入文件
                            file.write(
                                "New best mean reward: {:.2f} - Recorded at {}\n".format(mean_reward, datetime.now()))

        
        return True


# 总的训练timesteps
my_total_timesteps = 4000000
# 需要改变学习率的timestep
change_lr_timestep = 2000000


# 学习率调度函数
def learning_rate_schedule(progress_remaining):
    """
    参数 progress_remaining 表示剩下的训练进度(从1开始降低到0)。
    通过训练进度来动态调整学习率。
    """
    current_timestep = my_total_timesteps * (1 - progress_remaining)
    if current_timestep < change_lr_timestep:
        return 1e-3  # 1e-3
    elif change_lr_timestep <= current_timestep <= int(change_lr_timestep * 1.5):
        return 1e-5
    else:
        return 1e-7


env = gym_super_mario_bros.make('SuperMarioBros-1-2-v0')
env = JoypadSpace(env, COMPLEX_MOVEMENT)  # 使用复杂的按键映射

env = CustomRewardWrapper(env)  # 应用自定义奖励包装器

monitor_dir = r'./Mario_model_save/monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)  # 将环境包装为监视器

env = GrayScaleObservation(env, keep_dim=True)  # 转换为灰度观察,并保留通道维度
env = DummyVecEnv([lambda: env])  # 创建虚拟环境
env = VecFrameStack(env, 4, channels_order='last')  # 将最近4帧堆叠在一起

best_params = {
    'n_steps': 7168,  # 7168
    'gamma': 0.9,
    # 'learning_rate': 1e-3,   # 1e-3, 1e-4, 1e-5
    'clip_range': 0.2,
    'gae_lambda': 0.9,
}

# 更新best_params中的learning_rate参数
best_params.update({'learning_rate': learning_rate_schedule})

tensorboard_log = r'./Mario_model_save/tensorboard_log/'
# 正常训练
model = PPO("CnnPolicy", env, verbose=1,
            tensorboard_log=tensorboard_log,
            **best_params
            )
'''
# 加载预训练模型
pretrained_model_path = r'D:\python_project\Mario\model\mario_model\pretraining_model_4.zip'
model = PPO.load(pretrained_model_path, env=env, tensorboard_log=tensorboard_log, **best_params)'''

# 保存模型位置
save_model_dir = r'./Mario_model_save/model/mario_model/'
callback1 = SaveOnBestTrainingRewardCallback(10000, save_model_dir)

model.learn(total_timesteps=my_total_timesteps, callback=callback1)
# model.save("mario_model")

测试代码:

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, RIGHT_ONLY, COMPLEX_MOVEMENT
import time
from matplotlib import pyplot as plt
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
import os
from stable_baselines3 import PPO

from stable_baselines3.common.results_plotter import load_results, ts2xy
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, COMPLEX_MOVEMENT)

monitor_dir = r'./Mario/monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)

env = GrayScaleObservation(env, keep_dim=True)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')

save_model_dir = r'model/mario_model/pretraining_model_5.zip'
# save_model_dir = r'./Mario/model/mario_model/pretraining_model.zip'


model = PPO.load(save_model_dir)

obs = env.reset()
obs = obs.copy()
done = True
while True:
    if done:
        state = env.reset()
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    obs = obs.copy()
    # time.sleep(0.01)
    env.render()

env.close()

底模: 

最有底模为pretraining_model_5

链接:https://pan.baidu.com/s/1ed9IfgqvPC-uJmbGZMZtMQ?pwd=ru3t 
提取码:ru3t

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/713304.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 日志(二)

本篇将继续介绍MySQL日志的相关内容 目录 一、二进制日志 简介 注意事项 删除二进制日志 查看二进制日志 二进制日志的格式 二、服务器日志维护 一、二进制日志 简介 二进制日志中主要记录了MySQL的更改事件&#xff08;不包含SELECT和SHOW),例如&#xff1a;表的…

Base64编码的工作原理与实际应用

目录 前言 一、什么是Base64编码&#xff1f; 二、Base64编码的原理 三、Base64编码的应用场景 四、为什么要使用Base 64 五、Base64加密解密的实现 前言 当你需要将二进制数据转换为可传输和存储的文本格式时&#xff0c;Base64编码是一个常用的选择。在这篇博客中&#…

C++ 51 之 继承中的构造和析构

对象构造和析构的调用原则 继承中的构造和析构 子类对象在创建时会首先调用父类的构造函数父类构造函数执行完毕后&#xff0c;才会调用子类的构造函数当父类构造函数有参数时&#xff0c;需要在子类初始化列表(参数列表)中显示调用父类构造函数析构函数调用顺序和构造函数相…

可以用来制作硬模空心耳机壳的胶粘剂有哪些种类?

可以用来制作硬模空心耳机壳的胶粘剂有哪些种类&#xff1f; 制作耳机壳的胶粘剂有很多种类&#xff0c;常见的有环氧树脂胶水、UV树脂胶、快干胶、热熔胶等。 这些胶粘剂都有不同的特点和适用场景&#xff0c;可以根据自己的需求选择合适的类型。 例如&#xff1a; 环氧树脂…

Adobe设计替代软件精选列表

Adobe软件的替代列表&#xff0c;最初由 XdanielArt 收集&#xff0c;并由社区改进。您可以随意打开问题或拉出请求&#xff0c;或从数据中创建图像(以便于共享)。列表总是按照免费和开源选项的顺序排列&#xff0c;但根据您的用例&#xff0c;它可能不是最佳选择 替代因素 &am…

Python 潮流周刊#56:NumPy 2.0 里更快速的字符串函数

△△请给“Python猫”加星标 &#xff0c;以免错过文章推送 本周刊由 Python猫 出品&#xff0c;精心筛选国内外的 250 信息源&#xff0c;为你挑选最值得分享的文章、教程、开源项目、软件工具、播客和视频、热门话题等内容。愿景&#xff1a;帮助所有读者精进 Python 技术&am…

【linux】认识“文件”的本质,理解“文件系统”的设计逻辑,体会linux优雅的设计理念

⭐⭐⭐个人主页⭐⭐⭐ ~~~~~~~~~~~~~~~~~~ C站最❤❤❤萌❤❤❤博主 ~~~~~~~~~~~~~~~~~~~ ​♥东洛的克莱斯韦克-CSDN博客♥ ~~~~~~~~~~~~~~~~~~~~ 嗷呜~ ✌✌✌✌ 萌妹统治世界~ &#x1f389;&#x1f389;&#x1f389;&#x1f389; ✈✈✈✈相关文章✈✈✈✈ &#x1f4a…

2023年的Top20 AI应用在近一年表现怎么样?

AI应用现在进入寒武纪大爆发时代&#xff0c;百花争艳。如果倒回到2023年初&#xff0c;那时候排名靠前的AI应用在一年多时间&#xff0c;发生了哪些变化&#xff1f;能带给我们什么启示&#xff1f; 在2023年1月&#xff0c;排名靠前20的AI应用是&#xff1a; DeepL&#xff…

MATLAB中与直方图有关函数的关系

histogram Histogram plot画直方图 histcounts 直方图 bin 计数 histcounts是histogram的主要计算函数。 discretize 将数据划分为 bin 或类别 histogram2 画二元直方图 histcounts2 二元直方图 bin 计数 hist和histc过时了。替换不建议使用的 hist 和 histc 实例 hist → \r…

Day54 JDBC

Day54 JDBC JDBC&#xff1a;SUN公司提供的一套操作数据库的标准规范&#xff0c;就是使用Java语言操作关系型数据库的一套API JDBC与数据库驱动的关系&#xff1a;接口与实现的关系 给大家画一个jdbc的工作模式图 1.JDBC的四大金刚 1.DriverManager&#xff1a;用于注册驱动 2…

【Quartus 13.0】NIOS II 部署UART 和 PWM

打算在 EP1C3T144I7 芯片上部署 nios ii 做 uart & pwm控制 这个芯片或许不够做 QT 部署 这个芯片好老啊&#xff0c;但是做控制足够了&#xff0c;我只是想装13写 leader给的接口代码是用VHDL写的&#xff0c;我不会 当然verilog我也不太会 就这样&#xff0c;随便写吧 co…

SUSTAINABILITY,SCIESSCI双检期刊还能投吗?

本期&#xff0c;小编给大家介绍的是一本MDPI出版社旗下SCIE&SSCI双检“毕业神刊”——SUSTAINABILITY。据悉&#xff0c;早在2024年1月&#xff0c;ElSEVIER旗下的Scopus数据库已暂停收录检索期刊SUSTAINABILITY所发表文章&#xff0c;同时重新评估是否继续收录该期刊。随…

Qwen2——阿里巴巴最新的多语言模型挑战 Llama 3 等 SOTA

引言 经过几个月的期待&#xff0c; 阿里巴巴 Qwen 团队终于发布了 Qwen2 – 他们强大的语言模型系列的下一代发展。 Qwen2 代表了一次重大飞跃&#xff0c;拥有尖端的进步&#xff0c;有可能将其定位为 Meta 著名的最佳替代品 骆驼3 模型。在本次技术深入探讨中&#xff0c;我…

零基础入门学用Arduino 第三部分(三)

重要的内容写在前面&#xff1a; 该系列是以up主太极创客的零基础入门学用Arduino教程为基础制作的学习笔记。个人把这个教程学完之后&#xff0c;整体感觉是很好的&#xff0c;如果有条件的可以先学习一些相关课程&#xff0c;学起来会更加轻松&#xff0c;相关课程有数字电路…

python-基础篇-类与对象/面向对象程序设计-是什么

文章目录 定义一&#xff1a;面对对象是一种编程思想定义一&#xff1a;面向对象是一种抽象1、面向对象的两个基本概念2、面向对象的三大特性 定义一&#xff1a;你是土豪&#xff0c;全家都是土豪面向对象编程基础类和对象定义类创建和使用对象访问可见性问题面向对象的支柱 定…

C++初学者指南第一步---4.基本类型

C初学者指南第一步—4.基本类型 文章目录 C初学者指南第一步---4.基本类型1.变量声明2.快速概览Booleans 布尔型Characters 字符型Signed Integers 有符号整数Unsigned Integers 无符号整数Floating Point Types 浮点数类型 3.Common Number Representations 常用的数字表示常用…

用Copilot画漫画,Luma AI生成视频:解锁创意新玩法

近年来&#xff0c;随着人工智能技术的不断发展&#xff0c;各种创意工具也层出不穷。今天&#xff0c;我们就来介绍一种全新的创作方式&#xff1a;使用Copilot画漫画&#xff0c;再将漫画放入Luma AI生成视频。 Copilot&#xff1a;你的AI绘画助手 Copilot是一款基于人工智…

【Kubernetes项目部署】k8s集群+高可用、负载均衡+防火墙

项目架构图 &#xff08;1&#xff09;部署 kubernetes 集群 详见&#xff1a;http://t.csdnimg.cn/RLveS &#xff08;2&#xff09; 在 Kubernetes 环境中&#xff0c;通过yaml文件的方式&#xff0c;创建2个Nginx Pod分别放置在两个不同的节点上&#xff1b; Pod使用hostP…

TCP及UDP协议

tcp是点到点的&#xff0c;只有一条路径&#xff0c;到达顺序和发送顺序是相同的 回复的确认号是序发送端的序列号加上data的长度 1910 发送端的序列号也是那么算的 ack和下一个seq一样 那就没问题 三次握手四次挥手&#xff1a; 为啥是三次呢&#xff1f; 假如一次&#xf…

SpringBoot使用jasypt实现数据库信息的脱敏,以此来保护数据库的用户名username和密码password(容易上手,详细)

1.为什么要有这个需求&#xff1f; 一般当我们自己练习的时候&#xff0c;username和password直接是爆露出来的 假如别人路过你旁边时看到了你的数据库账号密码&#xff0c;他跑到他的电脑打开navicat直接就是一顿连接&#xff0c;直接疯狂删除你的数据库&#xff0c;那可就废…