【强化学习】——Q-learning算法为例入门Pytorch强化学习

🤵‍♂️ 个人主页:@Lingxw_w的个人主页

✍🏻作者简介:计算机研究生在读,研究方向复杂网络和数据挖掘,阿里云专家博主,华为云云享专家,CSDN专家博主、人工智能领域优质创作者,安徽省优秀毕业生
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+ 

目录

1、强化学习是什么

1.1 定义

1.2 基本组成

1.3 马尔可夫决策过程

2、强化学习的应用

3、常见的强化学习算法

3.1 Q-learning算法

3.2 Q-learning的算法步骤

3.3 Pytorch代码实现


1、强化学习是什么

1.1 定义

强化学习(Reinforcement Learning,RL)是一种机器学习方法,其目标是通过智能体(Agent)与环境的交互学习最优行为策略,以使得智能体能够在给定环境中获得最大的累积奖励。

强化学习在许多领域都有应用,例如机器人控制、游戏智能、自动驾驶、资源管理等。通过与环境的交互和试错学习,强化学习使得智能体能够在复杂、不确定的环境中做出优化的决策,并逐步提升性能。

1.2 基本组成

强化学习的基本组成部分包括:

  1. 智能体(Agent):在强化学习中,智能体是学习和决策的主体,它通过与环境的交互来获取知识和经验,并根据获得的奖励信号进行学习和优化。

  2. 环境(Environment):环境是智能体所处的外部世界,它可以是真实的物理环境,也可以是虚拟的模拟环境。智能体通过观察环境的状态,执行动作,并接收来自环境的奖励或惩罚信号。

  3. 状态(State):状态表示环境的某个特定时刻的观察或描述,它包含了智能体需要的所有信息来做出决策。

  4. 动作(Action):动作是智能体在某个状态下采取的行为,它会对环境产生影响并导致状态的转换。

  5. 奖励(Reward):奖励是环境根据智能体的行为给予的反馈信号,用于指导智能体学习合适的策略。奖励可以是正数(奖励)也可以是负数(惩罚),智能体的目标是最大化累积奖励。

1.3 马尔可夫决策过程

(Markov Decision Process,MDP)强化学习中常用的建模框架,用于描述具有马尔可夫性质的序贯决策问题。它是基于马尔可夫链(Markov Chain)和决策理论的组合。

在马尔可夫决策过程中,智能体与环境交互,通过采取一系列动作来影响环境的状态和获得奖励。MDP的关键特点是马尔可夫性质,即当前状态的信息足以决定未来状态的转移概率。这意味着在MDP中,未来的状态和奖励仅取决于当前状态和采取的动作,而与过去的状态和动作无关。 

2、强化学习的应用

强化学习旨在解决以下类型的问题:

  1. 决策问题:强化学习可以用于解决需要做出一系列决策的问题。例如,自动驾驶车辆需要在不同交通情况下选择合适的行驶策略,智能机器人需要学习在复杂环境中执行任务的最佳策略。

  2. 控制问题:强化学习可用于控制系统的优化。例如,通过学习最优策略来调整电力网格的能源分配,或者在金融投资中确定最佳的投资组合。

  3. 资源管理:强化学习可以应用于资源管理问题,如动态网络管理、数据中心的负载平衡、无线通信中的频谱分配等。智能体可以通过与环境的交互来学习如何最优地利用和分配有限的资源。

  4. 序列决策问题:强化学习适用于需要在连续时间步骤中做出决策的问题。例如,在自然语言处理中,可以使用强化学习来训练智能体生成合适的文本回复,或者在推荐系统中根据用户行为动态调整推荐策略。

  5. 探索与开发:强化学习可以用于探索未知环境和发现新知识。通过与环境的交互,智能体可以通过试错学习来积累经验并发现最优策略。

