DQN原理和代码实现

参考:王树森《强化学习》书籍、课程、代码


1、基本概念

折扣回报:
U t = R t + γ ⋅ R t + 1 + γ 2 ⋅ R t + 2 + ⋯ + γ n − t ⋅ R n . U_t=R_t+\gamma\cdot R_{t+1}+\gamma^2\cdot R_{t+2}+\cdots+\gamma^{n-t}\cdot R_n. Ut=Rt+γRt+1+γ2Rt+2++γntRn.
动作价值函数:
Q π ( s t , a t ) = E [ U t ∣ S t = s t , A t = a t ] , Q_\pi(s_t,a_t)=\mathbb{E}\Big[U_t\Big|S_t=s_t,A_t=a_t\Big], Qπ(st,at)=E[Ut St=st,At=at],
最大化动作价值函数<消除 π \pi π的影响>:
Q ⋆ ( s t , a t ) = max ⁡ π Q π ( s t , a t ) , ∀ s t ∈ S , a t ∈ A . Q^\star(s_t,a_t)=\max_{\pi}Q_\pi(s_t,a_t),\quad\forall s_t\in\mathcal{S},\quad a_t\in\mathcal{A}. Q(st,at)=πmaxQπ(st,at),stS,atA.


2、Q-learning算法
Observe a transition  ( s t , a t , r t , s t + 1 ) . TD target:  y t = r t + γ ⋅ max ⁡ a Q ⋆ ( s t + 1 , a ) . TD error:  δ t = Q ⋆ ( s t , a t ) − y t . Update: Q ⋆ ( s t , a t ) ← Q ⋆ ( s t , a t ) − α ⋅ δ t . \begin{aligned} &\text{Observe a transition }(s_t,{a_t},r_t,s_{t+1}). \\ &\text{TD target: }y_t=r_t+\gamma\cdot\max_{{a}}Q^\star(s_{t+1},{a}). \\ &\text{TD error: }\delta_t=Q^\star(s_t,{a_t})-y_t. \\ &\text{Update:}\quad Q^\star(s_t,{a_t})\leftarrow Q^\star(s_t,{a_t})-\alpha\cdot\delta_t. \end{aligned} Observe a transition (st,at,rt,st+1).TD target: yt=rt+γamaxQ(st+1,a).TD error: δt=Q(st,at)yt.Update:Q(st,at)Q(st,at)αδt.

3、DQN算法
Observe a transition  ( s t , a t , r t , s t + 1 ) . TD target:  y t = r t + γ ⋅ max ⁡ a Q ( s t + 1 , a ; w ) . TD error:  δ t = Q ( s t , a t ; w ) − y t . Update:  w ← w − α ⋅ δ t ⋅ ∂ Q ( s t , a t ; w ) ∂ w . \begin{aligned} &\text{Observe a transition }(s_t,{a_t},r_t,s_{t+1}). \\ &\text{TD target: }y_{t}=r_{t}+\gamma\cdot\max_{a}Q(s_{t+1},a;\mathbf{w}). \\ &\text{TD error: }\delta_t=Q(s_t,{a_t};\mathbf{w})-y_t. \\ &\text{Update: }\mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\delta_t\cdot\frac{\partial Q(s_t,a_t;\mathbf{w})}{\partial\mathbf{w}}. \end{aligned} Observe a transition (st,at,rt,st+1).TD target: yt=rt+γamaxQ(st+1,a;w).TD error: δt=Q(st,at;w)yt.Update: wwαδtwQ(st,at;w).
注:DQN使用神经网络近似最大化动作价值函数;


4、经验回放:Experience Replay(训练DQN的一种策略)

优点:可以重复利用离线经验数据;连续的经验具有相关性,经验回放可以在离线经验BUFFER随机抽样,减少相关性;

