pytorch强化学习(1)——DQNSARSA

实验环境

python=3.10
torch=2.1.1
gym=0.26.2
gym[classic_control]
matplotlib=3.8.0
numpy=1.26.2

DQN代码

首先是module.py代码,在这里定义了网络模型和DQN模型

import torch
import torch.nn as nn
import numpy as np

class Net(nn.Module):
    # 构造只有一个隐含层的网络
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b,n_states]-->[b,n_hidden]
        self.network = nn.Sequential(
            torch.nn.Linear(n_states, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_actions)
        )

    # 前传
    def forward(self, x):  # [b,n_states]
        return self.network(x)


class DQN:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # 属性分配
        self.n_states = n_states  # 状态的特征数
        self.n_hidden = n_hidden  # 隐含层个数
        self.n_actions = n_actions  # 动作数
        self.lr = lr  # 训练时的学习率
        self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放
        self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索
        # 计数器,记录迭代次数
        self.count = 0

        # 实例化训练网络
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)

        # 优化器,更新训练网络的参数
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.criterion = torch.nn.MSELoss()  # 损失函数

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # q_net(state)采取动作后的预测
            action = action_values.argmax().item()
        else:
            # 随机选择一个动作
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]
        # 前千万不能缺少done,如果下一步游戏结束的花,那下一步的q值应该为0
        q_target = reward + self.gamma * self.q_net(next_state).max() * (1 - float(done))

        self.optimizer.zero_grad()
        dqn_loss = self.criterion(q_value, q_target)
        dqn_loss.backward()
        self.optimizer.step()

然后是train.py代码,在这里调用DQN模型和gym环境,来进行训练:

import gym
import torch
from module import DQN
import matplotlib.pyplot as plt

lr = 1e-3  # 学习率
gamma = 0.95  # 折扣因子
epsilon = 0.8  # 贪心系数
n_hidden = 200  # 隐含层神经元个数

env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数

dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)

if __name__ == '__main__':
    reward_list = []
    for i in range(500):
        state = env.reset()[0]  # len=4
        total_reward = 0
        done = False
        while True:

            # 获取当前状态下需要采取的动作
            action = dqn.choose_action(state)
            # 更新环境
            next_state, reward, done, _, _ = env.step(action)
            dqn.update(state, action, reward, next_state, done)
            state = next_state

            total_reward += reward

            if done:
                break
        print("第%d回合,total_reward=%f" % (i, total_reward))
        reward_list.append(total_reward)

    # 绘图
    episodes_list = list(range(len(reward_list)))
    plt.plot(episodes_list, reward_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN Returns')
    plt.show()

SARSA代码

首先是module.py代码,在这里定义了网络模型和SARSA模型。
SARSA和DQN基本相同,只有在更新Q网络的时候略有不同,已在代码相应位置做出注释。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Net(nn.Module):
    # 构造只有一个隐含层的网络
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b,n_states]-->[b,n_hidden]
        self.network = nn.Sequential(
            torch.nn.Linear(n_states, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_actions)
        )

    # 前传
    def forward(self, x):  # [b,n_states]
        return self.network(x)


class SARSA:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # 属性分配
        self.n_states = n_states  # 状态的特征数
        self.n_hidden = n_hidden  # 隐含层个数
        self.n_actions = n_actions  # 动作数
        self.lr = lr  # 训练时的学习率
        self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放
        self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索
        # 计数器,记录迭代次数
        self.count = 0

        # 实例化训练网络
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)

        # 优化器,更新训练网络的参数
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.criterion = torch.nn.MSELoss()  # 损失函数

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        # 基于贪婪系数,有一定概率采取随机策略
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # q_net(state)是在当前状态采取各个动作后的预测
            action = action_values.argmax().item()
        else:
            # 随机选择一个动作
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]

        '''
        sarsa在更新网络时选择的是q_net(next_state)[next_action] 
        这是sarsa算法和dqn的唯一不同
        dqn是选择max(q_net(next))
        '''
        next_action = self.choose_action(next_state)
        # 千万不能缺少done,如果下一步游戏结束的话,那下一步的q值应该为0,而不是q网络输出的值
        q_target = reward + self.gamma * self.q_net(next_state)[next_action] * (1 - float(done))

        self.optimizer.zero_grad()
        dqn_loss = self.criterion(q_value, q_target)
        dqn_loss.backward()
        self.optimizer.step()