3、常见的强化学习算法

  • Q-learning:一种基于值函数(Q函数)的强化学习算法,通过迭代更新Q值来学习最优策略。
  • SARSA:另一种基于值函数的强化学习算法,与Q-learning类似,但在更新Q值时采用了一种“状态-动作-奖励-下一状态-下一动作(State-Action-Reward-State-Action)”的更新策略。
  • 策略梯度(Policy Gradient):一类直接学习策略函数的方法,通过优化策略函数的参数来提高智能体的性能。
  • 深度强化学习(Deep Reinforcement Learning):将深度学习方法与强化学习相结合,利用神经网络来表示值函数或策略函数,以解决具有高维状态空间的复杂任务。

3.1 Q-learning算法

Q-learning是一种经典的强化学习算法,用于解决马尔可夫决策过程(Markov Decision Process,MDP)的问题。它是基于值函数的强化学习算法,通过迭代地更新Q值来学习最优策略。

在Q-learning中,智能体与环境的交互过程由状态、动作、奖励下一个状态组成。智能体根据当前状态选择一个动作,与环境进行交互,接收到下一个状态和相应的奖励。Q-learning的目标是学习一个Q值函数,它估计在给定状态下采取特定动作所获得的长期累积奖励。

Q值函数表示为Q(s, a),其中s是状态,a是动作。初始时,Q值可以初始化为任意值。Q-learning使用贝尔曼方程(Bellman equation)来更新Q值,以逐步逼近最优的Q值函数:

Q(s, a) = Q(s, a) + α * (r + γ * max[Q(s', a')] - Q(s, a))

在上述方程中,α是学习率(learning rate),决定了每次更新的幅度;r是当前状态下执行动作a所获得的奖励;γ是折扣因子(discount factor),用于权衡即时奖励和未来奖励的重要性;s'是下一个状态,a'是在下一个状态下的最优动作。

3.2 Q-learning的算法步骤

  1. 初始化Q值函数。
  2. 在每个时间步骤中,根据当前状态选择一个动作。
  3. 执行动作,观察奖励和下一个状态。
  4. 根据贝尔曼方程更新Q值函数。
  5. 重复2-4步骤,直到达到预定的停止条件或收敛。

通过多次迭代更新Q值函数,Q-learning最终能够收敛到最优的Q值函数。智能体可以根据Q值函数选择具有最高Q值的动作作为策略,以实现最优的行为决策。

Q-learning是一种基于模型的强化学习方法,不需要对环境的模型进行显式建模,适用于离散状态空间和动作空间的问题。对于连续状态和动作空间的问题,可以通过函数逼近方法(如深度Q网络)来扩展Q-learning算法。

3.3 Pytorch代码实现

基于PyTorch的Q-learning算法来解决OpenAI Gym中的CartPole环境。

首先,导入所需的库,包括gym用于创建环境,random用于随机选择动作,以及torchtorch.nn用于构建和训练神经网络。

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

定义了一个Q网络(QNetwork)作为强化学习算法的近似函数。该网络具有三个全连接层,其中前两个层使用ReLU激活函数,最后一层输出动作值。forward方法用于定义网络的前向传播。

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

接下来,定义了一个QLearningAgent类。在初始化中,指定了状态维度、动作维度、折扣因子和探索率等超参数。同时创建了两个Q网络q_networktarget_networktarget_network用于计算目标Q值。还定义了优化器和损失函数。 

# Q-learning算法
class QLearningAgent:
    def __init__(self, state_dim, action_dim, gamma, epsilon):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率

        # 初始化Q网络和目标网络
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.Adam(self.q_network.parameters())
        self.loss_fn = nn.MSELoss()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.FloatTensor(state)
            q_values = self.q_network(state)
            return torch.argmax(q_values).item()

    def train(self, replay_buffer, batch_size):
        if len(replay_buffer) < batch_size:
            return

        # 从回放缓存中采样一个小批量样本
        samples = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 计算当前状态的Q值
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # 计算下一个状态的Q值
        next_q_values = self.target_network(next_states).max(1)[0]
        expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # 计算损失并更新Q网络
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

select_action方法用于根据当前状态选择动作。以epsilon的概率选择随机动作,以探索环境;以1-epsilon的概率选择基于当前Q值的最优动作。

