一文实践强化学习训练游戏ai--doom枪战游戏实践

一文实践强化学习训练游戏ai–doom枪战游戏实践
上次文章写道下载doom的环境并尝试了简单的操作,这次让我们来进行对象化和训练、验证,如果你有基础,可以直接阅读本文,不然请你先阅读Doom基础知识,其中包含了下载、动作等等的基础知识。
本次与之前的马里奥训练不同,马里奥是已经有做好的step等函数的,而这个doom没有,但也因此我们可以更好的一窥训练的过程。
完整代码在最后,可以复制执行。

文章目录

  • 一、训练模型
    • 1、vizdoom_train类
      • 1)__init__
      • 2)step
      • 3)togray
      • 4)其他函数
    • 2、保存模型函数
    • 3、训练模型
  • 二、成果验收
  • 完整代码
    • 训练代码
    • 测试代码

一、训练模型

1、vizdoom_train类

这是我们的训练基类,由于我们想用openai gym环境,因此我们需要手写此环境必须的__init__、step等函数

1)init

在这个函数中,我们要定义训练的观察空间(即游戏图像)、动作空间(即ai可以执行的操作)和一些基础设置。

        VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"
        self.game = vizdoom.DoomGame()
        self.game.load_config(VizDoom_basic_cfg)

        if render == False: 
            self.game.set_window_visible(False)
        else:
            self.game.set_window_visible(True)
        self.game.init()

此处,我们先指定游戏文件的位置,然后加一个判断,决定是否允许游戏窗口显示出来,最后初始化游戏。
初始化后,我们就要规定观察空间和动作空间了

        self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) 
        self.action_space = Discrete(3)

观察空间是我们游戏的界面,这是一个灰化后的图像,所以维度为1,大小为100*160
动作空间为离散空间3,即可以选取动作0、1、2,我们只需要在后续的代码中指定数字指代的动作就可以了。

2)step

step函数非常关键,这是ai执行动作的时候会调用的函数。
首先,我们定义我们的动作

        actions = np.identity(3,dtype=np.uint8)
        reward = self.game.make_action(actions[action], 4) 

这部分的详细解释在Doom基础知识中,事实上就是定义一个矩阵,调用游戏文件中的左移、右移和射击。
接下来,我们要获取当前的状态,比如得分,游戏图像等,这样我们才可以训练。

        try:
            state = self.game.get_state()
            img = state.screen_buffer
            img = self.togray(img)
            info = state.game_variables[0]
        except:
            img = np.zeros(self.observation_space.shape)
            info = 0 
        finally:
            info = {"info":info}
            done = self.game.is_episode_finished()
        #img_show(img)
        return img,reward,done,info

使用try,是因为gameover时有些内容获取不到,为了防止程序因此暂停,用try。最后,我们要把info变成字典形式,这是因为openai gym环境时这么要求的。为了方便理解,这里可以调用imgshow,查看目前的图像,其实现如下。

def img_show(img):
    plt.imshow(img)
    plt.show()
    time.sleep(5)

3)togray

灰度化图像,我们知道,彩色图像时由rgb三个颜色矩阵组成,但这么大的数据量给我们的训练增添的很多负担,于是我们采用灰度图。同时,我们缩小图像,这样可以训练的更快。

    def togray(self,observation):
        gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)
        resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)
        state = np.reshape(resize, (100,160,1))
        return state

observation 是一个 NumPy 数组,通常表示图像数据。
np.moveaxis 是 NumPy 库中的一个函数,用于重新排列数组的轴。
参数 0 表示将第0轴(通常是颜色通道)移动到新位置的最后一个轴位置(即 -1)。
例如,如果 observation 的形状是 (C, H, W)(即颜色通道在第一个维度),经过 np.moveaxis 后,形状将变为 (H, W, C),这对于 OpenCV 处理图像更为常见,因为 OpenCV 期望颜色通道是图像的最后一个维度。

cv2.cvtColor 是 OpenCV 中用于颜色空间转换的函数。
第一个参数是输入图像(在这里是经过 np.moveaxis 处理后的图像)。
第二个参数 cv2.COLOR_BGR2GRAY 指定了将图像从 BGR 颜色空间转换为灰度图像。

4)其他函数

我们要定义一个关闭游戏的函数close

   def close(self):
        self.game.close()

以及一个reset函数,用于在结束一个游戏后,重置状态,继续下一轮训练

    def reset(self):
        state = self.game.new_episode()
        state = self.game.get_state()
        return self.togray(state.screen_buffer)

至此,我们已经成功地把这个独立游戏包装成了可以使用openai gym的环境的游戏。

2、保存模型函数

class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

我们使用这段代码来存档数据,这部分复制即可

3、训练模型

