【深度学习】最强算法之:深度Q网络(DQN)

深度Q网络

  • 1、引言
  • 2、深度Q网络
    • 2.1 定义
    • 2.2 原理
    • 2.3 实现方式
    • 2.4 算法公式
    • 2.5 代码示例
  • 3、总结

1、引言

小屌丝:鱼哥, 马上清明小长假了, 你这准备去哪里玩啊?
小鱼:哪也不去,在家待着
小屌丝:在家? 待着? 干啥啊?
小鱼:啥也不干,床上躺着
小屌丝:床上… 躺着… 做啥啊?
小鱼:啥也不做,睡觉
小屌丝:睡觉?? 这大白天的,确定睡觉?
小鱼:我擦… 你这wc~
小屌丝:我很正经的好不好。
小鱼:… 我有点事,待会说
小屌丝: 待会,没时间了哦
小鱼:那就在多几个待会的
小屌丝:这火急火燎的, 肯定"有事"。
在这里插入图片描述

2、深度Q网络

2.1 定义

深度Q网络(DQN)是一种结合了深度学习和Q-learning的强化学习算法。它通过深度神经网络逼近值函数,并利用经验回放和目标网络等技术,使得Q-learning能够在高维连续状态空间中稳定学习。

2.2 原理

DQN的核心原理是利用深度神经网络来估计Q值函数。
在每个时刻,DQN根据当前状态s和所有可能的动作a计算出一组Q值,然后选择Q值最大的动作执行。
执行动作后,环境会给出新的状态s’和奖励r,DQN将这些信息存储到经验回放缓存中。

在训练过程中,DQN从经验回放缓存中随机采样一批历史数据,利用这些数据进行梯度下降更新神经网络参数。

此外,DQN还引入了目标网络来稳定学习过程,即每隔一定步数将当前网络参数复制给目标网络,用于计算目标Q值。

2.3 实现方式

实现DQN主要包括以下步骤:

  • 初始化深度神经网络(Q网络)和目标网络(目标Q网络)。
  • 初始化经验回放缓存。
  • 对于每个训练回合:
    • 初始化状态s。
    • 对于每个时间步t:
      • 使用ε-贪婪策略选择动作a。
      • 执行动作a,观察奖励r和新状态s’。
      • 将经验(s, a, r, s’)存储到经验回放缓存中。
      • 从经验回放缓存中采样一批数据,计算损失函数并更新Q网络参数。
      • 每隔一定步数更新目标网络参数。
    • 重复上述步骤直至满足终止条件。

2.4 算法公式

DQN的损失函数通常采用均方误差(MSE)形式,即:

L ( θ ) = 1 / N ∗ Σ [ ( r + γ ∗ m a x a ′ Q ( s ′ , a ′ ; θ − ) − Q ( s , a ; θ ) ) 2 ] L(θ) = 1/N * Σ[(r + γ * max_a' Q(s', a'; θ⁻) - Q(s, a; θ))^2] L(θ)=1/NΣ[(r+γmaxaQ(s,a;θ)Q(s,a;θ))2]

其中,

  • θ θ θ Q Q Q网络参数,
  • θ − θ⁻ θ是目标网络参数,
  • N N N是采样数据批量大小,
  • γ γ γ是折扣因子,
  • r r r是奖励,
  • s s s a a a分别是当前状态和动作,
  • s ′ s' s是下一状态,
  • a ′ a' a是下一状态的所有可能动作。

2.5 代码示例

# -*- coding:utf-8 -*-
# @Time   : 2024-04-01
# @Author : Carl_DJ

'''
实现功能:
    使用PyTorch框架的简单DQN(Deep Q-Network)实现示例

'''
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# 创建一个简单的神经网络,作为Q网络
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.net(x)

# 经验回放
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), action, reward, np.array(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN算法实现
class DQNAgent:
    def __init__(self, input_dim, output_dim):
        self.model = DQN(input_dim, output_dim)
        self.target_model = DQN(input_dim, output_dim)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters())
        self.buffer = ReplayBuffer(10000)
        self.steps_done = 0
        self.epsilon_start = 1.0
        self.epsilon_final = 0.01
        self.epsilon_decay = 500
        self.batch_size = 32

    def act(self, state):
        epsilon = self.epsilon_final + (self.epsilon_start - self.epsilon_final) * \
                  np.exp(-1. * self.steps_done / self.epsilon_decay)
        self.steps_done += 1
        if random.random() > epsilon:
            state = torch.FloatTensor(state).unsqueeze(0)
            q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(2)
        return action

    def update(self):
        if len(self.buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model(state)
        next_q_values = self.target_model(next_state)

        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + 0.99 * next_q_value * (1 - done)

        loss = (q_value - expected_q_value.data).pow(2).mean()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())

# 训练环境设置
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)