# 创建环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

train方法用于训练Q网络。它从回放缓存中采样一个小批量样本,并计算当前状态和下一个状态的Q值。然后计算损失并进行优化。

接下来,创建CartPole环境并获取状态和动作的维度。

然后,实例化一个QLearningAgent对象,并设置相关的超参数。

接下来,进行训练循环。在每个回合中,重置环境,然后在每个时间步中执行以下步骤:

  1. 根据当前状态选择一个动作。
  2. 执行动作,观察下一个状态、奖励和终止信号。
  3. 将状态、动作、奖励、下一个状态和终止信号存储在回放缓存中。
  4. 调用agent的train方法进行网络训练。

每隔一定的回合数,通过update_target_network方法更新目标网络的权重。

# 创建Q-learning智能体
agent = QLearningAgent(state_dim, action_dim, gamma=0.99, epsilon=0.2)

# 训练
replay_buffer = []
episodes = 1000
batch_size = 32

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        agent.train(replay_buffer, batch_size)

    if episode % 10 == 0:
        agent.update_target_network()
        print(f"Episode: {episode}, Total Reward: {total_reward}")

最后,使用训练好的智能体进行测试。在测试过程中,根据当前状态选择动作,并执行动作,直到终止信号出现。同时可通过env.render()方法显示环境的图形界面。

# 使用训练好的智能体进行测试
state = env.reset()
done = False
total_reward = 0

