pytorch实现长短期记忆网络 (LSTM)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

LSTM 通过 记忆单元(cell)三个门控机制(遗忘门、输入门、输出门)来控制信息流:

 记忆单元(Cell State)

  • 负责存储长期信息,并通过门控机制决定保留或丢弃信息。

 遗忘门(Forget Gate, ftf_tft​)

 输入门(Input Gate, iti_tit​)

 输出门(Output Gate, oto_tot​)

特性

传统 RNNLSTM
记忆能力短期记忆长短期记忆
计算复杂度
解决梯度消失
适用场景短序列数据长序列数据

LSTM 应用场景

  • 自然语言处理(NLP):文本生成、情感分析、机器翻译
  • 时间序列预测:股票预测、天气预报、传感器数据分析
  • 语音识别:自动字幕生成、语音转文字(ASR)
  • 机器人与控制系统:智能体决策、自动驾驶

例子:

下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMPolicy, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, hidden_state):
        batch_size = x.size(0)

        # 确保 hidden_state 维度正确
        if hidden_state[0].dim() == 2:
            hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),
                            hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))

        out, hidden_state = self.lstm(x, hidden_state)
        out = self.fc(out[:, -1, :])  # 取最后时间步的输出
        action_prob = self.softmax(out)  # 归一化输出,作为策略
        return action_prob, hidden_state

    def init_hidden(self, batch_size=1):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))


# ========== 2. 创建网格环境 ==========
class GridWorld:
    def __init__(self, grid_size=10, goal_position=9):
        self.grid_size = grid_size
        self.goal_position = goal_position
        self.reset()

    def reset(self):
        self.position = 0
        return self.position

    def step(self, action):
        if action == 0:
            self.position = max(0, self.position - 1)
        elif action == 1:
            self.position = min(self.grid_size - 1, self.position + 1)

        reward = 1 if self.position == self.goal_position else -0.1
        done = self.position == self.goal_position
        return self.position, reward, done


# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    optimizer = optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)

        log_probs = []
        rewards = []

        for step in range(max_steps):
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.multinomial(action_probs, 1).item()
            log_prob = torch.log(action_probs.squeeze(0)[action])
            log_probs.append(log_prob)

            next_state, reward, done = env.step(action)
            rewards.append(reward)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # 计算回报并更新策略
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (episode + 1) % 50 == 0:
            print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")

    torch.save(policy.state_dict(), "policy.pth")


# 训练智能体
train(500)


# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    policy.load_state_dict(torch.load("policy.pth"))

    plt.figure(figsize=(10, 5))
    best_path = None
    best_steps = float('inf')

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)
        positions = [env.position]  # 记录位置变化

        while True:
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.argmax(action_probs, dim=-1).item()
            next_state, reward, done = env.step(action)
            positions.append(next_state)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # 记录最佳路径(最短步数)
        if len(positions) < best_steps:
            best_steps = len(positions)
            best_path = positions

        # 绘制普通路径(蓝色)
        plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,
                 label=f'Episode {episode + 1}' if episode == 0 else "")

    # 绘制最佳路径(红色)
    if best_path:
        plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,
                 label="Best Path")

    # 打印最佳路径
    print(f"Best Path (steps={best_steps}): {best_path}")

    plt.xlabel("Time Steps")
    plt.ylabel("Agent Position")
    plt.title("Agent's Movement Path (Best Path in Red)")
    plt.legend()
    plt.grid(True)
    plt.show()


