Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

目录

  • 0 专栏介绍
  • 1 Q-Learning算法原理
  • 2 强化学习基本框架
  • 3 机器人走迷宫算法
    • 3.1 迷宫环境
    • 3.2 状态、动作和奖励
    • 3.3 Q-Learning算法实现
    • 3.4 完成训练
  • 4 算法分析
    • 4.1 Q-Table
    • 4.2 奖励曲线

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 Q-Learning算法原理

在Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)介绍到时序差分强化学习是动态规划与蒙特卡洛的折中

Q π ( s t , a t ) = n 次增量 Q π ( s t , a t ) + α ( R t − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) ⏟ 采样 \begin{aligned}Q^{\pi}\left( s_t,a_t \right) &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( R_t-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+\gamma R_{t+1}-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}{ \underset{\text{采样}}{\underbrace{Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }}}\end{aligned} Qπ(st,at)n次增量 Qπ(st,at)+α(RtQπ(st,at))n次增量 Qπ(st,at)+α(rt+1+γRt+1Qπ(st,at))n次增量 采样 Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)Qπ(st,at))

其中 r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) r_{t+1}+\gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) -Q^{\pi}\left( s_t,a_t \right) rt+1+γQπ(st+1,at+1)Qπ(st,at)称为时序差分误差。基于离轨策略的时序差分强化学习的代表性算法是Q-learning算法,其算法流程如下所示。具体的策略改进算法推导请见之前的文章,本文重点在于应用Q-learning算法解决实际问题

在这里插入图片描述

我们先来看看最终实现的效果

训练前
在这里插入图片描述

训练后

在这里插入图片描述

接下来详细讲解如何一步步实现这个智能体

2 强化学习基本框架

强化学习(Reinforcement Learning, RL)在潜在的不确定复杂环境中,训练一个最优决策 π \pi π指导一系列行动实现目标最优化的机器学习方法。在初始情况下,没有训练数据告诉强化学习智能体并不知道在环境中应该针对何种状态采取什么行动,而是通过不断试错得到最终结果,再反馈修正之前采取的策略,因此强化学习某种意义上可以视为具有“延迟标记信息”的监督学习问题。

在这里插入图片描述

强化学习的基本过程是:智能体对环境采取某种行动 a a a,观察到环境状态发生转移 s 0 → s s_0\rightarrow s s0s,反馈给智能体转移后的状态 s s s和对这种转移的奖赏 r r r。综上所述,一个强化学习任务可以用四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=S,A,P,R表征

  • 状态空间 S S S:每个状态 s ∈ S s \in S sS是智能体对感知环境的描述;
  • 动作空间 A A A:每个动作 a ∈ A a \in A aA是智能体能够采取的行动;
  • 状态转移概率 P P P:某个动作 a ∈ A a \in A aA作用于处在某个状态 s ∈ S s \in S sS的环境中,使环境按某种概率分布 P P P转换到另一个状态;
  • 奖赏函数 R R R:表示智能体对状态 s ∈ S s \in S sS下采取动作 a ∈ A a \in A aA导致状态转移的期望度,通常 r > 0 r>0 r>0为期望行动, r < 0 r<0 r<0为非期望行动。

所以,程序上也需要依次实现四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=S,A,P,R

3 机器人走迷宫算法

3.1 迷宫环境

我们创建的迷宫包含障碍物、起点和终点

class Maze(tk.Tk, object):
    '''
    * @breif: 迷宫环境类
    * @param[in]: None
    '''    
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze game')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self.buildMaze()

    '''
    * @breif: 创建迷宫
    '''
    def buildMaze(self):
        self.canvas = tk.Canvas(self, bg='white', height=MAZE_H * UNIT, width=MAZE_W * UNIT)
        # 网格地图
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建原点坐标
        origin = np.array([20, 20])

        # 创建障碍
        barrier_list = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0),
                        (0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6),
                        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 1), (6, 2),
                        (6, 3), (6, 4), (6, 5), (1, 2), (2, 2), (4, 1), (5, 4),
                        (1, 4), (3, 3)]
        self.barriers = [self.creatObject(origin, *index) for index in barrier_list]

        # 创建终点
        self.terminus = self.creatObject(origin, 5, 5, 'blue')

3.2 状态、动作和奖励

