【强化学习】Deep Q Learning

Deep Q Learning

在前两篇文章中,我们发现RL模型的目标是基于观察空间 (observations) 和最大化奖励和 (maximumize sum rewards) 的。

如果我们能够拟合出一个函数 (function) 来解决上述问题,那就可以避免存储一个 (在Double Q-Learning中甚至是两个) 巨大的Q_table。

Tabular -> Function

  • Continous Observation: 函数能够让我们处理连续的观察空间,而表只能处理离散的。
  • Saving the space: 不用存储 len(state) * len(action) 大小的Q_table

在早期人们试过使用核函数或者线性函数等各种方法去拟合这个function,但后来深度神经网络出现后人们纷纷开始研究如何用DNN来拟合。

然而以上的拟合方式不免存在一个问题,我们期望得到一个DNN,使得DNN(state)->Q-value

可是强化学习中,最好的Q-value在开始时是不知道的 (这也是强化学习和机器学习不一样的地方:我们不知道能否训练到一个Q值,直到有人把它训练出来),这就导致我们在训练过程中没有目标函数。

Natural Deep Q Learning

所有的第一步必须从高维的感官输入中获得对环境的有效表示

深度Q网络(DQN)是一种将深度学习和Q学习相结合的强化学习方法。DQN由DeepMind于2015年提出,并在玩Atari视频游戏方面取得了显著的成功。DQN的核心原理是使用深度神经网络来近似Q函数,即在给定状态下采取某一动作的预期累积奖励。

DQN的关键创新

  1. 使用神经网络近似Q函数

    • 传统的Q学习使用表格(Q表)来存储每个状态-动作对的Q值。当状态空间很大或连续时,这变得不切实际。
    • DQN通过使用深度神经网络来近似Q函数,克服了这一限制。网络输入是状态,输出是该状态下所有可能动作的Q值。
  2. 经验回放

    • DQN引入了经验回放机制,即将代理的经验(状态、动作、奖励、新状态)存储在回放缓冲区中。

      image-20231114211049019
    • 训练时,从这个缓冲区中随机抽取小批量经验进行学习。这增加了数据的多样性,减少了样本之间的相关性,从而稳定了训练。

  3. 目标网络

    • DQN使用两个结构相同但参数不同的网络:一个是在线网络 (dqn_model),用于当前Q值的估计;另一个是目标网络 (target_model),用于计算目标Q值。
    • 目标网络的参数定期从在线网络复制过来,但不是每个训练步骤都更新。这减少了学习过程中的震荡,提高了稳定性。
    image-20231114211236348

训练过程

  • 在每个时间步,代理根据当前的Q值(通常结合探索策略,如ε-贪婪)选择一个动作,接收环境的反馈(新状态和奖励),并将这个转换存储在经验回放缓冲区中。
  • 训练神经网络时,从缓冲区中随机抽取一批经验,然后使用贝尔曼方程计算目标Q值和预测Q值,通过最小化这两者之间的差异来更新网络参数。

DQN解决月球着陆问题

导入环境

import time
from collections import defaultdict

import gymnasium as gym
import numpy as np
import random

from matplotlib import pyplot as plt, animation
from IPython.display import display, clear_output
env = gym.make("LunarLander-v2", continuous=False, render_mode='rgb_array')

定义经验池

class ExperienceBuffer:
    def __init__(self, size=0):
        self.states = []
        self.actions = []
        self.rewards = []
        self.states_next = []
        self.actions_next = []
        self.size = 0

    def clear(self):
        self.__init__()

    def append(self, s, a, r, s_n, a_n):
        self.states.append(s)
        self.actions.append(a)
        self.rewards.append(r)
        self.states_next.append(s_n)
        self.actions_next.append(a_n)
        self.size += 1

    def batch(self, batch_size=128):
        indices = np.random.choice(self.size, size=batch_size, replace=True)
        return  (
            np.array(self.states)[indices],
            np.array(self.actions)[indices],
            np.array(self.rewards)[indices],
            np.array(self.states_next)[indices],
            np.array(self.actions_next)[indices],
        )
