PyTorch深度学习实战——使用深度Q学习进行Pong游戏

PyTorch深度学习实战——使用深度Q学习进行Pong游戏

    • 0. 前言
    • 1. 结合固定目标网络的深度 Q 学习模型
      • 1.1 模型输入
      • 1.2 模型策略
    • 2. 实现深度 Q 学习进行 Pong 游戏
    • 相关链接

0. 前言

我们已经学习了如何利用深度 Q 学习来进行 Gym 中的 CartPole 游戏。在本节中,我们将研究更复杂的 Pong 游戏,并了解如何结合深度 Q 学习与固定目标网络进行此游戏,同时利用基于卷积神经网络 (Convolutional Neural Networks, CNN) 的模型替代普通神经网络。

1. 结合固定目标网络的深度 Q 学习模型

1.1 模型输入

在本节中,我们的目标是构建一个可以与计算机进行乒乓球对战并击败它的智能体,该智能体预计能够获得 21 分。我们采用以下策略训练智能体用于进行 Pong 游戏:
裁剪图像的无关部分,获取游戏当前帧(状态):

模型输入

在上示图像中,我们获取了原始图像,并裁剪原始图像的顶部和底部像素。

1.2 模型策略

为了构建具有固定目标网络的深度 Q 学习模型,使用以下策略:

  • 堆叠四个连续的帧——智能体需要状态序列了解球是否正在向它靠近
  • 智能体在初始阶段采取随机动作进行游戏,并将当前状态、未来状态、采取的动作和奖励存储在内存中,只保留最近的 10,000 个动作的信息,并清除超过 10,000 的历史记录
  • 构建预测网络,从内存中获取状态样本并预测可能动作的值
  • 定义作为预测网络的副本——目标网络
  • 预测网络每更新 1000 次就更新一次目标网络
  • 1000epoch 结束时目标网络的权重与预测网络的权重相同
  • 利用目标网络计算下一个状态下最佳动作的 Q 值
  • 对于预测网络提议的动作,我们期望它预测即时奖励和下一个状态中最佳动作的 Q 值的总和
  • 最小化预测网络的 MSE 损失
  • 令智能体持续游戏,直到奖励最大化

2. 实现深度 Q 学习进行 Pong 游戏

根据上一小节的策略,使用 PyTorch 实现智能体,用于最大化 Pong 游戏奖励。

(1) 导入相关库,搭建游戏环境:

import gym
import numpy as np
import cv2
from collections import deque
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple, deque
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('PongDeterministic-v0')

(2) 获取状态空间和动作空间大小:

state_size = env.observation_space.shape[0]
action_size = env.action_space.n

(3) 定义预处理函数,以便删除无关的底部和顶部像素:

def preprocess_frame(frame): 
    bkg_color = np.array([144, 72, 17])
    img = np.mean(frame[34:-16:2,::2]-bkg_color, axis=-1)/255.
    resized_image = img
    return resized_image

(4) 定义用于堆叠四个连续游戏帧的函数。

函数将 stack_frames、当前状态 state 和标志 is_new_episode 作为输入:

def stack_frames(stacked_frames, state, is_new_episode):
    # Preprocess frame
    frame = preprocess_frame(state)
    stack_size = 4

如果为新一回合,则从初始帧开始:

    if is_new_episode:
        # Clear stacked_frames
        stacked_frames = deque([np.zeros((80,80), dtype=np.uint8) for i in range(stack_size)], maxlen=4)
        # Because we're in a new episode, copy the same frame 4x
        for i in range(stack_size):
            stacked_frames.append(frame) 
        # Stack the frames
        stacked_state = np.stack(stacked_frames, axis=2).transpose(2, 0, 1)

如果并非新一回合,则从 stacked_frames 中删除最旧的帧并添加最新的帧:

    else:
        # Append frame to deque, automatically removes the oldest frame
        stacked_frames.append(frame)
        # Build the stacked state (first dimension specifies different frames)
        stacked_state = np.stack(stacked_frames, axis=2).transpose(2, 0, 1) 
    return stacked_state, stacked_frames