# 训练循环
episodes = 100
for episode in range(episodes):
    state = env.reset()
    total_reward = 0
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        agent.buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        agent.update()
    agent.update_target()
    print('Episode: {}, Total reward: {}'.format(episode, total_reward))


解析:

  • 首先定义了一个简单的神经网络DQN,
  • 然后定义了ReplayBuffer用于经验回放,
  • 接着定义了DQNAgent类封装了DQN的决策、学习和目标网络更新逻辑。
  • 最后,通过创建一个gym环境(这里使用的是CartPole-v1)并在该环境中运行DQNAgent来进行训练。
    在这里插入图片描述

3、总结

深度Q网络(DQN)通过将深度学习与强化学习相结合,解决了传统Q-learning在高维连续状态空间中的维度灾难问题。

DQN利用深度神经网络的强大表征能力来估计Q值函数,并通过经验回放和目标网络等技术来稳定学习过程。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)测评一、二等奖获得者

关注小鱼,学习【机器学习】&【深度学习】领域的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/526276.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 开发篇+一个简单的数据库管理系统ZDB

说明&#xff1a;本文供数据库爱好者和初级开发人员学习使用 标签&#xff1a;数据库管理系统、RDBMS、Java小程序、Java、Java程序 系统&#xff1a;Windows 11 x86 CPU &#xff1a;Intel IDE &#xff1a;IntelliJ IDEA Community Edition 2024 语言&#xff1a;Java语言 标…

“AI+信创”两翼齐飞,实在智能全面加速自主可控实在智能RPA

近日&#xff0c;实在智能牵手华为昇腾、摩尔线程在信创领域展开紧密合作&#xff0c;共同加速推进AI和信创产业创新发展。 华为昇腾与实在智能达成昇腾原生大模型联合创新合作&#xff0c;基于华为昇腾AI自主创新软硬件平台全栈技术、实在智能自研RPA基础大模型解决方案能力&a…

简单好用高效的视频补帧软件:Squirrel-RIFE

Squirrel-RIFE&#xff0c;轻松实现高效补帧&#xff0c;让您的视频画面瞬间流畅升级&#xff01;- 精选真开源&#xff0c;释放新价值。 概览 在观看视频内容的过程中&#xff0c;尤其是面对复杂多变的动画场景或高速运动镜头时&#xff0c;观众时常会遭遇视频帧率不足所引发…

算法中的二阶差分

众所周知&#xff0c;在往区间的每一个数都加上一个相同的数k&#xff0c;进行n次后会得到一个新的数列&#xff0c;如果每次加都循环区间挨个数加上k&#xff0c;这样时间复杂度无疑是O(n^2)&#xff0c;很高。这时可以采用一阶差分就可解决&#xff0c;这里默认会一阶差分&am…

物联网行业趋势——青创智通

工业物联网解决方案-工业IOT-青创智通 随着科技的不断进步和应用场景的日益扩大&#xff0c;物联网行业呈现出迅猛发展的势头。作为当今世界最具前瞻性和战略意义的领域之一&#xff0c;物联网行业的趋势和未来发展值得深入探讨。 ​一、物联网行业正逐渐实现全面普及。随着物…

鸿蒙ArkUI开发实战:制作一个【简单计数器】

构建第一个页面 使用文本组件 工程同步完成后&#xff0c;在 Project 窗口&#xff0c;点击 entry > src > main > ets > pages &#xff0c;打开 Index.ets 文件&#xff0c;可以看到页面由 Row 、 Column 、 Text 组件组成。 index.ets 文件的示例如下&#xff1…

飞机降落(区间问题)

思路&#xff1a; 受P1803 凌乱的yyy / 线段覆盖的启发。 对于这道题&#xff0c;我的第一想法不是dfs&#xff0c;而是把它看作区间来看&#xff0c;分别就是【t&#xff0c;tl】和【td&#xff0c;tdl】。先按照结束时间排序&#xff0c;先用第一个飞机不延迟降落的时间a[0…

制造业智能化一体式I/O模块的集成与应用案例分享