超参数:Replay Buffer的长度;
∙  Find w by minimizing  L ( w ) = 1 T ∑ t = 1 T δ t 2 2 . ∙  Stochastic gradient descent (SGD): ∙  Randomly sample a transition,  ( s i , a i , r i , s i + 1 ) , from the buffer ∙  Compute TD error,  δ i . ∙  Stochastic gradient: g i = ∂ δ i 2 / 2 ∂ w = δ i ⋅ ∂ Q ( s i , a i ; w ) ∂ w ∙  SGD: w ← w − α ⋅ g i . \begin{aligned} &\bullet\text{ Find w by minimizing }L(\mathbf{w})=\frac{1}{T}\sum_{t=1}^{T}\frac{\delta_{t}^{2}}{2}. \\ &\bullet\text{ Stochastic gradient descent (SGD):} \\ &\bullet\text{ Randomly sample a transition, }(s_i,a_i,r_i,s_{i+1}),\text{from the buffer} \\ &\bullet\text{ Compute TD error, }\delta_i. \\ &\bullet\text{ Stochastic gradient: g}_{i}=\frac{\partial\delta_{i}^{2}/2}{\partial \mathbf{w}}=\delta_{i}\cdot\frac{\partial Q(s_{i},a_{i};\mathbf{w})}{\partial\mathbf{w}} \\ &\bullet\text{ SGD: w}\leftarrow\mathbf{w}-\alpha\cdot\mathbf{g}_i. \end{aligned}  Find w by minimizing L(w)=T1t=1T2δt2. Stochastic gradient descent (SGD): Randomly sample a transition, (si,ai,ri,si+1),from the buffer Compute TD error, δi. Stochastic gradient: gi=wδi2/2=δiwQ(si,ai;w) SGD: wwαgi.
注:实践中通常使用minibatch SGD,每次抽取多个经验,计算小批量随机梯度;


**5、使用经验回放训练的DQN的算法代码实现**
"""4.3节DQN算法实现。
"""
import argparse
from collections import defaultdict
import os
import random
from dataclasses import dataclass, field
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNet(nn.Module):
    """QNet.
    Input: feature
    Output: num_act of values
    """

    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class DQN:
    def __init__(self, dim_state=None, num_action=None, discount=0.9):
        self.discount = discount
        self.Q = QNet(dim_state, num_action)
        self.target_Q = QNet(dim_state, num_action)
        self.target_Q.load_state_dict(self.Q.state_dict())

    def get_action(self, state):
        qvals = self.Q(state)
        return qvals.argmax()

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # 计算s_batch,a_batch对应的值。
        qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # 使用target Q网络计算next_s_batch对应的值。
        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
        # 使用MSE计算loss。
        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
        return loss


def soft_update(target, source, tau=0.01):
    """
    update target by target = tau * source + (1 - tau) * target.
    """
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


@dataclass
class ReplayBuffer:
    maxsize: int
    size: int = 0
    state: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_state: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def push(self, state, action, reward, done, next_state):
        """
        :param state: 状态
        :param action: 动作
        :param reward: 奖励
        :param done:
        :param next_state:下一个状态
        :return:
        """
        if self.size < self.maxsize:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            position = self.size % self.maxsize
            self.state[position] = state
            self.action[position] = action
            self.reward[position] = reward
            self.done[position] = done
            self.next_state[position] = next_state
        self.size += 1

    def sample(self, n):
        total_number = self.size if self.size < self.maxsize else self.maxsize
        indices = np.random.randint(total_number, size=n)
        state = [self.state[i] for i in indices]
        action = [self.action[i] for i in indices]
        reward = [self.reward[i] for i in indices]
        done = [self.done[i] for i in indices]
        next_state = [self.next_state[i] for i in indices]
        return state, action, reward, done, next_state


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.no_cuda:
        torch.cuda.manual_seed(args.seed)


def train(args, env, agent):
    replay_buffer = ReplayBuffer(10_000)
    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
    optimizer.zero_grad()

    epsilon = 1
    epsilon_max = 1
    epsilon_min = 0.1
    episode_reward = 0
    episode_length = 0
    max_episode_reward = -float("inf")
    log = defaultdict(list)
    log["loss"].append(0)

    agent.Q.train()
    state, _ = env.reset(seed=args.seed)
    for i in range(args.max_steps):
        if np.random.rand() < epsilon or i < args.warmup_steps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(torch.from_numpy(state))
            action = action.item()
        env.render()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_length += 1

        replay_buffer.push(state, action, reward, done, next_state)
        state = next_state

        if done is True:
            log["episode_reward"].append(episode_reward)
            log["episode_length"].append(episode_length)

            print(
                f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log['loss'][-1]:.1e}, epsilon={epsilon:.3f}")

            # 如果得分更高,保存模型。
            if episode_reward > max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.Q.state_dict(), save_path)
                max_episode_reward = episode_reward

            episode_reward = 0
            episode_length = 0
            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, 1e-1)
            state, _ = env.reset()

        if i > args.warmup_steps:
            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
            bs = torch.tensor(bs, dtype=torch.float32)
            ba = torch.tensor(ba, dtype=torch.long)
            br = torch.tensor(br, dtype=torch.float32)
            bd = torch.tensor(bd, dtype=torch.float32)
            bns = torch.tensor(bns, dtype=torch.float32)

            loss = agent.compute_loss(bs, ba, br, bd, bns)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            log["loss"].append(loss.item())

            soft_update(agent.target_Q, agent.Q)

    # 3. 画图。
    plt.plot(log["loss"])
    plt.yscale("log")
    plt.savefig(f"{args.output_dir}/loss.png", bbox_inches="tight")
    plt.close()

    plt.plot(np.cumsum(log["episode_length"]), log["episode_reward"])
    plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
    plt.close()