SARSA也有tarin.py文件,功能和上面DQN的一样,内容也几乎完全一样,只是把DQN的名字改成SARSA而已,所以在这里不再赘述。

运行结果

DQN的运行结果如下:
在这里插入图片描述

SARSA运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/251775.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Zabbix监控系统部署与管理

目录 zabbix介绍 zabbix构成 zabbix进程 环境 zabbix-server节点部署 安装zabbix服务 安装与配置数据库 修改zabbix-PHP时区 登录网页安装 ​编辑数据库Access denied故障 zabbix-agent节点部署 zabbix web管理 中文乱码问题 zabbix介绍 zabbix是⼀个基于 Web 界…

【人工智能】实验二: 洗衣机模糊推理系统实验与基础知识

实验二: 洗衣机模糊推理系统实验 实验目的 理解模糊逻辑推理的原理及特点&#xff0c;熟练应用模糊推理。 实验内容 设计洗衣机洗涤时间的模糊控制。 实验要求 已知人的操作经验为&#xff1a; “污泥越多&#xff0c;油脂越多&#xff0c;洗涤时间越长”&#xff1b;“…

如何使用ycsb工具对mongodb进行性能测试过程

测试环境&#xff1a; linux系统&#xff1a;Centos 7.2 ,版本&#xff1a;Red Hat 4.8.5-44) YCSB简介 ycsb是一款性能测试工具&#xff0c;用Java写的&#xff0c;并且什么都可以压&#xff0c;像是mongodb&#xff0c;redis&#xff0c;mysql&#xff0c;hbase&#xff0c;等…

JavaScript值类型和引用类型两道经典面试题

JavaScript值类型和引用类型两道经典面试题 题目1题目2 题目1 首先&#xff0c;小试牛刀&#xff0c;请看第一道题。 let a {x: 10 } let b a a.x 20 console.log(b.x)a {x: 30 } console.log(b.x) a.x 40 console.log(b.x);那么上述代码输出结果是多少呢&#xff1f; …

逻辑分析仪_使用手册

LA1010 1> 能干啥&#xff1f;2> 硬件连接3> 软件安装4> 参数设置4.1> 采样深度和采样率4.2> 添加协议解析器4.3> 毛刺过滤设置 1> 能干啥&#xff1f; 测量通信波形&#xff0c;并自动解析&#xff1b; 比如测量&#xff0c;UART&#xff0c;SPI&…

Java系列-ConcurrentHashMap-addCount

