Pytorch深度强化学习2-1:基于价值的强化学习——DQN算法

目录

  • 0 专栏介绍
  • 1 基于价值的强化学习
  • 2 深度Q网络与Q-learning
  • 3 DQN原理分析
  • 4 DQN训练实例

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 基于价值的强化学习

根据不动点定理,最优策略和最优价值函数是唯一的(对该经典理论不熟悉的请看Pytorch深度强化学习1-4:策略改进定理与贝尔曼最优方程详细推导),通过优化价值函数间接计算最优策略的方法称为基于价值的强化学习(value-based)框架。设状态空间为 n n n维欧式空间 S = R n S=\mathbb{R} ^n S=Rn,每个维度代表状态的一个特征。此时状态-动作值函数记为

Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)

其中 s \boldsymbol{s} s是状态向量, a \boldsymbol{a} a是动作空间中的动作向量, θ \boldsymbol{\theta } θ是神经网络的参数向量。深度学习完成了从输入状态到输出状态-动作价值的映射

s → Q ( s , a ; θ ) [ Q ( s , a 1 ) Q ( s , a 2 ) ⋯ Q ( s , a m ) ] T    ( a 1 , a 2 , ⋯   , a m ∈ A ) \boldsymbol{s}\xrightarrow{Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}\left[ \begin{matrix} Q\left( \boldsymbol{s},a_1 \right)& Q\left( \boldsymbol{s},a_2 \right)& \cdots& Q\left( \boldsymbol{s},a_m \right)\\\end{matrix} \right] ^T\,\, \left( a_1,a_2,\cdots ,a_m\in A \right) sQ(s,a;θ) [Q(s,a1)Q(s,a2)Q(s,am)]T(a1,a2,,amA)

相当于对无穷维Q-Table的一次隐式查表,对经典Q-learing算法不熟悉的请看Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)、Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫。设目标价值函数为 Q ∗ Q^* Q,若采用最小二乘误差,可得损失函数为

J ( θ ) = E [ 1 2 ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) 2 ] J\left( \boldsymbol{\theta } \right) =\mathbb{E} \left[ \frac{1}{2}\left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) ^2 \right] J(θ)=E[21(Q(s,a)Q(s,a;θ))2]

采用梯度下降得到参数更新公式为

θ ← θ + α ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) ∂ Q ( s , a ; θ ) ∂ θ \boldsymbol{\theta }\gets \boldsymbol{\theta }+\alpha \left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) \frac{\partial Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}{\partial \boldsymbol{\theta }} θθ+α(Q(s,a)Q(s,a;θ))θQ(s,a;θ)

随着迭代进行, Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)将不断逼近 Q ∗ Q^* Q,由 Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)进行的策略评估和策略改进也将迭代至最优。

2 深度Q网络与Q-learning

Q-learning和深度Q学习(Deep Q-learning, DQN)是强化学习领域中两种重要的算法,它们在解决智能体与环境之间的决策问题方面具有相似之处,但也存在一些显著的异同。这里进行简要阐述以加深对二者的理解。

  • Q-learning是一种基于值函数的强化学习算法。它通过使用Q-Table来表示每个状态和动作对的预期回报。Q值函数用于指导智能体在每个时间步选择最优动作。通过不断更新Q值函数来使其逼近最优的Q值函数
  • DQN是对Q-learning的深度网络版本,它将神经网络引入Q-learning中,以处理具有高维状态空间的问题。通过使用深度神经网络作为函数逼近器,DQN可以学习从原始输入数据(如像素值)直接预测每个动作的Q值

在这里插入图片描述

3 DQN原理分析

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池(Experience Replay):考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性;
  • 目标网络(Target Network):考虑到若目标价值 与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建与 结构相同但慢于 更新的独立目标网络来评估目标价值,使模型更稳定。

拟合了高维状态空间,是Q-Learning算法的深度学习版本,算法流程如表所示

在这里插入图片描述

4 DQN训练实例

最简单的例子是使用全连接网络来构造DQN

class DQN(nn.Module):
	def __init__(self, input_dim, output_dim):
	    super(DQN, self).__init__()
	    self.input_dim = input_dim
	    self.output_dim = output_dim
	    
	    self.fc = nn.Sequential(
	        nn.Linear(self.input_dim[0], 128),
	        nn.ReLU(),
	        nn.Linear(128, 256),
	        nn.ReLU(),
	        nn.Linear(256, self.output_dim)
	    )
	
	def __str__(self) -> str:
	    return "Fully Connected Deep Q-Value Network, DQN"
	
	def forward(self, state):
	    qvals = self.fc(state)
	    return qvals

