基于LSTM的A股股票价格预测系统(torch) :从数据获取到模型训练的完整实现

在这里插入图片描述

1. 项目简介

本文介绍了一个使用LSTM(长短期记忆网络)进行股票价格预测的完整系统。该系统使用Python实现,集成了数据获取、预处理、模型训练和预测等功能。
这个代码使用的是 LSTM (Long Short-Term Memory) 模型,这是一种特殊的循环神经网络 (RNN)

2. 技术栈

  • Python 3.x
  • PyTorch (深度学习框架)
  • AKShare (股票数据获取)
  • Pandas (数据处理)
  • NumPy (数值计算)
  • Scikit-learn (数据预处理)

3. 系统架构

3.1 数据获取模块

def get_stock_data(stock_code, start_date, end_date, stock_name):
    """获取股票历史数据"""
    print(f"正在获取 {stock_name}{stock_code})的数据...")
    try:
        df = ak.stock_zh_a_hist(symbol=stock_code, 
                               period="daily", 
                               start_date=start_date, 
                               end_date=end_date, 
                               adjust="qfq")  # 使用前复权数据
        # ... 数据处理代码
        return df
    except Exception as e:
        print(f"获取{stock_name}数据时发生错误:{str(e)}")
        return None

3.2 LSTM模型定义

