TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文

import yfinance as yf
yf.pdr_override() #需要调用这个函数

# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')

# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import deque

x = []
y = []

seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:
    deq.append(list(i))
    if len(deq) == seq_len:
        x.append(list(deq))

x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)

x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam

# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均
              metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)
 
plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt

#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/54847.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

spring boot合并 http请求(前后端实现)

为什么要合并http请求 页面加载时要初始化很多资源信息,会发送好多http请求,根据http的通信握手特点,多个http请求比较浪费资源。进而如果能将多个http请求合并为一个发送给服务器,由服务器获取对应的资源返回给客户端 奇思妙解 …

Hadoop学习日记-YARN组件

YARN(Yet Another Resource Negotiator)作为一种新的Hadoop资源管理器,是另一种资源协调者。 YARN是一个通用的资源管理系统和调度平台,可为上层应用提供统一的资源管理和调度 YARN架构图 YARN3大组件: (物理层面&#xff09…

C语言每日一题:12《数据结构》相交链表。

题目: 题目链接 思路一: 1.如果最后一个节点相同说明一定有交点。 2.使用两个循环获取一下长度,同时可以获取到尾节点。 3。注意初始化lenA和lenB为1,判断下一个节点是空是可以保留尾节点的。长度会少一个,尾节点没有…

小白到运维工程师自学之路 第六十集 (docker的概述与安装)

一、概述 1、客户(老板)-产品-开发-测试-运维项目周期不断延后,项目质量差。 随着云计算和DevOps生态圈的蓬勃发展,产生了大量优秀的系统和软件。软件开发人员可以自由选择各种软件应用环境。但同时带来的问题就是需要维护一个非…

【HarmonyOS】ArkTS 组件内转场动画,动画播放时颜色异常问题

【关键字】 HarmonyOS、ArkTS、组件内转场动画、颜色异常 【问题描述】 根据组件内转场动画文档中示例编写代码,使用动画转场组件button,并给button设置背景色让button透明度为0,实现动画转场时,会先出现默认蓝色button&#xf…

Ai创作系统ChatGPT源码搭建教程+附源码

系统使用Nestjs和Vue3框架技术,持续集成AI能力到本系统! 更新内容: 同步官方图片重新生成指令 同步官方 Vary 指令 单张图片对比加强 Vary(Strong) | Vary(Subtle) 同步官方 Zoom 指令 单张图片无限缩放 Zoom out 2x | Zoom out 1.5x 新增GP…

Java使用hive连接kyuubi

一、Maven依赖 <dependency><groupId>org.apache.hive</groupId><artifactId>hive-jdbc</artifactId><version>2.3.9</version> </dependency> 二、相关配置信息 驱动类&#xff1a;org.apache.hive.jdbc.HiveDriver连接UR…

海外网红营销:如何利用故事打造独具魅力的品牌形象?

随着全球数字化时代的来临&#xff0c;品牌推广已经从传统的广告宣传方式逐渐转变为更加注重故事性和情感共鸣的营销手段。故事营销在品牌塑造和传播过程中发挥着重要作用&#xff0c;它能够吸引消费者的注意力&#xff0c;加深品牌与受众的情感连接&#xff0c;从而为品牌带来…

单元测试框架中测试用例的执行顺序

一、用例用例全部执行与选择执行 单元测试用例的执行顺序按照定义的用例的名称的编码大小&#xff0c;从小到大依次执行&#xff0c;因此一般通过后缀001、002...等来规划测试用例的执行顺序&#xff0c;例如&#xff1a; import unittestclass F1(unittest.TestCase):def set…

3D Tiles官方示例资源下载链接

本文列出Cesium官方提供的 3D Tiles 1.0和1.1规范的9个示例切块集&#xff08;tileset&#xff09;。 有关如何使用本地服务器托管这些示例的详细信息&#xff0c;请参阅 INSTRUCTIONS.md。 推荐&#xff1a;用 NSDT设计器 快速搭建可编程3D场景。 1、Metadata Granularities …

【java线上监控】Arthas由菜鸟到菜鸡

目录 1 arthas介绍 1.1 简介 1.2 背景 1.3 Arthas&#xff08;阿尔萨斯&#xff09;能为你做什么&#xff1f; 2 window环境搭建 2.1 下载和启动 2.1 启动后&#xff1a;选择需要监控的服务 3 快速入门 3.1 打开浏览器 3.2 开启trace请求 3.3 查看 dashboard 3.6.…

在人间烟火里,卡萨帝开启品牌新征程

大暑刚过天正热&#xff0c;尤其是在今年厄尔尼诺现象席卷太平洋的背景下&#xff0c;人们对空调的需求持续“升温”。CCTV2财经频道《正点财经》在专题报道中提到&#xff0c;6月国内空调产销两旺&#xff0c;同比增长均在35%以上。 炙热的天气下&#xff0c;南方居民有更加难…

fiddler 手机抓包(含https) 完整流程

第一部分&#xff1a;下载并安装fiddler 一.使用任一浏览器搜索【fiddler下载安装】&#xff0c;并下载fiddler 安装包。 二.fiddler安装包下载成功后&#xff0c;将下载的fiddler压缩包解压到自定义文件夹【fiddler】或者解压到当前文件夹下&#xff0c;双击文件夹中的【fidd…

哈工大计算机网络课程局域网详解之:交换机概念

哈工大计算机网络课程局域网详解之&#xff1a;交换机概念 文章目录 哈工大计算机网络课程局域网详解之&#xff1a;交换机概念以太网交换机&#xff08;switch&#xff09;交换机&#xff1a;多端口间同时传输交换机转发表&#xff1a;交换表交换机&#xff1a;自学习交换机互…

【SCSS】网格布局中的动画

效果 index.html <!DOCTYPE html> <html><head><title> Document </title><link type"text/css" rel"styleSheet" href"index.css" /></head><body><div class"container">&l…

RabbitMQ 教程 | 第3章 客户端开发向导

&#x1f468;&#x1f3fb;‍&#x1f4bb; 热爱摄影的程序员 &#x1f468;&#x1f3fb;‍&#x1f3a8; 喜欢编码的设计师 &#x1f9d5;&#x1f3fb; 擅长设计的剪辑师 &#x1f9d1;&#x1f3fb;‍&#x1f3eb; 一位高冷无情的编码爱好者 大家好&#xff0c;我是 DevO…

什么是云原生和 CNCF?

一、CNCF简介 CNCF&#xff1a;全称Cloud Native Computing Foundation&#xff08;云原生计算基金会&#xff09;&#xff0c;成立于 2015 年 12 月 11 日&#xff0c;是一个开源软件基金会&#xff0c;它致力于云原生&#xff08;Cloud Native&#xff09;技术的普及和可持续…

Docker中gitlab以及gitlab-runner的安装与使用

1、本文主要讲述如何使用Docker安装gitlab以及gitlab-runner&#xff0c;并且会讲述gitlab-runner如何使用 2、gitlab部分不需要修改过多的配置即可使用&#xff0c;本文未讲述https配置&#xff0c;如有需求&#xff0c;可自行百度 3、Docker如何安装可以自行百度 一、Docker安…

保姆级教程,Linux服务器docker搭建jenkins持续集成一键部署SpringBoot项目(Gradle)

前言&#xff1a; 在后台项目开发过程从Java延伸到Kotlin开发&#xff0c;从maven pom到gradle&#xff0c;IDEA新项目SpringBoot init框架官方推荐kotlingradle&#xff0c;本章以此为jenkins持续集成做项目部署&#xff0c;服务器为Centos&#xff0c;JDK 17&#xff0c;Spr…

【C++】文件操作(囊括特殊情况:读文件遇到的空格被跳过、“文件只读一次“)

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】文件操作&#xff08;囊括特殊情况&#xff1a;读文件遇到的空格被跳过、“文件只读一次”&#xff09; website&#xff1a;黑马程序员C date&#xff1a;2023年7月31日 目录 文本文件 写文件 源…