强化学习DQN实践(gymnasium+pytorch)

Pytorch官方教程中有强化学习教程,但是很多中文翻译都太老了,里面的代码也不能跑了
这篇blog按照官方最新教程实现,并加入了一些个人理解

工具

  • gymnasium:由gym升级而来,官方定义:An API standard for reinforcement learning with a diverse collection of reference environments。提供强化学习的“环境”
    • pip install gymnasium
  • pytorch

任务

倒立摆模型,使用强化学习控制小车来使倒立摆稳定,有小车向左和向右两个action
在这里插入图片描述
gym中提供了很多类似的强化学习“环境”

代码

准备工作

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

定义transition,包括state,action,next_state,reward:

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
	'''
	样本经验池,用于存储过往的transition,这些信息会被用来训练模型
	'''
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

DQN算法

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

根据上篇强化学习理论blog,强化学习的流程是通过贝尔曼最优公式得到最优策略 π ( s ) \pi(s) π(s),求解需要知道reward函数 r ( s , a ) r(s,a) r(s,a),action value函数 Q ( s , a ) Q(s,a) Q(s,a),其中reward函数可以自行定义,action value如何得到?
在Q-learning算法中,设定策略为贪心策略,即每次都是选择action value最大的action执行,这样就可以将原本原贝尔曼公式中的state value变为action value中,将 V ( s ′ ) V(s^{\prime}) V(s) Q π ( s ′ , π ( s ′ ) ) Q^{\pi}(s^{\prime},\pi (s^{\prime})) Qπ(s,π(s))来代替:
Q π ( s , a ) = r + γ Q π ( s ′ , π ( s ′ ) ) Q^{\pi}(s,a) = r + \gamma Q^{\pi}(s^{\prime},\pi (s^{\prime})) Qπ(s,a)=r+γQπ(s,π(s))
实际更新流程变为:

  • 初始化 Q 0 Q_0 Q0
  • 使用结果估计 Q ~ ( s , a ) = r + γ Q 0 ( s ′ , π ( s ′ ) ) \tilde{Q}(s,a) = r + \gamma Q_0(s^{\prime},\pi (s^{\prime})) Q~(s,a)=r+γQ0(s,π(s))
  • 计算误差 δ \delta δ δ = Q 0 ( s , a ) − Q ~ ( s , a ) \delta = Q_0(s,a)-\tilde{Q}(s,a) δ=Q0(s,a)Q~(s,a)
  • 使用损失函数 L L L Δ = L ( δ ) \Delta = L(\delta) Δ=L(δ)
  • 更新 Q Q Q Q 1 = Q 0 − α Δ Q_1 = Q_0-\alpha \Delta Q1=Q0αΔ

经过迭代后便会收敛到最优策略。
DNQ将Q-learning中的Q表改进为了神经网络,对于连续状态的环境更合适。而且会将过往的决策过程记录下来维护一个样本池,方便一次更新多个样本,而不是按时间一次次运行更新,这样回放机制就会减少应用于高度相关的状态序列时因为前后样本存在关联导致的强化学习震荡和发散的问题。

强化学习分类:
通过学习目标分类:

  • 基于价值的方法,训练agent学习行为价值函数,隐式学习了策略,如Q-learning
  • 基于策略的方法,训练agent直接学习策略
  • 把 value-based 和 policy-based 结合起来就是 演员-评论家(Actor-Critic)方法。这一类 agent 需要显式地学习价值函数和策略。如DDPG

通过交互策略和更新策略间的关系分类

  • on-policy方法:目标策略和交互策略是同一个策略 π ( s ) \pi(s) π(s)
  • off-policy方法:使用一种行为策略 μ ( s ) \mu(s) μ(s)来与环境交互,学习的目标策略是 π ( s ) \pi(s) π(s)

训练准备

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA 折扣因子,用于计算discounted
# EPS_START:随机选择action的概率初始值。
# EPS_END 随机选择action的概率末尾值
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR 优化器的学习率
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
	'''
	选择一个action,随机选择or使用模型输出,这么做是为了能够全面学习所有行为,而不陷入局部最优
	随着迭代随机选择的概率会逐渐减小至EPS_END
	'''
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
	'''
	画过去的随机概率的变化过程
	'''
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