import torch

from torch import nn
from torch.nn.functional import relu
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

定义DQN

class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = 32
        self.linear_1 = nn.Linear(self.state_size, self.hidden_size)
        self.linear_2 = nn.Linear(self.hidden_size, self.action_size)

        nn.init.uniform_(self.linear_1.weight, a=-0.1, b=0.1)
        nn.init.uniform_(self.linear_2.weight, a=-0.1, b=0.1)

    def forward(self, state):
        if not isinstance(state, torch.Tensor):
            state = torch.tensor([state], dtype=torch.float)
        state = state.to(device)
        return self.linear_2(relu(self.linear_1(state)))

定义policy

def policy(model, state, eval=False):
    eps = 0.1

    if not eval and random.random() < eps:
        return random.randint(0, model.action_size - 1)
    else:
        q_values = model(torch.tensor([state], dtype=torch.float))
        action = torch.multinomial(F.softmax(q_values), num_samples=1)
        return int(action[0])

collect

dqn_model = DQN(state_size=8, action_size=4).to(device)
target_model = DQN(state_size=8, action_size=4).to(device)
from tqdm.notebook import tqdm
# 学习率
alpha = 0.9
# 折扣因子
gamma = 0.95
# 训练次数
episode = 1000
experience_buffer = ExperienceBuffer()

eval_iter = 100
eval_num = 100

# collect
def collect():
    for e in tqdm(range(episode)):
        state, info = env.reset()
        action = policy(dqn_model, state)

        sum_reward = 0

        while True:
            state_next, reward, terminated, truncated, info_next = env.step(action)
            action_next= policy(dqn_model, state_next)

            sum_reward += reward

            experience_buffer.append(
                state, action, reward, state_next, action_next
            )

            if terminated or truncated:
                break

            state = state_next
            info = info_next
            action = action_next

learning

## learning
from torch.optim import Adam

loss_fn = nn.MSELoss()
optimizer = Adam(lr=1e-5, params=dqn_model.parameters())

losses = []
target_fix_period = 5
epoch = 3