机器人的状态可以设置为当前的位置坐标

s = self.canvas.coords(self.agent)

机器人的动作可以设为上、下左、右

if action == 0:   # up
      if s[1] > UNIT:
          base_action[1] -= UNIT
  elif action == 1:   # down
      if s[1] < (MAZE_H - 1) * UNIT:
          base_action[1] += UNIT
  elif action == 2:   # right
      if s[0] < (MAZE_W - 1) * UNIT:
          base_action[0] += UNIT
  elif action == 3:   # left
      if s[0] > UNIT:
          base_action[0] -= UNIT

机器人的奖励设置为以下几种:

  • 碰到障碍物:-10分,并进入终止状态
  • 成功到达终点: +50分,并进入终止状态
  • 未到达终点:-1分,能量耗散惩罚,防止机器人原地振荡
if s_ in [self.canvas.coords(barrier) for barrier in self.barriers]:
   reward = -10
   done = True
   s_ = 'terminal'
elif s_ == self.canvas.coords(self.terminus):
   reward = 50
   done = True
   s_ = 'terminal'
else:
   reward = -1
   done = False

3.3 Q-Learning算法实现

根据算法流程,实现下面的Q-Learning训练函数

def train(self, env, episodes=1000, reward_curve=[], file=None):
	with tqdm(range(episodes)) as bar:
	    for _ in bar:
	        # 初始化环境和该幕累计奖赏
	        state = env.reset()
	        acc_reward = 0
	        while True:
	            # 刷新环境
	            env.render()
	            # 采样一个动作并进行状态转移
	            action = self.policySample(str(state))
	            next_state, reward, done = env.step(action)
	            acc_reward += reward
	            # 智能体学习策略
	            self.learn(str(state), action, reward, str(next_state))
	            state = next_state
	            if done:
	                reward_curve.append(acc_reward)
	                break
	# 保存策略
	if not file:
	    self.q_table.to_csv(file)
	env.destroy()

3.4 完成训练

训练过程如下所示,完成后保存权重文件

if __name__ == "__main__":
    env = Maze()
    agent = Agent(actions=list(range(env.n_actions)))
    reward_curve = []

    # 训练智能体
    env.after(100, agent.train, env, 50, reward_curve, './weight/csv')

    # 主循环
    env.mainloop()

4 算法分析

4.1 Q-Table

在Q-Learning算法中,我们需要维护一个Q-Table,用来记录各种状态和动作的价值。Q-Table是一个二维表格,其中每一行表示一个状态,每一列表示一个动作。Q-Table中的值表示某个状态下执行某个动作所获得的回报(或者预期回报)。Q-Table的更新是Q-Learning算法的核心。在每次执行动作后,我们会根据当前状态、执行的动作、获得的奖励和下一个状态,来更新Q-Table中对应的值,更新方式是

Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) Q^{\pi}\left( s_t,a_t \right) ={ {Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }} Qπ(st,at)=Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)Qπ(st,at))

对应代码

self.q_table.loc[state, action] += self.lr * (q_target - q_predict)

在这里插入图片描述

保存的权重文件正是Q-Table,我们可以直观地看一下,其中0-3指的是上下左右四个动作,每行行首则是状态值,其余数是Q-Value

,0,1,2,3
"[45.0, 45.0, 75.0, 75.0]",-3.764746051087998,-4.129632180625153,2.070923999854885,-4.129632180625153
terminal,0.0,0.0,0.0,0.0
"[85.0, 45.0, 115.0, 75.0]",-3.7017636879676745,-3.2427095093971663,6.341493354722148,-2.4376270354451357
"[125.0, 45.0, 155.0, 75.0]",-2.822694674017249,12.009385340227768,-3.10550914130922,-1.7370066390489591
"[125.0, 85.0, 155.0, 115.0]",-1.018256983413196,-2.3765728565289628,19.23732307528551,-2.602996266117196
"[165.0, 85.0, 195.0, 115.0]",-2.063857163563445,27.370237164958994,-0.7307141976318489,0.14330394709222574
"[205.0, 85.0, 235.0, 115.0]",-0.4546075907459214,-0.45498153729692925,-0.490099501,0.3662096391980347
"[165.0, 125.0, 195.0, 155.0]",0.9791630128216775,35.427315495348594,-0.28782126600827374,-1.7383137616441329
"[205.0, 45.0, 235.0, 75.0]",-0.3940399,-0.38288265597631166,-0.3940399,-0.3940399
"[205.0, 125.0, 235.0, 155.0]",-0.31765122402993484,-0.3940399,-0.3940399,1.5298899806741253
...