train loop

def optimize_model():
	'''
	模型更新函数,从memory中抽取一个batch,然后更新
	'''
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # 计算Q(s,a),模型输出是每个action的概率,根据这个输出到action_batch中获取action value
    # These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # 软更新,即老的权重使用新的权重进行加权更新
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

Ref

强化学习基本原理
Pytorch官方tutorial

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/905561.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024快手面试算法题-生气传染

问题描述 思路分析 生气只会向后传播&#xff0c;最后一个生气的人一定是最长连续没有生气的人中的最后一个人&#xff0c;前提是前面得有一个人生气。 注意&#xff0c;一次只能传播一个人&#xff0c;比如示例1&#xff0c;第一次只会传播给第一个P&#xff0c;不会传播给第…

入门 | Kafka数据使用vector消费到Loki中使用grafana展示

一、Loki的基本介绍 1、基本介绍 Loki 是由 Grafana Labs 开发的一款水平可扩展、高性价比的日志聚合系统。它的设计初衷是为了有效地处理和存储大量的日志数据&#xff0c;与 Grafana 生态系统紧密集成&#xff0c;方便用户在 Grafana 中对日志进行查询和可视化操作。 从架构…

分类算法——逻辑回归 详解

逻辑回归&#xff08;Logistic Regression&#xff09;是一种广泛使用的分类算法&#xff0c;特别适用于二分类问题。尽管名字中有“回归”二字&#xff0c;逻辑回归实际上是一种分类方法。下面将从底层原理、数学模型、优化方法以及源代码层面详细解析逻辑回归。 1. 基本原理 …

【Spring MVC】DispatcherServlet 请求处理流程

一、 请求处理 Spring MVC 是 Spring 框架的一部分&#xff0c;用于构建 Web 应用程序。它遵循 MVC&#xff08;Model-View-Controller&#xff09;设计模式&#xff0c;将应用程序分为模型&#xff08;Model&#xff09;、**视图&#xff08;View&#xff09;和控制器&#x…

现代数字信号处理I--最佳线性无偏估计 BLUE 学习笔记

目录 1. 最佳线性无偏估计的由来 2. 简单线性模型下一维参数的BLUE 3. 一般线性模型下一维参数的BLUE 4. 一般线性模型下多维参数的BLUE 4.1 以一维情况说明Rao论文中的结论 4.2 矢量参数是MVUE的本质是矢量参数中的每个一维参数都是MVUE 4.3 一般线性模型多维参数BLUE的…

QT(绘图)

目录 QPainter QPainter 的一些关键步骤和使用方法&#xff1a; QPainter 的一些常用接口&#xff1a; 1. 基础绘制接口 2. 颜色和画刷设置 3. 图像绘制 4. 文本绘制 5. 变换操作 6. 渲染设置 7. 状态保存与恢复 8. 其它绘制方法 示例代码1&#xff1a; 示例代码…

【js逆向学习】某多多anti_content逆向(补环境)

文章目录 声明逆向目标逆向分析逆向过程总结 声明 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;不提供完整代码&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的…

【安全解决方案】深入解析:如何通过CDN获取用户真实IP地址

一、业务场景 某大型互联网以及电商公司为了防止客户端获取到真实的ip地址&#xff0c;以及达到保护后端业务服务器不被网站攻击&#xff0c;同时又可以让公安要求留存网站日志和排查违法行为&#xff0c;以及打击犯罪的时候&#xff0c;获取不到真实的ip地址&#xff0c;发现…

Java | Leetcode Java题解之第524题通过删除字母匹配到字典里最长单词

题目&#xff1a; 题解&#xff1a; class Solution {public String findLongestWord(String s, List<String> dictionary) {int m s.length();int[][] f new int[m 1][26];Arrays.fill(f[m], m);for (int i m - 1; i > 0; --i) {for (int j 0; j < 26; j) {…

python爬虫抓取豆瓣数据教程

