使用 PyTorch 构建 LSTM 股票价格预测模型

目录

      • 引言
      • 准备工作
      • 1. 训练模型(`train.py`)
      • 2. 模型定义(`model.py`)
      • 3. 测试模型和可视化(`test.py`)
      • 使用说明
      • 模型调整
      • 结论

引言

在金融领域,股票价格预测是一个重要且具有挑战性的任务。随着深度学习的发展,长短期记忆网络(LSTM)因其在处理时间序列数据方面的出色表现而受到关注。本篇博客将指导你如何使用PyTorch构建一个LSTM模型来预测股票价格,我们将逐步介绍数据预处理、模型训练和结果可视化的完整流程。

准备工作

  1. 安装依赖
    确保你已经安装了以下 Python 库:

    pip install pandas numpy torch matplotlib scikit-learn
    
  2. 下载数据
    使用 yfinance 库下载你感兴趣的股票的历史数据,并保存为 CSV 文件。我们这里使用 Apple(AAPL)过去五年的数据,文件命名为 AAPL_5y_data.csv。以下是一个下载数据的代码示例:

    import yfinance as yf
    
    # 下载Apple股票过去5年的数据
    data = yf.download('AAPL', start='2019-01-01', end='2024-01-01')
    data.to_csv('AAPL_5y_data.csv')
    

