强化学习10——免模型控制Q-learning算法

Q-learning算法

主要思路

由于 V π ( s ) = ∑ a ∈ A π ( a ∣ s ) Q π ( s , a ) V_\pi(s)=\sum_{a\in A}\pi(a\mid s)Q_\pi(s,a) Vπ(s)=aAπ(as)Qπ(s,a) ,当我们直接预测动作价值函数,在决策中选择Q值最大即动作价值最大的动作,则可以使策略和动作价值函数同时最优,那么由上述公式可得,状态价值函数也是最优的。
Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ max ⁡ a Q ( s t + 1 , a ) − Q ( s t , a t ) ] Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\alpha[r_t+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)] Q(st,at)Q(st,at)+α[rt+γamaxQ(st+1,a)Q(st,at)]
Q-learning基于时序差分的更新方法,具体流程如下所示:

  • 初始化 Q ( s , a ) Q(s,a) Q(s,a)
  • for 序列 e = 1 → E e=1\to E e=1E do:
    • 得到初始状态s
    • for 时步 t = 1 → T t=1\to T t=1T do:
      • 使用 ϵ − g r e e d y \epsilon -greedy ϵgreedy 策略根据Q选择当前状态s下的动作a
      • 得到环境反馈 r , s ′ r,s' r,s
      • Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha[r+\gamma\max_{a^{\prime}}Q(s^{\prime},a^{\prime})-Q(s,a)] Q(s,a)Q(s,a)+α[r+γmaxaQ(s,a)Q(s,a)]
      • s ← s ′ s\gets s' ss
    • end for
  • end for

算法实战

我们在悬崖漫步环境下实习Q-learning算法。

首先创建悬崖漫步的环境:

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # tqdm是显示循环进度条的库

class CliffWalkingEnv:
    def __init__(self, ncol, nrow):
        self.nrow = nrow
        self.ncol = ncol
        self.x = 0  # 记录当前智能体位置的横坐标
        self.y = self.nrow - 1  # 记录当前智能体位置的纵坐标

    def step(self, action):  # 外部调用这个函数来改变当前位置
        # 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
        self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
        next_state = self.y * self.ncol + self.x
        reward = -1
        done = False
        if self.y == self.nrow - 1 and self.x > 0:  # 下一个位置在悬崖或者目标
            done = True
            if self.x != self.ncol - 1:
                reward = -100
        return next_state, reward, done

    def reset(self):  # 回归初始状态,坐标轴原点在左上角
        self.x = 0
        self.y = self.nrow - 1
        return self.y * self.ncol + self.x

创建Q-learning算法

class QLearning:
    def __init__(self, ncol, nrow, epsilon, alpha, gamma,n_action=4):
        self.epsilon = epsilon  # 随机探索的概率
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.n_action = n_action  # 动作数量
        # 给每一个状态创建一个长度为4的列表。
        self.Q_table = np.zeros([nrow*ncol,n_action])  # 初始化Q(s,a)
        
    def take_action(self,state):
        # 选取下一步的操作
        if np.random.random()<self.epsilon:
            action = np.random.randint(self.n_action)  # 随机探索
        else:
            action = np.argmax(self.Q_table[state])  # 贪婪策略,选择Q值最大的动作
        return action
    
    def best_action(self, state):  # 用于打印策略
        Q_max = np.max(self.Q_table[state])
        a = [0 for _ in range(self.n_action)]
        for i in range(self.n_action):
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a
    
    def update(self,s0,a0,r,s1):
        td_error = r+self.gamma*self.Q_table[s1].max()-self.Q_table[s0,a0]
        self.Q_table[s0, a0] += self.alpha * td_error
ncol = 12
nrow = 4    
np.random.seed(0)
epsilon = 0.1
alpha = 0.1
gamma = 0.9
env = CliffWalkingEnv(ncol, nrow)
agent = QLearning(ncol, nrow, epsilon, alpha, gamma)
num_episodes = 500  # 智能体在环境中运行的序列的数量
return_list = [] # 记录每一条序列的回报
# 显示10个进度条
for i in range(10):
       # tqdm的进度条功能
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):  # 每个进度条的序列数
            episode_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done = env.step(action)
                episode_return += reward  # 这里回报的计算不进行折扣因子衰减
                agent.update(state, action, reward, next_state)
                state = next_state
            return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Q-learning on {}'.format('Cliff Walking'))
plt.show()

action_meaning = ['^', 'v', '<', '>']
print('Q-learning算法最终收敛得到的策略为:')
def print_agent(agent, env, action_meaning, disaster=[], end=[]):
    for i in range(env.nrow):
        for j in range(env.ncol):
            if (i * env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * env.ncol + j) in end:
                print('EEEE', end=' ')
            else:
                a = agent.best_action(i * env.ncol + j)
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()