环境准备 在开始之前&#xff0c;你需要确保你的Python环境已经安装了以下库&#xff1a; requests&#xff1a;用于发送HTTP请求。BeautifulSoup&#xff1a;用于解析HTML文档。 如果你还没有安装这些库&#xff0c;可以通过以下命令安装&#xff1a; pip install requests…

Python实现深度学习模型预测控制(tensorflow)DL-MPC(Deep Learning Model Predictive Control

链接&#xff1a;深度学习模型预测控制 &#xff08;如果认为有用&#xff0c;动动小手为我点亮github小星星哦&#xff09;&#xff0c;持续更新中…… 链接&#xff1a;WangXiaoMingo/TensorDL-MPC&#xff1a;DL-MPC&#xff08;深度学习模型预测控制&#xff09;是基于 P…

简单的ELK部署学习

简单的ELK部署学习 1. 需求 我们公司现在使用的是ELK日志跟踪&#xff0c;在出现问题的时候&#xff0c;我们可以快速定为到问题&#xff0c;并且可以对日志进行分类检索&#xff0c;比如对服务名称&#xff0c;ip , 级别等信息进行分类检索。此文章为本人学习了解我们公司的…

神经网络进行波士顿房价预测

前言 前一阵学校有五一数模节校赛&#xff0c;和朋友一起参加做B题&#xff0c;波士顿房价预测&#xff0c;算是第一次自己动手实现一个简单的小网络吧&#xff0c;虽然很简单&#xff0c;但还是想记录一下。 题目介绍 波士顿住房数据由哈里森和鲁宾菲尔德于1978年Harrison …

Spark的集群环境部署

一、Standalone集群 1.1、架构 架构&#xff1a;普通分布式主从架构 主&#xff1a;Master&#xff1a;管理节点&#xff1a;管理从节点、接客、资源管理和任务 调度&#xff0c;等同于YARN中的ResourceManager 从&#xff1a;Worker&#xff1a;计算节点&#xff1a;负责利…

[java][基础]JSP

目标&#xff1a; 理解 JSP 及 JSP 原理 能在 JSP中使用 EL表达式 和 JSTL标签 理解 MVC模式 和 三层架构 能完成品牌数据的增删改查功能 1&#xff0c;JSP 概述 JSP&#xff08;全称&#xff1a;Java Server Pages&#xff09;&#xff1a;Java 服务端页面。是一种动态的…

常见问题 | 数字签名如何保障电子商务交易安全?

如何解决电商交易中数据泄露、交易欺诈等问题&#xff1f; 数字签名是一种类似于电子“指纹”的安全技术&#xff0c;它在电子商务中扮演着至关重要的角色。随着电子商务的迅猛发展&#xff0c;网上交易的数量不断增加&#xff0c;确保交易的安全性和完整性成为了亟待解决的问题…

【Python基础】

一、编程语言介绍 1、分类 机器语言 (直接用 0 1代码编写&#xff09;汇编语言 &#xff08;英文单词替代二进制指令&#xff09;高级语言 2、总结 1、执行效率&#xff1a;机器语言&#xff1e;汇编语言>高级语言&#xff08;编译型>解释型&#xff09; 2、开发效率&…

Java项目实战II基于Java+Spring Boot+MySQL的编程训练系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在当今数字…

双指针习题篇(上)

双指针习题篇(上) 文章目录 双指针习题篇(上)1.移动零题目描述&#xff1a;算法原理&#xff1a;算法流程&#xff1a;代码实现&#xff1a; 2.复写零题目描述&#xff1a;算法原理&#xff1a;算法流程&#xff1a;代码实现&#xff1a; 3.快乐数题目描述&#xff1a;算法原理…

更安全高效的文件传输工具,Ftrans国产FTP替代方案可以了解

文件传输协议&#xff08;FTP&#xff09;&#xff0c;诞生于1971年&#xff0c;自20世纪70年代发明以来&#xff0c;FTP已成为传输大文件的不二之选。内置有操作系统的 FTP 可提供一个相对简便、看似免费的文件交换方法&#xff0c;因此得到广泛使用。 随着企业发展过程中新增…