首先,我们指定训练结果保存的路径

    CHECKPOINT_DIR = './train/train_basic'
    LOG_DIR = './logs/log_basic'
    callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    ```

然后我们调用训练函数进行训练

    env = vizdoom_train(render=False)
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)
    model.learn(total_timesteps=100000, callback=callback)

这里我们使用PPO这个强化学习算法。

经过几十分钟的等待,就得到训练好的模型了

二、成果验收

训练好模型后,我们要使用模型,看看效果。
首先,我们载入训练好的模型

model = PPO.load('./train/train_basic/best_model_100000')

然后测试以下模型的平均得分

mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)

这样还不够直观,我们让ai玩给我们看


for episode in range(100): 
    obs = env.reset()
    done = False
    total_reward = 0
    while not done: 
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        time.sleep(0.20)
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(total_reward, episode))
    time.sleep(2)

我们首先让ai模型预测,然后将预测的动作输入step函数,然后展示页面。

如图所示,平均得分非常高,基本能快速索敌,然后一枪秒杀。
至此,我们完成了ai的训练。

完整代码

训练代码

from gym import Env
from gym.spaces import Discrete,Box
import cv2
from vizdoom import *
import vizdoom
import random
import time
import numpy as np
#DIS离散空间 作用类似于random
#box用来装游戏图像
from matplotlib import pyplot as plt
class vizdoom_train(Env):
    def __init__(self, render=True):
        super().__init__()
        VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"
        self.game = vizdoom.DoomGame()
        self.game.load_config(VizDoom_basic_cfg)

        if render == False: 
            self.game.set_window_visible(False)
        else:
            self.game.set_window_visible(True)
        self.game.init()

        self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) 
        self.action_space = Discrete(3)
    def step(self,action):
        actions = np.identity(3,dtype=np.uint8)
        reward = self.game.make_action(actions[action], 4) 
        try:
            state = self.game.get_state()
            img = state.screen_buffer
            img = self.togray(img)
            info = state.game_variables[0]
        except:
            img = np.zeros(self.observation_space.shape)
            info = 0 
        finally:
            info = {"info":info}
            done = self.game.is_episode_finished()
        #img_show(img)
        return img,reward,done,info
    
    def close(self):
        self.game.close()

    def reset(self):
        state = self.game.new_episode()
        state = self.game.get_state()
        return self.togray(state.screen_buffer)
    
    def togray(self,observation):
        gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)
        resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)
        state = np.reshape(resize, (100,160,1))
        return state
    
def img_show(img):
    plt.imshow(img)
    plt.show()
    time.sleep(5)


# Import os for file nav
import os 
# Import callback class from sb3
from stable_baselines3.common.callbacks import BaseCallback
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True
    
if __name__ == "__main__":

    CHECKPOINT_DIR = './train/train_basic'
    LOG_DIR = './logs/log_basic'
    callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    

    train = vizdoom_train()

    from stable_baselines3 import PPO

    env = vizdoom_train(render=False)
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)
    model.learn(total_timesteps=100000, callback=callback)

测试代码

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
import time
model = PPO.load('./train/train_basic/best_model_100000')

from second_gym import vizdoom_train
env = vizdoom_train(render=True)

mean_reward, _ ,_= evaluate_policy(model, env, n_eval_episodes=5)

print(mean_reward)

for episode in range(100): 
    obs = env.reset()
    done = False
    total_reward = 0
    while not done: 
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        time.sleep(0.20)
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(total_reward, episode))
    time.sleep(2)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/788371.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

android CameraX构建相机拍照

Android CameraX 是一个 Jetpack 支持库,旨在简化相机应用的开发工作。它提供了一致且易用的API接口,适用于大多数Android设备,并可向后兼容至Android 5.0(API级别21)。 CameraX解决了在多种设备上实现相机功能时所遇…

14-56 剑和诗人30 - IaC、PaC 和 OaC 在云成功中的作用

介绍 随着各大企业在 2024 年加速采用云计算,基础设施即代码 (IaC)、策略即代码 (PaC) 和优化即代码 (OaC) 已成为成功实现云迁移、IT 现代化和业务转型的关键功能。 让我在云计划的背景下全面了解这些代码功能的当前状态。我们将研究现代云基础设施趋势、IaC、Pa…

java:获取当前的日期和时间

// 获取当前的日期和时间LocalDateTime now LocalDateTime.now();// 定义日期时间格式化器DateTimeFormatter formatter DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");// 格式化日期时间String formattedDateTime now.format(formatter);// 打印结果Syste…

【数据结构和算法的概念等】

目录 一、数据结构1、数据结构的基本概念2、数据结构的三要素2.1 数据的逻辑结构2.2 数据的存储(物理)结构2.3 数据的运算 二、算法1、算法概念2、算法的特性及特点3、算法分析 一、数据结构 1、数据结构的基本概念 数据: 是所有能输入到计…

前端八股文 对事件循环的理解

对事件循环的理解 思维导图 图示 实际案例的执行过程 总结

能源电子领域2区SCI,版面稀缺,即将截稿,无版面费!

【SciencePub学术】今天小编给大家推荐1本能源电子领域的SCI!影响因子1.0-2.0之间,最重要的是审稿周期较短,对急投的学者较为友好! 能源电子类SCI 01 / 期刊概况 【期刊简介】IF:1.0-2.0,JCR2区&#xf…

【C++】C++入门基础--引用,inline,nullptr

文章目录 前言一、引用?1.1 引用的概念和定义1.2 引用的特性1.3 引用的使用1.4 const引用(常引用)1.5 指针和引用的关系 二、inline2.1inline概念和定义2.2 inline使用2.3 inline注意事项 三.nullptr总结 前言 上一篇文章我们介绍了C中的命名…

枚举对象序列化规则(将Java枚举转换为JSON字符串的步骤)

文章目录 引言I 案例分析1.1 接口签名计算1.2 请求对象1.3 枚举对象序列化II 在JSON中以枚举的code值来表示枚举的实现方式2.1 自定义toString方法返回code引言 在Java中,每个对象都有一个toString方法,用于返回该对象的字符串表示。默认情况下,Enum类的toString方法返回的…

C语言笔记30 •单链表经典算法OJ题-2.移除链表元素•

移除链表元素 1.问题 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 2.代码实现&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <stdlib.h&g…

【RHCE】转发服务器实验

1.在本地主机上操作 2.在客户端操作设置主机的IP地址为dns 3.测试,客户机是否能ping通

特征及特征选择

1、特征&#xff08;Feature&#xff09;是什么&#xff1f; 特征是数据集中的一个可量化的属性或变量&#xff0c;用于描述数据点的特性。 特征可以是连续的数值&#xff0c;如身高、体重等&#xff0c;也可以是离散的类别&#xff0c;如性别、种族等。 常见的特征有边缘、角、…

Mosh|初学者 SQL 教程

sql文件链接&#xff1a;链接: https://pan.baidu.com/s/1okjsgssdxMkfKf8FEos7DA?pwdf9a9 提取码: f9a9 在mysql workbench 导入 create_databases.sql 文件&#xff0c;下面是运行成功的界面 快捷方式&#xff1a;全部运行可以同时按下controlcommandenter &#xff0c;或者…

Linux学习之网络配置问题

Linux学习——那些我们网络配置遇到过的问题&#xff1f;ping不通百度&#xff1f;XShell连接不上&#xff1f;&#xff08;超详细&#xff09; &#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感…

详细谈谈负载均衡的startupProbe探针、livenessProbe探针、readnessProbe探针如何使用以及使用差异化

文章目录 startupProbe探针startupProbe说明示例配置参数解释 使用场景说明实例——要求&#xff1a; 容器在8秒内完成启动&#xff0c;否则杀死对应容器工作流程说明timeoutSeconds: 和 periodSeconds: 参数顺序说明 livenessProbe探针livenessProbe说明示例配置参数解释 使用…

生产者消费者模型和线程同步问题

文章目录 线程同步概念生产者消费者模型条件变量使用条件变量唤醒条件变量 阻塞队列 线程同步概念 互斥能保证安全,但是仅有安全不够,同步可以更高效的使用资源 生产者消费者模型 下面就基于生产者消费者来深入线程同步等概念: 如何理解生产消费者模型: 以函数调用为例: 两…

LNMP搭建Discuz和Wordpress

1、LNMP L:linux操作系统 N&#xff1a;nginx展示前端页面web服务 M&#xff1a;mysql数据库&#xff0c;保存用户和密码&#xff0c;以及论坛相关的内容 P&#xff1a;php动态请求转发的中间件 数据库的作用&#xff1a; 登录时验证用户名和密码 创建用户和密码 发布和…

存储产品选型策略 OSS生命周期管理与运维

最近在看阿里云的 云存储通关实践认证训练营这个课程还是不错的。 存储产品选型策略、对象存储OSS入门、基于对象存储OSS快速搭建网盘、 如何做好权限控制、如何做好数据安全、如何做好数据管理、涉及对象存储OSS的权限控制、使用OSS完成静态网站托管、对OSS中存储的数据进行分…

ubuntu使用kubeadm搭建k8s集群

一、卸载k8s kubeadm reset -f modprobe -r ipip lsmod rm -rf ~/.kube/ rm -rf /etc/kubernetes/ rm -rf /etc/systemd/system/kubelet.service.d rm -rf /etc/systemd/system/kubelet.service rm -rf /usr/bin/kube* rm -rf /etc/cni rm -rf /opt/cni rm -rf /var/lib/etcd …

压缩感知2——算法模型

采集原理 其中Y就是压缩后的信号表示(M维)&#xff0c;Φ表示采集的测量矩阵&#xff0c;可以是一个随机矩阵&#xff0c;X代表原始的数字信号&#xff08;N维&#xff09;。 常见的测量矩阵——随机高斯矩阵 随机伯努利矩阵 稀疏随机矩阵等&#xff0c;矩阵需要满足与信号的稀…

57、基于概率神经网络(PNN)的分类(matlab)

1、基于概率神经网络(PNN)的分类简介 PNN&#xff08;Probabilistic Neural Network&#xff0c;概率神经网络&#xff09;是一种基于概率论的神经网络模型&#xff0c;主要用于解决分类问题。PNN最早由马科夫斯基和马西金在1993年提出&#xff0c;是一种非常有效的分类算法。…