优先经验回放(prioritized experience replay)

prioritized experience replay 思路

优先经验回放出自ICLR 2016的论文《prioritized experience replay》。

prioritized experience replay的作者们认为,按照一定的优先级来对经验回放池中的样本采样,相比于随机均匀的从经验回放池中采样的效率更高,可以让模型更快的收敛。其基本思想是RL agent在一些转移样本上可以更有效的学习,也可以解释成“更多地训练会让你意外的数据”。

那优先级如何定义呢?作者们使用的是样本的TD error δ \delta δ 的幅值。对于新生成的样本,TD error未知时,将样本赋值为最大优先级,以保证样本至少将会被采样一次。每个采样样本的概率被定义为
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=kpkαpiα
上式中的 p i > 0 p_i >0 pi>0是回放池中的第i个样本的优先级, α \alpha α则强调有多重视该优先级,如果 α = 0 \alpha=0 α=0,采样就退化成和基础DQN一样的均匀采样了。

p i p_i pi如何取值,论文中提供了如下两种方法,两种方法都是关于TD error δ \delta δ 单调的:

  • 基于比例的优先级: p i = ∣ δ i ∣ + ϵ p_i = |\delta_i| + \epsilon pi=δi+ϵ ϵ \epsilon ϵ是一个很小的正数常量,防止当TD error为0时样本就不会被访问到的情形。(目前大部分实现都是使用的这个形式的优先级)
  • 基于排序的优先级: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi=rank(i)1, 式中的 r a n k ( i ) rank(i) rank(i)是样本根据 ∣ δ i ∣ |\delta_i| δi 在经验回放池中的排序号,此时P就变成了带有指数 α \alpha α的幂率分布了。

作者们定义的概率调整了样本的优先级,因此也就在数据分布中引入了偏差,为了弥补偏差,使用了重要性采样权重(importance-sampling (IS) weights):
w i = ( 1 N ⋅ 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi=(N1P(i)1)β
上式权重中,当 β = 1 \beta=1 β=1时就完全补偿了非均匀概率采样引入的偏差,作者们提到为了收敛性考虑,最后让 β \beta β从0到1中的某个值开始,并逐渐增加到1。在Q-learning更新时使用这些权重乘以TD error,也就是使用 w i δ i w_i \delta_i wiδi而不是原来的 δ i \delta_i δi。此外,为了使训练更稳定,总是对权重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxiwi进行归一化。

以Double DQN为例,使用优先经验回放的算法(论文算法1)如下图:

在这里插入图片描述

prioritized experience replay 实现

直接实现优先经验回放池如下代码(修改自代码 )

class PrioReplayBufferNaive:
    def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):
        self.prob_alpha = prob_alpha
        self.capacity = buf_size
        self.pos = 0
        self.buffer = []
        self.priorities = np.zeros((buf_size, ), dtype=np.float32)
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.epsilon = epsilon

    def __len__(self):
        return len(self.buffer)

    def size(self):  # 目前buffer中数据的数量
        return len(self.buffer)

    def add(self, sample):
        # 新加入的数据使用最大的优先级,保证数据尽可能的被采样到
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)
        else:
            self.buffer[self.pos] = sample
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        probs = np.array(prios, dtype=np.float32) ** self.prob_alpha

        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)
        samples = [self.buffer[idx] for idx in indices]
        total = len(self.buffer)
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
        weights = (total * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return samples, indices, np.array(weights, dtype=np.float32)

    def update_priorities(self, batch_indices, batch_priorities):
        '''
        更新样本的优先级'''
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio + self.epsilon

直接实现的优先经验回放,在样本数很大时的采样效率不够高,作者们通过定义了sumtree的数据结构来存储样本优先级,该数据结构的每一个节点的值为其子节点之和,而样本优先级被放在树的叶子节点上,树的根节点的值为所有优先级之和 p t o t a l p_{total} ptotal,更新和采样时的效率为 O ( l o g N ) O(logN) O(logN)。在采样时,设采样批次大小为k,将 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal]均分为k等份,然后在每一个区间均匀的采样一个值,再通过该值从树中提取到对应的样本。python 实现如下(代码来源)

class SumTree:
    """
    父节点的值是其子节点值之和的二叉树数据结构
    """
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])