在现代制造业中&#xff0c;智能化一体式I/O模块的应用已经成为提升生产效率、优化工艺流程的关键技术之一。这种一体化I/O模块的主要功能在于作为PLC&#xff08;可编程逻辑控制器&#xff09;系统的扩展接口&#xff0c;以满足多样化的输入输出需求。本文将通过一个实际案例&…

DFS-0与异或问题,有奖问答,飞机降落

代码和解析 #include<bits/stdc.h> using namespace std; int a[5][5]{{1,0,1,0,1}}; //记录图中圆圈内的值&#xff0c;并初始化第1行 int gate[11]; //记录10个逻辑门的一种排列 int ans; //答案 int logic(int x, int y, int op){…

麒麟系统下安装qt5.9.1后不能输入中文

引言 在虚拟机上安装麒麟系统后,安装了qt5.9.1,只能输入英文和数字不能输入中文注释,编译的程序也不能输入中文。 原因 安装后的麒麟系统自带搜狗输入法,原本可以输入中文,但是qt5.9.1缺少支持搜狗输入法的fcitx插件。所以qt5.9.1中不能输入中文。 解决方法 安装fcit…

逆向入门:为CTF国赛而战day03

今天来做几道题目。 环境准备&#xff1a;ida ,Exeinfo,万能脱壳器&#xff08;后面有写资源&#xff09; 强推&#xff0c;亲测有效CTF小工具下载整理_ctf工具御剑下载-CSDN博客 [网站BUUCTF] 目录 题目一 题目二三 题目4&#xff1a;新年快乐 题目一 easyre题解_easyr…

在自定义数据集上微调 YOLOv9 模型

在自定义数据集上微调 YOLOv9模型可以显着提高目标检测性能,但这种改进有多显着呢?在这次全面的探索中,YOLOv9在SkyFusion数据集上进行了微调,分为三个不同的类别:飞机、船舶和车辆。通过一系列广泛的实验,包括修改学习率、图像大小和战略性冻结主干网,已经实现了令人印…

目标检测——RCNN系列学习(一)

前置知识 包括&#xff1a;非极大值抑制&#xff08;NMS&#xff09;、selective search等 RCNN [1311.2524] Rich feature hierarchies for accurate object detection and semantic segmentation (arxiv.org)https://arxiv.org/abs/1311.2524 1.网络训练 2.推理流程 3.总…

【数据库事务并发问题】脏读、丢失修改、不可重复读、幻读

文章目录 一、脏读二、丢失修改三、不可重复读四、幻读 一、脏读 第二个事务读了①修改的数据后&#xff0c;前一个事务回滚了 一个事务读取数据并且对数据进行了修改&#xff0c;这个修改对其他事务来说是可见的&#xff0c;即使当前事务没有提交。这时另外一个事务读取了这个…

幸运数(蓝桥杯)

该 import java.util.*; public class Main {public static void main(String[] args) {Scanner scannew Scanner(System.in);int cnt0;for(int i1;i<100000000;i) {String si"";int lens.length();if(len%2!0) continue;int sum10; //左边int sum20; //右边fo…

【Mybatis】Mybatis 二级缓存全详解教程

【Mybatis-Plus】Mybatis-Plus 二级缓存全详解 一&#xff0c;Mybatis-Plus介绍 MyBatis-Plus&#xff08;简称MP&#xff09;是一个基于 MyBatis 的增强工具&#xff0c;它简化了 MyBatis 的开发&#xff0c;并且提供了许多便利的功能&#xff0c;帮助开发者更高效地进行持久…

申请专利有用吗 好处

申请专利&#xff1a;一项值得考虑的策略 随着科技的快速发展和市场竞争的日益激烈&#xff0c;创新成为了企业或个人取得竞争优势的关键。在这样的背景下&#xff0c;申请专利成为了许多创新者保护自己创意和技术的重要手段。 申请专利真的有用吗&#xff1f; 申请专利可以…

CS162 Operating System笔记

What is an Operating System? it’s typically a special layer of software that provides the application access to hardware resources.So.it’s convenient abs fractions of complex hardware devices.

数位递增数-第12届蓝桥杯选拔赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第46讲。 数位递增数&#…

【软考】23种设计模式详解,记忆方式,并举例说明

23种设计模式详解&#xff0c;举例说明 一、创建型模式1.1、抽象工厂模式&#xff08;Abstract Factory&#xff09;1.1.1、简介1.1.2、意图与应用场景1.1.3、结构1.1.4、优缺点1.1.4、示例代码&#xff08;简化版&#xff09; 1.2、建造者模式&#xff08;Builder&#xff09;…