基于强化学习DQN的股票预测【股票交易】

强化学习笔记

第一章 强化学习基本概念
第二章 贝尔曼方程
第三章 贝尔曼最优方程
第四章 值迭代和策略迭代
第五章 强化学习实例分析:GridWorld
第六章 蒙特卡洛方法
第七章 Robbins-Monro算法
第八章 多臂老虎机
第九章 强化学习实例分析:CartPole
第十章 时序差分法
第十一章 值函数近似【DQN】
第十二章 基于强化学习DQN的股票预测


文章目录

  • 强化学习笔记
  • 一、DQN
  • 二、软更新
  • 三、实验
  • 四、参考资料


在金融决策问题中,如何制定有效的交易策略一直是一个重要且具有挑战性的问题。近年来,强化学习在这一领域的应用显示出了很大的潜力,比如,强化学习可以帮助我们在股票交易过程中进行决策。

在这里,我想先比较一下监督学习和强化学习在股票交易问题中的不同:

  1. 监督学习主要关注预测,即通过历史数据训练模型,然后对未来的数据进行预测。例如,我们可以通过监督学习预测股票的价格走势。如果要交易还得结合其他策略方法。
  2. 而强化学习不仅仅是预测,它可以进行交易决策。它不仅仅关注于预测未来的股票价格,更重要的是,它可以根据预测结果来制定买卖策略,以最大化我们的收益。

下图给出了强化学习在股票交易问题应用中的主要框架:

image-20240627140425224 其核心问题有以下几点:
  1. 如何定义奖励函数,即Reward如何设置?
  2. 采用强化学习中的哪种模型,DQN、PPO、A2C、DDPG……
  3. 状态空间如何定义?

一、DQN

本文我们介绍用深度强化学习中最经典的模型——DQN来进行建模,完整代码放在GitHub上——DQN-for-Stock-Trading。在DQN模型中,采用了多个全连接线性层,其模型结构如下:

class QNetwork(nn.Module):
    """QNetwork (Deep Q-Network), state is the input, 
        and the output is the Q value of each action.
    """
    def __init__(self, state_size, action_size, fc1_units=128, fc2_units=128, fc3_units=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size , fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, fc3_units)
        self.fc4 = nn.Linear(fc3_units, action_size)
        self.dropout = nn.Dropout(0.1)  # Dropout with 20% probability

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

其中:

  1. 输入也就是状态 s s s,建模为股票过去几天的波动情况,也就是相邻两天的差值,输入的维数由给定的一个滑动窗口大小决定;
  2. 输出则是 q ( s , a ) q(s,a) q(s,a),其中 ∣ A ∣ = 3 |\mathcal{A}|=3 A=3,也就是说action有三种0、1、2,分别代表买入,卖出或者不变.

DQN的一个核心思想是经验缓冲池,将数据都放入缓冲池内,训练网络时从这里面采样得到小批量数据,其主要代码如下:

class ReplayBuffer:
    def __init__(self, action_size, buffer_size, batch_size):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  # initialize replay buffer
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

DQN另一个重要思想是用两个神经网络来交替更新参数,其代码如下:

class Agent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

二、软更新

在更新target network时,我们采用软更新的策略。软更新是一种在深度强化学习中更新目标网络参数的方法。目标网络(target network)用于稳定训练过程,其参数并不像本地网络(local network)那样在每一步都更新,而是以较慢的速率进行更新。软更新通过将目标网络的参数逐步向本地网络的参数靠拢来实现这种较慢的更新。具体来说,软更新的公式如下:
θ target ← τ θ local + ( 1 − τ ) θ target \theta_{\text{target}} \leftarrow \tau \theta_{\text{local}} + (1 - \tau) \theta_{\text{target}} θtargetτθlocal+(1τ)θtarget其中:

  • θ target \theta_{\text{target}} θtarget 是目标网络的参数。
  • θ local \theta_{\text{local}} θlocal 是本地网络的参数。
  • τ \tau τ 是软更新的比例系数,通常是一个非常小的值(例如 0.001)。

这个公式表示目标网络的参数是本地网络参数的 τ \tau τ 倍加上目标网络自身参数的 ( 1 − τ ) (1 - \tau) (1τ) 倍。因此,目标网络参数的变化是渐进的,而不是像硬更新(hard update)那样直接将本地网络的参数复制到目标网络。

