pytorch实现门控循环单元 (GRU)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客  

特性GRULSTM
计算效率更快,参数更少相对较慢,参数更多
结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)
处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖
训练速度训练更快,梯度更稳定训练较慢,占用更多内存

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt

# 🏁 迷宫环境(5×5)
class MazeEnv:
    def __init__(self, size=5):
        self.size = size
        self.state = (0, 0)  # 起点
        self.goal = (size-1, size-1)  # 终点
        self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # 右、左、下、上
    
    def reset(self):
        self.state = (0, 0)  # 重置起点
        return self.state

    def step(self, action):
        dx, dy = self.actions[action]
        x, y = self.state
        nx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))
        
        reward = 1 if (nx, ny) == self.goal else -0.1
        done = (nx, ny) == self.goal
        
        self.state = (nx, ny)
        return (nx, ny), reward, done

# 🤖 GRU 策略网络
class GRUPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUPolicy, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])  # 只取最后时间步
        return out, hidden

# 🎯 训练参数
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 🎓 训练
num_episodes = 500
epsilon = 1.0  # 初始的ε值,控制探索的概率
epsilon_min = 0.01  # 最小ε值
epsilon_decay = 0.995  # ε衰减率
best_path = []  # 用于存储最佳路径

for episode in range(num_episodes):
    state = env.reset()
    hidden = torch.zeros(1, 1, 16)  # GRU 初始状态
    states, actions, rewards = [], [], []
    logits_list = []  

    for _ in range(20):  # 最多 20 步
        state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)
        logits, hidden = policy(state_tensor, hidden)
        logits_list.append(logits)

        # ε-greedy 策略
        if random.random() < epsilon:
            action = random.choice(range(4))  # 随机选择动作
        else:
            action = torch.argmax(logits, dim=1).item()  # 选择最大值对应的动作

        next_state, reward, done = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)

        if done:
            print(f"Episode {episode} - Reached Goal!")
            # 找到最优路径
            best_path = states + [next_state]  # 当前 episode 的路径
            break
        state = next_state

    # 计算损失
    logits = torch.cat(logits_list, dim=0)  # (T, 4)
    action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)
    loss = loss_fn(logits, action_tensor)  

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 衰减 ε
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if episode % 100 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")

# 🧐 确保 best_path 已经记录
if len(best_path) == 0:
    print("No path found during training.")
else:
    print(f"Best path: {best_path}")

# 🚀 测试路径(只绘制最佳路径)
fig, ax = plt.subplots(figsize=(6,6))

# 初始化迷宫图
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 迷宫
ax.imshow(maze, cmap="coolwarm", origin="upper")

# 画网格
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)

# 画出最佳路径(红色)
for (x, y) in best_path:
    ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))

# 画起点和终点
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")

plt.title("GRU RL Agent - Best Path")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964525.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【赵渝强老师】Spark RDD的依赖关系和任务阶段

Spark RDD彼此之间会存在一定的依赖关系。依赖关系有两种不同的类型&#xff1a;窄依赖和宽依赖。 窄依赖&#xff1a;如果父RDD的每一个分区最多只被一个子RDD的分区使用&#xff0c;这样的依赖关系就是窄依赖&#xff1b;宽依赖&#xff1a;如果父RDD的每一个分区被多个子RD…

为何在Kubernetes容器中以root身份运行存在风险?

作者:马辛瓦西奥内克(Marcin Wasiucionek) 引言 在Kubernetes安全领域,一个常见的建议是让容器以非root用户身份运行。但是,在容器中以root身份运行,实际会带来哪些安全隐患呢?在Docker镜像和Kubernetes配置中,这一最佳实践常常被重点强调。在Kubernetes清单文件中,…

Docker 安装详细教程(适用于CentOS 7 系统)

目录 步骤如下&#xff1a; 1. 卸载旧版 Docker 2. 配置 Docker 的 YUM 仓库 3. 安装 Docker 4. 启动 Docker 并验证安装 5. 配置 Docker 镜像加速 总结 前言 Docker 分为 CE 和 EE 两大版本。CE即社区版&#xff08;免费&#xff0c;支持周期7个月&#xff09;&#xf…

[MRCTF2020]Ez_bypass1(md5绕过)

[MRCTF2020]Ez_bypass1(md5绕过) ​​ 这道题就是要绕过md5强类型比较&#xff0c;但是本身又不相等&#xff1a; md5无法处理数组&#xff0c;如果传入的是数组进行md5加密&#xff0c;会直接放回NULL&#xff0c;两个NuLL相比较会等于true&#xff1b; 所以?id[]1&gg…

Deep Crossing:深度交叉网络在推荐系统中的应用

实验和完整代码 完整代码实现和jupyter运行&#xff1a;https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main 引言 在机器学习和深度学习领域&#xff0c;特征工程一直是一个关键步骤&#xff0c;尤其是对于大规模的推荐系统和广告点击率预…

C++多线程编程——基于策略模式、单例模式和简单工厂模式的可扩展智能析构线程

1. thread对象的析构问题 在 C 多线程标准库中&#xff0c;创建 thread 对象后&#xff0c;必须在对象析构前决定是 detach 还是 join。若在 thread 对象销毁时仍未做出决策&#xff0c;程序将会终止。 然而&#xff0c;在创建 thread 对象后、调用 join 前的代码中&#xff…

Rust中使用ORM框架diesel报错问题