# 测试并绘制智能体移动路径
test(5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963700.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++:抽象类习题

题目内容&#xff1a; 求正方体、球、圆柱的表面积&#xff0c;抽象出一个公共的基类Container为抽象类&#xff0c;在其中定义一个公共的数据成员radius(此数据可以作为正方形的边长、球的半径、圆柱体底面圆半径)&#xff0c;以及求表面积的纯虚函数area()。由此抽象类派生出…

GEE | 计算Sentinel-2的改进型土壤调整植被指数MSAVI

同学们好&#xff01;今天和大家分享的是 “改进型土壤调整植被指数MSAVI”&#xff0c;它能够更准确地反映植被生长状态&#xff0c;且广泛应用于植被覆盖监测、生态环境评估等领域。 1. MSAVI 改进型土壤调整植被指数&#xff08;MSAVI&#xff09;是一种针对植被覆盖区域土…

deepseek+vscode自动化测试脚本生成

近几日Deepseek大火,我这里也尝试了一下,确实很强。而目前vscode的AI toolkit插件也已经集成了deepseek R1,这里就介绍下在vscode中利用deepseek帮助我们完成自动化测试脚本的实践分享 安装AI ToolKit并启用Deepseek 微软官方提供了一个针对AI辅助的插件,也就是 AI Toolk…

CodeGPT使用本地部署DeepSeek Coder

目前NV和github都托管了DeepSeek&#xff0c;生成Key后可以很方便的用CodeGPT接入。CodeGPT有三种方式使用AI&#xff0c;分别时Agents&#xff0c;Local LLMs&#xff08;本地部署AI大模型&#xff09;&#xff0c;LLMs Cloud Model&#xff08;云端大模型&#xff0c;从你自己…

[c语言日寄]C语言类型转换规则详解

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋&#xff1a;这是一个专注于C语言刷题的专栏&#xff0c;精选题目&#xff0c;搭配详细题解、拓展算法。从基础语法到复杂算法&#xff0c;题目涉及的知识点全面覆盖&#xff0c;助力你系统提升。无论你是初学者&#xff0c;还是…

FPGA 使用 CLOCK_DEDICATED_ROUTE 约束

使用 CLOCK_DEDICATED_ROUTE 约束 CLOCK_DEDICATED_ROUTE 约束通常在从一个时钟区域中的时钟缓存驱动到另一个时钟区域中的 MMCM 或 PLL 时使 用。默认情况下&#xff0c; CLOCK_DEDICATED_ROUTE 约束设置为 TRUE &#xff0c;并且缓存 /MMCM 或 PLL 对必须布局在相同…

Ollama 介绍,搭建本地 AI 大模型 deepseek,并使用 Web 界面调用

Ollama 是一个基于 Go 语言的本地大语言模型运行框架&#xff0c;类 Docker产品&#xff08;支持 list,pull,push,run 等命令&#xff09;&#xff0c;事实上它保留了 Docker 的操作习惯&#xff0c;支持上传大语言模型仓库(有 deepseek、llama 2&#xff0c;mistral&#xff0…

OpenEuler学习笔记(十四):在OpenEuler上搭建.NET运行环境

一、在OpenEuler上搭建.NET运行环境 基于包管理器安装 添加Microsoft软件源&#xff1a;运行命令sudo rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm&#xff0c;将Microsoft软件源添加到系统中&#xff0c;以便后续能够从该源安装.…

【Linux】从硬件到软件了解进程

个人主页~ 从硬件到软件了解进程 一、冯诺依曼体系结构二、操作系统三、操作系统进程管理1、概念2、PCB和task_struct3、查看进程4、通过系统调用fork创建进程&#xff08;1&#xff09;简述&#xff08;2&#xff09;系统调用生成子进程的过程〇提出问题①fork函数②父子进程关…

物联网 STM32【源代码形式-使用以太网】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

物联网&#xff08;IoT&#xff09;‌是指通过各种信息传感器、射频识别技术、全球定位系统、红外感应器等装置与技术&#xff0c;实时采集并连接任何需要监控、连接、互动的物体或过程&#xff0c;实现对物品和过程的智能化感知、识别和管理。物联网的核心功能包括数据采集与监…

【背包问题】二维费用的背包问题

目录 二维费用的背包问题详解 总结&#xff1a; 空间优化&#xff1a; 1. 状态定义 2. 状态转移方程 3. 初始化 4. 遍历顺序 5. 时间复杂度 例题 1&#xff0c;一和零 2&#xff0c;盈利计划 二维费用的背包问题详解 前面讲到的01背包中&#xff0c;对物品的限定条件…

眼见着折叠手机面临崩溃,三星计划增强抗摔能力挽救它

据悉折叠手机开创者三星披露了一份专利&#xff0c;通过在折叠手机屏幕上增加一个抗冲击和遮光层的方式来增强折叠手机的抗摔能力&#xff0c;希望通过这种方式进一步增强折叠手机的可靠性和耐用性&#xff0c;来促进折叠手机的发展。 据悉三星和研发可折叠玻璃的企业的做法是在…

首发!ZStack 智塔支持 DeepSeek V3/R1/ Janus Pro,多种国产 CPU/GPU 可私有化部署

2025年2月2日&#xff0c;针对日益强劲的AI推理需求和企业级AI应用私有化部署场景&#xff08;Private AI&#xff09;&#xff0c;云轴科技 ZStack 宣布 AI Infra 平台 ZStack 智塔全面支持企业私有化部署 DeepSeek V3/R1/ Janus Pro三种模型&#xff0c;并可基于海光、昇腾、…

25寒假算法刷题 | Day1 | LeetCode 240. 搜索二维矩阵 II,148. 排序链表

目录 240. 搜索二维矩阵 II题目描述题解 148. 排序链表题目描述题解 240. 搜索二维矩阵 II 点此跳转题目链接 题目描述 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性&#xff1a; 每行的元素从左到右升序排列。每列的元素从上到…

it基础使用--5---git远程仓库

文章目录 it基础使用--5---git远程仓库1. 按顺序看2. 什么是远程仓库3. Gitee操作3.1 新建远程仓库3.2 远程操作基础命令3.3 查看当前所有远程地址别名 git remote -v3.4 创建远程仓库别名 git remote add 别名 远程地址3.4 推送本地分支到远程仓库 git push 别名 分支3.5 拉取…

SpringBoot 整合 Mybatis:注解版

第一章&#xff1a;注解版 导入配置&#xff1a; <groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>1.3.1</version> </dependency> 步骤&#xff1a; 配置数据源见 Druid…

海思ISP开发说明

1、概述 ISP&#xff08;Image Signal Processor&#xff09;图像信号处理器是专门用于处理图像信号的硬件或处理单元&#xff0c;广泛应用于图像传感器&#xff08;如 CMOS 或 CCD 传感器&#xff09;与显示设备之间的信号转换过程中。ISP通过一系列数字图像处理算法完成对数字…

基于springboot私房菜定制上门服务系统设计与实现(源码+数据库+文档)

私房菜定制上门服务系统目录 目录 基于springbootvue私房菜定制上门服务系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员功能实现 &#xff08;1&#xff09;菜品管理 &#xff08;2&#xff09;公告管理 &#xff08;3&#xff09; 厨师管理 2、用…

SpringBoot 整合 SpringMVC:配置嵌入式服务器

修改和 server 相关的配置(ServerProperties)&#xff1a; server.port8081 server.context‐path/tx server.tomcat.uri‐encodingUTF‐8 注册 Servlet 三大组件&#xff1a;Servlet、Fileter、Listener SpringBoot 默认是以 jar 包的方式启动嵌入式的 Servlet 容器来启动 Spr…

如何实现滑动网格的功能

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了SliverList组件相关的内容&#xff0c;本章回中将介绍SliverGrid组件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1 概念介绍 我们在本章回中介绍的SliverGrid组件是一种网格类组件&#xff0c;主要用来…