在代码中,软更新通过 soft_update 方法实现:

def soft_update(self, local_model, target_model, tau):
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

在DQN算法中,如果目标网络的参数频繁更新,会导致训练过程不稳定,因为目标网络用于计算目标值,而这些目标值需要在一段时间内保持相对稳定。因此,软更新通过缓慢调整目标网络的参数,能够有效地平滑训练过程,提高算法的收敛性和稳定性。

三、实验

在比较简单的环境设置下进行实验,不考虑交易成本,每次买入卖出都是1股股票,reward设置为卖出股票时赚的钱。下图是训练过程的累积收益,我们可以看到随着不断地学习,agent的决策确实使得我们在这只股票上挣钱了!

image-20240627142654564

下图是在训练数据上回测的结果,我们可以看到agent学到了一个简单的“低吸高抛”的策略。

image-20240627142617668

下图是在测试集上的实验,我们发现在没有训练的数据上用刚才的模型也能挣钱,并且策略仍然是低吸高抛.

image-20240627142728266

采用更复杂的交易环境,考虑交易成本,每次买入卖出的数量,奖励函数采用收益率,我们可以得到一个复杂的策略。下图图仍是在训练数据上的回测,我们可以看到相比前面的“低吸高抛”策略稍微复杂了一些,下面条形图表示持仓,可以看到学习的策略在股票价格最低时增大仓位,在股票价格高点时,抛售赚钱。

截屏2024-06-27 14.29.17

四、参考资料

  1. https://www.youtube.com/watch?v=05NqKJ0v7EE

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/769286.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

swiftui中常用组件picker的使用,以及它的可选样式

一个可选项列表就是一个picker组件搞出来的,它有多个样式可以选择,并且可以传递进去一些可选数据,有点像前端页面里面的seleted组件,但是picker组件的样式可以更多。可以看官方英文文档:PickerStyle | Apple Developer…

【Week-G2】人脸图像生成(DCGAN)--pytorch版本

文章目录 0、遇到的问题1、配置环境 & 导入数据2、定义模型3、训练模型4、什么是DCGAN? 🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 本文环境: 系统环境:…

从搜索框的提示词中再探防抖和节流