1 起初环境没有问题&#xff1a;在Rust开发的时候起初使用的是mingw64平台加stable-x86_64-pc-windows-gnu编译链&#xff0c;当使用到diesel时会报错&#xff0c;如下&#xff1a; x86_64-w64-mingw32/bin/ld.exe: cannot find -lmysql具体信息很长这是主要信息是rust找不到链…

Fastdds学习分享_xtpes_发布订阅模式及rpc模式

在之前的博客中我们介绍了dds的大致功能&#xff0c;与组成结构。本篇博文主要介绍的是xtypes.分为理论和实际运用两部分.理论主要用于梳理hzy大佬的知识&#xff0c;对于某些一带而过的部分作出更为详细的阐释&#xff0c;并在之后通过实际案例便于理解。案例分为普通发布订阅…

Windows:AList+RaiDrive挂载阿里云盘至本地磁盘

零、前言 电脑存储的文件多了&#xff0c;出现存储空间不够用的情况。又没前买新的硬盘或者笔记本电脑没有额外的插槽提供给新的硬盘。遇到这种情况&#xff0c;我想到可以使用网盘&#xff0c;但单纯的网盘又要上传下载&#xff0c;极其麻烦。看到WPS云盘可以直接挂载本地&…

Unity游戏(Assault空对地打击)开发(3) 摄像机的控制

详细步骤 打开My Assets或者Package Manager。 选择Unity Registry。 搜索Cinemachine&#xff0c;找到 Cinemachine包&#xff0c;点击 Install按钮进行安装。 关闭窗口&#xff0c;新建一个FreeLook Camera&#xff0c;如下。 接着新建一个对象Pos&#xff0c;拖到Player下面…

保姆级教程Docker部署Zookeeper官方镜像

目录 1、安装Docker及可视化工具 2、创建挂载目录 3、运行Zookeeper容器 4、Compose运行Zookeeper容器 5、查看Zookeeper运行状态 6、验证Zookeeper是否正常运行 1、安装Docker及可视化工具 Docker及可视化工具的安装可参考&#xff1a;Ubuntu上安装 Docker及可视化管理…

力扣73矩阵置零

给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 输入&#xff1a;matrix [[1,1,1],[1,0,1],[1,1,1]] 输出&#xff1a;[[1,0,1],[0,0,0],[1,0,1]] 输入&#xff1a;matrix [[0,1,2,0],[3,4,5,2],[…

【ArcGIS_Python】使用arcpy脚本将shape数据转换为三维白膜数据

说明&#xff1a; 该专栏之前的文章中python脚本使用的是ArcMap10.6自带的arcpy&#xff08;好几年前的文章&#xff09;&#xff0c;从本篇开始使用的是ArcGIS Pro 3.3版本自带的arcpy&#xff0c;需要注意不同版本对应的arcpy函数是存在差异的 数据准备&#xff1a;准备一个…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.19 线性代数核武器:BLAS/LAPACK深度集成

2.19 线性代数核武器&#xff1a;BLAS/LAPACK深度集成 目录 #mermaid-svg-yVixkwXWUEZuu02L {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-yVixkwXWUEZuu02L .error-icon{fill:#552222;}#mermaid-svg-yVixkwXWUEZ…

记8(高级API实现手写数字识别

目录 1、Keras&#xff1a;2、Sequential模型&#xff1a;2.1、建立Sequential模型&#xff1a;modeltf.keras.Sequential()2.2、添加层&#xff1a;model.add(tf.keras.layers.层)2.3、查看摘要&#xff1a;model.summary()2.4、配置训练方法&#xff1a;model.compile(loss,o…

RK3568中使用QT opencv(显示基础图像)

文章目录 一、查看对应的开发环境是否有opencv的库二、QT使用opencv一、查看对应的开发环境是否有opencv的库 在开发板中的/usr/lib目录下查看是否有opencv的库: 这里使用的是正点原子的ubuntu虚拟机,在他的虚拟机里面已经安装好了opencv的库。 二、QT使用opencv 在QT pr…

el-table表格点击单元格实现编辑

使用 el-table 和 el-table-column 创建表格。在单元格的默认插槽中&#xff0c;使用 div 显示文本内容&#xff0c;单击时触发编辑功能。使用 el-input 组件在单元格中显示编辑框。data() 方法中定义了 tableData&#xff0c;tabClickIndex: null,tabClickLabel: ,用于判断是否…

idea隐藏无关文件

idea隐藏无关文件 如果你想隐藏某些特定类型的文件&#xff08;例如 .log 文件或 .tmp 文件&#xff09;&#xff0c;可以通过以下步骤设置&#xff1a; 打开设置 在菜单栏中选择 File > Settings&#xff08;Windows/Linux&#xff09;或 IntelliJ IDEA > Preference…

备考蓝桥杯嵌入式4:使用LCD显示我们捕捉的PWM波

上一篇博客我们提到了定时器产生PWM波&#xff0c;现在&#xff0c;我们尝试的想要捕获我们的PWM波&#xff0c;测量它的频率&#xff0c;我们应该怎么做呢&#xff1f;答案还是回到我们的定时器上。 我们知道&#xff0c;定时器是一个高级的秒表&#xff08;参考笔者的比喻&a…

ChatGPT-4o和ChatGPT-4o mini的差异点

在人工智能领域&#xff0c;OpenAI再次引领创新潮流&#xff0c;近日正式发布了其最新模型——ChatGPT-4o及其经济实惠的小型版本ChatGPT-4o Mini。这两款模型虽同属于ChatGPT系列&#xff0c;但在性能、应用场景及成本上展现出显著的差异。本文将通过图文并茂的方式&#xff0…