def eval(args, env, agent):
    agent = DQN(args.dim_state, args.num_action)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.Q.load_state_dict(torch.load(model_path))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset()
    for i in range(5000):
        episode_length += 1
        action = agent.get_action(torch.from_numpy(state)).item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        env.render()
        episode_reward += reward

        state = next_state
        if done is True:
            print(f"episode reward={episode_reward}, episode length{episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--max_steps", default=50000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--warmup_steps", default=5000, type=int, help="Warmup steps without training.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--epsilon_decay", default=1 / 1000, type=float, help="Epsilon-greedy algorithm decay coefficient.")
    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    args.do_train = True
    args.do_eval = True

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    env = gym.make(args.env, render_mode="human")
    set_seed(args)
    agent = DQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
    agent.Q.to(args.device)
    agent.target_Q.to(args.device)

    if args.do_train:
        train(args, env, agent)

    if args.do_eval:
        eval(args, env, agent)


if __name__ == "__main__":
    main()

显示:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/61576.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于 APN 的 CXL 链路训练

&#x1f525;点击查看精选 CXL 系列文章&#x1f525; &#x1f525;点击进入【芯片设计验证】社区&#xff0c;查看更多精彩内容&#x1f525; &#x1f4e2; 声明&#xff1a; &#x1f96d; 作者主页&#xff1a;【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

Dockerfile构建mysql

使用dockerfile构建mysql详细教学加案例 Dockerfile 文件 # 使用官方5.6版本&#xff0c;latest为默认版本 FROM mysql:5.6 #复制my.cof至容器内 ADD my.cnf /etc/mysql/my.cof #设置环境变量 密码 ENV MYSQL_ROOT_PASSWORD123456my.cof 文件 [mysqld] character-set-server…

LNMP搭建

LNMP&#xff1a;目前成熟的企业网站的应用模式之一&#xff0c;指的是一套协同工作的系统和相关软件 能够提供静态页面服务&#xff0c;也可以提供动态web服务。 这是一个缩写 L linux系统&#xff0c;操作系统。 N nginx网站服务&#xff0c;也可也理解为前端&#xff0c…

企业计算机服务器中了locked勒索病毒怎么办,如何预防勒索病毒攻击

计算机服务器是企业的关键信息基础设备&#xff0c;随着计算机技术的不断发展&#xff0c;企业的计算机服务器也成为了众多勒索者的攻击目标&#xff0c;勒索病毒成为当下计算机服务器的主要攻击目标。近期&#xff0c;我们收到很多企业的求助&#xff0c;企业的服务器被locked…

uni-app、H5实现瀑布流效果封装,列可以自定义

文章目录 前言一、效果二、使用代码三、核心代码总结 前言 最近做项目需要实现uni-app、H5实现瀑布流效果封装&#xff0c;网上搜索有很多的例子&#xff0c;但是代码都是不够完整的&#xff0c;下面来封装一个uni-app、H5都能用的代码。在小程序中&#xff0c;一个个item渲染…

Godot 4 源码分析 - Path2D与PathFollow2D

学习演示项目dodge_the_creeps&#xff0c;发现里面多了一个Path2D与PathFollow2D 研究GDScript代码发现&#xff0c;它主要用于随机生成Mob var mob_spawn_location get_node(^"MobPath/MobSpawnLocation")mob_spawn_location.progress randi()# Set the mobs dir…

【机器学习】编码、创造和筛选特征

在机器学习和数据科学领域中&#xff0c;特征工程是提取、转换和选择原始数据以创建更具信息价值的特征的过程。假设拿到一份数据集之后&#xff0c;如何逐步完成特征工程呢&#xff1f; 文章目录 一、特性类型分析1.1 数值型特征1.2 类别型特征1.3 时间型特征1.4 文本型特征1.…

Android Studio安装AI编程助手Github Copilot

csdn原创谢绝转载 简介 文档链接 https://docs.github.com/en/copilot/getting-started-with-github-copilot 它是个很牛B的编程辅助工具&#xff0c;装它&#xff0c;快装它&#xff0e; 支持以下IDE: IntelliJ IDEA (Ultimate, Community, Educational)Android StudioAppC…

数据库操作系列-Mysql, Postgres常用sql语句总结

文章目录 1.如果我想要写一句sql语句&#xff0c;实现 如果存在则更新&#xff0c;否则就插入新数据&#xff0c;如何解决&#xff1f;MySQL数据库实现方案: ON DUPLICATE KEY UPDATE写法 Postgres数据库实现方案:方案1&#xff1a;方案2&#xff1a;关于更新&#xff1a;如何实…

【云原生】K8S二进制搭建一

目录 一、环境部署1.1操作系统初始化 二、部署etcd集群2.1 准备签发证书环境在 master01 节点上操作在 node01与02 节点上操作 三、部署docker引擎四、部署 Master 组件4.1在 master01 节点上操 五、部署Worker Node组件 一、环境部署 集群IP组件k8s集群master01192.168.243.1…

【雕爷学编程】MicroPython动手做(31)——物联网之Easy IoT

1、物联网的诞生 美国计算机巨头微软(Microsoft)创办人、世界首富比尔盖茨&#xff0c;在1995年出版的《未来之路》一书中&#xff0c;提及“物物互联”。1998年麻省理工学院提出&#xff0c;当时被称作EPC系统的物联网构想。2005年11月&#xff0c;国际电信联盟发布《ITU互联网…

在 Ubuntu 上安装 Docker 桌面

Ubuntu 22.04 (LTS) 安装 Docker 桌面 要成功安装 Docker Desktop&#xff0c;您必须&#xff1a; 满足系统要求拥有 64 位版本的 Ubuntu Jammy Jellyfish 22.04 (LTS) 或 Ubuntu Impish Indri 21.10。对于非 Gnome 桌面环境&#xff0c;必须安装 gnome-terminal&#xff1a;…

机器学习笔记 - YOLO-NAS 最高效的目标检测算法之一

一、YOLO-NAS概述 YOLO(You Only Look Once)是一种对象检测算法,它使用深度神经网络模型,特别是卷积神经网络,来实时检测和分类对象。该算法首次在 2016 年由 Joseph Redmon、Santosh Divvala、Ross Girshick 和 Ali Farhadi 发表的论文《You Only Look Once: Unified, Re…

Excel·VBA表格横向、纵向相互转换

如图&#xff1a;对图中区域 A1:M6 横向表格&#xff0c;转换成区域 A1:C20 纵向表格&#xff0c;即 B:M 列转换成每2列一组按行写入&#xff0c;并删除空行。同理&#xff0c;反向操作就是纵向表格转换成横向表格 目录 横向转纵向实现方法1转换结果 实现方法2转换结果 纵向转横…

ThreadLocal有内存泄漏问题吗

对于ThreadLocal的原理不了解或者连Java中的引用类型都不了解的可以看一下我的之前的一篇文章Java中的引用和ThreadLocal_鱼跃鹰飞的博客-CSDN博客 我这里也简单总结一下: 1. 每个Thread里都存储着一个成员变量&#xff0c;ThreadLocalMap 2. ThreadLocal本身不存储数据&…

Jenkins 自动化部署实例讲解,另附安装教程!

【2023】Jenkins入门与安装_jenkins最新版本_丶重明的博客-CSDN博客 也可以结合这个互补看 前言 你平常在做自己的项目时&#xff0c;是否有过部署项目太麻烦的想法&#xff1f;如果你是单体项目&#xff0c;可能没什么感触&#xff0c;但如果你是微服务项目&#xff0c;相…

Android的Handler消息通信详解

目录 背景 1. Handler基本使用 2. Handler的Looper源码分析 3. Handler的Message以及消息池、MessageQueue 4. Handler的Native实现 4.1 MessageQueue 4.2 Native结构体和类 4.2.1 Message结构体 4.2.2 消息处理类 4.2.3 回调类 4.2.5 ALooper类 5. 总结&…

【千题百解】华为机试题:求最小公倍数

“所有命运馈赠的礼物,都已在暗中标好了价格” 👨🏻‍💻作者:鳄鱼儿 🍀个人简介 👨🏻‍🎓计算机专业硕士研究生 🦨阿里云社区专家博主 🌙CSDN博客专家 & Java领域优质创作者 题目 解题 Java实现 注意a和b相乘时可能超过int最大值。 import java.uti

python调用pytorch的clip模型时报错

使用python调用pytorch中的clip模型时报错&#xff1a;AttributeError: partially initialized module ‘clip’ has no attribute ‘load’ (most likely due to a circular import) 目录 现象解决方案一、查看项目中是否有为clip名的文件二、查看clip是否安装成功 现象 clip…

命令模式(Command)

命令模式是一种行为设计模式&#xff0c;可将一个请求封装为一个对象&#xff0c;用不同的请求将方法参数化&#xff0c;从而实现延迟请求执行或将其放入队列中或记录请求日志&#xff0c;以及支持可撤销操作。其别名为动作(Action)模式或事务(Transaction)模式。 Command is …