1. 训练模型(train.py

在这个脚本中,我们将读取 CSV 文件,归一化数据,并使用 LSTM 模型进行训练。

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类

# 设置随机种子
torch.manual_seed(42)

# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)

# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)

# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values

# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)

# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]

def create_dataset(data, time_step=1):
    X, y = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        X.append(a)
        y.append(data[i + time_step, 0])  # 预测收盘价
    return np.array(X), np.array(y)

# 创建数据集
time_step = 50  # 时间步长
X_train, y_train = create_dataset(train_data, time_step)

# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float().view(-1, 1)

# 初始化模型、损失函数和优化器
model = LSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# 训练模型
num_epochs = 300
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')
print("模型已保存为 'lstm_model.pth'")

2. 模型定义(model.py

在这个文件中定义 LSTM 模型结构。

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=5, hidden_size=100, num_layers=2, batch_first=True)
        self.fc = nn.Linear(100, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # 取最后时间步的输出
        return out

3. 测试模型和可视化(test.py

在这个脚本中,我们将加载训练好的模型,并使用测试数据进行预测和可视化。

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类

# 设置字体为SimHei,用于显示中文
plt.rcParams['font.family'] = 'SimHei'

# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)

# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)

# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values

# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)

# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]

def create_dataset(data, time_step=1):
    X, y = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        X.append(a)
        y.append(data[i + time_step, 0])  # 预测收盘价
    return np.array(X), np.array(y)

# 创建测试数据集
time_step = 50  # 时间步长
X_test, y_test = create_dataset(test_data, time_step)

# 转换为PyTorch张量
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float().view(-1, 1)

# 加载模型
model = LSTM()
model.load_state_dict(torch.load('lstm_model.pth'))
model.eval()

# 测试模型
with torch.no_grad():
    test_outputs = model(X_test)
    # test_outputs 是预测的收盘价,将其重新归一化为原始价格
    test_outputs = scaler.inverse_transform(np.concatenate((test_outputs.numpy(), np.zeros((test_outputs.shape[0], 4))), axis=1))[:, 0]  # 反归一化收盘价
    y_test_inverse = scaler.inverse_transform(np.concatenate((y_test.numpy(), np.zeros((y_test.shape[0], 4))), axis=1))[:, 0]

# 可视化结果
plt.figure(figsize=(14, 7))
plt.plot(data.index[-len(y_test):], y_test_inverse, label='真实价格', color='blue')
plt.plot(data.index[-len(test_outputs):], test_outputs, label='预测价格', color='red')
plt.title('股票价格预测')
plt.xlabel('日期')
plt.ylabel('价格')
plt.legend()
plt.show()

使用说明

  1. 保存脚本

    • 将训练脚本代码保存为 train.py
    • 将模型定义代码保存为 model.py
    • 将测试脚本代码保存为 test.py
  2. 运行训练

    • 在命令行中运行训练脚本:
      python train.py
      
    • 训练完成后,模型将保存为 lstm_model.pth
  3. 运行测试和可视化

    • 在命令行中运行测试脚本:

      python test.py
      
    • 这将加载已训练的模型,并可视化预测结果。
      在这里插入图片描述
      这只是一个演示,模型的预测效果还有待进一步优化。

模型调整

如果预测的价格和真实价格差距较大,可能是由于以下几个原因:

  1. 数据规模不足

    • 如果训练数据不足,模型可能无法学到市场的长期趋势。
    • 改进:使用更多的历史数据,尽量包括多年的数据。可以尝试增加数据的时间跨度。
  2. 数据预处理问题

    • 数据没有正确归一化,或归一化范围过窄。
    • 改进:检查 MinMaxScaler 的应用。你可以尝试不同的归一化范围,例如 (0, 1)(-1, 1),也可以使用其他标准化方法(例如 StandardScaler)。
  3. 模型复杂度不足

    • 模型的层数或隐藏单元数量可能不足以捕捉数据的复杂性。
    • 改进:增加 LSTM 的隐藏层数量或隐藏单元数量。你还可以考虑添加其他类型的层,例如卷积层(CNN)或全连接层,以提高模型的表达能力。
  4. 超参数调整

    • 学习率、批大小和时间步长等超参数可能需要调整以优化模型性能。
    • 改进:尝试不同的学习率(例如,0.001、0.0001 等)、不同的批大小(如 16、32、64)和时间步长(如 30、60)。
  5. 更改损失函数

    • 在某些情况下,使用不同的损失函数可能有助于模型的收敛。
    • 改进:可以尝试使用其他损失函数,例如 Huber 损失函数(nn.SmoothL1Loss)或自定义损失函数,以更好地适应数据。

结论

通过使用 PyTorch 构建 LSTM 模型,我们成功地实现了股票价格的预测。在这个过程中,我们学习了如何处理时间序列数据,构建和训练深度学习模型,以及如何评估和可视化预测结果。尽管模型的性能可能需要进一步的优化和调整,但这个示例为未来的工作奠定了基础。

希望这篇博客能够帮助你在股票价格预测方面取得更好的成果。欢迎分享你的成果和经验,或者提出你的问题!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/899540.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

1024软件推荐-rubick

开源的插件化桌面端效率工具箱。插件是基于 npm 进行安装和卸载,非常轻便。插件数据支持 webdav 多端同步,非常安全。支持内网部署,可二次定制化开发,非常灵活。 前言 rubick 之前的插件管理,依托于云服务器存储&…

滴水逆向三期笔记与作业——02C语言——13 指针(3)(4)

滴水逆向三期笔记与作业——02C语言——13 指针3、4 一、模拟实现CE的数据搜索功能 OneNote迁移 一、模拟实现CE的数据搜索功能 //其中有0xAA,超过有符号char范围,在vscode中会报错,所以使用unsigned char unsigned char data[100] {0x00,0…

一起搭WPF架构之完结总结篇

一起搭WPF架构之完结总结篇 前言设计总结设计介绍页面一页面二页面三 结束 前言 整体基于WPF架构,根据自己的需求简单设计与实现了衣橱的数据统计、增加与读取数据、并展示数据的小软件。我知道自己在设计方面还有很多不足,暂时先做到这里了&#xff0c…

gbase8s权限管理

一 权限分类 分片级权限(分片表) 表引用 类型级权限 例程级权限 语言级权限 序列级权限 等... 其中常用的为 数据库级权限,表级权限,序列级权限以及例程级权限 二 权限控制 当创建一个用户时,该用户没有任何权…

为了数清还有几天到周末,我用python绘制了日历

日历的秘密 昨天,在看小侄子写作业的时候,发现了一个秘密:他在“演算纸”(计算数学题用的草纸)上画了非常多的日历。对此我感到了非常的困惑,“这是做什么的?” 后来,经过了我不懈…

机器学习面试笔试知识点-线性回归、逻辑回归(Logistics Regression)和支持向量机(SVM)

机器学习面试笔试知识点-线性回归、逻辑回归Logistics Regression和支持向量机SVM 一、线性回归1.线性回归的假设函数2.线性回归的损失函数(Loss Function)两者区别3.简述岭回归与Lasso回归以及使用场景4.什么场景下用L1、L2正则化5.什么是ElasticNet回归6.ElasticNet回归的使…

【设计模式】MyBatis 与经典设计模式:从ORM到设计的智慧

作者:后端小肥肠 🍇 我写过的文章中的相关代码放到了gitee,地址:xfc-fdw-cloud: 公共解决方案 🍊 有疑问可私信或评论区联系我。 🥑 创作不易未经允许严禁转载。 姊妹篇: 【设计模式】揭秘Spri…

计算机网络:数据链路层 —— 以太网(Ethernet)

文章目录 局域网局域网的主要特征 以太网以太网的发展100BASE-T 以太网物理层标准 吉比特以太网载波延伸物理层标准 10吉比特以太网汇聚层交换机物理层标准 40/100吉比特以太网传输媒体 局域网 局域网(Local Area Network, LAN)是一种计算机网络&#x…

GitLab-删除仓库分支(删除远程分支)

进入对应仓库选择对应的分支进行删除操作。

为什么学习使用数控加工中心吗?

现代制造业现代制造业对高精度、高效率的加工需求日益增长,数控加工中心作为核心设备,其操作和维护技能成为企业招聘的重要考量。企业需要能够熟练操作数控加工中心,并具备解决复杂加工问题的能力的人才。 学校通过系学习和实践,学…

不用编程,快速实现多台西门子PLC跟三菱PLC之间数据通讯

PLC通讯智能网关IGT-DSER模块支持汇川、西门子、三菱、欧姆龙、罗克韦尔AB、GE等各种品牌的PLC之间通讯,同时也支持PLC与Modbus协议的变频器、智能仪表等设备通讯。网关有多个网口、串口,也可选择WIFI无线通讯。PLC内无需编程开发,在智能网关…

基于SSM健身国际俱乐部系统的设计

管理员账户功能包括:系统首页,个人中心,用户管理,场地类别管理,场地信息管理,运动项目管理,场地类型管理,项目类型管理 用户账号功能包括:系统首页,个人中心…

使用SearXNG-搭建个人搜索引擎(附国内可用Docker镜像源)

介绍 SearXNG是聚合了七十多种搜索服务的开源搜索工具。我们可以匿名浏览页面,不会被记录和追踪。作为开发者,SearXNG也提供了清晰的API接口以及完整的开发文档。 部署 我们可以很方便地使用Docker和Docker compose部署SearXNG。下面给出Docker部署Se…

ChartCheck: Explainable Fact-Checking over Real-World Chart Images

论文地址: https://aclanthology.org/2024.findings-acl.828.pdfhttps://aclanthology.org/2024.findings-acl.828.pdf 1.概述 事实验证技术在自然语言处理领域获得了广泛关注,尤其是在针对误导性陈述的检查方面。然而,利用图表等数据可视化来传播信息误导的情况却很少受到…

反弹shell的小汇总

前提 理解正向连接和反向连接 正向连接:客户端主动发起连接到服务器或目标系统客户端充当主动方,向服务器发起连接请求,然后服务器接受并处理请求。 反向连接:目标系统(通常是受害者)主动建立与控制系统…

手机拍证件照,换正装有领衣服及底色的方法

证件照在我们的职业生涯的关键节点是经常会用到的,比如毕业入职、人事档案建立、升迁履历、执业资格考试和领证等,这些重要的证件照往往要求使用正装照,有时候手头没有合适的衣服,或者原先的证件照背景色不符合要求,就…

如何在算家云搭建ControlNext-SVD(视频生成)

一、ControlNext-SVD-V2简介 ControlNext-SVD-V2 是 ControlNext-SVD 的 V2 模型。其中 ControlNext-SVD 模型是通过添加 ControlNet 来控制 Stable Video Diffusion (SVD),使用高分辨率视频训练,具体来说它可以将图片生成与指定姿态相匹配的高质量视频…

Python异常检测- 单类支持向量机(One-Class SVM)

系列文章目录 Python异常检测- Isolation Forest(孤立森林) python异常检测 - 随机离群选择Stochastic Outlier Selection (SOS) python异常检测-局部异常因子(LOF)算法 Python异常检测- DBSCAN 文章目录 系列文章目录前言一、On…

1024程序员日|向改变世界的程序员 致敬!

“给我一行代码,我将点亮整个服务器” “给我一个键盘,我就能征服数字世界” 今天 10月24号 是广大程序员“法定”节日(据说是自定的) 因为1024是2的十次方 二进制计数的基本计量单位 1GB 1024MB,1MB 1024KB……

机器学习理论系列——线性模型(上)

系列文章目录 文章目录 线性模型线性回归线性回归与线性模型一元和多元线性回归最小二乘法损失函数与均方误差最小二乘与闭式解 正则化为什么引入正则化 L 1 L^1 L1, L 2 L^2 L2正则化 梯度下降什么是梯度下降算法上的体现 附录 线性模型 线性是数学中的基本概念&a…