基于贝尔曼最优原理的损失计算如下

def computeLoss(self, batch):
    states, actions, rewards, next_states, dones = batch
    states = torch.FloatTensor(states).to(self.device)
    actions = torch.LongTensor(actions).to(self.device)
    rewards = torch.FloatTensor(rewards).to(self.device)
    next_states = torch.FloatTensor(next_states).to(self.device)
    dones = (1 - torch.FloatTensor(dones)).to(self.device)

    # 根据实际动作提取Q(s,a)值
    curr_Q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    next_Q = self.target_model(next_states)
    max_next_Q = torch.max(next_Q, 1)[0]
    expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q * dones

    loss = self.criterion(curr_Q, expected_Q.detach())
    return loss

基于经验回放池和目标网络的参数更新如下

def update(self, batch_size):
	batch = self.replay_buffer.sample(batch_size)
	loss = self.computeLoss(batch)
	self.optimizer.zero_grad()
	loss.backward()
	self.optimizer.step()
	
	# 更新target网络
	for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
	    target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
	
	# 退火
	self.epsilon = self.epsilon + self.epsilon_delta \
	    if self.epsilon < self.epsilon_max else self.epsilon_max

基于DQN可以实现最基本的智能体,下面给出一些具体案例

  • Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/271420.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于电商场景的高并发RocketMQ实战-Raft协议的leader选举算法、Broker基于状态机实现的leader选举

&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308; 【11来了】文章导读地址&#xff1a;点击查看文章导读&#xff01; &#x1f341;&#x1f341;&#x1f341;&#x1f341;&#x1f341;&#x1f341;&#x1f3…

使用LLaMA-Factory微调ChatGLM3

1、创建虚拟环境 略 2、部署LLaMA-Factory &#xff08;1&#xff09;下载LLaMA-Factory https://github.com/hiyouga/LLaMA-Factory &#xff08;2&#xff09;安装依赖 pip3 install -r requirements.txt&#xff08;3&#xff09;启动LLaMA-Factory的web页面 CUDA_VI…

Mybatis如何兼容各类日志?

文章目录 适配器模式日志模块代理模式1、静态代理模式2、JDK动态代理 JDBC Logger总结 Apache Commons Logging、Log4j、Log4j2、java.util.logging 等是 Java 开发中常用的几款日志框架&#xff0c;这些日志框架来源于不同的开源组织&#xff0c;给用户暴露的接口也有很多不同…

ResNet网络分析与demo实例

参考自 up主的b站链接&#xff1a;霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频这位大佬的博客 Fun_机器学习,pytorch图像分类,工具箱-CSDN博客 ResNet 详解 原论文地址 [1512.03385] Deep Residual Learning for Image Recognition (arxiv.org) ResNet 网络是在 …

Python、PHP/JAVA/C#电商评论数据采集与分析

引言 在电商竞争日益激烈的情况下&#xff0c;商家既要提高产品质量&#xff0c;又要洞悉客户的想法和需求&#xff0c;关注客户购买商品后的评论&#xff0c;而第三方商家获取商品评价主要依赖于人工收集&#xff0c;不但效率低&#xff0c;而且准确度得不到保障。通过使用Py…

【数据结构和算法】找到最高海拔

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、题目描述 二、题解 2.1 前缀和的解题模板 2.1.1 最长递增子序列长度 2.1.2 寻找数组中第 k 大的元素 2.1.3 最长公共子序列…

利用MATLAB设计一个(2,1,7)卷积码编译码器

1、条件&#xff1a; 输入数字信号&#xff0c;可以随机产生&#xff0c;也可手动输入 2、要求&#xff1a; &#xff08;1&#xff09;能显示编码树、网格图或状态转移图三者之一&#xff1b; &#xff08;2&#xff09;根据输入数字信号编码生成卷积码并显示&#xf…

如何进行块存储管理

目录 块存储概念 块存储&#xff08;云盘&#xff09;扩容 方式一&#xff1a;直接扩容现有云盘 方式二&#xff1a;创建一块新数据盘 方式三&#xff1a;在更换操作系统时&#xff0c;同时更换系统盘 块存储&#xff08;云盘&#xff09;变配 云盘变配操作步骤 块存储概…

索引进阶 | 再谈 MySQL 的慢 SQL 优化

索引可以提高数据检索的效率&#xff0c;降低数据库的IO成本。 MySQL在300万条记录左右性能开始逐渐下降&#xff0c;虽然官方文档说500~800w记录&#xff0c;所以大数据量建立索引是非常有必要的。 MySQL提供了Explain&#xff0c;用于显示SQL执行的详细信息&#xff0c;可以…