4.2 奖励曲线

训练过程的奖励曲线如下所示

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/256537.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

flowable工作流学习笔记

不同版本使用方式不一样&#xff0c;案例使用两个版本6.5.0及6.6.0,学习中6.5.0 版本是独立框架&#xff08;服务单独部署&#xff09;使用的&#xff0c; 6.6.0与springboot集成&#xff0c; 6.5.0版本如下&#xff1a; 下载flowable&#xff1a; https://github.com/flowa…

kubernetesr安全篇之云原生安全概述

云原生 4C 安全模型 云原生 4C 安全模型&#xff0c;是指在四个层面上考虑云原生的安全&#xff1a; Cloud&#xff08;云或基础设施层&#xff09;Cluster&#xff08;Kubernetes 集群层&#xff09;Container&#xff08;容器层&#xff09;Code&#xff08;代码层&#xf…

电商API接口接入|电商系统中的商品功能就该这么设计,稳的一批!

商品功能作为电商系统的核心功能&#xff0c;它的设计可谓是非常重要的。就算不是电商系统中&#xff0c;只要是涉及到需要交易物品的项目&#xff0c;商品功能都具有很好的参考价值。今天就以mall项目中的商品功能为例&#xff0c;来聊聊商品功能的设计与实现。 mall项目简介 …

你必须知道的低代码和低代码代表厂商!

自低代码进入中国市场以来&#xff0c;已经有不少年头。低代码&#xff08;Low-Code&#xff09;是一种软件开发方法&#xff0c;它使得开发人员能够通过图形界面、拖放组件和模型驱动的逻辑&#xff0c;快速地构建和部署应用程序&#xff0c;而无需编写大量的代码。 低代码开…

减速机振动相关标准 - 笔记

参考标准&#xff1a;国家标准|GB/T 39523-2020 减速机的振动标准与发动机不同&#xff0c;摘引&#xff1a; 原始加速度传感器波形 可以明显看到调幅波 它的驱动电机是300Hz~2000Hz范围的。这个采样时间是5秒&#xff0c;看分辨率至少1024线。可分出500条谱线。 频谱部分 …

大模型上下文扩展之YaRN解析:从RoPE、到ALiBi、位置插值、到YaRN

前言 下半年以来&#xff0c;我全力推动我司大模型项目团队的组建&#xff0c;我虽兼管整个项目团队&#xff0c;但为了并行多个项目&#xff0c;最终分成了三个项目组&#xff0c;每个项目都有一个项目负责人&#xff0c;分别为霍哥、阿荀、朝阳 在今年Q4&#xff0c;我司第…

经纬恒润AUTOSAR成功适配曦华科技国产车规级芯片

近日&#xff0c;经纬恒润AUTOSAR基础软件产品INTEWORK-EAS-CP成功适配曦华科技的蓝鲸CVM014x系列车规级MCU芯片。同时&#xff0c;经纬恒润完成了对曦华科技开发板的MCAL软件适配和工程集成&#xff0c;为曦华科技提供了全套AUTOSAR解决方案。 基于蓝鲸CVM014x适配经纬恒润AUT…

微信万能表单源码系统 自定义你的表单系统+完整代码包+安装部署教程

表单系统已经成为了网站、APP等应用中不可或缺的一部分。无论是注册、登录、反馈还是其他各种场景&#xff0c;都需要表单来收集用户信息。然而&#xff0c;传统的表单系统往往存在着一些问题&#xff0c;如功能单一、扩展性差、维护困难等。 以下是部分代码示例&#xff1a; …

ArkTS 状态管理@Prop、@Link

当父子组件之间需要数据同步的时候&#xff0c;可以使用Prop和Link装饰器。 实现的案例之中&#xff0c;代码时平铺直叙的&#xff0c;阅读性可理解性比较差。我们应改遵循组件化开发的思想。 在我们使用组件开发的时候&#xff0c;遇到数据同步问题的时候&#xff0c;State状态…