class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTree
    epsilon = 0.01
    alpha = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (np.abs(error) + self.epsilon) ** self.alpha

    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)

            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = priorities / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def update(self, idx, error):
      '''
      这里是一次更新一个样本,所以在调用时,写for循环依次更次样本的优先级
      '''
        p = self._get_priority(error)
        self.tree.update(idx, p)

参考资料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的实现代码

  3. 相关blog: 1 (对应的代码), 2, 3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/178375.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

matlab绘图函数plot和fplot的区别

一、背景 有的函数用plot画就会报错&#xff0c;显示数据必须为可转换为双精度值的数值、日期时间、持续时间、分类或数组。 如下图所示&#xff1a; 但用fplot函数就没有问题&#xff0c;因此这里记录一下两者的区别&#xff0c;如果使用不当&#xff0c;画出的图可能就是下…

坚鹏:中国工商银行数字化转型发展现状与成功案例培训圆满结束

中国工商银行围绕“数字生态、数字资产、数字技术、数字基建、数字基因”五维布局&#xff0c;深入推进数字化转型&#xff0c;加快形成体系化、生态化实施路径&#xff0c;促进科技与业务加速融合&#xff0c;以“数字工行”建设推动“GBC”&#xff08;政务、企业、个人&…

用 jmeter 对 mongodb 进行测试方法合集

MongoDB 作为非关系型数据库&#xff0c;在现在企业中&#xff0c;还是有广泛的使用。但是&#xff0c;用 jmeter 如何测试 MongoDB&#xff0c;却是一个令很多人头疼的问题。去搜索&#xff0c;国内基本找不到一篇比较有价值的文章。 今天&#xff0c;我就用三种不同方法&…

如何通过RA过程识别Redcap UE?

以下是38.300中的描述 RedCap UE可以通过发送MSG3/MSGA的特定LCID识别&#xff0c;可选条件是通过MSGA/MSG1的PRACH occasion/PRACH preamble识别&#xff0c;根据这段描述&#xff0c;通过MSG3/MSGA的识别是必须项&#xff0c;而MSGA/MSG1的识别过程是可选项。如果通过MSGA/MS…

02 请求默认值

一、HTTP请求默认值&#xff1a;是用来管理所有请求共有的协议、网址、端口等信息的&#xff1b;通常情况下&#xff0c;一批量的接口测试&#xff0c;访问的是同一个站点&#xff0c;那么以上信息基本都是相同的&#xff0c;就不需要在每个请求中重复编写&#xff1b; 每个请…

Spring cloud - Hystrix源码

其实只是Hystrix初始化部分&#xff0c;我们从源码的角度分析一下EnableCircuitBreaker以及HystrixCommand注解的初始化过程。 从EnableCircuitBreaker入手 我们是通过在启动类添加EnableCircuitBreaker注解启用Hystrix的&#xff0c;所以&#xff0c;源码解析也要从这个注解…

2014年5月28日 Go生态洞察:GopherCon 2014大会回顾

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

ansible的基本安装

目录 一、简介 1.ansible自动化运维人工运维时代 2.自动化运维时代 3.ansible介绍 4.ansible特点 二、ansible实践 1.环境 2.ansible管理安装 3.ansible被管理安装 4.管理方式 5.添加被管理机器的ip 6.ssh密码认证方式管理 三、配置免密登录 1.ansible自带的密码…

hp惠普Victus Gaming Laptop 15-fa1025TX/fa1005tx原装出厂Win11系统ISO镜像

光影精灵9笔记本电脑原厂W11系统22H2恢复出厂时开箱状态一模一样 适用型号&#xff1a;15-fa1003TX&#xff0c;15-fa1005TX&#xff0c;15-fa1007TX&#xff0c;15-fa1025TX 链接&#xff1a;https://pan.baidu.com/s/1fBPjed1bhOS_crGIo2tP1w?pwduzvz 提取码&#xff1a…

【华为OD题库-033】经典屏保-java

题目 DVD机在视频输出时&#xff0c;为了保护电视显像管&#xff0c;在待机状态会显示"屏保动画”&#xff0c;如下图所示,DVD Logo在屏幕内来回运动&#xff0c;碰到边缘会反弹:请根据如下要求&#xff0c;实现屏保Logo坐标的计算算法 1、屏幕是一个800 * 600像素的矩形&…