def train():
    for e in range(epoch):
        batch_size = 128
        for i in range(experience_buffer.size // batch_size):
            s, a, r, s_n, a_n = experience_buffer.batch(batch_size)

            s = torch.tensor(s, dtype=torch.float).to(device)
            s_n = torch.tensor(s_n, dtype=torch.float).to(device)
            r = torch.tensor(r, dtype=torch.float).to(device)
            a = torch.tensor(a, dtype=torch.long).to(device)
            a_n = torch.tensor(a_n, dtype=torch.long).to(device)

            y = r + target_model(s_n).gather(1, a_n.unsqueeze(1)).squeeze(1)
            y_hat = dqn_model(s).gather(1, a.unsqueeze(1)).squeeze(1)

            loss = loss_fn(y, y_hat)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 500 == 0:
                print(f'i == {i}, loss = {loss} ')

            if i % target_fix_period == 0:
                target_model.load_state_dict(dqn_model.state_dict())

a_n:动作
s_n:状态

image-20231205221613164

image-20231205221643890

将状态 s_n 作为输入,target_model的输出是针对每个可能动作的 Q 值;如果 s_n 包含多个状态(比如一个批量),输出将是一个批量的 Q 值

image-20231205221710717

image-20231205221746045

image-20231205221827050

训练

for i in range(10):
    print(f'collect/train: {i}')
    experience_buffer.clear()
    collect()
    train()

结果

task_num = 10
frames = []

for _ in range(10):
    state, _ = env.reset()
    while True:
        action = policy(dqn_model, state, eval=True)
        state_next, reward, terminated, truncated, info_next = env.step(action)
        frames.append(env.render())

        if terminated or truncated:
            break

output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/258708.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis介绍与使用

1、Nosql 1.1 数据存储的发展 1.1.1 只使用Mysql 以前的网站访问量不大&#xff0c;单个数据库是完全够用的。 但是随着互联网的发展&#xff0c;就出现了很多的问题&#xff1a; 数据量太大&#xff0c;服务器放不下 访问量太大&#xff0c;服务器也承受不了 1.1.2 缓存…

STL stack练习

CSTL之stack栈容器 - 数据结构教程 - C语言网CSTL之stack栈容器1.再谈栈回顾一下之前所学的栈&#xff0c;栈是一种先进后出的数据结构&#xff0c;而实现方式需要创建多个结构体&#xff0c;通过链式的方式进行实现&#xff0c;这是标准的栈的思路&#xff0c;而在STL中栈可以…

ctfshow(web190-web200)

目录 web190 web191 web192 web193 web194 web195 web196 web197 web198 web199 web200 web190 为什么要有admin呢 这也是试出来的 这个admin必须是数据库中存在的 这样才能使用布尔注入 因为这个时候登录 有两种返回结果 一种密码错误 一种就是用户名错误 admin an…

HBase 整合 Phoenix

目录 一、Phoenix 简介 1.1 Phoenix定义 1.2 为什么使用 Phoenix 二、Phoenix 快速入门 2.1 安装部署 Phoenix 2.1.1 上传并解压 tar 包 2.1.2 复制 server 包并拷贝到各个节点的 hbase/lib 2.1.3 配置环境变量 2.1.4 重启 HBase 2.1.5 连接 Phoenix 2.2 Phoenix…

FBX模型 转换成带有空间参考的 3DTiles(.b3dm) 数据(FBX glTF 3DTiles)

目录 0 引言1 数据类型介绍1.1 FBX数据1.2 glTF数据1.3 3DTiles1.3.1 简介1.3.2 3DTiles格式的LOD是如何定义1.3.3 文件后缀格式 2 转换工具2.1 CesiumGS /3d-tiles-tools 工具2.1 glf-to-3d-tiles工具 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xf…

Linux下载安装Pychram社区版

一.背景 最近将开发也转到Linuxl发现vscode在Linux上面的速度个方面都比windows快太多了!!! 狠狠爱住,于是准备把pycharm也装上来,由于作者没有edu邮箱,于是拿社区版进行演示 二.具体步骤 2.1下载pycharm社区版 链接:下载PyCharm&#xff1a;JetBrains为专业开发者提供的P…

【经典LeetCode算法题目专栏分类】【第2期】组合与排列问题系列

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 组合总和1 class So…

你应该知道的C语言性能提升法之结构体优化

前两天码哥写了一篇《你应该知道的C语言Cache命中率提升法》的文章&#xff0c;讲述关于地址连续性带来的cache命中率提升&#xff0c;感兴趣的朋友可以先翻看一番。 今天的文章是关于如何优化结构体成员来提升cache命中率的。我们先来看一个例子&#xff1a; 代码一 /* a.c…

一文教你提高写代码效率,程序员别错过!

首先&#xff0c;每个程序员都是会利用工具的人&#xff0c;也有自己囊里私藏的好物。独乐乐不如众乐乐&#xff0c;今天笔者整理了 3 个辅助我们写代码的黑科技&#xff0c;仅供参考。如果你有更好的工具&#xff0c;欢迎评论区分享。 1、Google/Stackoverflow——搜索解决方…

ChimeraX使用教程-安装及基本操作

ChimeraX使用教程-安装及基本操作 1、访问https://www.cgl.ucsf.edu/chimerax/download.html进行下载&#xff0c;然后安装 安装完成后&#xff0c;显示界面 2、基本操作 1、点击file&#xff0c;导入 .PDB 文件。 &#xff08;注&#xff1a;在 alphafold在线预测蛋白》点…

编码器的数学描述

在数字信号处理和通信系统中&#xff0c;编码器扮演着非常重要的角色&#xff0c;它负责将原始信号转换成特定的编码形式&#xff0c;以便于传输、存储和处理。编码器的数学描述是理解其原理和设计实现的关键。本文将围绕编码器的数学描述展开&#xff0c;介绍编码器的基本原理…

智能物联网汽车3d虚拟漫游展示增强消费者对品牌的认同感和归属感

汽车3D虚拟展示系统是一种基于web3D开发建模和VR虚拟现实技术制作的360度立体化三维汽车全景展示。它通过计算机1:1模拟真实的汽车外观、内饰和驾驶体验&#xff0c;让消费者在购车前就能够更加深入地了解车辆的性能、特点和设计风格。 华锐视点云展平台是一个专业的三维虚拟展…

2023年中国法拍房用户画像和数据分析

法拍房主要平台 法拍房主要平台有3家&#xff0c;分别是阿里、京东和北交互联平台。目前官方认定纳入网络司法拍卖的平台共有7家&#xff0c;其中阿里资产司法拍卖平台的挂拍量最大。 阿里法拍房 阿里法拍房数据显示2017年&#xff0c;全国法拍房9000套&#xff1b;2018年&a…

C语言归并排序(合并排序)算法以及代码

合并排序是采用分治法&#xff0c;先将无序序列划分为有序子序列&#xff0c;再将有序子序列合并成一个有序序列的有效的排序算法。 原理&#xff1a;先将无序序列利用二分法划分为子序列&#xff0c;直至每个子序列只有一个元素(单元素序列必有序)&#xff0c;然后再对有序子序…

【VScode和Leecode的爱恨情仇】command ‘leetcode.signin‘ not found

文章目录 一、关于command ‘leetcode.signin‘ not found的问题二、解决方案第一&#xff0c;没有下载Nodejs&#xff1b;第二&#xff0c;有没有在VScode中配置Nodejs第三&#xff0c;力扣的默认在VScode请求地址中请求头错误首先搞定配置其次搞定登入登入方法一&#xff1a;…

netty线程调度定制

1、netty的线程调度问题 在netty的TCP调度中&#xff0c;线程的调度封装在NioEventLoopGroup中&#xff0c;线程执行则封装在NioEventLoop中。 线程调度规则封装在MultithreadEventExecutorGroup的next方法中&#xff0c;这个方法又封装了EventExecutorChooserFactory&#xf…

ArkTS @Observed、@ObjectLink状态装饰器的使用

作用 Observed、ObjectLink装饰器用于在涉及嵌套对象或者数组元素为对象的场景中进行双向数据同步。 状态的使用 1.嵌套对象 我们将父类设置为Observed状态&#xff0c;这个时候&#xff0c;子应该设置ObjectLink才能完成数据的双向绑定&#xff0c;所以我们构建一个组件&…

控制理论simulink+matlab

这里写目录标题 根轨迹二级目录三级目录 根轨迹 z [-1]; %开环传递函数的零点 p [0 -2 -3 -4]; %开环传递函数的系统极点 k 1; %开环传递函数的系数&#xff0c;反映在比例上 g zpk(z,p,k); %生成开环传递函数%生成的传递函数如下 % (s1) % -------------…

Vue3-23-组件-依赖注入的使用详解

什么是依赖注入 个人的理解 &#xff1a; 依赖注入&#xff0c;是在 一颗 组件树中&#xff0c;由 【前代组件】 给 【后代组件】 提供 属性值的 一种方式 &#xff1b;这种方式 突破了 【父子组件】之间通过 props 的方式传值的限制&#xff0c;只要是 【前代组件】提供的 依…

「Qt Widget中文示例指南」如何创建一个计算器?(三)

Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写&#xff0c;所有平台无差别运行&#xff0c;更提供了几乎所有开发过程中需要用到的工具。如今&#xff0c;Qt已被运用于超过70个行业、数千家企业&#xff0c;支持数百万设备及应用。 本文将展示如何使用…