1.addCount public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>implements ConcurrentMap<K,V>, Serializable {private final void addCount(long x, int check) {CounterCell[] as; long b, s;//1.counterCells不为null//2.或者 x加到baseCou…

如何在Docker部署draw.io流程图软件并实现公网远程访问

前言 提到流程图&#xff0c;大家第一时间可能会想到Visio&#xff0c;不可否认&#xff0c;VIsio确实是功能强大&#xff0c;但是软件为收费&#xff0c;并且因为其功能强大&#xff0c;导致安装需要很多的系统内存&#xff0c;并且是不可跨平台使用。所以&#xff0c;今天给…

深入学习《大学计算机》系列之第1章 1.3节——计算机科学的知识领域

一.欢迎来到我的酒馆 第1章 1.3节&#xff0c;计算机科学的知识领域。 目录 一.欢迎来到我的酒馆二.计算机科学的知识领域1.什么是计算机科学 二.计算机科学的知识领域 什么是计算机科学&#xff1f;什么是计算机学科&#xff1f;计算机科学包含哪些知识领域&#xff1f; …

PyCharm控制台异常堆栈乱码问题解决

目录 1、问题描述2、问题原因3、问题解决 1、问题描述 PyCharm环境都已经配置成了UTF-8编码&#xff0c;控制台打印中文也不会出现乱码&#xff0c;但异常堆栈信息中如果有中文会出现中文乱码&#xff1a; 这种该怎么解决呢&#xff1f; 2、问题原因 未将PyCharm编码环境与项目…

接口自动化测试实操【设置断言思路】

1 断言设置思路 这里总结了我在项目中常用的5种断言方式&#xff0c;基本可能满足90%以上的断言场景&#xff0c;具体参见如下脑图&#xff1a; 在这里插入图片描述 下面分别解释一下图中的五种思路&#xff1a; 1&#xff09; 响应码 对于http类接口&#xff0c;有时开发人…

Python:Jupyter

Jupyter是一个开源的交互式计算环境&#xff0c;由Fernando Perez和Brian Granger于2014年创立。它提供了一种方便的方式来展示、共享和探索数据&#xff0c;并且可以与多种编程语言和数据格式进行交互。Jupyter的历史可以追溯到2001年&#xff0c;当时Fernando Perez正在使用P…

Linux Shell——输入输出重定向

输入输出重定向 1. 重定向输入2. 重定向输出 总结 最近学习了shell语法&#xff0c;总结一下关于输入输出重定向的知识。 一般情况下&#xff0c;linux每次执行命令其实都会打开三个文件&#xff0c;分别是&#xff1a; 标准输入stdin 文件描述符为0 标准输出stdout 文件描述符…

《软件方法(下)》8.2.3 提炼类和属性(1)

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 8.2 建模步骤C-1 识别类和属性 8.2.2 三种分析类 8.2.2.6 自测题 扫码或访问http://www.umlchina.com/book/quiz8_2_2.html完成在线测试&#xff0c;做到全对以获得答案。 1. [单选…

springMVC-@RequestMapping

基本介绍 RequestMapping注解可以指定控制器/处理器的某个方法的请求的url, 示例 &#xff08;结合springMVC基本原理理解&#xff09; Controller public class UserHandler {RequestMapping(value "/login")public String login() {System.out.println("登…

微服务保护--熔断降级

1.熔断降级介绍 熔断降级是解决雪崩问题的重要手段。其思路是由断路器统计服务调用的异常比例、慢请求比例&#xff0c;如果超出阈值则会熔断该服务。即拦截访问该服务的一切请求&#xff1b;而当服务恢复时&#xff0c;断路器会放行访问该服务的请求。 断路器控制熔断和放行…

ShuffleNet V1+V2(pytorch)

V1 V1根本思想&#xff1a; 1.GConv替换resnet的普通1*1Conv 2.GConv后加channel shuffle模块 对GConv的不同组进行重新组合。channel_shuffle a是resnet模块&#xff0c;b&#xff0c;c是ShuffleNetV1的block&#xff0c;在V1版中&#xff0c;两模块branch2的第一个1*1卷积…

linux日志管理_日志系统

10.1 日志系统&#xff08;系统日志管理&#xff09;syslog&rsyslog 日志&#xff1a;主要用途是系统审计、监测追踪和分析统计。 ​ Linux内核由很多子系统组成&#xff0c;包括网络、文件访问、内存管理等。子系统需要给用户传送一些消息&#xff0c;这些消息内容包括消…

2023/12/17 初始化

普通变量&#xff08;int,float,double变量&#xff09;初始化&#xff1a; int a0; float b(0); double c0; 数组初始化&#xff1a; int arr[10]{0}; 指针初始化&#xff1a; 空指针 int *pnullptr; 被一个同类型的变量的地址初始化&#xff08;赋值&#xff09; int…

Latex表格的问题(如何合并单元格、单元格垂直居中、水平居中)

用到的package % 表格里面合并单元格用到的 \usepackage{multirow} % 表格 \usepackage{tabularx} % 限制图片或者表格在文字下方 \usepackage{float} % y应该就是这两个包&#xff0c;如果报错就去搜索一下&#xff0c;可以找得到的怎么实现水平居中 \begin{table}[H] \cent…

【ZYNQ】AXI4总线接口协议学习

建议翻看着底部的参考文档资料和本文一起辅助阅读 本文带你详细的了解AXI总线协议&#xff0c;并且基于官方手册&#xff0c;能够提高你的手册阅读能力。 什么是AXI AXI 的英文全称是 Advanced eXtensible Interface&#xff0c;即高级可扩展接口&#xff0c;它是 ARM 公司所提…