while not done:
    env.render()
    action = agent.select_action(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward

print(f"Test Total Reward: {total_reward}")

env.close()

代码执行完毕后,关闭环境并显示测试的总奖励。

总体而言,这段代码实现了基于PyTorch的Q-learning算法,并将其应用于CartPole环境。通过训练,智能体可以学习到一个最优策略,使得杆子保持平衡的时间尽可能长。

汇总的代码:

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Q-learning算法
class QLearningAgent:
    def __init__(self, state_dim, action_dim, gamma, epsilon):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率

        # 初始化Q网络和目标网络
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.Adam(self.q_network.parameters())
        self.loss_fn = nn.MSELoss()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.FloatTensor(state)
            q_values = self.q_network(state)
            return torch.argmax(q_values).item()

    def train(self, replay_buffer, batch_size):
        if len(replay_buffer) < batch_size:
            return

        # 从回放缓存中采样一个小批量样本
        samples = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 计算当前状态的Q值
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # 计算下一个状态的Q值
        next_q_values = self.target_network(next_states).max(1)[0]
        expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # 计算损失并更新Q网络
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# 创建环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# 创建Q-learning智能体
agent = QLearningAgent(state_dim, action_dim, gamma=0.99, epsilon=0.2)

# 训练
replay_buffer = []
episodes = 1000
batch_size = 32

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        agent.train(replay_buffer, batch_size)

    if episode % 10 == 0:
        agent.update_target_network()
        print(f"Episode: {episode}, Total Reward: {total_reward}")

# 使用训练好的智能体进行测试
state = env.reset()
done = False
total_reward = 0

while not done:
    env.render()
    action = agent.select_action(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward

print(f"Test Total Reward: {total_reward}")

env.close()

 相关博客专栏订阅链接

【机器学习】——房屋销售的探索性数据分析

【机器学习】——数据清理、数据变换、特征工程

【机器学习】——决策树、线性模型、随机梯度下降

【机器学习】——多层感知机、卷积神经网络、循环神经网络

【机器学习】——模型评估、过拟合和欠拟合、模型验证

【机器学习】——模型调参、超参数优化、网络架构搜索

【机器学习】——方差和偏差、Bagging、Boosting、Stacking

【机器学习】——模型调参、超参数优化、网络架构搜索

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/31787.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【30天熟悉Go语言】8 Go流程控制之循环结构for range、goto、break、continue

文章目录 一、前言二、for循环1、语法1&#xff09;和Java的for循环一样2&#xff09;和Java的while一样3&#xff09;和Java的for(;;)一样 2、for语句执行过程 三、for range1、语法1&#xff09;遍历key、value只遍历value 2&#xff09;遍历key 四、关键字1、break1&#xf…

【Java】如何优雅的关闭线程池

文章目录 背景一、线程中断 interrupt二、线程池的关闭 shutdown 方法2.1、第一步&#xff1a;advanceRunState(SHUTDOWN) 把线程池置为 SHUTDOWN2.2、第二步&#xff1a;interruptIdleWorkers() 把空闲的工作线程置为中断2.3、 第三步&#xff1a;onShutdown() 一个空实现&…

Java POI (1)—— 数据读写操作快速入门

一、Excel的版本区别&#xff08;03版和07版&#xff09; 所谓“03版” 和 “07版”&#xff0c;指的是 Microsoft Excel 版本号。这些版本号代表着不同的Excel 文件格式。2003版 Excel 使用的文件格式为 .xls&#xff0c;而2007版开始使用新的文件格式 .xlsx。 . xlsx 文件格式…

【Spring 】项目创建和使用

哈喽&#xff0c;哈喽&#xff0c;大家好~ 我是你们的老朋友&#xff1a;保护小周ღ 谈起Java 圈子里的框架&#xff0c;最年长最耀眼的莫过于 Spring 框架啦&#xff0c;如今已成为最流行、最广泛使用的Java开发框架之一。不知道大家有没有在使用 Spring 框架的时候思考过这…

VulnHub靶机渗透:SKYTOWER: 1

SKYTOWER: 1 靶机环境介绍nmap扫描端口扫描服务扫描漏洞扫描总结 80端口目录爆破 3128端口获取立足点获取立足点2提权总结 靶机环境介绍 https://www.vulnhub.com/entry/skytower-1,96/ 靶机IP&#xff1a;192.168.56.101 kali IP&#xff1a;192.168.56.102 nmap扫描 端口扫…

使用mpi并行技术实现wordcount算法

【问题描述】 编写程序统计一个英文文本文件中每个单词的出现次数&#xff08;词频统计&#xff09;&#xff0c;并将统计结果按单词字典序输出到屏幕上。 注&#xff1a;在此单词为仅由字母组成的字符序列。包含大写字母的单词应将大写字母转换为小写字母后统计。 【输入形…

ChatGPT使用的SSE技术是什么?

在现代web应用程序中&#xff0c;实时通信变得越来越重要。HTTP协议的传统请求/响应模式总是需要定期进行轮询以获得最新的数据&#xff0c;这种方式效率低下并且浪费资源。因此&#xff0c;出现了一些新的通信技术&#xff0c;如WebSocket和SSE。但是&#xff0c;GPT为什么选择…

分布式数据库架构

分布式数据库架构 1、MySQL常见架构设计 对于mysql架构&#xff0c;一定会使用到读写分离&#xff0c;在此基础上有五种常见架构设计&#xff1a;一主一从或多从、主主复制、级联复制、主主与级联复制结合。 1.1、主从复制 这种架构设计是使用的最多的。在读写分离的基础上…

JS 介绍 Babel 的使用及 presets plugins 的概念

一、Babel 是什么 Bebal 可以帮助我们将新 JS 语法编译为可执行且兼容旧浏览器版本的一款编译工具。 举个例子&#xff0c;ES6&#xff08;编译前&#xff09;&#xff1a; const fn () > {};ES5&#xff08;编译后&#xff09;&#xff1a; var fn function() {}二、B…

设计模式-抽象工厂模式

抽象工厂模式 1、抽象工厂模式简介2、具体实现 1、抽象工厂模式简介 抽象工厂模式(Abstract Factory Pattern)在工厂模式尚添加了一个创建不同工厂的抽象接口(抽象类或接口实现)&#xff0c;该接口可叫做超级工厂。在使用过程中&#xff0c;我们首先通过抽象接口创建不同的工厂…

【HTML界面设计(二)】说说模块、登录界面

记录很早之前写的前端界面&#xff08;具体时间有点久远&#xff09; 一、说说模板 采用 适配器&#xff08;Adapter&#xff09;原理 来设计这款说说模板&#xff0c;首先看一下完整效果 这是demo样图&#xff0c;需要通过业务需求进行修改的部分 这一部分&#xff0c;就是dem…

Redis系列--布隆过滤器(Bloom Filter)

一、前言 在实际开发中&#xff0c;会遇到很多要判断一个元素是否在某个集合中的业务场景&#xff0c;类似于垃圾邮件的识别&#xff0c;恶意ip地址的访问&#xff0c;缓存穿透等情况。类似于缓存穿透这种情况&#xff0c;有许多的解决方法&#xff0c;如&#xff1a;redis存储…

宏景eHR SQL注入漏洞复现(CNVD-2023-08743)

0x01 产品简介 宏景eHR人力资源管理软件是一款人力资源管理与数字化应用相融合&#xff0c;满足动态化、协同化、流程化、战略化需求的软件。 0x02 漏洞概述 宏景eHR 存在SQL注入漏洞&#xff0c;未经过身份认证的远程攻击者可利用此漏洞执行任意SQL指令&#xff0c;从而窃取数…

如何在大规模服务中迁移缓存

当您启动初始服务时&#xff0c;通常会过度设计以考虑大量流量。但是&#xff0c;当您的服务达到爆炸式增长阶段&#xff0c;或者如果您的服务请求和处理大量流量时&#xff0c;您将需要重新考虑您的架构以适应它。糟糕的系统设计导致难以扩展或无法满足处理大量流量的需求&…

docker基础

文章目录 通过Vagrant安装虚拟机修改虚拟机网络配置 docker CE安装(在linux上)docker desktop安装(在MacOS上)Docker架构关于-阿里云镜像加速服务配置centos卸载docker 官网: http://www.docker.com 仓库: https://hub.docker.com Docker安装在虚拟机上&#xff0c;可以通过V…

Go语言的TCP和HTTP网络服务基础

目录 【TCP Socket 编程模型】 Socket读操作 【HTTP网络服务】 HTTP客户端 HTTP服务端 TCP/IP 网络模型实现了两种传输层协议&#xff1a;TCP 和 UDP&#xff0c;其中TCP 是面向连接的流协议&#xff0c;为通信的两端提供稳定可靠的数据传输服务&#xff1b;UDP 提供了一种…

[MySQL]不就是SQL语句

前言 本期主要的学习目标是SQl语句中的DDL和DML实现对数据库的操作和增删改功能&#xff0c;学习完本章节之后需要对SQL语句手到擒来。 1.SQL语句基本介绍 SQL&#xff08;Structured Query Language&#xff09;是一种用于管理关系型数据库的编程语言。它允许用户在数据库中存…

双因素身份验证在远程访问中的重要性

在快速发展的数字环境中&#xff0c;远程访问计算机和其他设备已成为企业运营的必要条件。无论是在家庭办公室运营的小型初创公司&#xff0c;还是团队分散在全球各地的跨国公司&#xff0c;远程访问解决方案都能保证工作效率和连接性&#xff0c;能够跨越距离和时间的阻碍。 …

7Z045 引脚功能详解

本文针对7Z045芯片&#xff0c;详细讲解硬件设计需要注意的技术点&#xff0c;可以作为设计和检查时候的参考文件。问了方便实用&#xff0c;按照Bank顺序排列&#xff0c;包含配置Bank、HR Bank、HP Bank、GTX Bank、供电引脚等。 参考文档包括&#xff1a; ds191-XC7Z030-X…

怎么计算 flex-shrink 的缩放尺寸

计算公式: 子元素的宽度 - (子元素的宽度的总和 - 父盒子的宽度) * (某个元素的flex-shrink / flex-shrink总和) 面试问题是这样的下面 left 和 right 的宽度分别是多少 * {padding: 0;margin: 0;}.container {width: 500px;height: 300px;display: flex;}.left {width: 500px…