action_meaning = ['^', 'v', '<', '>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
Iteration 0: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2040.03it/s, episode=50, return=-105.700]
Iteration 1: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2381.99it/s, episode=100, return=-70.900] 
Iteration 2: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3209.35it/s, episode=150, return=-56.500] 
Iteration 3: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3541.95it/s, episode=200, return=-46.500] 
Iteration 4: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 5005.26it/s, episode=250, return=-40.800] 
Iteration 5: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3936.76it/s, episode=300, return=-20.400] 
Iteration 6: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 4892.00it/s, episode=350, return=-45.700] 
Iteration 7: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 5502.60it/s, episode=400, return=-32.800] 
Iteration 8: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6730.49it/s, episode=450, return=-22.700] 
Iteration 9: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6768.50it/s, episode=500, return=-61.700] 
Q-learning算法最终收敛得到的策略为:
Qling算法最终收敛得到的策略为:
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/306649.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis入门-redis的五大数据类型+三种特殊的数据类型

前言&#xff1a;Redis有五大基本类型与三种特殊类型的介绍 Redis有五大基本类型&#xff1a;字符串&#xff08;string&#xff09;、哈希&#xff08;hash&#xff09;、列表&#xff08;list&#xff09;、集合&#xff08;set&#xff09;和有序集合&#xff08;sorted se…

ModuleNotFoundError: No module named ‘SwissArmyTransformer‘

小问题&#xff0c;直接pip install pip install SwissArmyTransformer 但是&#xff0c;安装之后却还是提示&#xff0c;屏幕上依然标红 ModuleNotFoundError: No module named SwissArmyTransformer 查找环境目录发现&#xff0c; 这是因为新版的SwissArmyTransformer中&…

Spring boot 3 集成rocketmq-spring-boot-starter解决版本不一致问题

安装RocketMQ根据上篇文章使用Docker安装RocketMQ并启动之后&#xff0c;有个隐患详情见下文 Spring Boot集成 <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId><version>2.2…

【排序】快速排序

思想 快速排序是一种基于分治策略的排序算法&#xff0c;其核心思想通过选取一个基准元素&#xff0c;将数组分成两个子数组&#xff1a;一个包含小于基准元素的值&#xff0c;另一个包含大于基准元素的值。然后递归地对这两个子数组进行排序&#xff0c;最终将它们合并起来&a…

linux项目部署(jdk,tomcat,mysql,nginx,redis)

打开虚拟机&#xff0c;与连接工具连接好&#xff0c;创建一个文件夹 cd /tools 把jdk,tomcat安装包放入这个文件夹里面 jdk安装 #解压 tar -zxvf apache-tomcat-8.5.20.tar.gz #解压jdk tar -zxvf jdk-8u151-linux-x64.tar.gz 编辑jdk文件以及测试jdk安装 第一行代码路径…

数 据 分 析 1

1.使用Wireshark查看并分析靶机桌面下的capture.pcapng数据包文件&#xff0c;找到黑客的IP地址&#xff0c;并将黑客的IP地址作为Flag值&#xff08;如&#xff1a;172.16.1.1&#xff09;提交&#xff1b;172.16.1.41 查找&#xff1a;tcp.connection.syn 2.继续分析captu…

冲刺2024年AMC8竞赛:往年真题练一练和答案详解(3)

今天我们继续来做一做往年的AMC8真题&#xff0c;通过高质量的真题来体会我们所学的知识如何解题&#xff0c;建立快速思考、做对题目的策略。 今天分享的五道题目仍然是随机从六分成长独家制作的575道在线题库&#xff08;来自于往年真题&#xff09;中抽取5道题来做一下&…

TSP(Python):Qlearning求解旅行商问题TSP(提供Python代码)

一、Qlearning简介 Q-learning是一种强化学习算法&#xff0c;用于解决基于奖励的决策问题。它是一种无模型的学习方法&#xff0c;通过与环境的交互来学习最优策略。Q-learning的核心思想是通过学习一个Q值函数来指导决策&#xff0c;该函数表示在给定状态下采取某个动作所获…

Android Studio导入项目 下载gradle很慢或连接超时,提示:Read timed out---解决方法建议收藏!

目录 前言 一、报错信息 二、解决方法 三、更多资源 前言 一般来说&#xff0c;使用Android Studio导入项目并下载gradle的过程应该是相对顺利的&#xff0c;但是有时候会遇到下载速度缓慢或连接超时的问题&#xff0c;这可能会让开发者感到头疼。这种情况通常会出现在网络…