class StockRNN(nn.Module):
    """股票预测的LSTM模型"""
    def __init__(self, input_size, hidden_size, num_layers):
        super(StockRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

3.3 数据预处理

def prepare_data(df, sequence_length):
    """准备训练数据"""
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(df[['close']].values)
    
    X, y = [], []
    for i in range(len(scaled_data) - sequence_length):
        X.append(scaled_data[i:(i + sequence_length)])
        y.append(scaled_data[i + sequence_length])
    
    return np.array(X), np.array(y), scaler

4. 主要功能

  1. 股票数据获取和分析
  2. 市场状态评估
  3. 数据预处理和归一化
  4. LSTM模型训练
  5. 股价预测

5. 使用方法

  1. 运行程序
  2. 输入股票代码(或使用默认值)
  3. 设置日期范围
  4. 等待模型训练
  5. 获取预测结果
# 示例使用
stock_code = "002830"  # 股票代码
start_date = "20230101"  # 起始日期
end_date = "20240120"   # 结束日期
predict_date = "20241209"  # 预测日期

6. 模型参数

  • 序列长度:10天
  • LSTM隐藏层大小:64
  • LSTM层数:2
  • 训练轮数:100
  • 学习率:0.001

7. 风险提示

  1. 预测结果仅供参考,不构成投资建议
  2. 长期预测的准确性会显著降低
  3. 股市受多种因素影响,模型无法预测突发事件

8. 可能的改进方向

  1. 增加更多特征(如交易量、技术指标等)
  2. 优化模型架构
  3. 添加更多市场分析指标
  4. 实现实时数据更新
  5. 添加可视化功能

9. 总结

本项目展示了如何使用深度学习技术进行股票价格预测。通过整合数据获取、预处理和模型训练等功能,为股票分析提供了一个完整的解决方案。虽然预测结果仅供参考,但项目的实现过程对理解金融数据分析和深度学习应用具有重要的学习价值。

10. 环境配置与安装

10.1 Python环境要求

  • Python 3.8+

10.2 依赖包安装

# 创建虚拟环境(推荐)
python -m venv myvenv
source myvenv/bin/activate  # Linux/Mac
# 或
myvenv\Scripts\activate  # Windows

# 安装依赖包
pip install akshare
pip install torch
pip install pandas
pip install numpy
pip install scikit-learn

11. 完整代码实现

11.1 股票预测主程序 (stock_prediction_akshare.py)

import akshare as ak  # 导入akshare库,用于获取股票数据
import pandas as pd   # 导入pandas库,用于数据处理
import numpy as np    # 导入numpy库,用于数值计算
import torch         # 导入PyTorch库,用于深度学习
import torch.nn as nn  # 导入神经网络模块
import torch.optim as optim  # 导入优化器模块
from sklearn.preprocessing import MinMaxScaler  # 导入数据归一化工具
import random  # 导入随机数模块

[此处是完整的 stock_prediction_akshare.py 代码,与之前提供的相同]

11.2 下跌五日监控程序 (downfive5.py)

# 如果有 downfive5.py 的代码,请提供给我,我会添加到这里

11.3 Web应用接口 (app.py)

# 如果有 app.py 的代码,请提供给我,我会添加到这里

12. 运行说明

  1. 克隆或下载代码到本地
git clone [repository_url]
cd stock-prediction-system
  1. 安装依赖
pip install -r requirements.txt
  1. 运行股票预测程序
python stock_prediction_akshare.py
  1. 按提示输入:
    • 股票代码(例如:002830)
    • 起始日期(格式:YYYYMMDD)
    • 结束日期(格式:YYYYMMDD)
    • 预测日期(格式:YYYYMMDD)

13. 常见问题解答

  1. 数据获取失败

    • 检查网络连接
    • 确认股票代码是否正确
    • 验证日期格式
  2. 模型训练时间过长

    • 可以减少训练轮数(epochs)
    • 缩短历史数据范围
    • 使用GPU加速(如果可用)
  3. 预测结果异常

    • 检查数据预处理步骤
    • 调整模型参数
    • 验证输入数据的质量

14. 维护与更新

本项目仍在持续改进中,计划添加的功能包括:

  1. 多股票同时预测
  2. 更多技术指标支持
  3. 预测结果可视化
  4. 实时数据更新
  5. Web界面支持

15. 贡献指南

感谢以下开源项目:

  • AKShare
  • PyTorch
  • Pandas
  • NumPy
  • Scikit-learn

注意:

  1. 本项目仅供学习和研究使用
  2. 股市投资有风险,预测结果仅供参考
  3. 实际投资决策需要考虑多种因素

15. 代码:

import akshare as ak  # 导入akshare库,用于获取股票数据
import pandas as pd   # 导入pandas库,用于数据处理
import numpy as np    # 导入numpy库,用于数值计算
import torch         # 导入PyTorch库,用于深度学习
import torch.nn as nn  # 导入神经网络模块
import torch.optim as optim  # 导入优化器模块
from sklearn.preprocessing import MinMaxScaler  # 导入数据归一化工具
import random  # 导入随机数模块

def set_random_seed(seed=42):
    """设置随机种子,确保实验结果可重复"""
    random.seed(seed)  # 设置Python随机数种子
    np.random.seed(seed)  # 设置NumPy随机数种子
    torch.manual_seed(seed)  # 设置PyTorch CPU随机数种子
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # 设置PyTorch GPU随机数种子

class StockRNN(nn.Module):
    """定义股票预测的RNN模型类"""
    def __init__(self, input_size, hidden_size, num_layers):
        """
        初始化模型参数
        input_size: 输入特征维度
        hidden_size: LSTM隐藏层大小
        num_layers: LSTM层数
        """
        super(StockRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # 定义LSTM层
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # 定义全连接层,用于输出预测值
        self.fc = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        """
        定义前向传播过程
        x: 输入数据,形状为(batch_size, sequence_length, input_size)
        """
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # LSTM前向传播
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出进行预测
        out = self.fc(out[:, -1, :])
        return out

def get_stock_name(stock_code):
    """
    获取股票名称
    stock_code: 股票代码
    返回: 股票名称
    """
    try:
        # 获取A股实时行情数据
        df = ak.stock_zh_a_spot_em()
        # 查找对应股票代码的信息
        stock_info = df[df['代码'] == stock_code]
        if not stock_info.empty:
            return stock_info.iloc[0]['名称']
        
        # 如果没找到,尝试添加市场前缀
        if stock_code.startswith('6'):
            full_code = f"sh{stock_code}"
        else:
            full_code = f"sz{stock_code}"
        
        stock_info = df[df['代码'] == stock_code]
        if not stock_info.empty:
            return stock_info.iloc[0]['名称']
        
        return f"股票{stock_code}"
            
    except Exception as e:
        print(f"获取股票名称时发生错误:{e}")
        return f"股票{stock_code}"

def analyze_market_state(df, stock_name):
    """
    分析股票市场状态
    df: 股票数据DataFrame
    stock_name: 股票名称
    """
    # 计算日收益率
    daily_returns = df['close'].pct_change()
    
    # 计算年化波动率
    volatility = daily_returns.std() * np.sqrt(252)
    
    # 计算整体涨跌幅
    trend = (df['close'].iloc[-1] - df['close'].iloc[0]) / df['close'].iloc[0]
    
    # 计算最大回撤
    cummax = df['close'].cummax()  # 计算历史最高价
    drawdown = (cummax - df['close']) / cummax  # 计算回撤比例
    max_drawdown = drawdown.max()  # 计算最大回撤
    
    # 打印分析结果
    print(f"\n{stock_name}市场状态分析:")
    print(f"样本数量: {len(df)} 天")
    print(f"年化波动率: {volatility:.2%}")
    print(f"整体趋势: {trend:.2%}")
    print(f"最大回撤: {max_drawdown:.2%}")
    print(f"起始价格: {df['close'].iloc[0]:.2f}")
    print(f"结束价格: {df['close'].iloc[-1]:.2f}")

def get_stock_data(stock_code, start_date, end_date, stock_name):
    """
    获取股票历史数据
    stock_code: 股票代码
    start_date: 起始日期
    end_date: 结束日期
    stock_name: 股票名称
    """
    print(f"正在获取 {stock_name}{stock_code})的数据...")
    try:
        # 使用akshare获取股票历史数据
        df = ak.stock_zh_a_hist(symbol=stock_code, 
                               period="daily", 
                               start_date=start_date, 
                               end_date=end_date, 
                               adjust="qfq")  # 使用前复权数据
        
        # 重命名列
        df = df.rename(columns={'收盘': 'close', '日期': 'date'})
        df['date'] = pd.to_datetime(df['date'])  # 转换日期格式
        df.set_index('date', inplace=True)  # 设置日期为索引
        
        # 分析市场状态
        analyze_market_state(df, stock_name)
        
        return df
        
    except Exception as e:
        print(f"获取{stock_name}数据时发生错误:{str(e)}")
        return None

def prepare_data(df, sequence_length):
    """
    准备模型训练数据
    df: 股票数据DataFrame
    sequence_length: 序列长度(用多少天数据预测下一天)
    """
    # 数据归一化
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(df[['close']].values)
    
    # 创建序列数据
    X, y = [], []
    for i in range(len(scaled_data) - sequence_length):
        X.append(scaled_data[i:(i + sequence_length)])  # 输入序列
        y.append(scaled_data[i + sequence_length])      # 预测目标
    
    return np.array(X), np.array(y), scaler

def predict_next_day(model, last_sequence, scaler):
    """
    预测下一个交易日的价格
    model: 训练好的模型
    last_sequence: 最后一个序列数据
    scaler: 归一化器
    """
    with torch.no_grad():  # 不计算梯度
        # 准备输入数据
        last_sequence_tensor = torch.FloatTensor(last_sequence).unsqueeze(0)
        # 进行预测
        predicted_scaled = model(last_sequence_tensor)
        # 将预测结果转换回原始价格范围
        predicted_price = scaler.inverse_transform(predicted_scaled.numpy())
    return predicted_price[0][0]

def get_input_with_default(prompt, default_value):
    """
    获取用户输入,支持默认值
    prompt: 提示信息
    default_value: 默认值
    """
    user_input = input(f"{prompt} [默认: {default_value}]: ").strip()
    return user_input if user_input else default_value

def main():
    """主函数"""
    # 设置随机种子
    set_random_seed(42)
    
    try:
        # 设置默认参数
        default_stock_code = "002830"
        default_start_date = "20230101"
        default_end_date = "20240120"
        default_predict_date = "20241209"
        
        # 获取用户输入
        stock_code = get_input_with_default(
            "请输入股票代码", 
            default_stock_code
        )
        
        start_date = get_input_with_default(
            "请输入起始日期(格式:YYYYMMDD)", 
            default_start_date
        )
        
        end_date = get_input_with_default(
            "请输入结束日期(格式:YYYYMMDD)", 
            default_end_date
        )
        
        predict_date = get_input_with_default(
            "请输入预测日期(格式:YYYYMMDD)", 
            default_predict_date
        )
        
        # 获取股票名称
        stock_name = get_stock_name(stock_code)
        
        # 设置模型参数
        sequence_length = 10  # 使用10天数据预测下一天
        hidden_size = 64     # LSTM隐藏层大小
        num_layers = 2       # LSTM层数
        epochs = 100         # 训练轮数
        learning_rate = 0.001  # 学习率
        
        # 获取并分析数据
        df = get_stock_data(stock_code, start_date, end_date, stock_name)
        if df is None:
            return
            
        # 准备训练数据
        X, y, scaler = prepare_data(df, sequence_length)
        X_tensor = torch.FloatTensor(X)
        y_tensor = torch.FloatTensor(y)
        
        # 初始化模型
        model = StockRNN(input_size=1, hidden_size=hidden_size, num_layers=num_layers)
        criterion = nn.MSELoss()  # 使用均方误差损失
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam优化器
        
        # 训练模型
        for epoch in range(epochs):
            optimizer.zero_grad()  # 清空梯度
            outputs = model(X_tensor)  # 前向传播
            loss = criterion(outputs, y_tensor)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            
            # 每10轮打印一次损失
            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
        
        # 预测未来价格
        last_sequence = X[-1]
        predicted_price = predict_next_day(model, last_sequence, scaler)
        print(f"\n预测 {stock_name}{stock_code})在 {predict_date} 的收盘价: {predicted_price:.2f}")
        
        # 打印风险提示
        print("\n风险提示:")
        print(f"1. {stock_name}的预测是基于历史数据的模型预测,不构成投资建议")
        print("2. 长期预测(超过一周)的准确性会显著降低")
        print("3. 股市受多种因素影响,模型无法预测突发事件的影响")
        
    except Exception as e:
        print(f"程序运行出错:{str(e)}")

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/933403.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【python自动化五】接口自动化基础--requests的使用

python的接口请求可以用requests库,这个介绍就不多说了,网上说得很详细。 接下来直接记录下如何使用(当然也不限于自动化的使用) 1.安装requests requests也需要安装一下 pip install requests2.requests请求 1.常用的请求方法…

【NLP 5、深度学习的基本原理】

目录 一、梯度下降算法 1.引例 —— 找极小值问题 目标: 方法: 2.梯度 例: 3.求解目标 为什么损失函数越小越好 4.梯度下降法 代码实现 5.细节问题 6.梯度爆炸和梯度消失 梯度爆炸 梯度消失 7.过拟合和欠拟合 欠拟合(Underfitting…

DAY168内网对抗-基石框架篇单域架构域内应用控制成员组成用户策略信息收集环境搭建

知识点: 1、基石框架篇-单域架构-权限控制-用户和网络 2、基石框架篇-单域架构-环境搭建-准备和加入 3、基石框架篇-单域架构-信息收集-手工和工具 1、工作组(局域网) 将不同的计算机按照功能分别列入不同的工作组。想要访问某个部门的资源,只要在“…

MATLAB 建筑顶面面积计算(95)

MATLAB 建筑顶面面积计算(95) 一、算法介绍二、算法实现1.代码2.结果一、算法介绍 根据给出的建筑顶面点云,计算建筑面积,具体的方法实现和结果如下: 二、算法实现 1.代码 代码如下(示例): % 从 PLY 文件读取点云数据 filename = D:\shuju\屋顶2.ply; % 替换为你的…

Mac M1 安装数据库

1. Docker下载 由于Sqlserver和达梦等数据库,不支持M系列的芯片,所以我们通过docker安装 下载并安装docker: https://www.docker.com/get-started/ 安装完成后,打开docker 2. SQL Server 安装 2.1 安装 打开终端,执行命令 doc…

二十(GIT3)、echarts(折线图、柱状图、饼图)、黑马就业数据平台(主页图表实现、闭包了解、学生信息渲染)

1. echarts 数据可视化:将数据转换为图形,数据特点更加突出 echarts:一个基于 JavaScript 的开源可视化图表库 echarts官网 1.1 echarts核心使用步骤 // 1. 基于准备好的dom,初始化echarts实例 const myChart echarts.init…

软考高级架构-9.4.4-双机热备技术 与 服务器集群技术

一、双机热备 1、特点: 软硬件结合:系统由两台服务器(主机和备机)、一个共享存储(通常为磁盘阵列柜)、以及双机热备软件(提供心跳检测、故障转移和资源管理功能的核心软件)组成。 …

【Java若依框架】RuoYi-Vue的前端和后端配置步骤和启动步骤

🎙告诉你:Java是世界上最美好的语言 💎比较擅长的领域:前端开发 是的,我需要您的: 🧡点赞❤️关注💙收藏💛 是我持续下去的动力! 目录 一. 作者有话说 …

Kubernetes Nginx-Ingress | 禁用HSTS/禁止重定向到https

目录 前言禁用HSTS禁止重定向到https关闭 HSTS 和设置 ssl-redirect 为 false 的区别 前言 客户请求经过ingress到服务后,默认加上了strict-transport-security,导致客户服务跨域请求失败,具体Response Headers信息如下; 分析 n…

小程序入门学习(八)之页面事件

一、下拉刷新新事件 1. 什么是下拉刷新 下拉刷新是移动端的专有名词,指的是通过手指在屏幕上的下拉滑动操作,从而重新加载页面数据的行为。 2. 启用下拉刷新 启用下拉刷新有两种方式: 全局开启下拉刷新:在 app.json 的 window…

C++(十二)

前言: 本文将进一步讲解C中,条件判断语句以及它是如何运行的以及内部逻辑。 一,if-else,if-else语句。 在if语句中,只能判断两个条件的变量,若想实现判断两个以上条件的变体,就需要使用if-else,if-else语…

[Linux]文件属性和权限

目录 一.Linux文件的属性二.Linux用户权限分类三.文件权限的查询与修改1.修改用户的权限1).一般法2).8进制法 2.修改所属组和所属者3.如何在创建文件时权限预分配 在学习linux的时候,我们用ll命令显示文件的详情信息,难免会发现文件名前面会有一大堆其它…

ElK 8 收集 MySQL 慢查询日志并通过 ElastAlert2 告警至飞书

文章目录 1. 说明2. 启个 mysql3. 设置慢查询4. filebeat 设置5. 触发慢查询6. MySQL 告警至飞书 1. 说明 elk 版本:8.15.0 2. 启个 mysql docker-compose.yml 中 mysql: mysql:# restart: alwaysimage: mysql:8.0.27# ports:# - "3306:3306&q…

springSecurity权限控制

权限控制:不同的用户可以使用不同的功能。 我们不能在前端判断用户权限来控制显示哪些按钮,因为这样,有人会获取该功能对应的接口,就不需要通过前端,直接发送请求实现功能了。所以需要在后端进行权限判断。&#xff0…

力扣打卡9:重排链表

链接:143. 重排链表 - 力扣(LeetCode) 这是一道操作链表的题。按照要求,我们可以将解题的步骤分成三步。 1.找链表中间结点(我使用了快慢指针寻找),并断开。 2.现在有2链表,将后段…

计算机键盘的演变 | 键盘键名称及其功能 | 键盘指法

注:本篇为 “键盘的演变及其功能” 相关几篇文章合辑。 英文部分机翻未校。 The Evolution of Keyboards: From Typewriters to Tech Marvels 键盘的演变:从打字机到技术奇迹 Introduction 介绍 The keyboard has journeyed from a humble mechanical…

【Appium报错】安装uiautomator2失败

目录 1、通过nmp安装uiautomator2:失败 2、通过 Appium 的平台直接安装驱动程序 3、通过pip 来安装 uiautomator2 1、通过nmp安装uiautomator2:失败 我先是通过npm安装的uiautomator2,也显示已经安装成功了: npm install -g …

SSM整合原理实战案例《任务列表案例》

一、前端程序搭建和运行: 1.整合案例介绍和接口分析: (1).案例功能预览: (2).接口分析: 学习计划分页查询 /* 需求说明查询全部数据页数据 请求urischedule/{pageSize}/{currentPage} 请求方式 get 响应的json{"code":200,"flag":true,"data&…

Chrome扩展程序开发示例

项目文件夹内文件如下: manifest.json文件内容: {"manifest_version": 3,"name": "我的法宝","description": "我的有魔法的宝贝","version": "1.0","icons": {"…

石头剪子布

石头剪子布 C语言实现C实现Java实现Python实现 💐The Begin💐点点关注,收藏不迷路💐 石头剪子布,是一种猜拳游戏。起源于中国,然后传到日本、朝鲜等地,随着亚欧贸易的不断发展它传到了欧洲&…