前言 最近逛掘金时,看到了一篇文章。发现是我之前写过的一篇文章主题是防抖和节流的,看防抖时没感觉哪里不一样,但是当我看到节流时发现他的节流怎么这么繁琐(・∀・(・∀・(・∀・*)? 抱着疑惑的想法,我仔细拜读了这…

PyCharm 如何设置作者信息

1、点击pycharm右上角的齿轮,选择settings 2、选择editor 3、选择 Editor File and Code Templates 4、选择作者信息的文件类型,中间选择框选择Python Script 5、然后在右边的输入框中输入相关的信息 # -*- coding: utf-8 -*- """ Time …

kotlin接口,前端怎么调用?

文章目录 🎉欢迎来到Java学习路线专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博客🎈该系列文章专栏:Java学习路线📜其他专栏:Java学习路线 Jav…

构建大数据生态:Sqoop、Hadoop、IDEA和Maven的完整安装与数据预处理指南【实训Day03】

一、Sqoop安装 1 上传安装包并解压缩(在hadoop101上) # cd /opt/software 点击xftp上传sqoop的安装文件sqoop-1.4.6.bin__hadoop-2.0.4-alpha.tar.gz # tar -zxvf sqoop-1.4.6.bin__hadoop-2.0.4-alpha.tar.gz -C /opt/module/ # cd /opt/module/ # mv s…

vue3+vue-router+vite 实现动态路由

文章中出现的代码是演示版本,仅供参考,实际的业务需求会更加复杂 什么是动态路由 什么场景会用到动态路由 举一个最常见的例子,比如说我们要开发一个后台管理系统,一般来说后台管理系统都会分角色登录,这个时候也就涉…

基于Vue.js和SpringBoot的地方美食分享网站系统设计与实现

你好,我是计算机专业的学姐,专注于前端开发和系统设计。如果你对地方美食分享网站感兴趣或有相关需求,欢迎随时联系我。 开发语言 Java 数据库 MySQL 技术 Vue.js SpringBoot Java 工具 Eclipse, MySQL Workbench, Maven 系统展示…

【Spring Boot】关系映射开发(一):一对一映射

关系映射开发(一):一对一映射 1.认识实体间关系映射1.1 映射方向1.2 ORM 映射类型 2.实现 “一对一” 映射2.1 编写实体2.1.1 新建 Student 实体2.1.2 新建 Card 实体 2.2 编写 Repository 层2.2.1 编写 Student 实体的 Repository2.2.2 编写…

AMEYA360代理:海凌科60G客流量统计雷达模块 4T4R出入口绊数计数

数字化时代,不管是大型商城还是各种连锁店,客流统计分析都可以帮助企业更加精准地了解顾客需求和消费行为。 海凌科推出一款专用于客流量统计的60G雷达模块,4T4R,可以实时进行固定范围内的人体运动轨迹检测,根据人体的…

0703_ARM7

练习: 封装exti,cic初始化函数 //EXTI初始化 void hal_key_exti_init(int id,int exticr,int mode){//获取偏移地址int address_offset (id%4)*8;//获取寄存器编号int re_ser (id/4)1;//printf("address_offset%d,re_ser%d\n",address_o…

学会python——用python编写一个电子时钟(python实例十七)

目录 1.认识Python 2.环境与工具 2.1 python环境 2.2 Visual Studio Code编译 3.电子时钟程序 3.1 代码构思 3.2代码实例 3.3运行结果 4.总结 1.认识Python Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强的可读性…

软件模型分类及特点

在软件开发的世界里,我们经常会遇到业务模型、系统模型和软件模型这三个层次。这些模型各有特点,相互之间也有着紧密的联系。通过理解这三个层次之间的映射关系,我们能更好地理解和掌握软件开发的全过程 1. 业务模型 业务模型描述了组织的业…

搜维尔科技:Xsens实时动作捕捉,驱动人形机器人操作!

搜维尔科技:http://www.souvr.comhttp://www.souvr.com实时动作捕捉,驱动人形机器人操作! 搜维尔科技:Xsens实时动作捕捉,驱动人形机器人操作!

方向导数和梯度

方向导数和梯度 1 导数的回忆2 偏导数及其向量形式偏导数的几何意义偏导数的向量形式 3 方向导数向量形式几何意义方向导数和偏导的关系 4 梯度5 梯度下降算法 1 导数的回忆 导数的几何意义如图所示: 当 P 0 P_{0} P0​点不断接近 P P P时,导数如下定义…

vue 中使用element-ui实现锚点定位表单

效果图&#xff1a; 代码&#xff1a; html代码&#xff1a; <div class"content-left"><el-tabs :tab-position"left" tab-click"goAnchor"><el-tab-pane v-for"(item,index) in anchorNameList"v-anchor-scroll:ke…

LeetCode 60.排序排列(dfs暴力)

给出集合 [1,2,3,...,n]&#xff0c;其所有元素共有 n! 种排列。 按大小顺序列出所有排列情况&#xff0c;并一一标记&#xff0c;当 n 3 时, 所有排列如下&#xff1a; "123""132""213""231""312""321" 给定…

基于RSA的数字签名设计与实现

闲着没事把大三的实验发一下 这里用的技术和老师要求的略有不同但大体相同 1. RSA算法简介 公钥密码体制中,解密和加密密钥不同,解密和加密可分离,通信双方无须事先交换密钥就可建立起保密通信,较好地解决了传统密码体制在网络通信中出现的问题.另外,随着电子商务的发展,网络…

transformer模型学习路线_transformer训练用的模型

Transformer学习路线 完全不懂transformer&#xff0c;最近小白来入门一下&#xff0c;下面就是本菜鸟学习路线。Transformer和CNN是两个分支&#xff01;&#xff01;因此要分开学习 Transformer是一个Seq2seq模型&#xff0c;而Seq2seq模型用到了self-attention机制&#xf…

Python | Leetcode Python题解之第205题同构字符串

题目&#xff1a; 题解&#xff1a; class Solution:def isIsomorphic(self, s: str, t: str) -> bool:dicts Counter(s)dictt Counter(t) if list(dicts.values()) ! list(dictt.values()):return Falsefor i in range(len(s)):inds list(dicts.keys()).index(s…