基于Jackson自定义json数据的对象转换器

1、问题说明 后端数据表定义的id主键是Long类型&#xff0c;一共有20多位。 前端在接收到后端返回的json数据时&#xff0c;Long类型会默认当做数值类型进行处理。但前端处理20多位的数值会造成精度丢失&#xff0c;于是导致前端查询数据出现问题。 测试前端Long类型的代码 …

常见排序算法及其稳定性分析

前言&#xff1a; 排序算法可以说是每一个程序员在学习数据结构和算法时必须要掌握的知识点&#xff0c;同样也是面试过程中可能会遇到的问题&#xff0c;在早些年甚至还会考冒泡排序。由此可见呢&#xff0c;掌握一些常见的排序算法是一个程序员的基本素养。虽然现在的语言标…

【第一次使用finalshell连接虚拟机内的centos】小白处理方式

第一次使用finalshell连接centos7的时候&#xff0c;因为都是新环境什么都没有配置&#xff0c;所以就需要安装finalshell和对新的centos7 进行一些配置。 安装finalshel&#xff0c;默认不安装d盘&#xff0c;就需要对安装路径做一下调整&#xff0c;其余都是下一步默认安装的…

数据结构和算法-交换排序中的快速排序(演示过程 算法实现 算法效率 稳定性)

文章目录 总览快速排序&#xff08;超级重要&#xff09;啥是快速排序演示过程算法实现第一次quicksort函数第一次partion函数到第一次quicksort的第一个quicksort到第二次quicksort的第一个quicksort到第二次quicksort的第二个quicksort到第一次quicksort的第二个quicksort到第…

大漠插件7.2353

工具名称:大漠插件7.2353 更新时间2023-12-29更新内容/v7.23531. FindPicSim优化,防止有些时候会找不到图2. 增加接口TerminateProcessTree3. 解决AsmCall 模式6在部分WIN11下无法正常生效的BUG/ 工具简介:大漠 综合 插件 (dm.dll)采用vc6.0编写&#xff0c;识别速度超级快&…

怎样在Anaconda下安装pytorch(conda安装和pip安装)

前言 文字说明 本文中标红的&#xff0c;代表的是我认为比较重要的。 版本说明 python环境配置&#xff1a;jupyter的base环境下的python是3.10版本。CUDA配置是&#xff1a;CUDA11.6。目前pytorch官网提示支持的版本是3.7-3.9 本文主要用来记录自己在安装pytorch中出现的问…

【软件测试】学习笔记-脚本与数据的解耦 + Page Object模型

本篇文章介绍GUI测试中两个非常重要的概念&#xff1a;测试脚本和数据的解耦&#xff0c;以及页面对象&#xff08;Page Object&#xff09;模型。 测试脚本和数据的解耦 GUI自动化测试适用的场景&#xff0c;尤其适用于需要回归测试页面功能的场景。如果在测试脚本中硬编码&a…

开源的RNA-Seq分析软件Trinity的详细介绍和使用方法

介绍 GitHub - trinityrnaseq/trinityrnaseq: Trinity RNA-Seq de novo transcriptome assembly Trinity是一种开源的RNA-Seq分析软件&#xff0c;用于转录组的de novo组装。转录组de novo组装是通过将RNA-Seq数据中的短序列片段&#xff08;reads&#xff09;重新组装成完整的…

Java8 Stream集合的筛选、归约、分组、聚合讲解

目录 1 Stream概述 2 Stream的创建 3 Stream的使用 3.1 Optional 3.2 案例 3.2.1 遍历/匹配&#xff08;foreach/find/match&#xff09; 3.2.2 筛选&#xff08;filter&#xff09; 3.2.3 聚合&#xff08;max/min/count) 3.2.4 映射(map/flatMap) 3.2.5 归约(reduce…

仿stackoverflow名片与b站名片实现(HTML、CSS)

目录 前言一、仿stackoverflow名片HTMLCSS 二、仿b站名片HTMLCSS 素材 前言 学习自ACwing - Web应用课 一、仿stackoverflow名片 HTML <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport&…

【实用技巧】Windows 电脑向iPhone或iPad传输视频方法1:无线传输

一、内容简介 本文介绍如何使用 Windows 电脑向 iPhone 或 iPad 传输视频&#xff0c;以 iPhone 为例&#xff0c;iPad的操作方法类似&#xff0c;本文不作赘述。 二、所需原材料 Windows 电脑&#xff08;桌面或其它文件夹中存有要导入的视频&#xff09;、iPhone 14。 待…