软件测试:功能测试常用的测试用例大全

登录、添加、删除、查询模块是我们经常遇到的&#xff0c;这些模块的测试点该如何考虑 1)登录 ① 用户名和密码都符合要求(格式上的要求) ② 用户名和密码都不符合要求(格式上的要求) ③ 用户名符合要求&#xff0c;密码不符合要求(格式上的要求) ④ 密码符合要求&#xf…

JavaScript实现右键菜单

1、代码实现 window.onload function () {(function () {// 自定义右键菜单内容并插入到body最后一个节点前let dom <div id"rightMenuBars"><div class"rightMenu-group rightMenu-small"><div class"rightMenu-item"><…

【应用程序启动过程-三种加载控制器的方式-上午内容复习 Objective-C语言】

一、我们先来回忆一下,上午所有内容 1.首先呢,我们先说的是这个“应用程序启动过程”, 应用程序启动过程里面,有三方面内容 1)UIApplication对象介绍 2)AppDelegate对象介绍 3)应用程序启动过程 现在不知道大家对这个应用程序启动过程有印象吗, 2.首先,这个UIAp…

渗透测试高级技巧(一):分析验签与前端加密

“开局一个登录框” 在黑盒的安全测试的工作开始的时候&#xff0c;打开网站一般来说可能仅仅是一个登录框&#xff1b;很多时候这种系统往往都是自研或者一些业务公司专门研发。最基础的情况下&#xff0c;我们会尝试使用 SQL 注入绕过或者爆破之类的常规手段&#xff0c;如果…

开源计算机视觉库OpenCV详解

目录 1、概述 2、OpenCV详细介绍 2.1、OpenCV的起源 2.2、OpenCV开发语言 2.3、OpenCV的应用领域 3、OpenCV模块划分 4、OpenCV源码文件结构 4.1、根目录介绍 4.2、常用模块介绍 4.3、CUDA加速模块 5、OpenCV配置以及Visual Studio使用OpenCV 6、关于Lena图片 7、…

简于外 强于内,联想全新ThinkCentre M90a Pro Gen4以强劲性能开启商用新体验

近日&#xff0c;联想发布了最新一代商用台式一体机联想ThinkCentre M90a Pro Gen4。作为联想ThinkCentre M大师系列的旗舰产品&#xff0c;其配备了优质的显示屏&#xff0c;拥有强大的性能和稳定安全的特性&#xff0c;能够满足多样的工作场合&#xff0c;为商用一体机的行业…

面试题:你怎么理解System.out.println() ?

文章目录 首先分析System源码out源码分析println分析拓展知识点 你如何理解System.out.println() ? 学了这么久的面向对象编程&#xff0c;那如何用一行代码体现呢&#xff1f; 如果你能自己读懂System.out.println()&#xff0c;就真正了解了Java面向对象编程的含义 面向对…

Ajax入门-Express框架介绍和基本使用

电脑实在忒垃圾了&#xff0c;出现问题耗费了至少一刻钟time&#xff0c;然后才搞出来正常的效果&#xff1b; 效果镇楼 另外重新安装了VScode软件&#xff0c;原来的老是报错&#xff0c;bug。。&#xff1b; 2个必要的安装命令&#xff1b; 然后建立必要的文件夹和文件&…

MAX/MSP SDK学习07:list传递

实现自定义Obejct&#xff0c;要求将传入的一组数据100后传出。 #include "ext.h" #include "ext_obex.h" typedef struct _listTrans {t_object ob;void* outLet;t_atom* fArr;long listNum;} t_listTrans;void* listTrans_new(t_symbol* s, long arg…

【资深硬件工程师总结-千兆以太网设计指南】

文章目录 01通用PCB布线指南02标志焊盘中的接地过孔区示例03EMI注意事项04ESD注意事项 资深硬件工程师总结-千兆以太网设计指南 本应用笔记旨在帮助客户使用Microchip的10/100/1000 Mbps以太网器件系列设计PCB。本文档提供有关PCB布线的建 议&#xff0c; PCB 布线是保持信号完…