(5) 定义网络架构 DQNetwork

class DQNetwork(nn.Module):
    def __init__(self, states, action_size):
        super(DQNetwork, self).__init__()
        
        self.conv1 = nn.Conv2d(4, 32, (8, 8), stride=4)
        self.conv2 = nn.Conv2d(32, 64, (4, 4), stride=2)
        self.conv3 = nn.Conv2d(64, 64, (3, 3), stride=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2304, 512)
        self.fc2 = nn.Linear(512, action_size)
        
    def forward(self, state): 
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

(6) 定义 Agent 类。

定义 __init__ 方法:

class Agent():
    def __init__(self, state_size, action_size):
        
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(0)

        ## hyperparameters
        self.buffer_size = 10000
        self.batch_size = 32
        self.gamma = 0.99
        self.lr = 0.0001
        self.update_every = 4
        self.update_every_target = 1000 
        self.learn_every_target_counter = 0
        # Q-Network
        self.local = DQNetwork(state_size, action_size).to(device)
        self.target = DQNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.local.parameters(), lr=self.lr)

        # Replay memory
        self.memory = deque(maxlen=self.buffer_size) 
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        # Initialize time step (for updating every few steps)
        self.t_step = 0

__init__ 方法添加目标网络和目标网络的更新频率。

定义权重更新方法 step

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.append(self.experience(state[None], action, reward, next_state[None], done))
        
        # Learn every update_every time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
   # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.sample_experiences()
                self.learn(experiences, self.gamma)

定义 act 方法,根据给定状态获取要执行的动作:

    def act(self, state, eps=0.):
        # Epsilon-greedy action selection
        if random.random() > eps:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.local.eval()
            with torch.no_grad():
                action_values = self.local(state)
            self.local.train()
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

定义 learn 函数,训练预测模型:

    def learn(self, experiences, gamma):
        self.learn_every_target_counter+=1
        states, actions, rewards, next_states, dones = experiences
       # Get expected Q values from local model
        Q_expected = self.local(states).gather(1, actions)

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current state
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)

        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        if self.learn_every_target_counter%1000 ==0:
            self.target_update() 

在以上代码中,Q_targets_next 是使用目标模型预测的,而非预测模型,每经过 1,000 步后更新目标网络,其中 learn_every_target_counter 是用于确定是否应该更新目标模型的计数器。

定义函数 target_update 用于更新目标模型:

    def target_update(self):
        print('target updating')
        self.target.load_state_dict(self.local.state_dict())

定义函数 sample_experiences 用于从内存中采样经验:

    def sample_experiences(self):
        experiences = random.sample(self.memory, k=self.batch_size)        
        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)        
        return (states, actions, rewards, next_states, dones)

(7) 定义智能体对象:

agent = Agent(state_size, action_size)

(8) 定义用于训练智能体的参数:

n_episodes=5000
max_t=5000
eps_start=1.0
eps_end=0.02
eps_decay=0.995
scores = [] # list containing scores from each episode
scores_window = deque(maxlen=100) # last 100 scores
eps = eps_start
stack_size = 4
stacked_frames = deque([np.zeros((80,80), dtype=np.int) for i in range(stack_size)], maxlen=stack_size) 

(9) 训练智能体:

for i_episode in range(1, n_episodes+1):
    state = env.reset()
    state, frames = stack_frames(stacked_frames, state, True)
    score = 0
    for i in range(max_t):
        action = agent.act(state, eps)
        next_state, reward, done, _ = env.step(action)
        next_state, frames = stack_frames(frames, next_state, False)
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            break 
    scores_window.append(score) # save most recent score
    scores.append(score) # save most recent score
    eps = max(eps_end, eps_decay*eps) # decrease epsilon
    print('\rEpisode {}\tReward {} \tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode,score,np.mean(scores_window), eps), end="")
    if i_episode % 100 == 0:
        print('\rEpisode {}\tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode, np.mean(scores_window), eps))

随着训练回合的增加分数变化情况如下:

import matplotlib.pyplot as plt
plt.plot(scores)
plt.title('Scores over increasing episodes')
plt.show()

模型性能监测

从上图中可以看出,智能体逐渐学会了进行 Pong 游戏,能够获得较高奖励。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(45)——强化学习
PyTorch深度学习实战(46)——深度Q学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/842849.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis之List列表

目录 一.列表讲解 二.列表命令 三.内部编码 四.应用场景 Redis的学习专栏:http://t.csdnimg.cn/a8cvV 一.列表讲解 列表类型是用来存储多个有序的字符串,如下所示,a、b、c、d、e五个元素从左到右组成了一个有序的列表,列表中的…

git使用、git与idea结合、gitee、gitlab

本文章基于黑马程序javase模块中的"git"部分 先言:git在集成idea中,不同版本的idea中页面显示不同,操作时更注重基于选项的文字;git基于命令操作参考文档实现即可,idea工具继承使用重点掌握 1.git概述 git是目前世界上最先进的分布式文件版本控制系统 分布式:将…

快手矩阵系统全解析:功能、优势与特点一网打尽

在数字化时代,短视频已成为连接创作者与观众的重要媒介。快手矩阵系统以其独特的功能和优势,为短视频的创作、管理和发布提供了一站式解决方案,极大地提升了内容运营的效率和效果。 功能概览 智能创作:AI技术的应用使得快手矩阵…

ELK日志管理与应用

目录 一.ELK收集nginx日志 二.收集tomcat日志 三.Filebeat 一.ELK收集nginx日志 1.搭建好ELKlogstashkibana架构 2.关闭防火墙和selinux systemctl stop firewalld setenforce 0 3.安装nginx [rootlocalhost ~]# yum install epel-release.noarch -y [rootlocalhost …

因果推断 | 双重机器学习(DML)算法原理和实例应用

文章目录 1 引言2 DML算法原理2.1 问题阐述2.2 DML算法 3 DML代码实现3.1 策略变量为0/1变量3.2 策略变量为连续变量 4 总结5 相关阅读 1 引言 小伙伴们,好久不见呀。 距离上次更新已经过去了一个半月,上次发文章时还信誓旦旦地表达自己后续目标是3周更…

elementUI在手机端使用遇到的问题总结

之前的博客有写过用vue2elementUI封装手机端选择器picker组件,支持单选、多选、远程搜索多选,最终真机调试的时候发现有很多细节样式需要调整。此篇博客记录下我调试过程中遇到的问题和解决方法。 一、手机真机怎么连电脑本地代码调试? 1.确…

Spring通过工厂方法进行配置

在Spring的世界中, 我们通常会利用 xml配置文件 或者 annotation注解方式来配置bean实例! 在第一种利用 xml配置文件 方式中, 还包括如下三小类 反射模式(我们前面的所有配置都是这种模式)工厂方法模式Factory Bean模…

linux shell脚本编程(分支语句、循环语句)

一、分支语句 1、语法结构 : if 表达式 then 命令表 fi 如果表达式为真 , 则执行命令表中的命令 ; 否则退出 if 语句 , 即执行 fi 后面的语句。 if 和 fi 是条件语句的语句括号 , 必须成对使用 ;命令表中的命令可以是一条 , 也可以是若干条。 2、语法结构为 : if 表达式 t…

CSS3实现提示工具的渐入渐出效果及CSS3动画简介

上一篇文章用CSS3实现了一个提示工具,本文介绍如何利用CSS3实现提示工具以渐入的方式呈现,以渐出的方式消失。 CSS3主要可以通过两个样式来实现动画效果:animation和transition。 其中,animation需要自己定义一组关键帧从而实现…

vue使用x6画流程图,简单使用

官网 https://x6.antv.antgroup.com/tutorial/getting-started 安装 npm install antv/x6 --save 使用 <template><div>3333<div id"container" style"width: 800px;height: 800px;"></div></div> </template> <…

详细分析Springboot自定义启动界面(附Demo)

目录 前言1. banner.text1.1 配置文件关闭1.2 启动类关闭1.3 命令行关闭 2. 自定义Banner类3. 自动配置类4. 总结 前言 实现自定义启动动画是一项有趣的任务&#xff0c;虽然Spring Boot本身不提供内置的动画功能&#xff0c;但可以通过一些技巧来实现 以下主要以Demo的形式展…

Linux系统学习日记——vim操作手册

Vim编辑器是linux下的一个命令行编辑器&#xff0c;类似于我们windows下的记事本。 目录 打开文件 编辑 保存退出 打开文件 打开 hello.c不存在也可以打开&#xff0c;保存时vim会自动创建。 效果 Vim打开时&#xff0c;处于命令模式&#xff0c;即执行命令的模式&#x…

STM32自己从零开始实操:PCB全过程

一、PCB总体分布 以下只能让大家看到各个模块大致分布在板子的哪一块&#xff0c;只能说每个人画都有自己的理由&#xff1a; 电源&#xff1a;从外部接入电源&#xff0c;5V接到中间&#xff0c;向上变成4V供给无线&#xff0c;向下变成3V供给下面的接口&#xff08;也刻意放…

分布式搜索引擎ES-Elasticsearch进阶

1.head与postman基于索引的操作 引入概念&#xff1a; 集群健康&#xff1a; green 所有的主分片和副本分片都正常运行。你的集群是100%可用 yellow 所有的主分片都正常运行&#xff0c;但不是所有的副本分片都正常运行。 red 有主分片没能正常运行。 查询es集群健康状态&…

dd命令参数

GPT-4 (OpenAI) dd 是一个用于在 Unix 和 Unix 类操作系统中转换和复制文件的命令行工具。虽然它的名字听起来可能很简单&#xff0c;但是 dd 拥有很多功能强大的选项。下面列出了一些常见的 dd 参数&#xff08;选项&#xff09;&#xff1a; 1. if<文件名>&#xff…

mysql存储引擎和备份

索引 事务 存储引擎 概念&#xff1a;存储引擎&#xff0c;就是一种数据库存储数据的机制&#xff0c;索引的技巧&#xff0c;锁定水平。 存储引擎。存储的方式和存储的格式。 存储引擎也属于mysql当中的组件&#xff0c;实际上操作的&#xff0c;执行的就是数据的读写I/O。…

【ARM】MDK-解决CMSIS_DAP.DLL missing报错

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 记录解决CMSIS_DAP.DLL missing的报错情况&#xff0c;对应相关报错信息&#xff0c;供后续客户参考&#xff0c;快速解决客户问题。 2、 问题场景 客户进行硬件调试时&#xff0c;发现Target设置内有CMSIS_DAP.DL…

趣谈linux操作系统 9 网络系统-读书笔记

文章目录 网络协议栈基础知识回顾网络分层网络分层的目的各层作用简介延伸-ip地址,有类,无类,cidr socket实现分析tcp/udp回顾socket编程回顾TCP编程回顾UDP编程回顾差异 socket相关接口实现浅析sokcet实现解析创建socket的三个参数socket函数定义及其参数创建socket结构体关联…

StarRocks on AWS Graviton3,实现 50% 以上性价比提升

在数据时代&#xff0c;企业拥有前所未有的大量数据资产&#xff0c;但如何从海量数据中发掘价值成为挑战。数据分析凭借强大的分析能力&#xff0c;可从不同维度挖掘数据中蕴含的见解和规律&#xff0c;为企业战略决策提供依据。数据分析在营销、风险管控、产品优化等领域发挥…

8、添加第三方包

目录 1、安装Django Debug Toolbar Django的一个优势就是有丰富的第三方包生态系统。这些由社区开发的包&#xff0c;可以用来快速扩展应用程序的功能集 1、安装Django Debug Toolbar Django Debug Toolbar位于名列前三的第三方包之一 这是一个用于调试Debug Web应用程序的有…