质量免费吗?

本文首发于个人网站「BY林子」&#xff0c;转载请参考版权声明。 两个场景 场景一&#xff1a;有限经费与质量改进 “要写自动化的单元测试、E2E测试&#xff0c;就会需要更多的钱&#xff0c;可是我们经费有限暂时做不了。” “CI上配置SonarQube扫描&#xff0c;对于扫描出来…

godot 报错Unable to initialize Vulkan video driver解决

版本 godot 4.2.1 现象 godot4.2.1 默认使用vulkan驱动&#xff0c;如果再不支持vulkan驱动的主机上&#xff0c;进入引擎编辑器将报错如下 解决 启动参数添加 –rendering-driver opengl3 即可进入引擎编辑器 此时运行项目仍然会报错无法初始化驱动 在项目设置中配置编…

Vue-Pinina基本教程

前言 官网地址&#xff1a;Pinia | The intuitive store for Vue.js (vuejs.org) 看以下内容&#xff0c;需要有vuex的基础&#xff0c;下面很多概念会直接省略&#xff0c;比如state、actions、getters用处含义等 1、什么是Pinina Pinia 是 Vue 的存储库&#xff0c;它允许您跨…

储能:东风已至,破浪在即——安科瑞 顾烊宇

今年的各省政府工作报告已经陆续发布&#xff0c;新能源是各省能源工作的重点&#xff0c;从目前31个省&#xff08;区、市&#xff09;相继公布的2022年经济增长数据来看&#xff0c;一些提前布局新能源产业的省市纷纷交出不错的成绩单&#xff0c;新能源成为当地GDP增速的重要…

饥荒Mod 开发(二三):显示物品栏详细信息

饥荒Mod 开发(二二)&#xff1a;显示物品信息 源码 前一篇介绍了如何获取 鼠标悬浮物品的信息&#xff0c;这一片介绍如何获取 物品栏的详细信息。 拦截 inventorybar 和 itemtile等设置字符串方法 在modmain.lua 文件中放入下面代码即可实现鼠标悬浮到 物品栏显示物品详细信…

微信小程序云开发-下载云存储中的文件

一、前言 很多时候我们需要实现用户在客户端下载服务端的文件&#xff08;图片、视频、pdf等&#xff09;到用户本地并保存起来&#xff0c;小程序也经常需要实现这样的需求。 在传统服务器开发下网上已经有很多关于小程序下载服务端文件的资料了&#xff0c;但是基于云开发的…

苹果怎么备份QQ的聊天记录?这3招教你快速备份!

QQ聊天记录是我们与好友之间的重要互动和沟通记录。但是&#xff0c;有时可能会由于各种原因&#xff0c;比如系统崩溃、更换手机、自身误操作、QQ闪退等&#xff0c;可能会导致聊天记录丢失。 因此&#xff0c;备份QQ聊天记录显得尤为重要。那么&#xff0c;苹果手机怎么备份…

SAP CO系统配置-与PS集成相关配置(机器人制造项目实例)

维护分配结构 配置路径 IMG菜单路径:控制>内部订单>实际过帐>结算>维护分配结构 事务代码 OKO6 维护结算参数文件 定义利润分析码

ZED-Mini 标定完全指南(应该是最详细的吧)

标定 ZED-Mini 相机主要为了跑 VINS-Fusion 以及后期的联合标定相关事宜 双目相机标定 出厂标定数据 关于ZED相机的内参&#xff0c;使用出厂标定的数据就好了&#xff0c;如果安装ZED的SDK时使用的是默认的安装路径&#xff0c;可以在/usr/local/zed/settings下面找到一个SN…

漏洞处理-未设置X-Frame-Options

漏洞名称&#xff1a;iFrame注入 风险描述&#xff1a;系统未设置x-frame-options头 风险等级&#xff1a;低 整改建议&#xff1a;为系统添加x-frame-options头 知识 X-Frame-Options 响应头 X-Frame-Options HTTP 响应头是用来给浏览器指示允许一个页面可否在 <fram…

通过 Bytebase API 做数据库 Schema 变更

Bytebase 是一款数据库 DevOps 和 CI/CD 工具&#xff0c;适用于开发人员、DBA 和平台工程团队。 它提供了一个直观的图形用户界面来管理数据库 Schema 变更。另一方面&#xff0c;一些团队可能希望将 Bytebase 集成到现有的内部 DevOps 研发平台中。这需要调用 Bytebase API。…