循环神经网络(RNN)实现股票预测

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
  • 四、数据预处理
    • 1.归一化
    • 2.设置测试集训练集
  • 五、构建模型
  • 六、激活模型
  • 七、训练模型
  • 八、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('./datasets/SH600519.csv')  # 读取股票文件

data
Unnamed: 0dateopenclosehighlowvolumecode
0742010-04-2688.70287.38189.07287.362107036.13600519
1752010-04-2787.35584.84187.35584.68158234.48600519
2762010-04-2884.23584.31885.12883.59726287.43600519
3772010-04-2984.59285.67186.31584.59234501.20600519
4782010-04-3083.87182.34083.87181.52385566.70600519
242124952020-04-201221.0001227.3001231.5001216.80024239.00600519
242224962020-04-211221.0201200.0001223.9901193.00029224.00600519
242324972020-04-221206.0001244.5001249.5001202.22044035.00600519
242424982020-04-231250.0001252.2601265.6801247.77026899.00600519
242524992020-04-241248.0001250.5601259.8901235.18019122.00600519

2426 rows × 8 columns

training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

四、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置测试集训练集

x_train = []
y_train = []

x_test = []
y_test = []

"""
使用前60天的开盘价作为输入特征x_train
    第61天的开盘价作为输入标签y_train
    
for循环共构建2426-300-60=2066组训练数据。
       共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):
    x_train.append(training_set[i - 60:i, 0])
    y_train.append(training_set[i, 0])
    
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
    
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)

调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)

"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

五、构建模型

model = tf.keras.Sequential([
    SimpleRNN(80, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。
    Dropout(0.2),                         #防止过拟合
    SimpleRNN(80),
    Dropout(0.2),
    Dense(1)
])

六、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')  # 损失函数用均方误差

七、训练模型

history = model.fit(x_train, y_train, 
                    batch_size=64, 
                    epochs=20, 
                    validation_data=(x_test, y_test), 
                    validation_freq=1)                  #测试的epoch间隔数

model.summary()
Epoch 1/20
33/33 [==============================] - 6s 123ms/step - loss: 0.1809 - val_loss: 0.0310
Epoch 2/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0257 - val_loss: 0.0721
Epoch 3/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0165 - val_loss: 0.0059
Epoch 4/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0097 - val_loss: 0.0111
Epoch 5/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0099 - val_loss: 0.0139
Epoch 6/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0067 - val_loss: 0.0167
Epoch 7/20
33/33 [==============================] - 3s 86ms/step - loss: 0.0067 - val_loss: 0.0095
Epoch 8/20
33/33 [==============================] - 3s 91ms/step - loss: 0.0063 - val_loss: 0.0218
Epoch 9/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0052 - val_loss: 0.0109
Epoch 10/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0043 - val_loss: 0.0120
Epoch 11/20
33/33 [==============================] - 3s 92ms/step - loss: 0.0044 - val_loss: 0.0167
Epoch 12/20
33/33 [==============================] - 3s 89ms/step - loss: 0.0039 - val_loss: 0.0032
Epoch 13/20
33/33 [==============================] - 3s 88ms/step - loss: 0.0041 - val_loss: 0.0052
Epoch 14/20
33/33 [==============================] - 3s 93ms/step - loss: 0.0035 - val_loss: 0.0179
Epoch 15/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0033 - val_loss: 0.0124
Epoch 16/20
33/33 [==============================] - 3s 95ms/step - loss: 0.0035 - val_loss: 0.0149
Epoch 17/20
33/33 [==============================] - 4s 111ms/step - loss: 0.0028 - val_loss: 0.0111
Epoch 18/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0029 - val_loss: 0.0061
Epoch 19/20
33/33 [==============================] - 3s 104ms/step - loss: 0.0027 - val_loss: 0.0110
Epoch 20/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0028 - val_loss: 0.0037
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 60, 80)            6560      
_________________________________________________________________
dropout (Dropout)            (None, 60, 80)            0         
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 80)                12880     
_________________________________________________________________
dropout_1 (Dropout)          (None, 80)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 81        
=================================================================
Total params: 19,521
Trainable params: 19,521
Non-trainable params: 0
_________________________________________________________________

八、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(0,1)反归一化到原始范围

# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction by K同学啊')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)

print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 1833.92534
均方根误差: 42.82435
平均绝对误差: 36.23424
R2: 0.72347

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/175188.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

为什么录屏没声音?实用技巧大放送!

录屏已成为我们在数字时代记录和分享内容的重要方式之一。但有时,您可能会遇到录制视频却没有声音的问题。这个问题可能出现在不同的录屏软件中,导致许多人感到疑惑。在本文中,我们将探讨为什么录屏没声音,并提供两种解决方案&…

防雷接地+防雷工程施工综合方案

一、地凯科技防雷工程接地概述 防雷接地工程是指在建筑物或其他设施上安装防雷装置,以防止雷电对人员、设备和建筑物造成危害的工程。防雷装置主要包括避雷针(网)、引下线、接地体(网)等部分,其中接地体&a…

代码随想录刷题】Day17 二叉树04

文章目录 1.【110】平衡二叉树(优先掌握递归)1.1 题目描述1.2 解题思路1.3 java代码实现 2.【257】二叉树的所有路径(优先掌握递归)2.1 题目描述2.2 解题思路2.3 java代码实现 3.【404】左叶子之和(优先掌握递归&#…

CSS特效015:7个小球转圈圈加载效果

CSS常用示例100专栏目录 本专栏记录的是经常使用的CSS示例与技巧,主要包含CSS布局,CSS特效,CSS花边信息三部分内容。其中CSS布局主要是列出一些常用的CSS布局信息点,CSS特效主要是一些动画示例,CSS花边是描述了一些CSS…

渲染器——快速Diff算法

讨论第三种用于比较新旧两组子节点的方式:快速Diff 算法。正如其名,该算法的实测速度非常快。该算法最早应用于 ivi 和 inferno 这两个框架,Vue.js 3 借鉴并扩展了它。 下图比较了 ivi、inferno 以及 Vue.js 2 的性能: 上图来自…

Docker Swarm总结

1、swarm 理论基础 1.1 简介 Docker Swarm 是由 Docker 公司推出的 Docker 的原生集群管理系统,它将一个 Docker 主机池变成了一个单独的虚拟主机,用户只需通过简单的 API 即可实现与 Docker 集群的通 信。Docker Swarm 使用 GO 语言开发。从 Docker 1.…

SpringBoot : ch04 整合数据源

前言 Spring Boot 是当今最流行的 Java 开发框架之一,它以简洁、高效的特点帮助开发者快速构建稳健的应用程序。在实际项目中,涉及到数据库操作的需求时,我们需要对数据源进行整合。本文将重点介绍如何在 Spring Boot 中整合数据源&#xff…

这是一棵适合搜索二叉树

🎈个人主页:🎈 :✨✨✨初阶牛✨✨✨ 🐻强烈推荐优质专栏: 🍔🍟🌯C的世界(持续更新中) 🐻推荐专栏1: 🍔🍟🌯C语言初阶 🐻推荐专栏2: 🍔…

Tesco EDI需求分析

Tesco,成立于1919年,是一家全球领先的综合性零售企业,总部位于英国。公司致力于提供高质量、多样化的商品和服务,以满足客户的需求。Tesco的使命是通过创新和卓越的客户服务,为客户创造更美好的生活。多年来&#xff0…

【带你读懂数据手册】CN3702 一款锂电池充电芯片

大家在学习智能车或者飞行器的时候,是不是外接一个电池?最近刚好学习了一款充电芯片,来和大家分享一下,也算是我的一点点笔记。 一款7.4V锂电池,基本上也满足了单片机的外设,如果需要12V或者24V的电压&…

【前端】前端监控⊆埋点

文章目录 前端监控分为三个方面前端监控流程异常监控常见的错误捕获方法主要是 try / catch 、window.onerror 和window.addEventListener 等。Promise 错误Vue 错误React 错误 性能监控用户行为监控常见的埋点方案来源 前端监控分为三个方面 异常监控(监控前端页面…

软件临界资源访问冲突

1. 基础概念 1.1 cpu执行汇编代码 处理指令的步骤主要包括以下几步动作: 1.提取(Fetch)指令。 2.解码(Decode)指令。 3.执行(Execute)指令。 cpu运行一条汇编需要执行三个步骤,按照顺序依次执行。异常触发中断需要等待一条汇编运行完成才能跳转&…

数据库的基本概念以及MySQL基本操作

一、数据库的基本概念 1、数据库的组成 数据:描述事物的符号记录 包括数字,文字、图形、图像、声音、档案记录等 以“记录”形式按统一格式进行存储 表:将不同的记录组织在一起,用来存储具体数据 数据库: 表的集合…

Python 跨文件夹导入自定义包

一、问题再现 有时我们自己编写一些模块时,跨文件夹调用会出现ModuleNotFoundError: No module named XXX 二、解决方案 只需要在下层文件夹中的__init__.py文件中,添加如下代码即可: import sys from os import path sys.path.append(pa…

万字解析设计模式之 适配器模式

一、 适配器模式 1.1概述 将一个接口转换成客户希望的另一个接口,适配器模式使接口不兼容的那些类可以一起工作。 适配器模式分为类适配器模式和对象适配器模式,前者类之间的耦合度比后者高,且要求程序员了解现有组件库中的相关组件的内部结…

imx VPU解码分析4-wrap与hantro的关系

前面已经分析了wrap和hantro,但是二者是如何结合的,wrap是如何封装hantro的,提供了哪些接口,封装了哪些细节还不太清楚,此文来探究下。这里还是只关注解码。 imx VPU解码分析1-wrap-CSDN博客 imx VPU解码分析2-hantr…

值得收藏推荐的 21 款免费数据恢复软件工具

使用这些免费数据恢复工具 之一找回您认为永远消失的文件。我根据这些程序的易用性和提供的功能对这些程序进行了排名。 这些应用程序从您的硬盘驱动器、USB 驱动器、媒体卡等恢复文档、视频、图像、音乐等。我建议每个计算机所有者安装其中一个程序,最好尽快&#…

【MySQL】一些内置函数(时间函数、字符串函数、数学函数等,学会了有妙用)

内置函数 前言正式开始时间函数显示当前日期、时间、日期时间的日期计算相差多少天示例创建一张表,记录生日 留言表 字符串函数charsetconcatinstr(string, substring)ucase和lcaseleft(string, length)length求字符串长度replace(str, search_str, replace_str)tri…

【LeetCode刷题笔记】DFSBFS(一)

51. N 皇后 解题思路: DFS + 回溯 :由于 NxN 个格子放 N 个皇后, 同一行不能放置 2 个皇后,所以皇后必然放置在不同行 。 因此,可以从第 0 行开始,逐行地尝试,在每一个 i

Pyside6/PyQt6的QTreeWidget如何添加多级子项,如何实现选中父项,子项也全部选中功能,源码示例

文章目录 📖 介绍 📖🏡 环境 🏡📒 使用方法 📒📝 数据📝 源码📖 介绍 📖 在UI开发中经常会需要展示/让用户多层级选择,这篇文章记录了一个QTreeWidget如何添加多级子项,如何实现选中父项,子项也全部选中/取消选中功能的源码示例,大家可以举一反三实现自…