HuatuoGPT模型介绍

文章目录 HuatuoGPT 模型介绍LLM4Med&#xff08;医疗大模型&#xff09;的作用ChatGPT 存在的问题HuatuoGPT的特点ChatGPT 与真实医生的区别解决方案用于SFT阶段的混合数据基于AI反馈的RL 评估单轮问答多轮问答人工评估 HuatuoGPT 模型介绍 HuatuoGPT&#xff08;华佗GPT&…

利用台阶仪测量薄膜厚度的方法和技巧

在薄膜制备过程中&#xff0c;薄膜厚度是一个至关重要的参数&#xff0c;直接影响薄膜的性能和应用。为了准确测量薄膜厚度&#xff0c;研究者广泛使用台阶仪&#xff0c;这是一种方便、直接、准确的测量方法。本文将介绍如何利用台阶仪进行薄膜厚度测量的方法和技巧。 选择合…

JavaWeb编程语言—登录校验

一、前言&简介 前言&#xff1a;小编的上一篇文章“JavaWeb编程语言—登录功能实现”&#xff0c;介绍了如何通过Java代码实现通过接收前端传来的账号、密码信息来登录后端服务器&#xff0c;但是没有实现登录校验功能&#xff0c;这代表着用户不需要登录也能直接访问服务器…

龙迅LT6211B,HDMI1.4转LVDS,应用于AR/VR市场

产品描述 LT6211B 是一款用于 VR/ 显示应用的高性能 HDMI1.4 至 LVDS 芯片。 对于 LVDS 输出&#xff0c;LT6211B 可配置为单端口、双端口或四端口。对于2D视频流&#xff0c;同一视频流可以映射到两个单独的面板&#xff0c;对于3D视频格式&#xff0c;左侧数据可以发送到一个…

基于YOLOv8深度学习的智能小麦害虫检测识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

MYSQL中使用IN,在xml文件中怎么写?

MYSQL&#xff1a; Spring中&#xff1a; mysql中IN后边的集合&#xff0c;在后端中使用集合代替&#xff0c;其他的没有什么注意的&#xff0c;还需要了解foreach 语法即可。

ros2 学习03-开发工具vscode 插件配置

VSCode插件配置 为了便于后续ROS2的开发与调试&#xff0c;我们还可以安装一系列插件&#xff0c;无限扩展VSCode的功能。 中文语言包 Python插件 C插件 CMake插件 vscode-icons ROS插件 Msg Language Support Visual Studio IntelliCode URDF Markdown All in One VSCode支持的…

Linux服务器修改系统时间

一、修改时区 1、查看系统当前时间 date 2、删除当前时间&#xff1a; #删除当前默认时区 rm -rf /etc/localtime 3、 将当前时区修改为上海时区 #修改默认时区为上海 ln -s /usr/share/zoneinfo/Asia/Shanghai /etc/localtime 二、修改系统时间 1、查看系统当前时间 d…

2018年第七届数学建模国际赛小美赛D题速度扼杀爱情解题全过程文档及程序

2018年第七届数学建模国际赛小美赛 D题 速度扼杀爱情 原题再现&#xff1a; 在网上约会的时代&#xff0c;有比鱼更多的浪漫选择&#xff0c;好吧&#xff0c;你知道的。例如&#xff0c;在命名恰当的网站Plenty of Fish上&#xff0c;你可以仔细查看数百或数千名潜在伴侣的档…

Web前端-HTML(简介)

文章目录 1. HTML1.1概述1.2 HTML骨架标签1.3 HTML元素标签及分类1.4 HTML标签关系 2. 代码开发工具&#xff08;书写代码&#xff09;3. 文档类型<!DOCTYPE>4. 页面语言lang5. 字符集 1. HTML 1.1概述 HTML 指的是超文本标记语言 (Hyper Text Markup Language)&#x…

STM32 CAN多节点组网项目实操 挖坑与填坑记录2

系列文章&#xff0c;持续探索CAN多节点通讯&#xff0c; 上一篇文章链接&#xff1a; STM32 CAN多节点组网项目实操 挖坑与填坑记录-CSDN博客文章浏览阅读120次。CAN线性组网项目开发过程中遇到的数据丢包问题&#xff0c;并尝试解决的记录和推测分析。开发了一个多节点线性…